Transformer在预测阶段解码器的输入是什么?就只输入一个BOS开始符号吗?

最近学了transformer框架
在训练过程解码器的输入都是正确标签加上掩码
那预测阶段呢或者说实际应用阶段解码器输入应该是什么呢?

  • 请看👉 :还在魔改Transformer结构吗?微软&中山大学开源超强的视觉位置编码,涨点显著
  • 以下回答由chatgpt基于相关博客总结生成:

    根据问题描述,Transformer在训练过程中,解码器的输入是正确的标签加上掩码。在预测阶段或者实际应用阶段,解码器的输入应该是通过前一个时间步的解码器输出来生成的。具体的解决方案如下:

    Step 1: 准备Transformer模型 首先,需要准备一个Transformer模型。你可以使用现成的模型库,比如transformers库,或者根据自己的需求实现一个Transformer模型。

    Step 2: 准备解码器输入 在预测阶段,解码器的初始输入是一个特殊的标记,比如开始标记<sos>。接下来的每一个时间步,解码器的输入都是通过前一个时间步的解码器输出来生成的。所以,需要一个循环来生成解码器的输入序列。

    Step 3: 生成解码器输入序列 使用一个循环来生成解码器的输入序列,直到达到预定的序列长度或者遇到结束标记<eos>。在每个时间步,将前一个时间步的解码器输出作为当前时间步的输入,并将当前时间步的输出添加到结果序列中。

    下面是一个示例代码,用来生成解码器的输入序列:

    # 准备Transformer模型
    model = TransformerModel()
    
    # 准备解码器输入
    decoder_input = torch.tensor([<sos>])  # 初始输入为开始标记
    
    max_length = 100  # 设置最大序列长度
    
    result = []  # 保存解码器的输出序列
    
    # 生成解码器输入序列
    for i in range(max_length):
        # 使用前一个时间步的解码器输出作为当前时间步的输入
        decoder_output = model.decoder(decoder_input)
    
        # 获取当前时间步的输出
        output_token = decoder_output.argmax(dim=-1)
    
        # 将当前时间步的输出添加到结果序列中
        result.append(output_token.item())
    
        # 检查是否达到结束标记
        if output_token.item() == <eos>:
            break
    
        # 更新解码器输入为当前时间步的输出
        decoder_input = output_token
    
    # 将结果序列转换为字符串或标签的形式
    result = convert_to_string(result)
    

    在上述代码中,首先准备了一个Transformer模型model,然后设置了一个循环来生成解码器的输入序列。在每个时间步中,使用前一个时间步的解码器输出作为当前时间步的输入,然后获取当前时间步的输出,并将其添加到结果序列中。当输出为结束标记时,循环停止。最后,可以将结果序列转换为字符串或标签的形式。

    请注意,上述代码仅为示例,实际应用中可能需要根据具体情况进行调整和优化。另外,具体的实现方式也可能因不同的Transformer模型库而有所差异。以上是根据问题描述和参考资料给出的一个解决方案,希望能帮到你。如果有任何问题,请随时提问。

在预测阶段(或称为推理阶段)和实际应用阶段,解码器的输入是由模型自身生成的。具体地说,对于序列生成任务如机器翻译、文本摘要等,我们通常采用某种策略,例如贪婪解码(Greedy Decoding)、集束搜索(Beam Search)或者拓扑采样(Top-k Sampling)等,将模型的当前输出作为下一个时间步的输入。

对于第一时间步,解码器的输入通常是一个特殊的起始符号,例如""或""。在这之后的每一个时间步,解码器的输入就是前一时间步模型生成的输出。

以贪婪解码为例:在第一个时间步,我们将起始符号输入解码器,解码器会产生一个输出。然后,我们取这个输出中概率最高的单词,再将其作为下一个时间步的输入。这个过程反复进行,直到模型生成了一个结束符号,例如""或"",或者生成的序列达到了某个预设的最大长度。

需要注意的是,在预测阶段,我们不再需要使用掩码(masking)技术,因为在每个时间步,解码器只看到其前一步的输出,而不能看到未来的信息。这确保了预测阶段的解码过程符合序列生成的自回归(auto-regressive)特性。