请问Pytorch中怎样给变量添加约束?比如在训练时时刻保持:
1、参数矩阵主对角线元素恒为0;
2、参数矩阵所有元素非负。
我目前看到TensorFlow中的add_weight()中有一个constraint参数可以加约束,不知道Pytorch中有没有类似的方法?
首先编写模型结构:
class Model(nn.Module):
def init(self):
super(Model,self).init()
self.l1=nn.Linear(100,50)
self.l2=nn.Linear(50,10)
self.l3=nn.Linear(10,1)
self.sig=nn.Sigmoid()
def forward(self,x):
x=self.l1(x)
x=self.l2(x)
x=self.l3(x)
x=self.sig(x)
return(x
然后编写限制权重范围的类:
class weightConstraint(object):
def init(self):
pass
def __call__(self,module):
if hasattr(module,'weight'):
print("Entered")
w=module.weight.data
w=w.clamp(0.5,0.7) #将参数范围限制到0.5-0.7之间
module.weight.data=w