[8] Assertion failed: axis >= 0 && axis < nbDims

将 PyTorch 双向 LSTM 模型导出为 ONNX 后再转换到 TensorRT (6.0.1.5) 时，遇到错误 `[8] Assertion failed: axis >= 0 && axis < nbDims`：

问题复现:

import torch
import torch.nn as nn


class BidirectionLSTM(nn.Module):
    """Bidirectional LSTM followed by a per-timestep linear projection.

    The LSTM is created with ``batch_first=True``, so inputs are laid out as
    ``(batch, seq_len, input_size)``. Its per-timestep outputs carry both
    directions (``2 * hidden_size`` features), which ``self.linear`` maps to
    ``output_size`` features.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super(BidirectionLSTM, self).__init__()
        # The submodule names ("rnn", "linear") are state_dict keys, and the
        # creation order (LSTM first) fixes the order of random initialization.
        self.rnn = nn.LSTM(
            input_size,
            hidden_size,
            bidirectional=True,
            batch_first=True,
        )
        # Two directions -> twice the hidden width feeds the projection.
        projected_in = 2 * hidden_size
        self.linear = nn.Linear(projected_in, output_size)

    def forward(self, input):
        """Run the LSTM over ``input`` and project every timestep.

        The LSTM's final ``(h_n, c_n)`` state is discarded. Only the sequence
        of per-timestep outputs is passed through the linear layer.
        """
        per_step, _final_state = self.rnn(input)
        projected = self.linear(per_step)
        return projected


class Model(nn.Module):
    """Two stacked ``BidirectionLSTM`` layers used to reproduce the export issue.

    The sizes default to the previously hard-coded values: input 512, 256
    hidden units per direction, and 256 output features per layer. With those
    defaults, ``Model()`` builds the same network as before. Pass other sizes
    to try smaller or larger variants of the same topology.

    Args:
        input_size: Feature size of the input to the first layer.
        hidden_size: Hidden units per direction in each LSTM.
        output_size: Feature size produced by each layer's projection. The
            second layer consumes the first layer's output, so this is also
            the second layer's input size.
    """

    def __init__(self, input_size=512, hidden_size=256, output_size=256):
        super(Model, self).__init__()
        # The attribute name is kept unchanged because it is part of the
        # state_dict keys.
        self.SequenceModeling = nn.Sequential(
            BidirectionLSTM(input_size, hidden_size, output_size),
            # Chained: this layer's input width must match the previous output.
            BidirectionLSTM(output_size, hidden_size, output_size),
        )

    def forward(self, input):
        """Apply both layers in sequence and return the final projection."""
        return self.SequenceModeling(input)


if __name__ == '__main__':
    model = Model()
    # Put all modules in inference mode before tracing the model for export.
    model.eval()

    dummy_input = torch.rand((1, 64, 512))
    # Reference PyTorch output (e.g. for comparing against the TensorRT
    # engine's result). No autograd graph is needed, so skip building one.
    with torch.no_grad():
        dummy_output = model(dummy_input)

    # torch.onnx.export writes the traced graph to "test.onnx". Its return
    # value is not used, so it is not bound to a name.
    # NOTE(review): the TensorRT 6.0.1.5 ONNX parser fails on this graph with
    # "axis >= 0 && axis < nbDims". One thing to check is whether the exported
    # opset (opset_version, left at the PyTorch default here) is one that
    # parser supports. TODO confirm before changing it.
    torch.onnx.export(model, dummy_input, "test.onnx", export_params=True, verbose=True,
                      input_names=["input"], output_names=["output"])

导出的 ONNX 模型在 Netron 中的可视化如下：

请问该如何解决？