Pytorch中的nn.ModuleList改写成TensorFlow2.6时可以换成什么?

我想将一个Pytorch中的PyramidPooling类改写为TensorFlow2.6的版本使用,但其中的nn.ModuleList和nn.Sequential不知道如何修改,另外在class中的def和普通定义的函数def中,conv2d的使用有何区别?

import torch
from torch import nn

class PyramidPooling(nn.Module):
    def __init__(self, in_channels, out_channels, scales=(4, 8, 16, 32), ct_channels=1):
        super().__init__()
        self.stages = []
        self.stages = nn.ModuleList([self._make_stage(in_channels, scale, ct_channels) for scale in scales])
        self.bottleneck = nn.Conv2d(in_channels + len(scales) * ct_channels, out_channels, kernel_size=1, stride=1)
        self.relu = nn.LeakyReLU(0.2, inplace=True)

    def _make_stage(self, in_channels, scale, ct_channels):
        prior = nn.AvgPool2d(kernel_size=(scale, scale))
        conv = nn.Conv2d(in_channels, ct_channels, kernel_size=1, bias=False)
        relu = nn.LeakyReLU(0.2, inplace=True)
        return nn.Sequential(prior, conv, relu)

    def forward(self, feats):
        h, w = feats.size(2), feats.size(3)
        priors = torch.cat([F.interpolate(input=stage(feats), size=(h, w), mode='nearest') for stage in self.stages] + [feats], dim=1)
        return self.relu(self.bottleneck(priors))

代码如上所示,谢谢

tersorflow里面是model类,建立这个类就可以,conv2d你直接搜TensorFlow函数中文手册就可以,和pytorch差不多。

tensorflow在构建图的过程中,会将所有训练参数放到网络中的,不需要ModuleList()这样的模块,pytorch是动态图,所有的层都可以动态加入网络中,所以需要将变量注册到网络中,才会被优化优化。