video swin transformer的全连接层在哪个代码里面

video swin transformer(mmaction)的全连接层在哪个代码里面

  • 这篇博客: 【代码解析】mmaction2: Video Swin Transformer中的 1.2 解析 部分也许能够解决你的问题, 你可以仔细阅读以下内容或跳转源博客中阅读:
    • SwinTransformer3D

      • patch_embed: PatchEmbed3D
        将输入三维信号切分成多个3d-patch,patch_size默认(2,4,4),对每个patch使用3d-conv进行特征提取并降采样
        • padding:对无法被patch_size整除维度进行填零padding
        • self.proj = conv3d(3, 96, kernel_size = patch_size, stride=patch_size):对输入特征进行三维卷积,即对每个patch_size大小窗口的输入进行特征提取,每个patch_size输出一个96维特征
        • norm(optional): fllatten + transpose + layer_norm(对channel维度进行norm,即对每个patch_size的96维特征进行归一化)+transpose
    • pos_drop: nn.Drop

    • self.layers : depths [2, 2, 6, 2] 多个BasicLayer进行串联

      • BasicLayer 进一步对上层输出信号切分成多个3d-window,window_size默认(8,7,7),对patch和patch之间的特征关联进行信息提取
        • get_window_size((D,H,W), window_size=(8,7,7), shift_size=(4,3,3))
        • rearrange(x, 'b c d h w -> b d h w c')
        • self.attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device) 根据输入尺度和window_size生成transformer中的mask,对非自身window的特征关联信息进行抑制
          在这里插入图片描述
      • nn.ModuleList(SwinTransformerBlock3D(for i in range(depth)])多个SwinTransformerBlock3D进行串联 (B,D,H,W,C)
        在这里插入图片描述
        • nn.LayerNorm
        • F.pad
        • torch.roll(optional)
        • x_windows = window_partition: shape (B*nW, Wd*Wh*Ww, C) window切分
        • attn_windows = self.attn(x_windows, mask=attn_mask): WindowAttention3D 对window内部进行self-attention特征提取, shape (B*nW, Wd*Wh*Ww, C)
          • nn.Linear(dim, dim * 3, bias=qkv_bias) 将输入升维三倍
          • q, k, v = qkv[0], qkv[1], qkv[2] 提取K,Q,V特征
          1. q * self.scale = head_dim ** -0.5根据head_num进行缩放,防止multi-head大小对信号量影响过大
          2. attn = q @ k.transpose(-2, -1) 内积
          • attn + relative_position_bias: relative_position_bias_table 加入位置编码(防止特征顺序对transformer模块失效,不参与学习)
          • attn.view(B_ // nW, nW, self.num_heads, N, N) + mask 加入关联特征激活/抑制mask,这里mask就是之前提取的self.attn_mask
          • self.softmax(attn) + self.attn_drop(attn) Transformer标准模块
          • x = (attn @ v) Transformer标准模块
          • self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) Transformer标准模块
          • x = shortcut + self.drop_path(x) FFN模块
    • downsample: PatchMerging 对输出特征进行重排,H和W变为1/2(不对D进行降采样),channel会变成4倍在这里插入图片描述

      • 对H和W进行间隔采样
      • norm: nn.LayerNorm
      • nn.Linear(4 * dim, 2 * dim) channel降维
    • rearrange(x, 'b d h w c -> b c d h w')

    • rearrange + norm + rearrange

    Swin-trans参数膨胀
    inflate_weights

    • patch_embed 中的conv3d选择直接膨胀初始化conv2d
    • relative_position_bias_table 两种:膨胀初始化、中心初始化