torch1.0.0 torch.nn.functional.one_hot()函数.

torch 1.0.0版本

arg_input = F.one_hot(ind_arr, num_classes=num_const)  

Error:cannot find reference 'one_hot' in functional.py

 

 

 

记录一下用过的方法

参考:

index = torch.tensor([[1], [2], [0], [3]])
onehot = torch.zeros(4, 4)
onehot.scatter_(1, index, 1)
print(onehot)

# 输出
tensor([[0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.],
        [0., 0., 0., 1.]])

torch.zeros() 中的参数改为 tensor 的行数

onehot = torch.zeros(list(index.shape)[0], list(index.shape)[0])

将原来的代码改成:

arg_input = torch.zeros(list(ind_arr.shape)[0], list(ind_arr.shape)[0]).scatter_(1, ind_arr, 1)

但是,计算机又报了新的错:

RuntimeError: CUDA out of memory. Tried to allocate 5.00 MiB (GPU 0; 2.00 GiB total capacity; 685.78 MiB already allocated; 4.41 MiB free; 227.00 KiB cached)
 

。。。

torch1.1加入one_hot()函数,根据cuda情况,可以升级一下torch