图像分割：UNet 数据集制作

制作 UNet 数据集时出现了下列问题：

img

import os

import numpy as np
import torch
from torch.utils.data import Dataset
from utils import *
from torchvision import transforms

# Preprocessing pipeline applied to both images and masks in MyDataset.
# It holds a single step, ToTensor, which (per the torchvision docs) converts
# a PIL image / HxWxC ndarray to a CxHxW float tensor and scales 8-bit input
# to [0, 1].
_pipeline_steps = [transforms.ToTensor()]
transform = transforms.Compose(_pipeline_steps)
###############

class MyDataset(Dataset):
    """Paired image / segmentation-mask dataset for UNet training.

    ``path`` must contain two sub-directories, ``train_label`` and
    ``train_img``. Every file in ``train_label`` must have a file with the
    exact same name in ``train_img``.

    Args:
        path: Root directory of the dataset.
        transform: Optional callable applied to both the image and the mask.
            When ``None`` (the default), the module-level ``transform``
            pipeline is used, which matches the original behavior.
    """

    def __init__(self, path, transform=None):
        self.path = path
        # Sort so the index -> file mapping is reproducible across runs and
        # machines; os.listdir returns entries in arbitrary order.
        self.name = sorted(os.listdir(os.path.join(path, 'train_label')))
        self.transform = transform

    def __len__(self):
        """Return the number of label files found under ``train_label``."""
        return len(self.name)

    def __getitem__(self, index):
        """Return ``(image, mask)`` for the ``index``-th label file, both transformed.

        Raises:
            FileNotFoundError: if ``train_img`` has no file with the same
                name as the label file.
        """
        segment_name = self.name[index]  # e.g. xx.png
        segment_path = os.path.join(self.path, 'train_label', segment_name)  # label path
        image_path = os.path.join(self.path, 'train_img', segment_name)
        # Fail early with a message that names both files. A typical mismatch
        # is a label saved as .png while the image is saved as .jpg.
        if not os.path.isfile(image_path):
            raise FileNotFoundError(
                f"No image matching label {segment_path!r}: expected {image_path!r}"
            )
        segment_image = keep_image_size_open(segment_path)
        image = keep_image_size_open(image_path)
        # Resolve the module-level default at call time instead of binding it
        # as a default argument, so the global is only looked up when needed.
        tf = self.transform if self.transform is not None else transform
        # NOTE(review): with the default ToTensor pipeline, 8-bit masks are
        # scaled to [0, 1]. If masks store class indices (0..K) for
        # one_hot/CrossEntropy, convert them with
        # torch.from_numpy(np.array(mask)).long() instead. Confirm the label format.
        return tf(image), tf(segment_image)


if __name__ == '__main__':
    # Smoke test: load the first sample and print the tensor shapes.
    # Use a raw string so the Windows backslashes are not parsed as escape
    # sequences ('\p' is an invalid escape and warns on recent Pythons).
    data = MyDataset(r'E:\project\ph2_dataset')
    # Fetch the sample once; each data[i] call re-runs __getitem__, which
    # opens both the image and the label file.
    image, segment = data[0]
    print(image.shape)
    print(segment.shape)
    # Optional one-hot check (needs integer class indices in the mask):
    # from torch.nn.functional import one_hot
    # out = one_hot(segment.long())
    # print(out.shape)