Relative positional encoding in PyTorch for 1D data

I'd like to ask how relative positional encoding for a Transformer is implemented. I'm using PyTorch and my data is 1D. Has anyone here implemented this? I gave it a try myself, but the experimental results aren't great, so I'd like to know whether my code is wrong, or whether there is a working implementation I could try.

import torch
import torch.nn as nn

# Build the relative-position index matrix; the trainable embeddings are not applied yet
def position_distance(Seq):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    positional_l = torch.arange(Seq, dtype=torch.long, device=device).view(-1, 1)      # positions as a column vector
    positional_r = torch.arange(Seq, dtype=torch.long, device=device).view(1, -1)      # positions as a row vector
    distance = positional_l - positional_r      # pairwise relative distances via broadcasting
    distance = distance + Seq - 1       # shift so every index is non-negative (range 0 .. 2*Seq-2)
    return distance
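
# Worked example of the indices above: with Seq = 3, positional_l - positional_r is
#   [[ 0, -1, -2],
#    [ 1,  0, -1],
#    [ 2,  1,  0]]
# and after adding Seq - 1 = 2 it becomes
#   [[2, 1, 0],
#    [3, 2, 1],
#    [4, 3, 2]]
# so every entry lies in [0, 2*Seq-2], matching the 2*Seq-1 rows of the embedding table below.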


class Multihead_Attention(nn.Module):
    def __init__(self, dim, num_heads, Seq):
        super().__init__()

        # Q, K, V projection matrices
        self.q = nn.Linear(dim, dim, bias=False)
        self.k = nn.Linear(dim, dim, bias=False)
        self.v = nn.Linear(dim, dim, bias=False)
        self.num_heads = num_heads

        self.position_dim = dim // num_heads     # relative positions are applied per head, so the embedding dim is dim // num_heads
        # self.Seq_embedding = nn.Embedding(2*Seq-1, self.position_dim)
        # one embedding per possible relative distance in [-(Seq-1), Seq-1]; zero init makes the position term contribute nothing at the start
        self.Seq_embedding = nn.Parameter(torch.zeros((2*Seq-1), self.position_dim))

    def forward(self, x):  # [batch, Seq, dim]

        # *************multi-head attention*************

        batch_size, Seq, dim = x.shape

        # q k v -> [batch, head, seq, dim]
        q = self.q(x).reshape(batch_size, Seq, self.num_heads, -1).permute(0, 2, 1, 3)
        k = self.k(x).reshape(batch_size, Seq, self.num_heads, -1).permute(0, 2, 1, 3)
        v = self.v(x).reshape(batch_size, Seq, self.num_heads, -1).permute(0, 2, 1, 3)

        # look up the relative-position embedding for every (query, key) pair
        distance = position_distance(Seq)
        distance = self.Seq_embedding[distance]     # ->[seq, seq, dim/head]
        # distance = self.Seq_embedding(distance)  # ->[seq, seq, dim/head]
        distance = distance.transpose(1, 2)     # ->[seq, dim/head, seq]

        # multiply q with the relative-position embeddings
        q_distance = q.permute(2, 0, 1, 3).reshape(Seq, batch_size*self.num_heads, self.position_dim)   # ->[seq, batch*head, dim/head]
        QmD = (q_distance @ distance).reshape(Seq, batch_size, self.num_heads, Seq).permute(1, 2, 0, 3)    # ->[batch, head, seq, seq]

        # dot-product attention scores plus the relative-position term, scaled by 1/sqrt(d_head)
        MultiHead_attn = ((q@k.transpose(2, 3)) + QmD) * (self.position_dim ** -0.5)      # -> [batch, head, seq, seq]
        MultiHead_attn = MultiHead_attn.softmax(dim=-1)

        # weight v by the attention scores and merge the heads  -> [batch, Seq, dim]
        MultiHead_attn = (MultiHead_attn @ v).permute(0, 2, 1, 3).reshape(batch_size, Seq, dim)

        return MultiHead_attn
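

For reference, here is a minimal shape check for the module above; the batch size, sequence length, dim and head count are just arbitrary placeholder values:

if __name__ == "__main__":
    batch_size, Seq, dim, num_heads = 2, 16, 64, 4
    attn = Multihead_Attention(dim=dim, num_heads=num_heads, Seq=Seq)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    attn = attn.to(device)

    x = torch.randn(batch_size, Seq, dim, device=device)
    out = attn(x)
    print(out.shape)    # expected: torch.Size([2, 16, 64])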



OP, for a relative position implementation you can refer to this article:
https://blog.csdn.net/cyz0202/article/details/124929307
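
If you just want a simple baseline to compare against, below is a minimal sketch of the bias-table flavour of relative position encoding: a learned scalar bias per head for each relative distance, added directly to the attention logits. It is not the method from the blog above and not your per-dimension embedding approach; the class name and hyperparameters are placeholders of my own.

import torch
import torch.nn as nn

class RelPosBiasAttention(nn.Module):
    # Minimal sketch: multi-head attention with a learned per-head relative position bias (1D)
    def __init__(self, dim, num_heads, max_len):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3, bias=False)
        self.proj = nn.Linear(dim, dim)
        # one bias per head for every relative distance in [-(max_len-1), max_len-1]
        self.rel_bias = nn.Parameter(torch.zeros(2 * max_len - 1, num_heads))
        # precomputed [max_len, max_len] matrix of shifted relative-distance indices
        pos = torch.arange(max_len)
        self.register_buffer("rel_index", pos.view(-1, 1) - pos.view(1, -1) + max_len - 1)

    def forward(self, x):  # x: [batch, seq, dim], seq must be <= max_len
        B, S, D = x.shape
        qkv = self.qkv(x).reshape(B, S, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]                        # each [B, head, S, head_dim]
        attn = (q @ k.transpose(-2, -1)) * self.head_dim ** -0.5
        bias = self.rel_bias[self.rel_index[:S, :S]]            # [S, S, head]
        attn = attn + bias.permute(2, 0, 1).unsqueeze(0)        # broadcast to [B, head, S, S]
        attn = attn.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, S, D)       # merge heads -> [B, S, dim]
        return self.proj(out)

Swapping something like this in place of your module is a quick way to check whether the relative-position term itself, or something else in your training setup, is causing the poor results.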


You can also take a look at atleast_1d() (in module torch) in the PyTorch reference manual.
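
For context, atleast_1d just promotes a 0-d tensor to a 1-d tensor and leaves higher-dimensional inputs unchanged, e.g.:

import torch

torch.atleast_1d(torch.tensor(5.0))        # tensor([5.])  -- 0-d tensor promoted to 1-d
torch.atleast_1d(torch.tensor([1, 2, 3]))  # already 1-d, returned unchanged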