export failure: ONNX export failed: Couldn't export operator aten::unfold
不知道你这个问题是否已经解决, 如果还没有解决的话:答案:
这个问题的根本原因是 ONNX 无法导出 `aten::unfold` 运算符。为了解决这个问题,我们可以修改 PyTorch 模型,以使用 ONNX 能够导出的运算符来代替 `aten::unfold`。我们可以使用以下步骤解决这个问题:
首先,检查模型中是否使用了 `aten::unfold` 运算符。我们可以使用以下代码来查找模型中的所有运算符及其类型:

```python
import torch
from torch.onnx import register_custom_op_symbolic

def print_model_ops(model):
    """Print the parameters, buffers and direct children of every module.

    Walks ``model.named_modules()`` and, for each module visited, prints:

    * one line per parameter owned directly by that module, with the
      flattened values;
    * one line per buffer owned directly by that module, with the
      flattened values;
    * each direct child's name followed by the child's ``repr``;
    * the module's ``op`` attribute, when the module has one.

    This reports module-level state only; it does not trace ``forward`` and
    therefore does not list the operators executed at run time.

    Args:
        model: The ``torch.nn.Module`` to inspect.
    """
    for module_name, module in model.named_modules():
        owner = module.__class__.__name__

        for param_name, param in module.named_parameters(recurse=False):
            # detach() first: numpy() refuses tensors that require grad.
            values = param.detach().cpu().contiguous().view(-1).numpy()
            print(f"{param_name} ({owner}.{param.__class__.__name__}): {values}")

        for buffer_name, buffer in module.named_buffers(recurse=False):
            values = buffer.detach().cpu().contiguous().view(-1).numpy()
            print(f"{buffer_name} ({owner}.{buffer.__class__.__name__}): {values}")

        # Only direct children are listed here; the outer named_modules()
        # loop already reaches every descendant on its own iteration.
        for child_name, child in module.named_children():
            print(f"{child_name} ({owner}):")
            print(child)

        if hasattr(module, "op"):
            print(f"{module_name} ({owner}.{module.op.__class__.__name__}): {module.op}")
# Register custom symbolic function for operator
# NOTE(review): `my_custom_symbolic_linear` is not defined anywhere in the
# visible text, so this call fails with NameError as written — confirm where
# the symbolic function is supposed to come from.
# NOTE(review): torch.onnx.register_custom_op_symbolic is documented with a
# third `opset_version` argument — confirm and supply it.
register_custom_op_symbolic("mynamespace::myop", my_custom_symbolic_linear)
model = torch.load("mymodel.pt")
print_model_ops(model)
```
如果这个函数输出了 `aten::unfold`,表示模型中使用了这个运算符。

其次,尽量用 ONNX 能够导出的运算符代替 `aten::unfold`。如果我们需要 `aten::unfold`
,我们可以手动实现这个函数。以下是一个简单的实现:

