My code runs fine during training, but validation fails with a batch-size error.
What is the cause, and how can I fix it?
Here is the model:
class TransformerModel(nn.Module):
    def __init__(self, vocab_size: int, embed_size: int, num_heads: int, hidden_size: int,
                 num_layers: int):
        super().__init__()
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(embed_size, 0.2, 128)
        encoder_layers = TransformerEncoderLayer(embed_size, num_heads, dropout=0.1, batch_first=True)
        self.transformer_encoder = TransformerEncoder(encoder_layers, num_layers)
        self.encoder = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.embed_size = embed_size
        self.decoder = nn.Linear(hidden_size, 10)

    def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:
        src = self.encoder(src)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_key_padding_mask=src_mask)
        output = self.lstm(output)[0]
        output = self.decoder(output)
        return output
Here are the train and validation functions:
def validate(
    model: nn.Module,
    dataloader: DataLoader,
    device: torch.device,
):
    acc_metric = torchmetrics.Accuracy(task='multiclass', num_classes=10, compute_on_step=False).to(device)
    loss_metric = torchmetrics.MeanMetric(compute_on_step=False).to(device)
    model.eval()
    with torch.no_grad():
        for batch in tqdm(dataloader):
            input_ids, input_mask, tags = batch[0].to(device), batch[1].to(device), batch[2].to(device)
            # output shape: (batch_size, max_length, num_classes)
            logits = model(input_ids, input_mask)
            # ignore padding index 0 when calculating loss
            loss = F.cross_entropy(logits.reshape(-1, 10), tags.reshape(-1), ignore_index=0)
            loss_metric.update(loss, input_mask.numel() - input_mask.sum())
            is_active = torch.logical_not(input_mask)  # non-padding elements
            # only consider non-padded tokens when calculating accuracy
            acc_metric.update(logits[is_active], tags[is_active])
    print(f"| Validate | loss {loss_metric.compute():.4f} | acc {acc_metric.compute():.4f} |")
def train(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: optim.Optimizer,
    device: torch.device,
    epoch: int,
):
    acc_metric = torchmetrics.Accuracy(task='multiclass', num_classes=10, compute_on_step=False).to(device)
    loss_metric = torchmetrics.MeanMetric(compute_on_step=False).to(device)
    model.train()
    # loop through all batches in the training set
    for batch in tqdm(dataloader):
        input_ids, input_mask, tags = batch[0].to(device), batch[1].to(device), batch[2].to(device)
        optimizer.zero_grad()
        # output shape: (batch_size, max_length, num_classes)
        logits = model(input_ids, input_mask)
        # ignore padding index 0 when calculating loss
        loss = F.cross_entropy(logits.reshape(-1, 10), tags.reshape(-1), ignore_index=0)
        loss.backward()
        optimizer.step()
        loss_metric.update(loss, input_mask.numel() - input_mask.sum())
        is_active = torch.logical_not(input_mask)  # non-padding elements
        # only consider non-padded tokens when calculating accuracy
        acc_metric.update(logits[is_active], tags[is_active])
    print(f"| Epoch {epoch} | loss {loss_metric.compute():.4f} | acc {acc_metric.compute():.4f} |")
Here is the training loop and the error message:
model = TransformerModel(vocab_size=len(tokenizer),
                         embed_size=256,
                         num_heads=4,
                         hidden_size=256,
                         num_layers=2).to(device)
optimizer = optim.Adam(model.parameters())

for epoch in range(5):
    train(model, train_dataloader, optimizer, device, epoch)
    validate(model, val_dataloader, device)
ValueError                                Traceback (most recent call last)
<ipython-input-18-dc022b7dd0b5> in <module>
     19 for epoch in range(5):
     20     train(model, train_dataloader, optimizer, device, epoch)
---> 21     validate(model, val_dataloader, device)

1 frames
<ipython-input-10-c57c46bc539f> in validate(model, dataloader, device)
     15             logits = model(input_ids, input_mask)
     16             # ignore padding index 0 when calculating loss
---> 17             loss = F.cross_entropy(logits.reshape(-1, 10), tags.reshape(-1), ignore_index=0)
     18
     19             loss_metric.update(loss, input_mask.numel() - input_mask.sum())

/usr/local/lib/python3.9/dist-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
   3024     if size_average is not None or reduce is not None:
   3025         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3026     return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
   3027
   3028

ValueError: Expected input batch_size (1312) to match target batch_size (4096).
(The following answer was generated with the help of GPT / OpenAI.)
Judging from the error message, cross_entropy fails because the flattened logits and tags no longer line up: tags.reshape(-1) has 4096 elements (batch_size × max_length), while logits.reshape(-1, 10) has only 1312 rows, so for this validation batch the model output covers a different sequence length than the labels.

To address this, make sure every sequence in a batch is padded to the same length before it reaches the model, for example with torch.nn.utils.rnn.pad_sequence, so that the tensors inside a batch always share one shape. For instance, the validate function could be modified as follows:
def validate(model: nn.Module, dataloader: DataLoader, device: torch.device):
    acc_metric = torchmetrics.Accuracy(task='multiclass', num_classes=10, compute_on_step=False).to(device)
    loss_metric = torchmetrics.MeanMetric(compute_on_step=False).to(device)
    model.eval()
    with torch.no_grad():
        for batch in tqdm(dataloader):
            # pad sequences in the batch, then move them to the device
            input_ids = nn.utils.rnn.pad_sequence(batch[0], batch_first=True).to(device)
            input_mask = nn.utils.rnn.pad_sequence(batch[1], batch_first=True).to(device)
            tags = nn.utils.rnn.pad_sequence(batch[2], batch_first=True).to(device)
            # output shape: (batch_size, max_length, num_classes)
            logits = model(input_ids, input_mask)
            # ignore padding index 0 when calculating loss
            loss = F.cross_entropy(logits.reshape(-1, 10), tags.reshape(-1), ignore_index=0)
            loss_metric.update(loss, input_mask.numel() - input_mask.sum())
            is_active = torch.logical_not(input_mask)  # non-padding elements
            # only consider non-padded tokens when calculating accuracy
            acc_metric.update(logits[is_active], tags[is_active])
    print(f"| Validate | loss {loss_metric.compute():.4f} | acc {acc_metric.compute():.4f} |")
torch.nn.utils.rnn.pad_sequence pads every sequence in a batch up to the length of the longest one. If that longest length is large, this increases compute cost, because the model has to process the extra padding positions, so it is often better to choose a fixed padding length based on your data distribution. In practice this padding usually lives in the DataLoader's collate_fn rather than inside validate, as in the sketch below.
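Here is a minimal collate_fn sketch under a few assumptions: each dataset sample is an (input_ids, input_mask, tags) tuple of variable-length 1-D tensors, a mask value of 1/True marks a padding position (matching how src_key_padding_mask is used above), and pad_collate, MAX_LEN and val_dataset are hypothetical names. MAX_LEN = 128 is only chosen because the model's PositionalEncoding above is built with a maximum length of 128.

import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

MAX_LEN = 128  # hypothetical fixed length; pick it from your data distribution

def pad_collate(samples):
    # samples: list of (input_ids, input_mask, tags) tuples (assumption about the dataset)
    ids, masks, tags = zip(*samples)
    ids = pad_sequence(ids, batch_first=True, padding_value=0)
    masks = pad_sequence(masks, batch_first=True, padding_value=1)  # 1/True = padding position
    tags = pad_sequence(tags, batch_first=True, padding_value=0)    # 0 is the loss's ignore_index

    def to_fixed(t, pad_value):
        # pad or truncate along the sequence dimension so every batch shares one max_length
        if t.size(1) < MAX_LEN:
            filler = t.new_full((t.size(0), MAX_LEN - t.size(1)), pad_value)
            return torch.cat([t, filler], dim=1)
        return t[:, :MAX_LEN]

    return to_fixed(ids, 0), to_fixed(masks, 1), to_fixed(tags, 0)

# hypothetical usage:
# val_dataloader = DataLoader(val_dataset, batch_size=32, collate_fn=pad_collate)

With this in place, the training and validation loops can keep moving the already-padded tensors to the device exactly as in the original train function.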
You can also first try a smaller batch size and see whether the error still occurs. Either way, it helps to confirm which tensor has the unexpected shape; see the quick check below.
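A minimal shape check (debugging only, not part of the fix), placed right before the cross_entropy call inside validate, shows whether logits or tags has the wrong sequence length for the failing batch:

# expected: logits (batch_size, max_length, 10), tags (batch_size, max_length)
print("logits:", tuple(logits.shape), "tags:", tuple(tags.shape))
print("rows after reshape:", logits.reshape(-1, 10).size(0), "vs", tags.reshape(-1).size(0))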