PyTorch 实现 LSTM 编码器时出错

我用 LSTM 模块编写了一个 Encoder 编码器,运行时出现了如下错误:
https://img-mid.csdnimg.cn/release/static/image/mid/ask/65654403184613.png

class BNLSTMCell(nn.Module):
    """A batch-normalized LSTM (BN-LSTM) cell.

    The cell owns an input-to-hidden and a hidden-to-hidden weight matrix,
    an optional bias, and three ``SeparatedBatchNorm1d`` layers (defined
    elsewhere in the project) that are applied to the two pre-activations
    and to the new cell state. Each of those layers is called with the
    current timestep via the ``time`` keyword.

    Args:
        input_size: Number of features in each input vector.
        hidden_size: Number of features in the hidden and cell state.
        max_length: Forwarded unchanged to every ``SeparatedBatchNorm1d``.
        use_bias: When False no bias parameter is created and no bias is
            added to the gate pre-activations.
    """

    def __init__(self, input_size, hidden_size, max_length, use_bias=True):
        super(BNLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.max_length = max_length
        self.use_bias = use_bias
        # torch.empty only allocates storage; reset_parameters() below
        # assigns the real initial values. Using it for all three tensors
        # keeps them on the same (default) dtype.
        self.weight_ih = nn.Parameter(
            torch.empty(input_size, 4 * hidden_size))
        self.weight_hh = nn.Parameter(
            torch.empty(hidden_size, 4 * hidden_size))
        if use_bias:
            self.bias = nn.Parameter(torch.empty(4 * hidden_size))
        else:
            self.register_parameter('bias', None)
        # BN parameters
        self.bn_ih = SeparatedBatchNorm1d(
            num_features=4 * hidden_size, max_length=max_length)
        self.bn_hh = SeparatedBatchNorm1d(
            num_features=4 * hidden_size, max_length=max_length)
        self.bn_c = SeparatedBatchNorm1d(
            num_features=hidden_size, max_length=max_length)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize all parameters of the cell.

        * ``weight_ih`` is initialized orthogonally.
        * ``weight_hh`` is set to four side-by-side identity matrices,
          one per gate.
        * ``bias`` (when present) is zeroed.
        * The BN layers are reset, then the biases of ``bn_ih``/``bn_hh``
          are set to 0 and the weights of all three BN layers to 0.1.
        """
        # The trailing-underscore initializers are the in-place API; the
        # names without the underscore are deprecated aliases.
        init.orthogonal_(self.weight_ih)
        identity_per_gate = torch.eye(self.hidden_size).repeat(1, 4)
        with torch.no_grad():
            self.weight_hh.copy_(identity_per_gate)
        # The bias parameter is absent when the cell was built with
        # use_bias=False, so it must not be touched in that case.
        if self.bias is not None:
            init.constant_(self.bias, 0)
        # Initialization of BN parameters.
        self.bn_ih.reset_parameters()
        self.bn_hh.reset_parameters()
        self.bn_c.reset_parameters()
        self.bn_ih.bias.data.fill_(0)
        self.bn_hh.bias.data.fill_(0)
        self.bn_ih.weight.data.fill_(0.1)
        self.bn_hh.weight.data.fill_(0.1)
        self.bn_c.weight.data.fill_(0.1)

    def forward(self, input_, hx, time):
        """Run the cell for a single timestep.

        Args:
            input_: A (batch, input_size) tensor containing input
                features.
            hx: A tuple (h_0, c_0), which contains the initial hidden
                and cell state, where the size of both states is
                (batch, hidden_size).
            time: The current timestep value; it is passed on to every
                batch-norm layer as the ``time`` keyword.

        Returns:
            h_1, c_1: Tensors containing the next hidden and cell state.
        """
        h_0, c_0 = hx
        wh = torch.mm(h_0, self.weight_hh)  # matrix multiplication
        wi = torch.mm(input_, self.weight_ih)
        bn_wh = self.bn_hh(wh, time=time)
        bn_wi = self.bn_ih(wi, time=time)
        gates = bn_wh + bn_wi
        if self.bias is not None:
            # The (4 * hidden_size,) bias broadcasts over the batch axis.
            gates = gates + self.bias
        # The chunk size is passed positionally: the keyword was renamed
        # from ``split_size`` to ``split_size_or_sections``, so the
        # positional form is the one accepted by every PyTorch release.
        f, i, o, g = torch.split(gates, self.hidden_size, dim=1)
        c_1 = torch.sigmoid(f) * c_0 + torch.sigmoid(i) * torch.tanh(g)
        h_1 = torch.sigmoid(o) * torch.tanh(self.bn_c(c_1, time=time))
        return h_1, c_1

显示在这一行

img

请各位帮忙看一下错在哪里,或者 PyTorch 该如何排查这类错误。