看informer源代码时,attn.py里的ProbAttention有这么一段:
def _prob_QK(self, Q, K, sample_k, n_top): # n_top: c*ln(L_q)
# Q [B, H, L, D]
B, H, L_K, E = K.shape
_, _, L_Q, _ = Q.shape
# calculate the sampled Q_K
K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
index_sample = torch.randint(L_K, (L_Q, sample_k)) # real U = U_part(factor*ln(L_k))*L_q
K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze()
其中对K_expand进行切片索引的操作,即K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]这一句,不明白为什么要对torch.arange(L_Q)进行维度扩展的操作。通过实验发现
K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]的维度是(B,H,L_Q,sample_k,E)
K_expand[:, :, :, index_sample, :]的维度是(B,H,L_Q,L_Q,sample_k,E)
K_expand[:,:,torch.arange(L_Q),index_sample,:]会报错,出现indexing tensors could not be broadcast together with shapes [L_Q], [L_Q, sample_k]的问题
推测torch的张量索引是通过广播的机制来完成的,但是具体的实现机制依旧不清楚,十分困惑,希望有懂的人能解答!
报错是下面的torch.matmul这里乘法对应的维度不匹配报错无法广播(mxn的矩阵要和nxk的矩阵才能做乘法),而不是切片这里报错
至于你说的这里是用的是列表切片法,具体切片效果可以看这篇文章。总之这里的切片就是要保证下面的矩阵乘法维度数对应上。