sklearn实现recal时报错?

问题遇到的现象和发生背景

使用CNN分类图片(GPU),用sklearn导入recall函数评价网络的分类性能。

遇到的现象和发生背景,请写出第一个错误信息

报错:can‘t convert cuda:0 device type tensor to numpy.Use Tensor.cpu() to copy the tensor to host memory

用代码块功能插入代码,请勿粘贴截图。 不用代码块回答率下降 50%
from sklearn.metrics import precision_score, recall_score, f1_score


            model.eval()            
            pbar = tqdm(enumerate(val_loader), total=len(val_loader))
            with torch.no_grad(): #验证阶段程序不计算参数梯度,解决GPU内存满了的问题
              for batch_idx, (data, target) in pbar:
                  data, target = data.cuda(), target.cuda()
                  output = model(data)
                  prediction = torch.argmax(output, 1)# 预测值
                  correct += (prediction == target).sum().float()
                
                  total += len(target)
                  r = recall_score(target, prediction, average='macro)  #导入recall
                  description = f'epoch {epoch}'
                  pbar.set_description(description)

              val_acc_str = 'Accuracy: %f' % ((correct / total).cpu().detach().data.numpy())
              val_acc.append((correct / total).cpu().detach().data.numpy())
              scheduler.step((correct / total).cpu().detach().data.numpy())
              print('r', r,'epoch', epoch, 'train_acc', train_acc_str, 'val_acc', val_acc_str)

运行结果及详细报错内容

img

我的解答思路和尝试过的方法,不写自己思路的,回答率下降 60%

网上说先将 tensor 转换到 CPU ,因为 Numpy 是 CPU-only。我试过将报错代码self.numpy()改为self.cpu().numpy(),还是报错。也试过降低numpy版本至1.1.9,也还是报错。

我想要达到的结果,如果你需要快速回答,请尝试 “付费悬赏”

想解决报错

望采纳


在 PyTorch 中,当你使用 cpu() 方法将 Tensor 从 GPU 上移动到 CPU 上后,就不能再使用 detach() 方法。更多细节可以参考 PyTorch 官网的文档:https://pytorch.org/docs/stable/tensors.html


在 PyTorch 中,你可以通过以下两种方式将 Tensor 的值转换为 Python 的标量:

  • 使用 Tensor 的 item() 方法,例如 x.item(),其中 x 是一个 Tensor。
  • 使用 Numpy 库的 numpy() 方法,例如 x.numpy(),其中 x 是一个 Tensor。

可以试试如下修改

val_acc_str = 'Accuracy: %f' % ((correct / total).cpu().numpy())
val_acc.append((correct / total).cpu().clone())
scheduler.step((correct / total).cpu().clone())
print('r', r,'epoch', epoch, 'train_acc', train_acc_str, 'val_acc', val_acc_str)