How to extract the feature-map outputs of sub-layers in PyTorch

While reproducing the STDN network, I need to extract the feature maps of six intermediate layers of the last dense block, to be used for prediction and regression.
A partial printout of the network I wrote follows:

…………
(transition3): Sequential(
    (transition_bn): BatchNorm2d(1280, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (transition_relu): ReLU(inplace=True)
    (transition_conv): Conv2d(1280, 640, kernel_size=(1, 1), stride=(1, 1))
    (transition_pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
  )
  (denseblock4): Sequential(
    (dense_0): denselayer(
      (bn1): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU(inplace=True)
      (conv1): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True)
      (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
    (dense_1): denselayer(
      (bn1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu1): ReLU(inplace=True)
      (conv1): Conv2d(672, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu2): ReLU(inplace=True)
      (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    )
…………

Here (transition3) is the transition layer before (denseblock4), and (denseblock4) contains 32 densely connected layers, dense_0 through dense_31. I now need to extract the output feature maps of dense_4, dense_9, dense_14, dense_19, dense_24, and dense_31 from inside (denseblock4). How can I do that?
The methods I found online can only extract the output of the whole (denseblock4); they cannot reach down into its sub-layers.
Thanks!!

In forward, pick out the required feature maps by checking each sub-layer's name:

    def forward(self, x):
        …… ……
        # Names of the sub-layers whose feature maps we want to keep
        name_list = ['dense_4', 'dense_9', 'dense_14', 'dense_19', 'dense_24', 'dense_31']
        feature_list = []
        for name, module in self.denseblock4.named_children():
            x = module(x)          # run each dense layer in order
            if name in name_list:  # save the output of the six target layers
                feature_list.append(x)
        return x, feature_list
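
For reference, a minimal usage sketch of this approach; the class name STDN and the input size are placeholders for your own network and data:

    import torch

    model = STDN()                    # placeholder for your network class
    x = torch.rand(1, 3, 300, 300)    # placeholder input size
    out, feature_list = model(x)
    # feature_list holds the outputs of dense_4/9/14/19/24/31 in forward order
    print([f.shape for f in feature_list])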

Based on GPT's answer and my own approach: you can use hooks to intercept the output of any layer in the network. Here, you register a hook on each of the dense_4, dense_9, dense_14, dense_19, dense_24, and dense_31 layers before running the forward pass. A code example follows:

import torch
from torch.nn import (Module, ModuleList, Sequential, BatchNorm2d, ReLU, Conv2d,
                      AvgPool2d, MaxPool2d, AdaptiveAvgPool2d, Linear)

class Flatten(Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

class DenseBlock(Module):
    def __init__(self, n_layers, in_channels, growth_rate, kernel_size=3, padding=1):
        super(DenseBlock, self).__init__()

        self.layer_list = self._make_layer(n_layers, in_channels, growth_rate, kernel_size, padding)
        self.transition = self._make_transition(in_channels + n_layers * growth_rate)

    def _make_layer(self, n_layers, in_channels, growth_rate, kernel_size, padding):
        layer_list = []
        for i in range(n_layers):
            layer = self._make_single_layer(in_channels + i * growth_rate, growth_rate, kernel_size, padding)
            layer_list.append(layer)

        return ModuleList(layer_list)

    def _make_single_layer(self, in_channels, out_channels, kernel_size, padding):
        return Sequential(
            BatchNorm2d(in_channels),
            ReLU(inplace=True),
            Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=False),
        )

    def _make_transition(self, in_channels):
        return Sequential(
            BatchNorm2d(in_channels),
            Conv2d(in_channels, in_channels // 2, kernel_size=1),
            AvgPool2d(kernel_size=2, stride=2)
        )

    def forward(self, x, save_features=None):
        features = []
        for layer in self.layer_list:
            out = layer(x)
            x = torch.cat([x, out], dim=1)
            if save_features is not None:
                features.append(out)

        x = self.transition(x)

        return (x, features) if save_features is not None else x

class DenseNet(Module):
    def __init__(self, growth_rate=32, n_layers_per_block=(6, 12, 24, 16), in_channels=3, num_classes=10):
        super(DenseNet, self).__init__()

        self.layer1 = Sequential(
            Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False),
            BatchNorm2d(64),
            ReLU(inplace=True),
            MaxPool2d(kernel_size=3, stride=2, padding=1)
        )

        self.dense_block1 = DenseBlock(n_layers_per_block[0], in_channels=64, growth_rate=growth_rate)
        self.dense_block2 = DenseBlock(n_layers_per_block[1], in_channels=128, growth_rate=growth_rate)
        self.dense_block3 = DenseBlock(n_layers_per_block[2], in_channels=256, growth_rate=growth_rate)
        self.dense_block4 = DenseBlock(n_layers_per_block[3], in_channels=512, growth_rate=growth_rate)
        self.classifier = Sequential(
            AdaptiveAvgPool2d((1, 1)),
            Flatten(),
            Linear(512, num_classes)
        )

        self.layers = [self.dense_block1, self.dense_block2, self.dense_block3, self.dense_block4]

    def forward(self, x, save_features=None):
        x = self.layer1(x)
        for layer in self.layers:
            # DenseBlock.forward only returns a tuple when save_features is set
            if save_features is not None:
                x, d_features = layer(x, save_features=save_features)
                save_features.append(d_features)
            else:
                x = layer(x)

        x = self.classifier(x)

        return x

Then you can collect the intercepted sub-layer outputs in either of two ways: pass a list as the save_features argument of forward() to record every layer's output, or register forward hooks on exactly the layers you need. The hook version looks like this (indices a block doesn't have are skipped, since the blocks here only contain 6/12/24/16 layers):

densenet = DenseNet()
save_features = []
hook_handles = []
for name, layer in densenet.named_modules():
    if isinstance(layer, DenseBlock):
        for i in [4, 9, 14, 19, 24, 31]:
            if i < len(layer.layer_list):  # skip indices this block doesn't have
                hook_handles.append(layer.layer_list[i].register_forward_hook(
                    lambda module, inp, out: save_features.append(out)))

output = densenet(torch.rand((1, 3, 224, 224)))  # the hooks fill save_features during this call

print([f.shape for f in save_features])
for handle in hook_handles:
    handle.remove()  # remove the hooks once the features have been collected
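
Applied to the network from the question, the same hook idea works without touching forward at all. A minimal sketch, assuming your model instance is named net and denseblock4 is an attribute exactly as in your printout:

wanted = ['dense_4', 'dense_9', 'dense_14', 'dense_19', 'dense_24', 'dense_31']
features = {}
handles = []

def make_hook(name):
    # each hook stores its layer's output under that layer's name
    def hook(module, inp, out):
        features[name] = out
    return hook

for name, module in net.denseblock4.named_children():
    if name in wanted:
        handles.append(module.register_forward_hook(make_hook(name)))

_ = net(torch.rand(1, 3, 300, 300))   # placeholder input size; use your own
print({k: v.shape for k, v in features.items()})
for h in handles:
    h.remove()   # always remove hooks once you are done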