y = torch.tensor([0,2])
y_hat = torch.tensor([[0.1,0.3,0.6], [0.3,0.2,0.5]])
y_hat[[0,1], y]
tensor([0.1000, 0.5000])
如题,谢谢各位。
y_hat =
[[0.1,0.3,0.6],
[0.3,0.2,0.5]]
然后y_hat[[0,1], y],也就是
y_hat[[0,1], [0,2]]
意思是从y_hat里面挑选出【0,0】元素和【1,2】元素
得到
tensor([0.1000, 0.5000])
望采纳, 谢谢!
我好像明白了,[0,1]就是取第0行和第1行,y为[0,2] 就是取第0列和第2列。犯傻了我