pytorch这个问题怎么解决?

img


请问一下这个问题怎么解决?越快解决越好
,越快解决越好,希望有人能够帮忙解决这个问题

  • 你可以参考下这个问题的回答, 看看是否对你有帮助, 链接: https://ask.csdn.net/questions/7790249
  • 这篇博客你也可以参考下:PyTorch开始使用的一些问题
  • 除此之外, 这篇博客: PyTorch下训练数据小文件转大文件读写(附有各种存储格式对比)中的 读取数据库 部分也许能够解决你的问题, 你可以仔细阅读以下内容或跳转源博客中阅读:
  • class SimpleDataset(Dataset):
        def __init__(self, db_path, transform=None) -> None:
            self.db_path = db_path
            self.conn = None
            self.establish_conn()
     
            # 数据库中表名
            self.table_name = 'Synthetic_chinese_dataset'
     
            self.cursor.execute(f'select max(rowid) from {self.table_name}')
            self.nums = self.cursor.fetchall()[0][0]
            self.transform = transform
     
        def __getitem__(self, index: int):
            self.establish_conn()
     
            # 查询
            search_sql = f'select * from {self.table_name} where rowid=?'
            self.cursor.execute(search_sql, (index+1, ))
            img_path, img_bytes, label = self.cursor.fetchone()
     
            # 还原图像和标签
            img = Image.open(BytesIO(img_bytes))
            img = img.convert('RGB')
            img = scale_resize_pillow(img, (320, 32))
     
            if self.transform:
                img = self.transform(img)
            return img, label
     
        def __len__(self) -> int:
            return self.nums
     
        def establish_conn(self):
            if self.conn is None:
                self.conn = sqlite3.connect(self.db_path,
                                            check_same_thread=False,
                                            cached_statements=1024)
                self.cursor = self.conn.cursor()
            return self
     
        def close_conn(self):
            if self.conn is not None:
                self.cursor.close()
                self.conn.close()
     
                del self.conn
                self.conn = None
            return self  
     
    # --------------------------------------------------
    train_dataset = SimpleDataset(train_db_path, train_transforms)
    # ✧✧使用部分,需要手动关闭数据库连接
    train_dataset.close_conn()
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=batch_size,
                                  num_workers=n_worker,
                                  pin_memory=True,
                                  sampler=train_sampler)
    
  • 您还可以看一下 覃秉丰老师的PyTorch从入门到实战一次学会课程中的 使用卷积网络进行手写数字识别小节, 巩固相关知识点
  • 以下回答由chatgpt基于相关博客总结生成:

    针对在使用pytorch中遇到的问题,需要先了解具体错误信息和出现问题的代码段,才能提供具体的解决方案。如果问题涉及到pytorch的数据读取方面,可以考虑以下两个方面:

    1.使用PyTorch读取大规模数据时,可通过封装自定义数据集(Dataset)和数据迭代器(DataLoader)来实现高效数据读取。可以参考上面的参考资料段落0中的实现。同时,重点关注以下几点:

    a.尽量减少数据读取、处理和网络计算所用的时间,以便充分利用GPU性能,提高网络训练速度。

    b.确保使用的数据集格式与DataLoader的读取方式一致,比如读取图像数据时,要注意图像的尺寸大小、像素格式等问题,以免出现数据不匹配的情况。

    2.如果数据规模较小,可以直接将数据转成numpy格式导入到内存中,然后利用PyTorch的Tensor格式进行训练。这种方法比通过Dataset和DataLoader读取的速度更快,因为省去了对数据集进行预处理的过程。

    如果以上方法都不能解决你的问题,请提供更具体的错误信息或代码片段,以便更好地理解问题。