torch 1.0.0版本
arg_input = F.one_hot(ind_arr, num_classes=num_const)
Error:cannot find reference 'one_hot' in functional.py
记录一下用过的方法
参考:
index = torch.tensor([[1], [2], [0], [3]])
onehot = torch.zeros(4, 4)
onehot.scatter_(1, index, 1)
print(onehot)
# 输出
tensor([[0., 1., 0., 0.],
[0., 0., 1., 0.],
[1., 0., 0., 0.],
[0., 0., 0., 1.]])
torch.zeros() 中的参数改为 tensor 的行数
onehot = torch.zeros(list(index.shape)[0], list(index.shape)[0])
将原来的代码改成:
arg_input = torch.zeros(list(ind_arr.shape)[0], list(ind_arr.shape)[0]).scatter_(1, ind_arr, 1)
但是,计算机又报了新的错:
RuntimeError: CUDA out of memory. Tried to allocate 5.00 MiB (GPU 0; 2.00 GiB total capacity; 685.78 MiB already allocated; 4.41 MiB free; 227.00 KiB cached)
。。。
torch1.1加入one_hot()函数,根据cuda情况,可以升级一下torch