pytorch中继承nn.Module后类的奇怪用法

我一直以为我看的代码有语法错误,直到我刚刚才发现他的玄机
先上代码:

import torch.nn as nn
class A():#nn.Module
    def __init__(self,a="MaXh",b=100):
        super(A, self).__init__()
        self.name = a
        self.age = b

    def forward(self,name2):
        print(f"Hello, I'm{self.name} and i am {self.age} yearsold","\n","my girlfriend is",name2)
        return name2


class B:
    def __init__(self,c=0):
        self.a = A()

    def forward(self,name2):
        out=self.a(name2)#看似有语法错误,不应该是out=self.a.forward(name2) ?
        return out

b = B()
b.forward("ma")


运行结果:报错
但是,如果让A类继承nn.Module:


import torch.nn as nn
class A(nn.Module):#
    def __init__(self,a="MaXh",b=100):
        super(A, self).__init__()
        self.name = a
        self.age = b

    def forward(self,name2):
        print(f"Hello, I'm{self.name} and i am {self.age} yearsold","\n","my girlfriend is",name2)
        return name2


class B:
    def __init__(self,c=0):
        self.a = A()

    def forward(self,name2):
        out=self.a(name2)#看似有语法错误,不应该是out=self.a.forward(name2) ?
        return out

b = B()
b.forward("ma")

代码中看似有语法错误的那一行就能正常运行:

Hello, I'mMaXh and i am 100 yearsold 
 my girlfriend is ma
'ma'

Why?

基于Monster 组和GPT的调写:

  • 在第一个代码示例中,A类没有继承自nn.Module,因此它不是一个PyTorch的模块,而是一个普通的Python类。在这种情况下,forward方法只是一个普通的实例方法,并不具有特殊的含义。
  • 当你调用self.a(name2)时,self.a是一个A类的实例,而不是一个PyTorch模块,所以它不会自动调用forward方法。因此,你会得到一个TypeError,因为你不能像调用函数一样调用一个非函数的对象。
  • 在第二个代码示例中,A类继承自nn.Module,它成为了一个PyTorch的模块。PyTorch模块中的forward方法具有特殊的含义,用于定义模块的前向传播逻辑。
  • 因此,在第二个示例中,当你调用self.a(name2)时,由于self.a是一个PyTorch模块,它会自动调用其forward方法,从而正常运行代码并输出结果。
  • 这篇博客: [PyTorch 学习笔记] 3.1 模型创建步骤与 nn.Module中的 nn.Module 部分也许能够解决你的问题, 你可以仔细阅读以下内容或跳转源博客中阅读:
  • nn.Module 有 8 个属性,都是OrderDict(有序字典)。在 LeNet 的__init__()方法中会调用父类nn.Module__init__()方法,创建这 8 个属性。

        def __init__(self):
            """
            Initializes internal Module state, shared by both nn.Module and ScriptModule.
            """
            torch._C._log_api_usage_once("python.nn_module")
    
            self.training = True
            self._parameters = OrderedDict()
            self._buffers = OrderedDict()
            self._backward_hooks = OrderedDict()
            self._forward_hooks = OrderedDict()
            self._forward_pre_hooks = OrderedDict()
            self._state_dict_hooks = OrderedDict()
            self._load_state_dict_pre_hooks = OrderedDict()
            self._modules = OrderedDict()
    
    • _parameters 属性:存储管理 nn.Parameter 类型的参数
    • _modules 属性:存储管理 nn.Module 类型的参数
    • _buffers 属性:存储管理缓冲属性,如 BN 层中的 running_mean
    • 5 个 ***_hooks 属性:存储管理钩子函数

    其中比较重要的是parametersmodules属性。

    在 LeNet 的__init__()中创建了 5 个子模块,nn.Conv2d()nn.Linear()都是 继承于nn.module,也就是说一个 module 都是包含多个子 module 的。

    class LeNet(nn.Module):
    	# 子模块创建
        def __init__(self, classes):
            super(LeNet, self).__init__()
            self.conv1 = nn.Conv2d(3, 6, 5)
            self.conv2 = nn.Conv2d(6, 16, 5)
            self.fc1 = nn.Linear(16*5*5, 120)
            self.fc2 = nn.Linear(120, 84)
            self.fc3 = nn.Linear(84, classes)
            ...
            ...
            ...
    

    当调用net = LeNet(classes=2)创建模型后,net对象的 modules 属性就包含了这 5 个子网络模块。


    下面看下每个子模块是如何添加到 LeNet 的`_modules` 属性中的。以`self.conv1 = nn.Conv2d(3, 6, 5)`为例,当我们运行到这一行时,首先 Step Into 进入 `Conv2d`的构造,然后 Step Out。右键`Evaluate Expression`查看`nn.Conv2d(3, 6, 5)`的属性。

    上面说了`Conv2d`也是一个 module,里面的`_modules`属性为空,`_parameters`属性里包含了该卷积层的可学习参数,这些参数的类型是 Parameter,继承自 Tensor。

    此时只是完成了`nn.Conv2d(3, 6, 5)` module 的创建。还没有赋值给`self.conv1 `。在`nn.Module`里有一个机制,会拦截所有的类属性赋值操作(`self.conv1`是类属性),进入到`__setattr__()`函数中。我们再次 Step Into 就可以进入`__setattr__()`。
        def __setattr__(self, name, value):
            def remove_from(*dicts):
                for d in dicts:
                    if name in d:
                        del d[name]
    
            params = self.__dict__.get('_parameters')
            if isinstance(value, Parameter):
                if params is None:
                    raise AttributeError(
                        "cannot assign parameters before Module.__init__() call")
                remove_from(self.__dict__, self._buffers, self._modules)
                self.register_parameter(name, value)
            elif params is not None and name in params:
                if value is not None:
                    raise TypeError("cannot assign '{}' as parameter '{}' "
                                    "(torch.nn.Parameter or None expected)"
                                    .format(torch.typename(value), name))
                self.register_parameter(name, value)
            else:
                modules = self.__dict__.get('_modules')
                if isinstance(value, Module):
                    if modules is None:
                        raise AttributeError(
                            "cannot assign module before Module.__init__() call")
                    remove_from(self.__dict__, self._parameters, self._buffers)
                    modules[name] = value
                elif modules is not None and name in modules:
                    if value is not None:
                        raise TypeError("cannot assign '{}' as child module '{}' "
                                        "(torch.nn.Module or None expected)"
                                        .format(torch.typename(value), name))
                    modules[name] = value
                ...
                ...
                ...
    

    在这里判断 value 的类型是Parameter还是Module,存储到对应的有序字典中。

    这里nn.Conv2d(3, 6, 5)的类型是Module,因此会执行modules[name] = value,key 是类属性的名字conv1,value 就是nn.Conv2d(3, 6, 5)