@torch.jit.script
def a(x: torch.Tensor):
return (x >= 0).to(x)
在jupyter跑的时候报bug:ValueError: Compiled functions don't support annotations
感觉是@torch.jit.script的问题,但不清楚是哪里有问题,请帮忙看看,谢谢!!
这个错误是因为TorchScript编译器不支持函数参数和返回值的注释。可以尝试移除函数定义中的注释,或者在定义中使用类型提示,而不是注释。例如,将你的代码修改为:
import torch
@torch.jit.script
def a(x):
# 不要在这里使用注释
return (x >= 0).to(x)
注意,这里的x没有类型注释,TorchScript编译器会根据参数的运行时类型进行推断。如果需要使用类型注释,可以在函数内部使用torch.jit.annotate函数,例如:
import torch
@torch.jit.script
def a(x):
# 在函数内部使用类型注释
x: torch.Tensor
return (x >= 0).to(x)
不知道你这个问题是否已经解决, 如果还没有解决的话: