#pyhton报错怎么修改 给出修改方法
datalist = []
for i in range(1,20):
if i>= 1 and i<= 9:
i = '0'+str(i)
for j in range(1,9):
for k in range(1,61):
if k >= 1 and k <= 9:
k = '0' + str(k)
path = r'C:\Users\14852\Desktop\data' + r'\a' + str(i) + '\p' + str(j) + '\s' + str(k) + ".txt"
data1 = open(path)
data1 = data1.read()
data1 = data1.split('\n')
for num in data1:
datalist.append([num,i])
datalist = datalist[0:300]
input_batch = []
target_batch = []
for item in datalist:
try:
numm = item[0].split(',')
numm = [torch.LongTensor([float(i)]) for i in numm]
input_batch.append(torch.LongTensor(numm))
target_batch.append(int(item[1]))
except:
a = 1
target_batch1 = target_batch
input_batch1 = input_batch
input_batch = torch.LongTensor(input_batch)
target_batch = torch.LongTensor(target_batch)
dataset = data.TensorDataset(input_batch,target_batch)
loader = data.DataLoader(dataset,batch_size = 1)
class LSTM(nn.Module):
def init(self,max_len=1):
super(LSTM,self).init()
self.lstm = nn.LSTM(input_size=45,hidden_size=128)
self.fc = nn.Linear(128,19)
def forward(self,inputt):
inputt = inputt.view(len(inputt),1,-1)
x,(h_n,c_n) = self.lstm(inputt)
x = x.reshape(1,-1)
out = self.fc(x)
return out
model = LSTM()
optimizer = optim.Adam(model.parameters(),0.005)
def train(epoch):
for idx,(inputt,target) in enumerate(loader):
optimizer.zero_grad()
output = model(inputt)
print(output)
loss = F.nll_loss(output,target)
loss.backward()
optimizer.step()
print(loss.item())
for i in range(1):
train(i)
看起来是tensor的数据类型问题,你可以在你的ipython里面输入%debug
,然后检查一下哪一处的变量数据类型不对
RuntimeError : Expected object of scalar type Long but got scalar type Float for argument #3 'mat2' in call to_th_addm_out
RuntimeError 表示一般的运行时报错
报错意思是类型报错,需要标量类型为Long的对象,但在调用to_th_addm_out时获得了参数#3“mat2”的标量类型Float
错误的关键是这句:
numm = [torch.LongTensor([float(i)]) for i in numm]
改成:
numm = [torch.LongTensor([int(i)]) for i in numm]
就是类型错误了。 你在for idx,(inputt,target) in enumerate(loader): 这行下面加一句inputt=inputt.to(torch.long) 试下
标量类型的原因,期望的类型为Long,实际上却是float,类型改为期望的类型即可。
译:对于调用_th_addmm_out中的参数#3 'mat2',期望的标量类型为Long,但得到的标量类型为Float。
不知道你解决了没?
如果没有解决,我们可以聊聊。