multihead_attn = nn.MultiheadAttention(embed_dim, num_heads)
attn_output = multihead_attn(query, key, value)[0]
我想知道这个query、key、value是经过线性变换前的还是已经变换后的。
比如原来的基础embding是x、y。
这个地方应该是query = y,key =x,value = x
还是 query =wy ,key =wx, value = wx ,其中w为训练参数。
该回答引用ChatGPT
根据 nn.MultiheadAttention 的文档,query, key, value 输入的形状应该是 (seq_len, batch_size, embed_dim),即每个时间步的输入向量形状是 (batch_size, embed_dim)。这里的 query, key, value 应该是经过线性变换后的向量。
在 nn.MultiheadAttention 的初始化中,有三个线性层,分别对应 query, key, value 的线性变换,可以使用 nn.Linear 模块来实现,其中 in_features 表示输入向量的维度,out_features 表示输出向量的维度。比如:
query_linear = nn.Linear(embed_dim, embed_dim)
key_linear = nn.Linear(embed_dim, embed_dim)
value_linear = nn.Linear(embed_dim, embed_dim)
query = query_linear(y)
key = key_linear(x)
value = value_linear(x)
这样,query, key, value 就是经过线性变换后的向量。