使用torch.autograd.Function自定义激活函数时,如何在父类中对子类传入参数?

这里贴上yolov5的一个高效自定义激活函数的源码:

class MemoryEfficientMish(nn.Module):
    class F(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)    # 表示forward()的结果要存起来,以后给backward()
            return x.mul(torch.tanh(F.softplus(x)))  # x * tanh(ln(1 + exp(x)))

        # grad_output是最终object对的forward()输出的导数, 也就是理解为上一层求导的结果
        # ctx是一个元祖
        @staticmethod
        def backward(ctx, grad_output):    # grad_output上一层求导的结果
            x = ctx.saved_tensors[0]       # ctx.saved_tensors得到之前forward()存的结果
            sx = torch.sigmoid(x)
            fx = F.softplus(x).tanh()
            return grad_output * (fx + x * sx * (1 - fx * fx))

    def forward(self, x):
        return self.F.apply(x)

这里我如何传入一个beta参数到子类F中去,也就是在父类MemoryEfficientMish中传入一个参数到子类F中,使得可以控制子类的forward与backward函数的返回。一个设想的伪代码实现如下:

class MemoryEfficientMish(nn.Module):

  # 可以传入参数beta,默认为1,也就是简化的版本
    def __init__(self, beta=1.):
        super().__init__()
        self.beta = beta

    class F(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x):
            ctx.save_for_backward(x)    # 表示forward()的结果要存起来,以后给backward()
            
            # 传入参数beta使得可以控制返回函数
            if self.beta != 1.0:
                 return ...
            return x.mul(torch.tanh(F.softplus(x)))  # x * tanh(ln(1 + exp(x)))

        # grad_output是最终object对的forward()输出的导数, 也就是理解为上一层求导的结果
        # ctx是一个元祖
        @staticmethod
        def backward(ctx, grad_output):    # grad_output上一层求导的结果
            x = ctx.saved_tensors[0]       # ctx.saved_tensors得到之前forward()存的结果
            sx = torch.sigmoid(x)
            fx = F.softplus(x).tanh()

            # 传入参数beta使得可以控制返回函数
            if self.beta != 1.0:
                 return ...
            return grad_output * (fx + x * sx * (1 - fx * fx))

    def forward(self, x):
        return self.F.apply(x)