train_data = get_train_data(unit=list(range(52)))
train_g = construct_graph(train_data, neighbor=5)
train_g = train_g.to("cuda:0")  # this line raises an error
train_g.edata['h_e'] = train_g.edata['h_e'].cuda()
train_data = torch.Tensor(train_data)
train_data = train_data.cuda()
This Q&A, https://datascience.stackexchange.com/questions/54907/model-cuda-in-pytorch, explains the issue well. Perhaps you could try replacing train_g.to("cuda:0") on the failing line with
train_g.cuda()
or
train_g.to(torch.device('cuda:0'))
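A minimal PyTorch sketch of the distinction the linked answer discusses. It assumes only that torch is installed; the names x, model and returned are illustrative, and the CPU fallback is an addition so the sketch runs without a GPU. Tensor .to()/.cuda() return a tensor that has to be reassigned, while nn.Module .to()/.cuda() move the parameters in place and return the same module. As far as I know, DGL's DGLGraph.to() also returns a new graph instead of modifying the original, which is why its result has to be assigned back to train_g.

```python
import torch
import torch.nn as nn

# Use the first GPU if there is one, otherwise fall back to CPU (added so the sketch runs anywhere).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Tensors: .to() returns a tensor on the target device (self if it is already there),
# so the result must be reassigned; the original tensor is not moved in place.
x = torch.zeros(3)
x = x.to(device)

# Modules: .to() moves the parameters in place and returns the same module object.
model = nn.Linear(3, 1)
returned = model.to(device)

print(x.device, next(model.parameters()).device, returned is model)
```

Passing the string "cuda:0" or torch.device("cuda:0") should select the same device for tensors and modules. If the original call still fails, the error message (not shown in the question) would help pin down whether the device string or the graph object is the problem.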