unet数据集制作,出现下列问题
import os
import numpy as np
import torch
from torch.utils.data import Dataset
from utils import *
from torchvision import transforms
transform = transforms.Compose([
transforms.ToTensor()
])
###############
class MyDataset(Dataset):
def __init__(self, path):
self.path = path
self.name = os.listdir(os.path.join(path, 'train_label')) #拼接
def __len__(self):
return len(self.name)
def __getitem__(self, index): #数据集制作
segment_name = self.name[index] # xx.png
segment_path = os.path.join(self.path, 'train_label', segment_name)#标签地址
image_path = os.path.join(self.path, 'train_img', segment_name)
segment_image = keep_image_size_open(segment_path)
image = keep_image_size_open(image_path)
return transform(image), transform(segment_image)
if __name__ == '__main__':
#from torch.nn.functional import one_hot
data = MyDataset('E:\project\ph2_dataset')
print(data[0][0].shape)
print(data[0][1].shape)
#out=one_hot(data[0][1].long())
#print(out.shape)
1.源码下载
可以直接到up主对应的github上进行下载,GitHub - bubbliiiing/unet-pytorch: 这是一个unet-pytorch的源码,可以训练自己的模型
他的代码里面进行了中文说明,对小白很友好
2.训练所需要的模型下载
链接: 百度网盘 请输入提取码
提取码: 6n2c
下载完之后放至源码根目录中model_data文件夹,没有就新建一个