Where in the code are the fully connected layers of Video Swin Transformer (mmaction)?
SwinTransformer3D
  patch_embed: PatchEmbed3D
    - patch_size (2, 4, 4): a 3D conv extracts features from each patch and downsamples
    - padding: zero-pads any dimension that is not divisible by patch_size
    - self.proj = nn.Conv3d(3, 96, kernel_size=patch_size, stride=patch_size):
      3D convolution over the input, i.e. feature extraction on each patch_size
      window; every patch yields one 96-dim feature
    - norm (optional): flatten + transpose + layer_norm + transpose
      (normalizes over the channel dim, i.e. over each patch's 96-dim feature)
  pos_drop: nn.Dropout
  self.layers: depths [2, 2, 6, 2]
    - multiple BasicLayer modules chained in series
BasicLayer
  - splits the previous layer's output into multiple 3D windows (default
    window_size (8, 7, 7)) and extracts the feature correlations between patches
  - get_window_size((D, H, W), window_size=(8, 7, 7), shift_size=(4, 3, 3))
  - rearrange(x, 'b c d h w -> b d h w c')
  - self.attn_mask = compute_mask(Dp, Hp, Wp, window_size, shift_size, x.device):
    builds the transformer mask from the input size and window_size, suppressing
    correlations between features that do not belong to the same window
  - nn.ModuleList([SwinTransformerBlock3D(...) for i in range(depth)]):
    multiple SwinTransformerBlock3D modules chained in series, operating on
    (B, D, H, W, C)
SwinTransformerBlock3D
  - nn.LayerNorm
  - F.pad
  - torch.roll (optional; shifts the feature map for shifted-window attention)
  - x_windows = window_partition(...): window split, shape (B*nW, Wd*Wh*Ww, C)
  - attn_windows = self.attn(x_windows, mask=attn_mask): WindowAttention3D,
    self-attention inside each window, shape (B*nW, Wd*Wh*Ww, C)
      - nn.Linear(dim, dim * 3, bias=qkv_bias): projects the input to 3x its width
      - q, k, v = qkv[0], qkv[1], qkv[2]: split out the Q, K, V features
      - q * self.scale, where self.scale = head_dim ** -0.5: rescales by head size
        so the number of heads does not change the magnitude of the attention logits
      - attn = q @ k.transpose(-2, -1): dot-product similarity
      - attn + relative_position_bias, looked up from relative_position_bias_table:
        a learnable relative position encoding (without it, attention would be
        invariant to the ordering of positions inside the window)
      - attn.view(B_ // nW, nW, self.num_heads, N, N) + mask: adds the mask that
        activates/suppresses cross-window correlations; this is the self.attn_mask
        computed earlier
      - self.softmax(attn), then self.attn_drop(attn): standard transformer step
      - x = attn @ v: standard transformer step
      - self.proj = nn.Linear(dim, dim), self.proj_drop = nn.Dropout(proj_drop):
        standard output projection
  - x = shortcut + self.drop_path(x)
  - FFN (MLP) module
  downsample: PatchMerging
    - rearranges the output features: H and W become 1/2 (D is not downsampled),
      so the channel count becomes 4x
    - norm: nn.LayerNorm
    - nn.Linear(4 * dim, 2 * dim): channel reduction
  rearrange(x, 'b d h w c -> b c d h w')
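A minimal sketch of the PatchMerging downsample (H and W halved, channels 4x, then projected 4*dim -> 2*dim); the padding of odd H/W is an assumption here:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchMerging(nn.Module):
    """Sketch: concatenate each 2x2 spatial neighbourhood into the channel dim
    (4x channels, D untouched), then project 4*dim -> 2*dim."""

    def __init__(self, dim):
        super().__init__()
        self.norm = nn.LayerNorm(4 * dim)
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)

    def forward(self, x):                      # x: (B, D, H, W, C)
        B, D, H, W, C = x.shape
        # pad H/W to even before subsampling
        x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
        x0 = x[:, :, 0::2, 0::2, :]
        x1 = x[:, :, 1::2, 0::2, :]
        x2 = x[:, :, 0::2, 1::2, :]
        x3 = x[:, :, 1::2, 1::2, :]
        x = torch.cat([x0, x1, x2, x3], -1)    # (B, D, H/2, W/2, 4*C)
        return self.reduction(self.norm(x))    # (B, D, H/2, W/2, 2*C)
```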
Output norm at the end of the backbone forward:
  rearrange + norm + rearrange (LayerNorm over the channel dim)
Swin transformer parameter inflation (inflate_weights): initializing from 2D Swin weights
  - the conv3d in patch_embed: directly inflated from the pretrained conv2d
  - relative_position_bias_table: two options, inflation initialization or
    center initialization
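The two initialization strategies can be illustrated on the patch_embed conv kernel; the helper name below is hypothetical (the real inflate_weights loads a 2D Swin checkpoint and transforms its tensors in place):

```python
import torch

def inflate_conv2d_to_conv3d(conv2d_weight, time_dim, mode="inflate"):
    """Sketch of 2D -> 3D kernel inflation.
    'inflate': repeat the 2D kernel along time and divide by time_dim, so the
    response to a temporally constant input matches the 2D network.
    'center': put the 2D kernel in the central time slice, zeros elsewhere."""
    out_c, in_c, kh, kw = conv2d_weight.shape
    if mode == "inflate":
        w3d = conv2d_weight.unsqueeze(2).repeat(1, 1, time_dim, 1, 1) / time_dim
    else:  # "center"
        w3d = torch.zeros(out_c, in_c, time_dim, kh, kw)
        w3d[:, :, time_dim // 2] = conv2d_weight
    return w3d
```

Both modes preserve the summed response over time, which is why either is a reasonable starting point for fine-tuning.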