def _call_impl(self, *input, **kwargs):
for hook in itertools.chain(
_global_forward_pre_hooks.values(),
self._forward_pre_hooks.values()):
result = hook(self, input)
if result is not None:
if not isinstance(result, tuple):
result = (result,)
input = result
if torch._C._get_tracing_state(): # 我对这句话不理解,if中为什么没有使用输入的形参?
result = self._slow_forward(*input, **kwargs)
else:
result = self.forward(*input, **kwargs) # 错误提示跳转到了这里
if torch._C._get_tracing_state():就是用来判断是否使用JIT来跟踪模型。像pytorch构建一个计算图就会用到一个中央的context去管理变量,而JIT跟踪模型也类似,比如用以下方式标记了这个module需要用JIT跟踪:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.jit import ScriptModule, script_method, trace
class MyScriptModule(ScriptModule):
def __init__(self):
super(MyScriptModule, self).__init__()
# trace produces a ScriptModule's conv1 and conv2
self.conv1 = trace(nn.Conv2d(1, 20, 5), torch.rand(1, 1, 16, 16))
self.conv2 = trace(nn.Conv2d(20, 20, 5), torch.rand(1, 20, 16, 16))
@script_method
def forward(self, input):
x = F.relu(self.conv1(input))
x = F.relu(self.conv2(input))
return x
而torch跟踪代码后同样也会在c++后端中存入这个MyScriptModule是否需要跟踪的信息,因此只是需要调用_C的_get_tracing_state()就可以判断是否需要跟踪这个module,而不需要使用forward中传入的形参。
_get_tracing_state()
是一个获取跟踪状态的方法
这个方法是 torch._C这个对象的。
随机抽样 Random sampling
torch.manual_seed
torch.manual_seed(seed)
设定生成随机数的种子,并返回一个 torch._C.Generator 对象.
参数: seed (int or long) – 种子.