在pytorch种如何解决难易样本的分类问题

我查阅资料,发现可以使用focal loss来解决,但是pytorch自带的包里没有这个函数,目前我使用的是交叉熵损失函数,如何将这个损失函数换成focal loss呢?

img


focal loss在网上的多分类代码很多,但我就是不会替换,怎么换呢?
或者不用focal loss,其他的损失函数也行,只要能解决难易样本带来的识别问题就行

来自kaggle上面的这个实现 可以直接用


class FocalLoss(nn.Module):
    def __init__(self, alpha=1, gamma=2, logits=False, reduce=True):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.logits = logits
        self.reduce = reduce

    def forward(self, inputs, targets):
        if self.logits:
            BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduce=False)
        else:
            BCE_loss = F.binary_cross_entropy(inputs, targets, reduce=False)
        pt = torch.exp(-BCE_loss)
        F_loss = self.alpha * (1-pt)**self.gamma * BCE_loss

        if self.reduce:
            return torch.mean(F_loss)
        else:
            return F_loss