Paddlepaddle创建数据集显存爆了

问题遇到的现象和发生背景

使用paddlepaddle创建自己的数据集时发现显存飙升,用的是百度AI Studio环境

用代码块功能插入代码,请勿粘贴截图
import os
from paddle.io import Dataset
import paddle.vision.transforms as transforms
from paddle.io import DataLoader
import paddle

import cv2

change = transforms.ToTensor()

class FlameSet(Dataset):
    def __init__(self,img_path,row_label_path,col_label_path):
        img_list = []
        label1_list = []
        label2_list = []
        img_dir = img_path
        col_label_dir = col_label_path
        row_label_dir = row_label_path
        count = 0
        for filename in os.listdir(img_dir):
            if count == 800:
                break
            if filename.split(".")[1] == "jpg":
                img = cv2.imread(img_dir+filename)
                #print(img_dir+filename)
                img = cv2.resize(img,(640,640))
                img_ = change(img)
                img_list.append(img_)
                label1 = cv2.imread(col_label_dir+filename)
                label1 = cv2.resize(label1,(640,640))
                label1_tensor = change(label1)
                label1_list.append(label1_tensor)
                label2 = cv2.imread(row_label_dir+filename)
                label2 = cv2.resize(label2,(640,640))
                label2_tensor = change(label2)
                label2_list.append(label2_tensor)
                count += 1

        self.imgs = paddle.stack(img_list,axis=0)
        self.labels1 = paddle.stack(label1_list,axis=0)
        self.labels2 = paddle.stack(label2_list, axis=0)
    def __len__(self):
        return len(self.imgs)

    # 每取一个元素,就会调用该函数
    def __getitem__(self, index):
        #print(f"取出下标为{index}的数据")

        data = self.imgs[index]
        label1 = self.labels1[index]
        label2 = self.labels2[index]
        return data,label1,label2
        # print(self.imgs.size())
        # print(self.labels.size())
def get_data():
    dataset = DataLoader(FlameSet('./datasets/data_aug/train/','./datasets/data_aug/row/','./datasets/data_aug/col/'),batch_size=1)
    return dataset

运行结果及报错内容

上面是我的代码,之前没用过paddlepaddle,根据pytorch写的转化的

我的解答思路和尝试过的方法

尝试清空无用显存也结束不了

我想要达到的结果

不知道哪里有问题