```python
def my_unfold(input, kernel_size, dilation=1, padding=0, stride=1):
    """Extract sliding local blocks from a batched image tensor.

    Follows the contract of ``torch.nn.functional.unfold`` for 4-D input,
    but is built from zero padding, integer indexing and a reshape instead
    of a dedicated unfold/im2col call.

    Args:
        input: Tensor of shape ``(N, C, H, W)``.
        kernel_size: Window size, an int or a ``(kh, kw)`` pair.
        dilation: Spacing between window elements, an int or a pair.
        padding: Zero padding added to both sides of H and W, an int or a
            pair.
        stride: Step between consecutive windows, an int or a pair.

    Returns:
        Tensor of shape ``(N, C * kh * kw, L)`` where ``L`` is the number of
        window positions (``out_h * out_w``).

    Raises:
        ValueError: If ``input`` is not 4-D, if an argument is neither an
            int nor a pair, or if the window does not fit in the padded
            input.
    """
    def _pair(value, label):
        # Accept either a single int (used for both axes) or a 2-sequence.
        if isinstance(value, int):
            return value, value
        pair = tuple(value)
        if len(pair) != 2:
            raise ValueError(f"{label} must be an int or a pair, got {value!r}")
        return pair

    if input.dim() != 4:
        raise ValueError(
            f"my_unfold expects a 4-D (N, C, H, W) tensor, got {input.dim()}-D"
        )

    kh, kw = _pair(kernel_size, "kernel_size")
    dh, dw = _pair(dilation, "dilation")
    ph, pw = _pair(padding, "padding")
    sh, sw = _pair(stride, "stride")

    # Calculate output shape
    n, c, h, w = input.shape
    out_h = (h + 2 * ph - dh * (kh - 1) - 1) // sh + 1
    out_w = (w + 2 * pw - dw * (kw - 1) - 1) // sw + 1
    if out_h <= 0 or out_w <= 0:
        raise ValueError("kernel does not fit inside the padded input")

    # Pad input (F.pad lists the last dimension first: W, then H)
    padded = torch.nn.functional.pad(input, (pw, pw, ph, ph))

    # Row/column index of every (kernel offset, window position) pair.
    # The two index tensors broadcast to (kh, kw, out_h, out_w).
    device = input.device
    row_index = (
        (torch.arange(kh, device=device) * dh).view(kh, 1, 1, 1)
        + (torch.arange(out_h, device=device) * sh).view(1, 1, out_h, 1)
    )
    col_index = (
        (torch.arange(kw, device=device) * dw).view(1, kw, 1, 1)
        + (torch.arange(out_w, device=device) * sw).view(1, 1, 1, out_w)
    )

    # Gather every window: (N, C, kh, kw, out_h, out_w)
    patches = padded[:, :, row_index, col_index]

    # Flatten to (N, C * kh * kw, L): channel-major rows, row-major positions
    return patches.reshape(n, c * kh * kw, out_h * out_w)
```
我们需要将所有的 `aten::unfold` 替换为 `my_unfold`。以下是一个简单的脚本,可以自动找到模型中的所有 `aten::unfold` 运算符,并将它们替换为 `my_unfold`:
```python
import torch
from torch.onnx import SymbolicShapeFinder

class ReplaceUnfold(torch.nn.Module):
    """Module wrapper that applies ``my_unfold`` to its input.

    Args:
        kernel_size: Forwarded to ``my_unfold``. It may be omitted at
            construction time, but ``forward`` refuses to run without it
            because ``my_unfold`` has no default for this argument.
        dilation: Forwarded to ``my_unfold``.
        padding: Forwarded to ``my_unfold``.
        stride: Forwarded to ``my_unfold``.
    """

    def __init__(self, kernel_size=None, dilation=1, padding=0, stride=1):
        super().__init__()
        self.kernel_size = kernel_size
        self.dilation = dilation
        self.padding = padding
        self.stride = stride

    def forward(self, x):
        """Return ``my_unfold`` applied to ``x`` with the stored settings.

        Raises:
            TypeError: If no ``kernel_size`` was configured.
        """
        if self.kernel_size is None:
            raise TypeError(
                "ReplaceUnfold needs a kernel_size; pass it to the constructor"
            )
        return my_unfold(
            x,
            self.kernel_size,
            dilation=self.dilation,
            padding=self.padding,
            stride=self.stride,
        )
def replace_unfold(model):
    """Swap ``aten::unfold`` graph nodes for ``mynamespace::my_unfold`` nodes.

    Registers a symbolic function that emits ``mynamespace::my_unfold``,
    then scans the nodes reported by ``SymbolicShapeFinder`` for a fixed
    ``(1, 3, 224, 224)`` random input and rewires each ``aten::unfold`` node
    to a newly created ``mynamespace::my_unfold`` node.

    NOTE(review): only the layout of this function has been restored; the
    calls themselves are unchanged and unverified:
      * ``graph`` is never bound inside this function — confirm which graph
        object is meant.
      * ``SymbolicShapeFinder`` and the ``graph.*`` methods used below could
        not be matched to documented torch.onnx / TorchScript APIs —
        confirm they exist in the targeted PyTorch version.
      * ``register_custom_op_symbolic`` is documented with a third
        ``opset_version`` argument — confirm and supply it.

    Args:
        model: The model whose graph is inspected.
    """
    # Register custom symbolic function for my_unfold
    def my_unfold_symbolic(g, input, kernel_size, dilation, padding, stride):
        return g.op("mynamespace::my_unfold", input, kernel_size, dilation, padding, stride)

    register_custom_op_symbolic("mynamespace::my_unfold", my_unfold_symbolic)

    # Find all aten::unfold ops in the model
    symbolic_shape_finder = SymbolicShapeFinder()
    symbolic_shape_finder.check_masks = True
    symbolic_shape_finder.find_shapes(model, torch.randn((1, 3, 224, 224)))
    for node in symbolic_shape_finder.nodes:
        if node.kind() == "aten::unfold":
            unfold_node = node
            unfold_name = unfold_node.outputs()[0].debugName()
            unfold_inputs = unfold_node.inputs()
            # Replace aten::unfold with my_unfold
            my_unfold_node = graph.create("mynamespace::my_unfold", [*unfold_inputs[0:1], *unfold_inputs[1].node().inputs()[1:], *unfold_inputs[2:]], 1)
            my_unfold_name = my_unfold_node.outputs()[0].debugName()
            unfold_node.outputs()[0].replaceAllUsesWith(my_unfold_node.outputs()[0])
            graph.eraseUnusedNodeOutputNames()
            my_unfold_node.outputs()[0].copyMetadata(unfold_node.outputs()[0])
            my_unfold_node.moveAfter(unfold_node)
            graph.eraseOutput(unfold_node.outputs()[0])
            graph.insertOutput(my_unfold_node.outputs()[0], 0)
            graph.node(unfold_node).destroy()
model = torch.load("mymodel.pt")
replace_unfold(model)
torch.onnx.export(model, (torch.randn((1, 3, 224, 224)),), "mymodel.onnx")
```
```python
import onnx

model = onnx.load("mymodel.onnx")
onnx.checker.check_model(model)
```