pyg包中GATConv的Static graphs not supported in 'GATConv'问题

用torch_geometrix.nn里的GATConv遇到Static graphs not supported in 'GATConv'问题,如图所示

img

我通过debug追进去看到源码中加了断言,如图所示

img

要求输入的数据x必须是二维的才能正常运行。
但我想要输入的数据形状是(batch_size, num_nodes, node_features),想请问:我是必须把batch_size设为1,每次输入一个(num_nodes, node_features)形状的数据吗?还是说我通过某种方式把它变成(batch_size, num_nodes*node_features)形状?还是说有其他更好的解决方法?