pytorch如何用dataset制作图片序列的数据集,并训练。

我的模型输入是三通道的图片序列,输出也是一段图片序列。如何使用dataset制作合适的数据集。模型是卷积之后的循环结构。

img


我的数据集格式是这样的。每一行代表一个时间段里的一系列图片。需要根据前半部分序列预测后半部分序列。我想的是用一个循环网络的结构去训练。但是不太清楚如何构建合适的dataset。分batch之后好像没法直接一起把数据喂给模型。


你可以学别人的做法,继承dataset类,然后重写__getitem__魔术方法,在迭代的时候截取数据
torch.utils.data — PyTorch 1.11.0 documentation https://pytorch.org/docs/stable/data.html

data相关可以在此页面找到

PyTorch 模型训练实用教程(一):数据
https://blog.csdn.net/qq_38156104/article/details/108029372