怎么分割tensor按指定条件?

img

这个tensor的shape是(400),怎么以里面的2分割tensor成若干个tensor啊?

需要借助numpy

import torch
import numpy as np
t=torch.tensor([1,2,3,4,5,6,8,2,56,5,2,10])
t_numpy=t.numpy()
index=np.argwhere(t_numpy==2)
print(index)#在tensor中的索引,可以根据该索引对tensor切片


我觉得你这个就是一维向量,遍历这个向量,遇到一个2你就新建一个张量把这个数之前的数存下来,不知道我说的可以不?