[[0.1 3.5]
[1.14 3.51]
[9.17 3.51]
[9.98 9.30]
[9.98 9.31]
[1.00 8.98]]
上面表示一条数据(一个tensor),一条数据中有n * 2个浮点数,n是不定的。
请问怎样利用AutoEncoder(AutoEncoder不行的话,其他网络有什么)把数据由n*2的格式转换成256*2,谢谢各位。
class autoencoder(nn.Module):
def __init__(self):
super(autoencoder,self).__init__()
self.encoder=nn.Sequential(
nn.Linear(4, 3),
nn.Tanh(),
nn.Linear(3, 2),
)
self.decoder=nn.Sequential(
nn.Linear(2, 3),
nn.Tanh(),
nn.Linear(3, 4),
nn.Sigmoid()
)
def forward(self, x):
encoder=self.encoder(x)
decoder=self.decoder(encoder)
return encoder,decoder
1.用torch.reshape
或
2.降维(pcb等)
可以看看transforms bert 是怎么弄不定长的