在模型完成训练后需要输入测试集进行测试。目前测试功能只允许单张图片的输入。
博客来源:参考的文章
代码来源:https://github.com/codecat0/CV/tree/main/Image_Classification
import os
import json
import argparse
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from models.base_model import BaseModel
def main(args):
    """Classify one image, then print, save and show the annotated result.

    Args:
        args: Parsed command-line namespace providing ``img_path``,
            ``real_label``, ``model_name``, ``num_classes`` and
            ``model_weight_path``.

    Raises:
        AssertionError: If the image file or ``./class_indices.json`` does
            not exist.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Deterministic preprocessing: resize, 224x224 center crop, tensor
    # conversion, then per-channel normalization with mean/std 0.5.
    data_transform = transforms.Compose(
        [
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ]
    )

    img_path = args.img_path
    assert os.path.exists(img_path), f"file {img_path} does not exist."
    # Force three channels: Normalize is configured for 3-channel input, so
    # grayscale / RGBA / palette images would otherwise fail.
    img = Image.open(img_path).convert('RGB')
    plt.imshow(img)
    img = data_transform(img)
    # [C, H, W] -> [1, C, H, W]
    img = torch.unsqueeze(img, dim=0)

    json_path = './class_indices.json'
    assert os.path.exists(json_path), f"file {json_path} does not exist."
    # Context manager so the file handle is closed right after parsing.
    with open(json_path, 'r') as json_file:
        class_indict = json.load(json_file)

    model = BaseModel(name=args.model_name, num_classes=args.num_classes).to(device)
    model.load_state_dict(torch.load(args.model_weight_path, map_location=device))
    model.eval()
    with torch.no_grad():
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = int(torch.argmax(predict))

    # class_indices.json is keyed by the class index as a string.
    print_res = "real: {} predict: {} prob: {:.3f}".format(
        args.real_label, class_indict[str(predict_cla)], predict[predict_cla].item()
    )
    plt.title(print_res)
    plt.xticks([])
    plt.yticks([])
    print(print_res)
    # savefig does not create missing directories.
    os.makedirs('./data', exist_ok=True)
    plt.savefig('./data/predict.jpg', bbox_inches='tight', dpi=600, pad_inches=0.0)
    plt.show()
if __name__ == '__main__':
    # Command-line options as (flag, value type, default value).
    cli_options = (
        ('--img_path', str, r'C:\Users\11831\Desktop\FinalProject\Code\data\testing\akiec\ISIC_0026729.jpg'),
        ('--real_label', str, 'akiec'),
        ('--model_name', str, 'densenet'),
        ('--num_classes', int, 7),
        ('--model_weight_path', str, './weights/densenet.pth'),
    )
    parser = argparse.ArgumentParser()
    for flag, value_type, default_value in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value)
    args = parser.parse_args()
    main(args)
我自己按照这段代码尝试写过,没有成功。(代码链接在最上面)
import os
import json
import random
import matplotlib.pyplot as plt
from PIL import Image
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
def read_split_data(root: str, val_rate: float = 0.2, plot_image: bool = False):
    """Split an image-folder dataset into training and validation subsets.

    Every sub-directory of ``root`` is treated as one class. Class names are
    sorted and numbered from 0, and the index -> class-name mapping is written
    to ``class_indices.json`` in the current working directory.

    Args:
        root: Dataset root directory containing one sub-directory per class.
        val_rate: Fraction of each class's images sampled for validation
            (the per-class count is truncated with ``int``).
        plot_image: When true, show a bar chart of the per-class image counts.

    Returns:
        A tuple ``(train_images_path, train_images_label, val_images_path,
        val_images_label)`` of four lists; labels are integer class indices.

    Raises:
        AssertionError: If ``root`` does not exist.
    """
    # Fixed seed so the sampled validation subset is reproducible.
    random.seed(0)
    assert os.path.exists(root), f'dataset root {root} does not exist.'

    # One sub-directory per class; sorted so the indices are stable.
    flower_classes = sorted(
        cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))
    )
    # Class name -> numeric index.
    class_indices = {name: index for index, name in enumerate(flower_classes)}
    # Persist the inverse mapping (index -> class name) for use at predict time.
    json_str = json.dumps({index: name for name, index in class_indices.items()}, indent=4)
    with open('class_indices.json', 'w') as f:
        f.write(json_str)

    # Paths and labels of the training images.
    train_images_path, train_images_label = [], []
    # Paths and labels of the validation images.
    val_images_path, val_images_label = [], []
    # Number of images found for each class.
    every_class_num = []
    # Supported image file extensions.
    images_format = (".jpg", ".JPG", ".png", ".PNG")

    for cla in flower_classes:
        cla_path = os.path.join(root, cla)
        # Sorted because os.listdir returns entries in arbitrary order and
        # random.sample depends on the order of its input; without sorting
        # the fixed seed would not give a reproducible split.
        images = sorted(
            os.path.join(cla_path, i) for i in os.listdir(cla_path)
            if os.path.splitext(i)[-1] in images_format
        )
        image_class = class_indices[cla]
        every_class_num.append(len(images))
        # A set makes the membership test below O(1) per image.
        val_path = set(random.sample(images, k=int(len(images) * val_rate)))
        for img_path in images:
            if img_path in val_path:
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    print(f"{sum(every_class_num)} images found in dataset.")
    print(f"{len(train_images_path)} images for training.")
    print(f"{len(val_images_path)} images for validation.")

    if plot_image:
        plt.bar(range(len(flower_classes)), every_class_num, align='center')
        plt.xticks(range(len(flower_classes)), flower_classes)
        # Print each count slightly above its bar.
        for i, v in enumerate(every_class_num):
            plt.text(x=i, y=v + 5, s=str(v), ha='center')
        plt.xlabel('image class')
        plt.ylabel('number of images')
        plt.title('flower class distribution')
        plt.show()

    return train_images_path, train_images_label, val_images_path, val_images_label
class MyDataSet(Dataset):
    """Dataset backed by parallel lists of image paths and labels."""

    def __init__(self, images_path: list, images_label: list, transform=None):
        # transform, when given, is applied to each opened image.
        self.transform = transform
        self.images_label = images_label
        self.images_path = images_path

    def __len__(self):
        """Return the number of image paths held by the dataset."""
        return len(self.images_path)

    def __getitem__(self, item):
        """Open the image at position ``item`` and return ``(image, label)``.

        Raises:
            ValueError: If the opened image is not in RGB mode.
        """
        picture = Image.open(self.images_path[item])
        if picture.mode != 'RGB':
            raise ValueError(f"image: {self.images_path[item]} is not RGB mode")
        target = self.images_label[item]
        if self.transform is None:
            return picture, target
        return self.transform(picture), target

    @staticmethod
    def collate_fn(batch):
        """Stack the images of a batch along dim 0 and tensorize the labels."""
        image_seq, label_seq = zip(*batch)
        return torch.stack(image_seq, dim=0), torch.as_tensor(label_seq)
def get_dataset_dataloader(data_path, batch_size):
    """Build the training/validation datasets and their data loaders.

    Args:
        data_path: Dataset root directory, forwarded to ``read_split_data``.
        batch_size: Batch size used by both data loaders.

    Returns:
        A tuple ``(train_dataset, val_dataset, train_dataloader,
        val_dataloader)``.
    """
    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(root=data_path)

    data_transform = {
        # Training: random crop and horizontal flip as augmentation.
        "train": transforms.Compose([transforms.RandomResizedCrop(224),
                                     transforms.RandomHorizontalFlip(),
                                     transforms.ToTensor(),
                                     transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])]),
        # Validation: deterministic resize and center crop.
        "val": transforms.Compose([transforms.Resize(224),
                                   transforms.CenterCrop(224),
                                   transforms.ToTensor(),
                                   transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])])
    }

    train_dataset = MyDataSet(images_path=train_images_path,
                              images_label=train_images_label,
                              transform=data_transform['train'])
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_label=val_images_label,
                            transform=data_transform['val'])

    # os.cpu_count() returns None when the count cannot be determined, which
    # would make min() raise TypeError; fall back to a single CPU.
    cpu_count = os.cpu_count() or 1
    # No worker processes for batch_size <= 1, and never more than 8.
    nw = min([cpu_count, batch_size if batch_size > 1 else 0, 8])
    print(f"Using {nw} dataloader workers every process.")

    train_dataloader = DataLoader(
        dataset=train_dataset,
        batch_size=batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=nw,
        collate_fn=train_dataset.collate_fn
    )
    val_dataloader = DataLoader(
        dataset=val_dataset,
        batch_size=batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=nw,
        collate_fn=val_dataset.collate_fn
    )
    return train_dataset, val_dataset, train_dataloader, val_dataloader
希望测试代码能够遍历测试文件夹下的七个子文件夹并获取其中的所有样本,记录每个样本的类别。
对每个样本进行预测,预测类别与记录类别不相符为错误,相符为正确。
最后计算测试集的准确率:准确率 = 预测正确的样本数 / 测试集样本总数。
你运行以下代码试一下,不知道能否得到你想要的结果:
#coding:utf-8
import os
import json
import argparse
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from models.base_model import BaseModel
# Models already loaded, keyed by (model name, class count, weight path,
# device). main() is called once per test image, so without this cache the
# network would be rebuilt and its weights re-read from disk for every image.
_MODEL_CACHE = {}


def _get_model(args, device):
    """Return the eval-mode model described by ``args``, loading it once."""
    key = (args.model_name, args.num_classes, args.model_weight_path, str(device))
    if key not in _MODEL_CACHE:
        model = BaseModel(name=args.model_name, num_classes=args.num_classes).to(device)
        model.load_state_dict(torch.load(args.model_weight_path, map_location=device))
        model.eval()
        _MODEL_CACHE[key] = model
    return _MODEL_CACHE[key]


def main(args):
    """Classify one image and count it as an error if the label mismatches.

    Increments the module-level ``error_num`` counter when the predicted
    class name differs from ``args.real_label``.

    Args:
        args: Parsed command-line namespace providing ``img_path``,
            ``real_label``, ``model_name``, ``num_classes`` and
            ``model_weight_path``.

    Raises:
        AssertionError: If the image file or ``./class_indices.json`` does
            not exist.
    """
    global error_num
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # Deterministic preprocessing: resize, 224x224 center crop, tensor
    # conversion, then per-channel normalization with mean/std 0.5.
    data_transform = transforms.Compose(
        [
            transforms.Resize(224),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
        ]
    )

    img_path = args.img_path
    assert os.path.exists(img_path), f"file {img_path} does not exist."
    # Force three channels: Normalize is configured for 3-channel input, so
    # grayscale / RGBA / palette images would otherwise fail.
    # The image is intentionally not drawn with plt.imshow here: nothing in
    # this function shows or saves the figure, so drawing every test image
    # would only pile up images on the current axes.
    img = Image.open(img_path).convert('RGB')
    img = data_transform(img)
    # [C, H, W] -> [1, C, H, W]
    img = torch.unsqueeze(img, dim=0)

    json_path = './class_indices.json'
    assert os.path.exists(json_path), f"file {json_path} does not exist."
    # Context manager so the file handle is closed right after parsing.
    with open(json_path, 'r') as json_file:
        class_indict = json.load(json_file)

    model = _get_model(args, device)
    with torch.no_grad():
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = int(torch.argmax(predict))

    # class_indices.json is keyed by the class index as a string.
    if args.real_label != class_indict[str(predict_cla)]:
        error_num += 1
# Root directory of the test set; each sub-directory holds one class.
root = r'C:\Users\11831\Desktop\FinalProject\Code\data\testing'
# Class names are the sub-directory names, kept in sorted order.
flower_classes = sorted(
    entry for entry in os.listdir(root) if os.path.isdir(os.path.join(root, entry))
)
# Supported image file extensions.
images_format = [".jpg", ".JPG", ".png", ".PNG"]
# Collect one [image path, class name] pair per supported image file.
all_images = []
for cla in flower_classes:
    cla_path = os.path.join(root, cla)
    for file_name in os.listdir(cla_path):
        if os.path.splitext(file_name)[-1] in images_format:
            all_images.append([os.path.join(cla_path, file_name), cla])
total_num = len(all_images)
cur_num = 0
if __name__ == '__main__':
    error_num = 0

    # Parse the command line once instead of rebuilding the parser for
    # every image.
    parser = argparse.ArgumentParser()
    # --img_path / --real_label stay accepted for command-line compatibility,
    # but they are overwritten for each test image below. Previously a value
    # given on the command line replaced the per-image default, so every
    # iteration evaluated the same image.
    parser.add_argument('--img_path', type=str, default=None)
    parser.add_argument('--real_label', type=str, default=None)
    parser.add_argument('--model_name', type=str, default='densenet')
    parser.add_argument('--num_classes', type=int, default=7)
    parser.add_argument('--model_weight_path', type=str, default='./weights/densenet.pth')
    args = parser.parse_args()

    for cur_num, (path, label) in enumerate(all_images, start=1):
        print('正在识别第{}张图片,共{}张图片'.format(cur_num, total_num))
        args.img_path = path
        args.real_label = label
        main(args)

    if total_num:
        print("error_num={},total_num={},correct_ratio={}".format(error_num, total_num, 1 - error_num / total_num))
    else:
        # Avoid ZeroDivisionError when no test image was found.
        print("no test images found, accuracy is undefined")
你是要循环输入每一张图片,然后再计算整体的准确率?
利用pytorch训练好的模型测试单张图片
https://blog.csdn.net/qq_41167777/article/details/109013155