class MyIterableDataset(IterableDataset):
    """Stream ``<index>,<sentence>`` records from a UTF-8 text file.

    Each non-blank line is split on its FIRST comma only, so commas inside the
    sentence stay part of the sentence. Blank lines are skipped.

    Without a tokenizer (the default), every record is yielded as a list of
    raw strings, e.g. ``['1', 'some text']``. Raw strings cannot be batched
    by the Hugging Face ``Trainer``; pass ``tokenizer`` to get model-ready
    items instead.

    With a tokenizer, the last field (the sentence) is encoded with
    truncation and padding to ``max_length``. The yielded dict also gets a
    ``labels`` entry copied from ``input_ids``, with padded positions
    (``attention_mask == 0``) set to -100 so the loss ignores them. This is
    the usual causal-LM setup for text generation, where the text is its own
    label.

    NOTE: PyTorch calls ``__iter__`` once per DataLoader worker, so with
    ``num_workers > 0`` every worker would re-read the whole file.

    Args:
        file_path: Path of the text file to read.
        tokenizer: Optional callable with the Hugging Face tokenizer call
            signature. Padding to ``max_length`` requires a pad token; for
            GPT-2-style tokenizers, e.g. set
            ``tokenizer.pad_token = tokenizer.eos_token``.
        max_length: Token length that every encoded sentence is truncated or
            padded to.
    """

    def __init__(self, file_path, tokenizer=None, max_length=128):
        self.file_path = file_path
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __iter__(self):
        """Yield one record per non-blank line (see the class docstring)."""
        with open(self.file_path, 'r', encoding="utf-8") as file_obj:
            for line in file_obj:
                line = line.rstrip('\n')
                if not line.strip():
                    continue  # a blank line would otherwise yield ['']
                # maxsplit=1: only the first comma separates index from text.
                line_data = line.split(',', 1)
                if self.tokenizer is None:
                    yield line_data
                    continue
                yield self._encode(line_data[-1])

    def _encode(self, text):
        """Tokenize *text* into a dict of lists with causal-LM ``labels``."""
        encoding = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
        )
        item = dict(encoding)
        mask = item.get('attention_mask')
        if mask is None:
            item['labels'] = list(item['input_ids'])
        else:
            # -100 is the ignore_index of the cross-entropy loss, so padding
            # tokens do not contribute to training.
            item['labels'] = [
                token if keep else -100
                for token, keep in zip(item['input_ids'], mask)
            ]
        return item
# NOTE(review): the accompanying question describes a .txt file, but this
# reads 'text.csv' — confirm the path.
# NOTE(review): as constructed here, the dataset yields lists of raw strings
# such as ['1', 'sentence']; the Trainer needs tokenized dicts
# (input_ids / attention_mask / labels), so each sentence must be tokenized
# before it reaches the Trainer.
dataset = MyIterableDataset('text.csv')
training_args = TrainingArguments(
    output_dir='./results',          # output directory
    # An IterableDataset has no __len__, so the Trainer cannot derive the
    # number of steps per epoch: max_steps must be set and, when > 0, it
    # overrides num_train_epochs (which is therefore ignored here).
    num_train_epochs=3,
    per_device_train_batch_size=16,  # batch size per device during training
    per_device_eval_batch_size=64,   # unused: no eval_dataset is passed
    # Warmup must be shorter than training. The previous value (500) exceeded
    # max_steps=100, so the linear warmup never finished and the learning
    # rate only reached 20% of its configured peak.
    warmup_steps=10,
    weight_decay=0.01,               # strength of weight decay
    logging_dir='./logs',            # directory for storing logs
    max_steps=100,                   # total number of optimizer steps
)
# `model` is defined outside this snippet (not visible here).
trainer = Trainer(
    model=model,            # the instantiated 🤗 Transformers model to be trained
    args=training_args,     # training arguments, defined above
    train_dataset=dataset,  # training dataset
)
trainer.train()
我用以上方法构建数据集。数据文件（文中说是 .txt，代码里读取的是 text.csv）有两列：一列是序号，另一列是文本句子。由于这是文本生成任务，数据本身没有单独的 label。
运行后报错（此处应附上完整的报错信息 / traceback）：
查了一些相关问答，但都比较抽象；我的 torch 基础不好，不太理解。想请教一下，我现在的代码应该怎么改？