pytorch批量返回tensor特定值索引

问题遇到的现象和发生背景

在pytorch,pyg中,
想从1万多个边(all_edges:tensor shape[2,10640])找到指定64个边(target_nodes:tensor shape[2,64])对应的索引,代码怎么写会快一点?

问题相关代码,请勿粘贴截图

idx = []
for u, v in target_edges:
idx.append(all_egdes.index([u, v]))

我的解答思路和尝试过的方法

尝试过两个tensor都转为list,遍历64个边,对每一个边,用.index()找到它在1万个边的位置,这样做太慢了,有没有可以批量处理的方法?

法一:遍历64个边,每边 用 repeat 变成 all_edges 的形状,然后利用“==” ,取出为 True 的索引。
法二:将 all_edges:tensor shape[2,10640] 增加一个维度变成 all_edges:tensor shape[2, 1, 10640] ,再 repeat,变成 all_edges:tensor shape[2, 64, 10640]
将 target_nodes:tensor shape[2,64] 变成 target_nodes:tensor shape[2,64, 1],再 repeat,变成 target_nodes:tensor shape[2,64, 10640],
两者 "==" 取出为 True 的索引,反着推就能推出要得到的2维索引了

把all_edges先转为dict会不会快些?字典最快