The size of tensor a (9) must match the size of tensor b (18) at non-singleton dimension 3

The size of tensor a (9) must match the size of tensor b (18) at non-singleton dimension 3

def __init__(self, batch_size=8, num_pos=4, temperature=50):
    super(CMAlign, self).__init__()
    self.batch_size = batch_size
    self.num_pos = num_pos
    self.criterion0 = nn.TripletMarginLoss(margin=0.3, p=2.0, reduce=False)#修改与原码不一样,criterion+0
    self.temperature = temperature

def _random_pairs(self):
    batch_size = self.batch_size
    num_pos = self.num_pos

    pos = []
    for batch_index in range(batch_size):
        pos_idx = random.sample(list(range(num_pos)),
                                num_pos)  # random.sample用于截取列表的指定长度num_pos的随机数,list列出所有的整数(0-num_pos)
        pos_idx = np.array(pos_idx) + num_pos * batch_index  # np.array创建一个数组
        pos = np.concatenate((pos, pos_idx))  # np.concatenate拼接两个数组
    pos = pos.astype(int)

    neg = []
    for batch_index in range(batch_size):
        batch_list = list(range(batch_size))
        batch_list.remove(batch_index)  # remove移除列表中的元素

        batch_idx = random.sample(batch_list, num_pos)
        neg_idx = random.sample(list(range(num_pos)), num_pos)

        batch_idx, neg_idx = np.array(batch_idx), np.array(neg_idx)
        neg_idx = batch_idx * num_pos + neg_idx
        neg = np.concatenate((neg, neg_idx))
    neg = neg.astype(int)

    return {'pos': pos, 'neg': neg}

def _define_pairs(self):
    pairs_v = self._random_pairs()
    pos_v, neg_v = pairs_v['pos'], pairs_v['neg']

    pairs_t = self._random_pairs()
    pos_t, neg_t = pairs_t['pos'], pairs_t['neg']

    pos_v += self.batch_size * self.num_pos
    neg_v += self.batch_size * self.num_pos

    return {'pos': np.concatenate((pos_v, pos_t)), 'neg': np.concatenate((neg_v, neg_t))}

def feature_similarity(self, feat_q, feat_k):
    batch_size, fdim, h, w = feat_q.shape
    feat_q = feat_q.view(batch_size, fdim, -1)  # .view将特征图铺平,转化为一维向量
    feat_k = feat_k.view(batch_size, fdim, -1)

    feature_sim = torch.bmm(F.normalize(feat_q, dim=1).permute(0, 2, 1),
                            F.normalize(feat_k, dim=1))  # .bmm两个tensor矩阵乘法,.permute将tensor维度换位
    return feature_sim

def matching_probability(self, feature_sim):
    M, _ = feature_sim.max(dim=-1, keepdim=True)  # .max返回输入张量给定维度上每行的最大值,并同时返回每个最大值的位置索引。
    feature_sim = feature_sim - M  # for numerical stability
    exp = torch.exp(self.temperature * feature_sim)
    exp_sum = exp.sum(dim=-1, keepdim=True)
    return exp / exp_sum

def soft_warping(self, matching_pr, feat_k):
    batch_size, fdim, h, w = feat_k.shape
    feat_k = feat_k.view(batch_size, fdim, -1)
    feat_warp = torch.bmm(matching_pr, feat_k.permute(0, 2, 1))
    feat_warp = feat_warp.permute(0, 2, 1).view(batch_size, fdim, h, w)

    return feat_warp

def reconstruct(self, mask, feat_warp, feat_q):
    return mask * feat_warp + (1.0 - mask) * feat_q

def compute_mask(self, feat):
    batch_size, fdim, h, w = feat.shape
    norms = torch.norm(feat, p=2, dim=1).view(batch_size, h * w)  # torch.norm对feat的行求2范数

    norms -= norms.min(dim=-1, keepdim=True)[0]
    norms /= norms.max(dim=-1, keepdim=True)[0] + 1e-12
    mask = norms.view(batch_size, 1, h, w)

    return mask.detach()  # detach作用:返回一个新的tensor,从当前计算图中分离下来的,但是仍指向原变量的存放位置,不同之处只是requires_grad为false,得到的这个tensor永远不需要计算其梯度,不具有grad

def compute_comask(self, matching_pr, mask_q, mask_k):
    batch_size, mdim, h, w = mask_q.shape
    mask_q = mask_q.view(batch_size, -1, 1)
    mask_k = mask_k.view(batch_size, -1, 1)
    comask = mask_q * torch.bmm(matching_pr, mask_k)

    comask = comask.view(batch_size, -1)
    comask -= comask.min(dim=-1, keepdim=True)[0]
    comask /= comask.max(dim=-1, keepdim=True)[0] + 1e-12
    comask = comask.view(batch_size, mdim, h, w)

    return comask.detach()

def forward(self, feat_v, feat_t):
    feat = torch.cat([feat_v, feat_t], dim=0)  # torch.cat在列上拼接两个tensor
    mask = self.compute_mask(feat)
    batch_size, fdim, h, w = feat.shape

    pairs = self._define_pairs()
    pos_idx, neg_idx = pairs['pos'], pairs['neg']

    # positive
    feat_target_pos = feat[pos_idx]
    feature_sim = self.feature_similarity(feat, feat_target_pos)
    matching_pr = self.matching_probability(feature_sim)

    comask_pos = self.compute_comask(matching_pr, mask, mask[pos_idx])
    feat_warp_pos = self.soft_warping(matching_pr, feat_target_pos)
    feat_recon_pos = self.reconstruct(mask, feat_warp_pos, feat)

    # negative
    feat_target_neg = feat[neg_idx]
    feature_sim = self.feature_similarity(feat, feat_target_neg)
    matching_pr = self.matching_probability(feature_sim)

    feat_warp = self.soft_warping(matching_pr, feat_target_neg)
    feat_recon_neg = self.reconstruct(mask, feat_warp, feat)

    loss0 = torch.mean(comask_pos * self.criterion0(feat, feat_recon_pos, feat_recon_neg))#修改与原码不一样,loss+0

    return {'feat': feat_recon_pos, 'loss0': loss0} 

问题未描述清楚,代码没看出什么问题

代码没有给完整。