torch.matmul()

在使用torch.matmul时遇到cuda out of memory如何解决,共四维,最后两个维度分别为65536 8,8 65536

内存爆炸,还能有什么方法呢?
(1)扩大显存
(2)减小 batch