我在将 PyTorch 模型（.pth 文件）转换为 ONNX 文件时遇到了问题，模型中包含 Attention 模块。希望有空的朋友帮忙看一下，谢谢！
Traceback (most recent call last):
File "/Users/zhukaili/PycharmProjects/gujiocr/trs.py", line 1, in <module>
from deep_ocr import transfer
File "/Users/zhukaili/PycharmProjects/gujiocr/deep_ocr/transfer.py", line 118, in <module>
'output' : {0 : 'batch_size'}})
File "/opt/anaconda3/envs/pixelLink/lib/python3.6/site-packages/torch/onnx/__init__.py", line 276, in export
custom_opsets, enable_onnx_checker, use_external_data_format)
File "/opt/anaconda3/envs/pixelLink/lib/python3.6/site-packages/torch/onnx/utils.py", line 94, in export
use_external_data_format=use_external_data_format)
File "/opt/anaconda3/envs/pixelLink/lib/python3.6/site-packages/torch/onnx/utils.py", line 701, in _export
dynamic_axes=dynamic_axes)
File "/opt/anaconda3/envs/pixelLink/lib/python3.6/site-packages/torch/onnx/utils.py", line 459, in _model_to_graph
use_new_jit_passes)
File "/opt/anaconda3/envs/pixelLink/lib/python3.6/site-packages/torch/onnx/utils.py", line 420, in _create_jit_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args)
File "/opt/anaconda3/envs/pixelLink/lib/python3.6/site-packages/torch/onnx/utils.py", line 380, in _trace_and_get_graph_from_model
torch.jit._get_trace_graph(model, args, strict=False, _force_outplace=False, _return_inputs_states=True)
File "/opt/anaconda3/envs/pixelLink/lib/python3.6/site-packages/torch/jit/_trace.py", line 1139, in _get_trace_graph
outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
File "/opt/anaconda3/envs/pixelLink/lib/python3.6/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
result = self.forward(*input, **kwargs)
File "/opt/anaconda3/envs/pixelLink/lib/python3.6/site-packages/torch/jit/_trace.py", line 130, in forward
self._force_outplace,
File "/opt/anaconda3/envs/pixelLink/lib/python3.6/site-packages/torch/jit/_trace.py", line 116, in wrapper
outs.append(self.inner(*trace_inputs))
File "/opt/anaconda3/envs/pixelLink/lib/python3.6/site-packages/torch/nn/modules/module.py", line 887, in _call_impl
result = self._slow_forward(*input, **kwargs)
File "/opt/anaconda3/envs/pixelLink/lib/python3.6/site-packages/torch/nn/modules/module.py", line 860, in _slow_forward
result = self.forward(*input, **kwargs)
File "/Users/zhukaili/PycharmProjects/gujiocr/deep_ocr/model.py", line 64, in forward
prediction = self.Prediction(contextual_feature.contiguous(), batch_max_length=self.opt.batch_max_length)
File "/opt/anaconda3/envs/pixelLink/lib/python3.6/site-packages/torch/nn/modules/module.py", line 887, in _call_impl
result = self._slow_forward(*input, **kwargs)
File "/opt/anaconda3/envs/pixelLink/lib/python3.6/site-packages/torch/nn/modules/module.py", line 860, in _slow_forward
result = self.forward(*input, **kwargs)
File "/Users/zhukaili/PycharmProjects/gujiocr/deep_ocr/modules/prediction.py", line 38, in forward
targets = torch.LongTensor(batch_size).fill_(0).to(device) # [GO] token
IndexError: slice() cannot be applied to a 0-dim tensor.