class SimpleDataset(Dataset):
def __init__(self, db_path, transform=None) -> None:
self.db_path = db_path
self.conn = None
self.establish_conn()
# 数据库中表名
self.table_name = 'Synthetic_chinese_dataset'
self.cursor.execute(f'select max(rowid) from {self.table_name}')
self.nums = self.cursor.fetchall()[0][0]
self.transform = transform
def __getitem__(self, index: int):
self.establish_conn()
# 查询
search_sql = f'select * from {self.table_name} where rowid=?'
self.cursor.execute(search_sql, (index+1, ))
img_path, img_bytes, label = self.cursor.fetchone()
# 还原图像和标签
img = Image.open(BytesIO(img_bytes))
img = img.convert('RGB')
img = scale_resize_pillow(img, (320, 32))
if self.transform:
img = self.transform(img)
return img, label
def __len__(self) -> int:
return self.nums
def establish_conn(self):
if self.conn is None:
self.conn = sqlite3.connect(self.db_path,
check_same_thread=False,
cached_statements=1024)
self.cursor = self.conn.cursor()
return self
def close_conn(self):
if self.conn is not None:
self.cursor.close()
self.conn.close()
del self.conn
self.conn = None
return self
# --------------------------------------------------
train_dataset = SimpleDataset(train_db_path, train_transforms)
# ✧✧使用部分,需要手动关闭数据库连接
train_dataset.close_conn()
train_dataloader = DataLoader(train_dataset,
batch_size=batch_size,
num_workers=n_worker,
pin_memory=True,
sampler=train_sampler)
针对在使用pytorch中遇到的问题,需要先了解具体错误信息和出现问题的代码段,才能提供具体的解决方案。如果问题涉及到pytorch的数据读取方面,可以考虑以下两个方面:
1.使用PyTorch读取大规模数据时,可通过封装自定义数据集(Dataset)和数据迭代器(DataLoader)来实现高效数据读取。可以参考上面的参考资料段落0中的实现。同时,重点关注以下几点:
a.尽量减少数据读取、处理和网络计算所用的时间,以便充分利用GPU性能,提高网络训练速度。
b.确保使用的数据集格式与DataLoader的读取方式一致,比如读取图像数据时,要注意图像的尺寸大小、像素格式等问题,以免出现数据不匹配的情况。
2.如果数据规模较小,可以直接将数据转成numpy格式导入到内存中,然后利用PyTorch的Tensor格式进行训练。这种方法比通过Dataset和DataLoader读取的速度更快,因为省去了对数据集进行预处理的过程。
如果以上方法都不能解决你的问题,请提供更具体的错误信息或代码片段,以便更好地理解问题。