用 Python(PyTorch)实现命名实体识别(NER)时报错

我的代码训练(train)可以正常运行,但验证(validation)时会报 batch size 不匹配的错误。
请问原因是什么,应该如何修改?

以下是model:

class TransformerModel(nn.Module):
    """Per-token tagger: Embedding -> PositionalEncoding -> TransformerEncoder -> LSTM -> Linear.

    Args:
        vocab_size: number of rows in the token embedding table.
        embed_size: embedding size; also the Transformer ``d_model`` and the LSTM input size.
        num_heads: attention heads per encoder layer.
        hidden_size: LSTM hidden size (input size of the final linear layer).
        num_layers: number of stacked Transformer encoder layers.
        num_classes: number of output tag classes. Defaults to 10, the value that
            was previously hard-coded, so existing callers are unaffected.
    """

    def __init__(self, vocab_size: int, embed_size: int, num_heads: int, hidden_size: int,
                 num_layers: int, num_classes: int = 10):
        super().__init__()
        self.model_type = 'Transformer'
        # Project-local module; the positional arguments presumably are
        # (d_model, dropout, max_len) — TODO confirm against its definition.
        self.pos_encoder = PositionalEncoding(embed_size, 0.2, 128)
        encoder_layers = TransformerEncoderLayer(embed_size, num_heads, dropout=0.1, batch_first=True)
        # Disable the nested-tensor fast path. When that path is active (eval
        # mode, autograd disabled, src_key_padding_mask given), the encoder
        # re-pads its output only up to the longest *unpadded* sequence in the
        # batch. The output's sequence length can then be shorter than the
        # input's / the tags'. That is why validation failed with "Expected
        # input batch_size (1312) to match target batch_size (4096)" while
        # training, where the fast path is off, worked.
        self.transformer_encoder = TransformerEncoder(
            encoder_layers, num_layers, enable_nested_tensor=False
        )
        self.encoder = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTM(embed_size, hidden_size, batch_first=True)
        self.embed_size = embed_size
        self.decoder = nn.Linear(hidden_size, num_classes)

    def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:
        """Return per-token class logits.

        Args:
            src: token ids, presumably (batch, seq_len) since the encoder and the
                LSTM are built with ``batch_first=True``.
            src_mask: forwarded as ``src_key_padding_mask``; for a bool mask,
                True marks positions the attention must ignore (padding).

        Returns:
            Logits of shape (batch, seq_len, num_classes); seq_len matches ``src``
            because the nested-tensor fast path is disabled.
        """
        src = self.encoder(src)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_key_padding_mask=src_mask)
        output = self.lstm(output)[0]  # nn.LSTM returns (output, (h_n, c_n))
        return self.decoder(output)

以下是train和validation的代码:

def validate(
    model: nn.Module,
    dataloader: DataLoader,
    device: torch.device,
    *,
    num_classes: int = 10,
):
    """Evaluate ``model`` on ``dataloader`` and print token-level loss and accuracy.

    Each batch is indexed as ``(input_ids, input_mask, tags)``. ``input_mask`` is
    passed to the model as the key padding mask, and its negation selects the
    tokens scored for accuracy, so it is expected to be True at padding positions.

    Args:
        model: the tagger; called as ``model(input_ids, input_mask)``.
        dataloader: yields the batches described above.
        device: device the metrics and batch tensors are moved to.
        num_classes: number of tag classes (default 10, the previous hard-coded value).
    """
    # `compute_on_step` is deprecated in torchmetrics and only affects forward();
    # this function uses update()/compute() exclusively, so it is omitted.
    acc_metric = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes).to(device)
    loss_metric = torchmetrics.MeanMetric().to(device)
    model.eval()

    with torch.no_grad():
        for batch in tqdm(dataloader):
            input_ids, input_mask, tags = batch[0].to(device), batch[1].to(device), batch[2].to(device)
            # output shape: (batch_size, max_length, num_classes)
            logits = model(input_ids, input_mask)
            # Tag index 0 (padding) is ignored by the loss.
            loss = F.cross_entropy(logits.reshape(-1, num_classes), tags.reshape(-1), ignore_index=0)

            # Weight each batch's mean loss by its number of non-padding tokens.
            loss_metric.update(loss, input_mask.numel() - input_mask.sum())
            is_active = torch.logical_not(input_mask)  # non-padding elements
            # Only non-padded tokens count towards accuracy.
            acc_metric.update(logits[is_active], tags[is_active])

    print(f"| Validate | loss {loss_metric.compute():.4f} | acc {acc_metric.compute():.4f} |")
def train(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: optim.Optimizer,
    device: torch.device,
    epoch: int,
    *,
    num_classes: int = 10,
):
    """Run one training epoch and print token-level loss and accuracy.

    Each batch is indexed as ``(input_ids, input_mask, tags)``. ``input_mask`` is
    passed to the model as the key padding mask, and its negation selects the
    tokens scored for accuracy, so it is expected to be True at padding positions.

    Args:
        model: the tagger; called as ``model(input_ids, input_mask)``.
        dataloader: yields the batches described above.
        optimizer: optimizer stepped once per batch.
        device: device the metrics and batch tensors are moved to.
        epoch: epoch number, used only in the printed summary.
        num_classes: number of tag classes (default 10, the previous hard-coded value).
    """
    # `compute_on_step` is deprecated in torchmetrics and only affects forward();
    # this function uses update()/compute() exclusively, so it is omitted.
    acc_metric = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes).to(device)
    loss_metric = torchmetrics.MeanMetric().to(device)
    model.train()

    for batch in tqdm(dataloader):
        input_ids, input_mask, tags = batch[0].to(device), batch[1].to(device), batch[2].to(device)
        optimizer.zero_grad()
        # output shape: (batch_size, max_length, num_classes)
        logits = model(input_ids, input_mask)
        # Tag index 0 (padding) is ignored by the loss.
        loss = F.cross_entropy(logits.reshape(-1, num_classes), tags.reshape(-1), ignore_index=0)

        loss.backward()
        optimizer.step()

        # Metrics only need values: detach so they never reference the autograd graph.
        loss_metric.update(loss.detach(), input_mask.numel() - input_mask.sum())
        is_active = torch.logical_not(input_mask)  # non-padding elements
        acc_metric.update(logits.detach()[is_active], tags[is_active])

    print(f"| Epoch {epoch} | loss {loss_metric.compute():.4f} | acc {acc_metric.compute():.4f} |")

以下是训练模型的代码和报错信息:

# Build the tagger, train it for five epochs, then run one validation pass
# after the last epoch.
model = TransformerModel(
    vocab_size=len(tokenizer),
    embed_size=256,
    num_heads=4,
    hidden_size=256,
    num_layers=2,
)
model = model.to(device)

optimizer = optim.Adam(model.parameters())

for epoch in range(5):
    train(model, train_dataloader, optimizer, device, epoch)

# Validation happens once, outside the epoch loop.
validate(model, val_dataloader, device)

ValueError                                Traceback (most recent call last)
18-dc022b7dd0b5> in <module>
     19 for epoch in range(5):
     20     train(model, train_dataloader, optimizer, device, epoch)
---> 21 validate(model, val_dataloader, device)
1 frames
10-c57c46bc539f> in validate(model, dataloader, device)
     15             logits = model(input_ids, input_mask)
     16             # ignore padding index 0 when calculating loss
---> 17             loss = F.cross_entropy(logits.reshape(-1, 10), tags.reshape(-1), ignore_index=0)
     18 
     19             loss_metric.update(loss, input_mask.numel() - input_mask.sum())
/usr/local/lib/python3.9/dist-packages/torch/nn/functional.py in cross_entropy(input, target, weight, size_average, ignore_index, reduce, reduction, label_smoothing)
   3024     if size_average is not None or reduce is not None:
   3025         reduction = _Reduction.legacy_get_string(size_average, reduce)
-> 3026     return torch._C._nn.cross_entropy_loss(input, target, weight, _Reduction.get_enum(reduction), ignore_index, label_smoothing)
   3027 
   3028 
ValueError: Expected input batch_size (1312) to match target batch_size (4096).

该回答引用 GPT(OpenAI)
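从报错信息来看,问题并不是 validate 中的 batch size 太小。`logits.reshape(-1, 10)` 得到 1312 行,而 `tags.reshape(-1)` 有 4096 个元素,这说明模型输出的序列长度比 tags 短(例如若 batch_size=32,则 4096 = 32×128,1312 = 32×41)。

原因在于 `nn.TransformerEncoder` 的 nested tensor 快速路径。在 `model.eval()` 且 `torch.no_grad()` 下并传入 `src_key_padding_mask` 时,PyTorch 会把输入转换为 nested tensor,输出只补齐到该 batch 中最长的真实(非 padding)序列长度,而不是原始的 max_length。训练模式下不会走这条路径,所以 train 正常,只有 validation 报错。

修改方法(推荐):构造编码器时关闭 nested tensor,即 `TransformerEncoder(encoder_layers, num_layers, enable_nested_tensor=False)`。

增大或减小 batch_size 都不能解决问题,模型也没有 batch_size 属性。`torch.nn.utils.rnn.pad_sequence` 同样无济于事,因为 DataLoader 给出的 batch 已经是填充到同一长度的张量。

下面的 validate 示例仅供参考。注意所有张量(包括 tags 和 input_mask)都必须移动到同一个 device 上: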

def validate(model: nn.Module, dataloader: DataLoader, device: torch.device, *, num_classes: int = 10):
    """Evaluate ``model`` and print token-level loss and accuracy (workaround variant).

    Use this when the model itself cannot be changed. If the model's output has
    a shorter sequence dimension than ``tags`` (as produced by
    ``nn.TransformerEncoder``'s nested-tensor fast path in eval mode), the
    logits are zero-padded back to the tag length. This assumes right-padded
    sequences (TODO confirm against the collate function), so the re-padded
    trailing positions are all padding. Those positions are excluded from the
    loss (``ignore_index=0``) and from the accuracy (``is_active``).

    Args:
        model: the tagger; called as ``model(input_ids, input_mask)``.
        dataloader: yields ``(input_ids, input_mask, tags)`` batches.
        device: device every tensor and metric is moved to.
        num_classes: number of tag classes (default 10).
    """
    acc_metric = torchmetrics.Accuracy(task='multiclass', num_classes=num_classes).to(device)
    loss_metric = torchmetrics.MeanMetric().to(device)
    model.eval()

    with torch.no_grad():
        for batch in tqdm(dataloader):
            # The batches are already stacked tensors (the training loop consumes
            # them directly), so no pad_sequence is needed. Every tensor must be
            # moved to `device`, including tags and the mask used for indexing.
            input_ids, input_mask, tags = batch[0].to(device), batch[1].to(device), batch[2].to(device)

            # output shape: (batch_size, seq_len, num_classes)
            logits = model(input_ids, input_mask)
            missing = tags.size(1) - logits.size(1)
            if missing > 0:
                # Pad dim 1 (sequence) at the end; the last dim (classes) is untouched.
                logits = F.pad(logits, (0, 0, 0, missing))

            # Tag index 0 (padding) is ignored by the loss.
            loss = F.cross_entropy(logits.reshape(-1, num_classes), tags.reshape(-1), ignore_index=0)

            loss_metric.update(loss, input_mask.numel() - input_mask.sum())
            is_active = torch.logical_not(input_mask)  # non-padding elements
            # Only non-padded tokens count towards accuracy.
            acc_metric.update(logits[is_active], tags[is_active])

    print(f"| Validate | loss {loss_metric.compute():.4f} | acc {acc_metric.compute():.4f} |")


使用 torch.nn.utils.rnn.pad_sequence 函数会把所有序列填充到最长序列的长度。如果填充后的长度过大,模型需要处理更多的填充元素,会增加计算成本。因此,最好根据数据的长度分布手动设置填充长度。(注意:本例中 DataLoader 输出的 batch 已经填充到同一长度,pad_sequence 对上述报错并没有帮助。)

先把 batch size 设置小一点,看看是否还报错。(注:该报错源于模型输出的序列长度与 tags 不一致——1312 和 4096 都等于 batch_size × 序列长度——因此调小 batch size 通常无法解决。)