最近在学习ios开发,需要将一个pytorch模型转换为coreml模型,过程中需要用torch.jit.trace转换出torchscript。
教程中给出的torch.jit.trace都有两个参数,如下:
example_input = torch.rand(1, 3, 224, 224)
traced_model = torch.jit.trace(model, example_input)
但是modle.py中的forward()只有一个参数,如下:
def forward(self):
# first transfer and removal
self.z_non_makeup_c = self.enc_content(self.non_makeup)
self.z_non_makeup_s = self.enc_semantic(self.non_makeup_parse)
self.z_non_makeup_a = self.enc_makeup(self.non_makeup)
self.z_makeup_c = self.enc_content(self.makeup)
self.z_makeup_s = self.enc_semantic(self.makeup_parse)
self.z_makeup_a = self.enc_makeup(self.makeup)
# warp makeup style
self.mapX, self.mapY, self.z_non_makeup_a_warp, self.z_makeup_a_warp = self.transformer(self.z_non_makeup_c,
self.z_makeup_c,
self.z_non_makeup_s,
self.z_makeup_s,
self.z_non_makeup_a,
self.z_makeup_a)
# makeup transfer and removal
self.z_transfer = self.gen(self.z_non_makeup_c, self.z_makeup_a_warp)
self.z_removal = self.gen(self.z_makeup_c, self.z_non_makeup_a_warp)
# rec
self.z_rec_non_makeup = self.gen(self.z_non_makeup_c, self.z_non_makeup_a)
self.z_rec_makeup = self.gen(self.z_makeup_c, self.z_makeup_a)
# second transfer and removal
self.z_transfer_c = self.enc_content(self.z_transfer)
# self.z_non_makeup_s = self.enc_semantic(self.non_makeup_parse)
self.z_transfer_a = self.enc_makeup(self.z_transfer)
self.z_removal_c = self.enc_content(self.z_removal)
# self.z_makeup_s = self.enc_semantic(self.makeup_parse)
self.z_removal_a = self.enc_makeup(self.z_removal)
# warp makeup style
self.mapX2, self.mapY2, self.z_transfer_a_warp, self.z_removal_a_warp = self.transformer(self.z_transfer_c,
self.z_removal_c,
self.z_non_makeup_s,
self.z_makeup_s,
self.z_transfer_a,
self.z_removal_a)
# makeup transfer and removal
self.z_cycle_non_makeup = self.gen(self.z_transfer_c, self.z_removal_a_warp)
self.z_cycle_makeup = self.gen(self.z_removal_c, self.z_transfer_a_warp)
运行之后报错:forward() takes 1 positional argument but 2 were given
我现在完全不知道该怎么修改,请求帮助。
你在哪儿调用的forward啊?
example_input = torch.rand(1, 3, 224, 224)
traced_model = torch.jit.trace(model, example_input)
这两个也没调用他呀