我一直以为我看的代码有语法错误,直到我刚刚才发现他的玄机
先上代码:
import torch.nn as nn
class A():#nn.Module
def __init__(self,a="MaXh",b=100):
super(A, self).__init__()
self.name = a
self.age = b
def forward(self,name2):
print(f"Hello, I'm{self.name} and i am {self.age} yearsold","\n","my girlfriend is",name2)
return name2
class B:
def __init__(self,c=0):
self.a = A()
def forward(self,name2):
out=self.a(name2)#看似有语法错误,不应该是out=self.a.forward(name2) ?
return out
b = B()
b.forward("ma")
运行结果:报错
但是,如果让A类继承nn.Module:
import torch.nn as nn
class A(nn.Module):#
def __init__(self,a="MaXh",b=100):
super(A, self).__init__()
self.name = a
self.age = b
def forward(self,name2):
print(f"Hello, I'm{self.name} and i am {self.age} yearsold","\n","my girlfriend is",name2)
return name2
class B:
def __init__(self,c=0):
self.a = A()
def forward(self,name2):
out=self.a(name2)#看似有语法错误,不应该是out=self.a.forward(name2) ?
return out
b = B()
b.forward("ma")
代码中看似有语法错误的那一行就能正常运行:
Hello, I'mMaXh and i am 100 yearsold
my girlfriend is ma
'ma'
Why?
A
类没有继承自nn.Module
,因此它不是一个PyTorch的模块,而是一个普通的Python类。在这种情况下,forward
方法只是一个普通的实例方法,并不具有特殊的含义。self.a(name2)
时,self.a
是一个A
类的实例,而不是一个PyTorch模块,所以它不会自动调用forward
方法。因此,你会得到一个TypeError
,因为你不能像调用函数一样调用一个非函数的对象。A
类继承自nn.Module
,它成为了一个PyTorch的模块。PyTorch模块中的forward
方法具有特殊的含义,用于定义模块的前向传播逻辑。self.a(name2)
时,由于self.a
是一个PyTorch模块,它会自动调用其forward
方法,从而正常运行代码并输出结果。nn.Module
有 8 个属性,都是OrderDict
(有序字典)。在 LeNet 的__init__()
方法中会调用父类nn.Module
的__init__()
方法,创建这 8 个属性。
def __init__(self):
"""
Initializes internal Module state, shared by both nn.Module and ScriptModule.
"""
torch._C._log_api_usage_once("python.nn_module")
self.training = True
self._parameters = OrderedDict()
self._buffers = OrderedDict()
self._backward_hooks = OrderedDict()
self._forward_hooks = OrderedDict()
self._forward_pre_hooks = OrderedDict()
self._state_dict_hooks = OrderedDict()
self._load_state_dict_pre_hooks = OrderedDict()
self._modules = OrderedDict()
其中比较重要的是parameters
和modules
属性。
在 LeNet 的__init__()
中创建了 5 个子模块,nn.Conv2d()
和nn.Linear()
都是 继承于nn.module
,也就是说一个 module 都是包含多个子 module 的。
class LeNet(nn.Module):
# 子模块创建
def __init__(self, classes):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16*5*5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, classes)
...
...
...
当调用net = LeNet(classes=2)
创建模型后,net
对象的 modules 属性就包含了这 5 个子网络模块。
def __setattr__(self, name, value):
def remove_from(*dicts):
for d in dicts:
if name in d:
del d[name]
params = self.__dict__.get('_parameters')
if isinstance(value, Parameter):
if params is None:
raise AttributeError(
"cannot assign parameters before Module.__init__() call")
remove_from(self.__dict__, self._buffers, self._modules)
self.register_parameter(name, value)
elif params is not None and name in params:
if value is not None:
raise TypeError("cannot assign '{}' as parameter '{}' "
"(torch.nn.Parameter or None expected)"
.format(torch.typename(value), name))
self.register_parameter(name, value)
else:
modules = self.__dict__.get('_modules')
if isinstance(value, Module):
if modules is None:
raise AttributeError(
"cannot assign module before Module.__init__() call")
remove_from(self.__dict__, self._parameters, self._buffers)
modules[name] = value
elif modules is not None and name in modules:
if value is not None:
raise TypeError("cannot assign '{}' as child module '{}' "
"(torch.nn.Module or None expected)"
.format(torch.typename(value), name))
modules[name] = value
...
...
...
在这里判断 value 的类型是Parameter
还是Module
,存储到对应的有序字典中。
这里nn.Conv2d(3, 6, 5)
的类型是Module
,因此会执行modules[name] = value
,key 是类属性的名字conv1
,value 就是nn.Conv2d(3, 6, 5)
。