关于#人工智能#的问题:这个地方应该是query = y,value = x还是 query =wy ,key =wx, value = wx ,其中w为训练参数

multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
attn_output = multihead_attn(query, key, value)[0]

output: torch.Size([12, 64, 300])

batch_size 为 64,有 12 个词,每个词的向量是 300 维

我想知道这个query、key、value是经过线性变换前的还是已经变换后的。

比如原来的基础embding是x、y。

这个地方应该是query = y,key =x,value = x
还是 query =wy ,key =wx, value = wx ,其中w为训练参数。

该回答引用ChatGPT

根据 nn.MultiheadAttention 的文档,query, key, value 输入的形状应该是 (seq_len, batch_size, embed_dim),即每个时间步的输入向量形状是 (batch_size, embed_dim)。这里的 query, key, value 应该是经过线性变换后的向量。

在 nn.MultiheadAttention 的初始化中,有三个线性层,分别对应 query, key, value 的线性变换,可以使用 nn.Linear 模块来实现,其中 in_features 表示输入向量的维度,out_features 表示输出向量的维度。比如:

query_linear = nn.Linear(embed_dim, embed_dim)
key_linear = nn.Linear(embed_dim, embed_dim)
value_linear = nn.Linear(embed_dim, embed_dim)

query = query_linear(y)
key = key_linear(x)
value = value_linear(x)

这样,query, key, value 就是经过线性变换后的向量。