import os
import json
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from vit_model import vit_base_patch16_224_in21k as create_model
def main():
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data_transform = transforms.Compose(
[transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
# load image
img_path = "first41.jpg"
assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
img = Image.open(img_path)
plt.imshow(img)
# [N, C, H, W]
img = data_transform(img)
# expand batch dimension
img = torch.unsqueeze(img, dim=0)
# read class_indict
json_path = './class_indices.json'
assert os.path.exists(json_path), "file: '{}' dose not exist.".format(json_path)
with open(json_path, "r") as f:
class_indict = json.load(f)
# create model
model = create_model(num_classes=1, has_logits=False).to(device)
# load model weights
model_weight_path = "./weights/model-9.pth"
model.load_state_dict(torch.load(model_weight_path, map_location=device))
model.eval()
with torch.no_grad():
# predict class
output = torch.squeeze(model(img.to(device))).cpu()
predict = torch.softmax(output, dim=0)
predict_cla = torch.argmax(predict).numpy()
print_res = "class: {} prob: {:.3}".format(class_indict[str(predict_cla)],
predict[predict_cla].numpy())
plt.title(print_res)
for i in range(len(predict)):
print("class: {:10} prob: {:.3}".format(class_indict[str(i)],
predict[i].numpy()))
plt.show()
if __name__ == '__main__':
main()
运行后出现predict[predict_cla].numpy())
IndexError: too many indices for tensor of dimension 0
用目前网上给出的方法都没有用
这个错误可能是由于你的predict[predict_cla]是一个标量,而不是一个向量或矩阵,所以你不能用.numpy()方法来转换它。你可以尝试用.item()方法来获取它的数值,或者用.cpu()方法来把它复制到主内存中2。
如果你想了解更多关于这个错误的原因和解决办法,你可以参考以下链接:
1 https://github.com/pytorch/tutorials/issues/552
3 https://discuss.pytorch.org/t/indexerror-too-many-indices-for-tensor-of-dimension-0/45326
2 https://discuss.pytorch.org/t/how-to-solve-indexerror-too-many-indices-for-tensor-of-dimension-1/40168
希望这对你有帮助。🙏
predict_cla = torch.argmax(predict).numpy()
predict[predict_cla].numpy()
predict[predict_cla] 要求 predict_cla 是一个整数
但是 predict_cla 被转换为 numpy 数组了,不是一个整数,不能作为 predict[*] 的序号使用。