已有:(1)一个包含所有图像数据的文件夹 ‘。/Image/’;
(2)一个包含file name + 标签label的dataY csv文件‘dataY.csv'
目的:使用 【tensoflow】 构建一个类,可以根据kfold得到的index + dataY.csv 取Image图像,从而进行图像分类
使用 pytorch构建自己的torch.utils.data.Dataset类,修改‘__init__’ ‘__getitem__‘即可实现
不用那么麻烦,直接用sklearn的kfold就行,kfold无非就是个随机分数据集训练而已。。
from sklearn.model_selection import KFold
kfold = KFold(n_splits=num_folds, shuffle=True)
for train, test in kfold.split(inputs, targets):
#下面跟tensorflow的模型训练代码就行了