Pytorch的tensor问题,请问第3行代码是怎么计算出结果的,谢x。

代码是这样的,
y = torch.tensor([0,2])
y_hat = torch.tensor([[0.1,0.3,0.6], [0.3,0.2,0.5]])
y_hat[[0,1], y]
运行结果

tensor([0.1000, 0.5000])

如题,谢谢各位。

y_hat =

[[0.1,0.3,0.6],
[0.3,0.2,0.5]]

然后y_hat[[0,1], y],也就是

y_hat[[0,1], [0,2]]

意思是从y_hat里面挑选出【0,0】元素和【1,2】元素
得到
tensor([0.1000, 0.5000])

望采纳, 谢谢!

我好像明白了,[0,1]就是取第0行和第1行,y为[0,2] 就是取第0列和第2列。犯傻了我