想问一下这段代码中的print的*号是什么意思哇,print那一行代码可以解释一下嘛?没看懂【0】的是啥意思?求解答~
def my_init(m):
if type(m) == nn.Linear:
print("Init", *[(name, param.shape)
for name, param in m.named_parameters()][0])
nn.init.uniform_(m.weight, -10, 10)
m.weight.data *= m.weight.data.abs() >= 5
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
X = torch.rand(size=(2, 4))
net(X)
net.apply(my_init)
net[0].weight[:2]
当*
号放在序列类型(列表、元组、集合)之前,作为一个函数的实际参数时,会把这个序列的每一项当作单个参数传入该函数。(相当于去掉它们的大/中/小括号)
>>> print(1,2,(3,4,5))
1 2 (3, 4, 5)
>>> print(1,2,*(3,4,5)) #相当于print(1,2,3,4,5)
1 2 3 4 5
加星号的列表是用列表生成式表示的,列表的每一项都是(name, param.shape)
元组。在列表后面加一个[0]
,表示取该列表的第一项(name, param.shape)
元组,并用星号把该元组作为多个参数传入print()
。
望采纳!