为什么说self-attention只在通道维度共享权重

如题,在一篇论文里面看到上述描述,没太能理解。self-attention的权重不是来自于qkv矩阵的计算吗 这个过程应该是通道之间不交叉,空间维度共享权重才对吧

  • 这篇文章讲的很详细,请看:如何理解self attention中的QKV矩阵
  • 除此之外, 这篇博客: 通俗易懂:Attention中的Q、K、V是什么?怎么得到Q、K、V?中的 示例1 Self-Attention之打野捉上单被反杀了怎么办 部分也许能够解决你的问题, 你可以仔细阅读以下内容或跳转源博客中阅读:
  • 大部分人在讲Attention时都会讲到Self-Attention,毕竟这也是经典了。
    而当我们打游戏出现打野来上被他反杀这种现象时,我们的第一反应是:在NLP中,我们能否通过语言理解去做到,这个他指的是上单还是打野?
    在这里插入图片描述

    那么在Self-Attention中的做法是:
    1、根据这个句子得到打野、上、他的embedding,在下图表示为e1、e2、e3e_{1}、e_{2}、e_{3}e1e2e3
    2、将e通过不同的线性变换Q、K、V。(注意理解:所有的e到Q或K或V都采取相同的线性变换,而e到Q和K和V采取不同的线性变换。)实际上这里把Q、K、V都看做与e一样的话可能更好理解。
    3、根据Q3Q_{3}Q3分别与K1、K2K_{1}、K_{2}K1K2计算相似程度即Attention值,得到A1,3、A2,3A_{1,3}、A_{2,3}A1,3A2,3
    4、AttentionV3=(A1,3+A2,3+1)⋅V3Attention V_{3}=(A_{1,3}+A_{2,3}+1)\cdot V _{3}AttentionV3=(A1,3+A2,3+1)V3
    AttentionV3Attention V_{3}AttentionV3向量中包括了这样的信息:他更可能指代上单,而不是打野。
    总结:这里的Q、K、V是根据单词embedding的线性变换得到的。
    图有点丑抱歉!

  • 以下回答由chatgpt基于相关博客总结生成:

    我理解的是参考资料中的内容,即在Self-Attention中的权重共享仅在通道维度上共享。这意味着,三个参数矩阵Q、K、V都是通过通道维度上的相同权重矩阵进行线性变换得到的。因此,对于一个单词的Q、K、V矩阵,它们之间的权重是共享的,也就是不同空间维度的权重是不同的。这样可以大大减少参数数量,提高模型训练效率。

    以下是一个示例代码,用于展示在Self-Attention中的权重共享是如何进行的:

    import torch
    import torch.nn as nn
    
    class SelfAttention(nn.Module):
        def __init__(self, in_dim, num_heads):
            super(SelfAttention, self).__init__()
            self.in_dim = in_dim
            self.num_heads = num_heads
            self.head_dim = in_dim // num_heads
    
            # Define weight matrices
            self.q_weight = nn.Linear(in_dim, in_dim, bias=False)
            self.k_weight = nn.Linear(in_dim, in_dim, bias=False)
            self.v_weight = nn.Linear(in_dim, in_dim, bias=False)
    
            # Weight initialization
            nn.init.xavier_uniform_(self.q_weight.weight)
            nn.init.xavier_uniform_(self.k_weight.weight)
            nn.init.xavier_uniform_(self.v_weight.weight)
    
            # Define output linear layer
            self.output_layer = nn.Linear(in_dim, in_dim)
    
        def forward(self, x):
            batch_size, seq_len, in_dim = x.size()
            num_heads = self.num_heads
            head_dim = self.head_dim
    
            # Linearly transform input to Q, K, V matrices
            q = self.q_weight(x)
            k = self.k_weight(x)
            v = self.v_weight(x)
    
            # Reshape Q, K, V matrices into multiple heads
            q = q.view(batch_size, seq_len, num_heads, head_dim)
            k = k.view(batch_size, seq_len, num_heads, head_dim)
            v = v.view(batch_size, seq_len, num_heads, head_dim)
    
            # Transpose Q, K matrices to enable batch multiplication
            q = q.permute(0, 2, 1, 3)  # (batch_size, num_heads, seq_len, head_dim)
            k = k.permute(0, 2, 3, 1)  # (batch_size, num_heads, head_dim, seq_len)
    
            # Multiply Q, K matrices to get attention weights
            att_weights = torch.matmul(q, k)
            att_weights = att_weights / (head_dim ** 0.5)
    
            # Apply softmax activation to get normalized attention weights
            att_weights = nn.Softmax(dim=-1)(att_weights)
    
            # Transpose attention weights matrix back to its original shape
            att_weights = att_weights.permute(0, 2, 1, 3)  # (batch_size, seq_len, num_heads, head_dim)
    
            # Multiply attention weights matrix with V matrix to get attention output
            att_output = torch.matmul(att_weights, v)
    
            # Reshape attention output back into its original shape
            att_output = att_output.view(batch_size, seq_len, in_dim)
    
            # Linearly transform attention output and add residual connection
            output = self.output_layer(att_output) + x
    
            return output
    

    在这个代码中,我们定义了一个名为SelfAttention的PyTorch模块,用于执行Self-Attention操作。该模块通过三个线性变换(self.q_weight、self.k_weight和self.v_weight)来计算输入x中每个单词的Q、K、V矩阵。这些权重矩阵在通道维度上是共享的。我们还定义了一个输出线性层(self.output_layer)来将注意力输出线性变换回原始输入的维度,并添加残差连接。

    请注意,在这个代码中主要是用PyTorch的内置函数实现的。有关注意力机制的更多详细信息,请查看论文或其他教程。

通道维度共享权重,空间维度共享权重,都没有错。
输入有多个通道,例如 RGB 图像有 3个通道,每个通道的权重共享,不同通道的权重不同。
你可以说是通道维度共享维度,强调的是不同通道的权值不同。
你也可以说上空间维度共享,强调的是同一通道(的不同空间位置)的权值相同。