pytorch测试集看每类准确率 class_correct[label] +=( c[i].item())报错

class_correct[label] +=( c[i].item())
IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number
就是网络上最常用的测试每类数据集准确率的代码

img

找到一个外网ds遇到相似问题的评论,
correct seems to be a 0-dim tensor, which you cannot index. This error might be raised, as you are calling np.squeeze to create correct.

有没有xd会改的

应该是pytorch版本不同造成的,你降低下版本试试

报错已经提示你了索引的问题。原因就是这句:
label=labels[i]
这么写的话你的label是torch.tensor,而这个不能作为列表class_correct的索引,你需要用torch.item()转成数字,如果是浮点数还需要转成整数。

img