pytorch中piqa库里面的SSIM指标

输入只能是三通道吗,这里是多光谱,有多个通道,直接调用会报错因为通道不匹配
目前尝试的办法,第一种,把SSIM里面的channel改成31,后边loss.backward()时,会出现修改变量的问题
第二种,遍历每个通道,计算然后累加,同样也不行,通道数必须是3啊??