While reproducing the STDN network, I need to extract the output feature maps of 6 intermediate layers of the last dense block, to be used for prediction and regression.
A partial print of the network I wrote looks like this:
…………
(transition3): Sequential(
  (transition_bn): BatchNorm2d(1280, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (transition_relu): ReLU(inplace=True)
  (transition_conv): Conv2d(1280, 640, kernel_size=(1, 1), stride=(1, 1))
  (transition_pool): AvgPool2d(kernel_size=2, stride=2, padding=0)
)
(denseblock4): Sequential(
  (dense_0): denselayer(
    (bn1): BatchNorm2d(640, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): ReLU(inplace=True)
    (conv1): Conv2d(640, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu2): ReLU(inplace=True)
    (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  )
  (dense_1): denselayer(
    (bn1): BatchNorm2d(672, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu1): ReLU(inplace=True)
    (conv1): Conv2d(672, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu2): ReLU(inplace=True)
    (conv2): Conv2d(128, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  )
…………
Here (transition3) is the transition layer that precedes (denseblock4), and (denseblock4) contains 32 densely connected layers, dense_0 through dense_31. I now need to extract the output feature maps of dense_4, dense_9, dense_14, dense_19, dense_24 and dense_31 inside (denseblock4). How should I do this?
The methods I found online can only extract the output of the whole (denseblock4); they cannot reach down into the sub-layers inside it.
Thanks!!
In forward, pick out the feature maps you need by sub-layer name:
def forward(self, x):
    …… ……
    # collect the outputs of the selected dense layers as the block runs
    feature_list = []
    name_list = ['dense_4', 'dense_9', 'dense_14', 'dense_19', 'dense_24', 'dense_31']
    for name, module in self.denseblock4.named_children():
        x = module(x)  # each denselayer returns its input concatenated with the new features
        if name in name_list:
            feature_list.append(x)
    return x, feature_list
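For reference, this is how a caller would consume that forward. A minimal usage sketch, assuming your model class is called STDNNet (a hypothetical name; substitute your own) and using an arbitrary input size:

import torch

model = STDNNet()  # hypothetical class name for the network whose forward is shown above
model.eval()
with torch.no_grad():
    out, feature_list = model(torch.rand(1, 3, 224, 224))
# feature_list holds one tensor per name in name_list, in forward order
print([f.shape for f in feature_list])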
Based on GPT plus my own thinking: you can use forward hooks to capture the output of any layer in the network. Here, you register hooks on the dense_4 / 9 / 14 / 19 / 24 / 31 layers before running the forward pass. Example code:
import torch
from torch.nn import (Module, ModuleList, Sequential, BatchNorm2d, ReLU, Conv2d,
                      MaxPool2d, AvgPool2d, AdaptiveAvgPool2d, Linear)

class Flatten(Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

class DenseBlock(Module):
    def __init__(self, n_layers, in_channels, growth_rate, kernel_size=3, padding=1):
        super(DenseBlock, self).__init__()
        self.layer_list = self._make_layer(n_layers, in_channels, growth_rate, kernel_size, padding)
        # the transition halves the channel count after the block
        self.transition = self._make_transition(in_channels + n_layers * growth_rate)

    def _make_layer(self, n_layers, in_channels, growth_rate, kernel_size, padding):
        layer_list = []
        for i in range(n_layers):
            # layer i sees the block input plus the growth of all earlier layers
            layer = self._make_single_layer(in_channels + i * growth_rate, growth_rate, kernel_size, padding)
            layer_list.append(layer)
        return ModuleList(layer_list)

    def _make_single_layer(self, in_channels, out_channels, kernel_size, padding):
        return Sequential(
            BatchNorm2d(in_channels),
            ReLU(inplace=True),
            Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=padding, bias=False),
        )

    def _make_transition(self, in_channels):
        return Sequential(
            BatchNorm2d(in_channels),
            Conv2d(in_channels, in_channels // 2, kernel_size=1),
            AvgPool2d(kernel_size=2, stride=2)  # torch.nn has AvgPool2d, not AveragePooling2d
        )

    def forward(self, x, save_features=None):
        features = []
        for layer in self.layer_list:
            out = layer(x)                  # growth_rate new channels
            x = torch.cat([x, out], dim=1)  # dense connectivity: concatenate along channels
            if save_features is not None:
                features.append(out)
        x = self.transition(x)
        return (x, features) if save_features is not None else x

class DenseNet(Module):
    def __init__(self, growth_rate=32, n_layers_per_block=(6, 12, 24, 16), in_channels=3, num_classes=10):
        super(DenseNet, self).__init__()
        self.layer1 = Sequential(
            Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False),
            BatchNorm2d(64),
            ReLU(inplace=True),
            MaxPool2d(kernel_size=3, stride=2, padding=1)
        )
        # each block's in_channels is the previous block's output after its transition halves it
        self.dense_block1 = DenseBlock(n_layers_per_block[0], in_channels=64, growth_rate=growth_rate)
        self.dense_block2 = DenseBlock(n_layers_per_block[1], in_channels=128, growth_rate=growth_rate)
        self.dense_block3 = DenseBlock(n_layers_per_block[2], in_channels=256, growth_rate=growth_rate)
        self.dense_block4 = DenseBlock(n_layers_per_block[3], in_channels=512, growth_rate=growth_rate)
        self.classifier = Sequential(
            AdaptiveAvgPool2d((1, 1)),
            Flatten(),
            Linear(512, num_classes)
        )
        self.layers = [self.dense_block1, self.dense_block2, self.dense_block3, self.dense_block4]

    def forward(self, x, save_features=None):
        x = self.layer1(x)
        for layer in self.layers:
            if save_features is not None:
                # each block returns (output, list of per-layer growth outputs)
                x, d_features = layer(x, save_features=save_features)
                save_features.append(d_features)
            else:
                x = layer(x)  # without save_features a block returns only its output tensor
        x = self.classifier(x)
        return x
Then you have two routes: passing a save_features list into forward() records every layer's growth output of every block, while forward hooks let you capture exactly the layers you want. The hook route looks like this:

densenet = DenseNet()
save_features = []
hook_handles = []
# hook only the layers we care about inside the last dense block; indices beyond the
# block's depth are skipped (your denseblock4 has 32 layers, so in your network all
# six indices exist)
for i in [4, 9, 14, 19, 24, 31]:
    if i < len(densenet.dense_block4.layer_list):
        hook_handles.append(densenet.dense_block4.layer_list[i].register_forward_hook(
            lambda module, input, output: save_features.append(output)))
output = densenet(torch.rand((1, 3, 224, 224)))  # hooks fire during this forward pass
print([f.shape for f in save_features])
for handle in hook_handles:
    handle.remove()  # detach the hooks once you are done
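The toy DenseNet above is only an illustration; to apply the same hook idea directly to your own network, select the sub-layers by the names shown in your printout. A sketch, assuming `net` is your model instance and `denseblock4` / `dense_N` are the attribute names from your print:

import torch

features = []
handles = []
target_names = ['dense_4', 'dense_9', 'dense_14', 'dense_19', 'dense_24', 'dense_31']
for name, module in net.denseblock4.named_children():
    if name in target_names:
        # each hook stores that denselayer's output as it is produced
        handles.append(module.register_forward_hook(
            lambda module, input, output: features.append(output)))
_ = net(torch.rand(1, 3, 224, 224))  # hooks fire during the forward pass
print([f.shape for f in features])
for h in handles:
    h.remove()  # clean up so later forward passes are unaffected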