I built an Encoder out of this LSTM cell module and I am getting the following error:
https://img-mid.csdnimg.cn/release/static/image/mid/ask/65654403184613.png
import torch
import torch.nn as nn
from torch.nn import init

# SeparatedBatchNorm1d (a batch-norm layer that keeps separate running
# statistics per timestep) is assumed to be defined/imported elsewhere
# in the project.


class BNLSTMCell(nn.Module):

    """A BN-LSTM cell."""

    def __init__(self, input_size, hidden_size, max_length, use_bias=True):
        super(BNLSTMCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.max_length = max_length
        self.use_bias = use_bias
        self.weight_ih = nn.Parameter(
            torch.Tensor(input_size, 4 * hidden_size))
        self.weight_hh = nn.Parameter(
            torch.Tensor(hidden_size, 4 * hidden_size))
        if use_bias:
            self.bias = nn.Parameter(torch.FloatTensor(4 * hidden_size))
        else:
            self.register_parameter('bias', None)
        # BN parameters
        self.bn_ih = SeparatedBatchNorm1d(
            num_features=4 * hidden_size, max_length=max_length)
        self.bn_hh = SeparatedBatchNorm1d(
            num_features=4 * hidden_size, max_length=max_length)
        self.bn_c = SeparatedBatchNorm1d(
            num_features=hidden_size, max_length=max_length)
        self.reset_parameters()

    def reset_parameters(self):
        """
        Initialize parameters following the way proposed in the paper.
        """
        # The input-to-hidden weight matrix is initialized orthogonally.
        # init.orthogonal / init.constant are deprecated; use the
        # in-place variants orthogonal_ / constant_ instead.
        init.orthogonal_(self.weight_ih.data)
        # The hidden-to-hidden weight matrix is initialized as an identity
        # matrix, repeated once per gate.
        weight_hh_data = torch.eye(self.hidden_size)
        weight_hh_data = weight_hh_data.repeat(1, 4)
        self.weight_hh.data.set_(weight_hh_data)
        # The bias is just set to zero vectors (only if it exists).
        if self.use_bias:
            init.constant_(self.bias.data, val=0)
        # Initialization of BN parameters.
        self.bn_ih.reset_parameters()
        self.bn_hh.reset_parameters()
        self.bn_c.reset_parameters()
        self.bn_ih.bias.data.fill_(0)
        self.bn_hh.bias.data.fill_(0)
        self.bn_ih.weight.data.fill_(0.1)
        self.bn_hh.weight.data.fill_(0.1)
        self.bn_c.weight.data.fill_(0.1)

    def forward(self, input_, hx, time):
        """
        Args:
            input_: A (batch, input_size) tensor containing input
                features.
            hx: A tuple (h_0, c_0), which contains the initial hidden
                and cell state, where the size of both states is
                (batch, hidden_size).
            time: The current timestep value, which is used to
                get appropriate running statistics.

        Returns:
            h_1, c_1: Tensors containing the next hidden and cell state.
        """
        h_0, c_0 = hx
        batch_size = h_0.size(0)
        bias_batch = (self.bias.unsqueeze(0)
                      .expand(batch_size, *self.bias.size()))
        wh = torch.mm(h_0, self.weight_hh)      # matrix multiplication
        wi = torch.mm(input_, self.weight_ih)
        bn_wh = self.bn_hh(wh, time=time)
        bn_wi = self.bn_ih(wi, time=time)
        # NOTE: torch.split no longer accepts the keyword `split_size`
        # (newer PyTorch calls it `split_size_or_sections`); pass the
        # chunk size positionally so it works on old and new releases.
        f, i, o, g = torch.split(bn_wh + bn_wi + bias_batch,
                                 self.hidden_size, dim=1)
        c_1 = torch.sigmoid(f) * c_0 + torch.sigmoid(i) * torch.tanh(g)
        h_1 = torch.sigmoid(o) * torch.tanh(self.bn_c(c_1, time=time))
        return h_1, c_1
The error is reported on this line. Could anyone help me figure out where the mistake is, or how to track down errors like this in PyTorch?
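In case it helps narrow things down, below is a minimal single-step test sketch (the sizes are made up, and it assumes BNLSTMCell and SeparatedBatchNorm1d are already defined as above). On recent PyTorch versions the usual suspects in code like this are the torch.split keyword (the old `split_size` was renamed to `split_size_or_sections`) and the deprecated `init.orthogonal` / `init.constant` (now `orthogonal_` / `constant_`). If the same traceback appears in this tiny test, the problem is inside the cell itself rather than in the surrounding Encoder code.

import torch

# Made-up sizes for an isolated test.
batch, input_size, hidden_size, max_length = 4, 8, 16, 10

cell = BNLSTMCell(input_size, hidden_size, max_length)

x = torch.randn(batch, input_size)
h0 = torch.zeros(batch, hidden_size)
c0 = torch.zeros(batch, hidden_size)

# One timestep through the cell.
h1, c1 = cell(x, (h0, c0), time=0)
print(h1.shape, c1.shape)   # expected: torch.Size([4, 16]) for both

# Quick check of the split that commonly breaks on newer PyTorch:
# passing the chunk size positionally avoids the renamed keyword.
gates = torch.randn(batch, 4 * hidden_size)
f, i, o, g = torch.split(gates, hidden_size, dim=1)
print(f.shape)              # expected: torch.Size([4, 16])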