pytorch自定义loss函数

大家好,在实现自定义的语义分割的loss函数的时候,遇到了问题,请大家帮忙一下,
这个自定义的loss函数的做法是,根据真实label(batchsize,h,w)的每个pixel的对应的class值,在网络的输出的预测值(batch-size,num-class,h,w)中,选出class对应的那个预测值,得到的就是真实label的每个pixel的class对应的预测值(batchsize,h,w),现在我自己按照我下面的方式想实现上述的目的,但是在pytorh中的loss函数,想要能够反向传播就必须所有的值都是Variable,现在发现的问题就在pytorch的tensor中的flatten函数会有问题,想问问大家有没有什么方式能够在tensor的方式下实现。

图片说明

看有些帖子里说自定义loss函数里面的数学操作都要用torch实现,如上图numpy可以吗