How to draw SHAP decision plots and force plots for a single sample

How can I draw a SHAP decision plot and a force plot for a single sample, similar to what is done in this article?

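  • The "Common usage of SHAP values" section of the blog post "An Introduction to SHAP and Its Applications (with Code)" may help with your problem. You can read the excerpt below or go to the original blog post: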
  • !pip install shap
    import warnings
    warnings.filterwarnings("ignore")
    import shap
    shap.initjs()
    
    Collecting shap
      Downloading shap-0.40.0-cp37-cp37m-manylinux2010_x86_64.whl (564 kB)
    Successfully installed shap-0.40.0 slicer-0.0.7
    

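    import numpy as np

    # clf and train come from earlier in the source blog and are not shown in this excerpt.
    # With a large dataset, this step runs very slowly.
    explainer = shap.TreeExplainer(clf)
    shap_values = explainer.shap_values(train.drop(columns=['Survived']))  # compute the SHAP values
    np.array(shap_values).shape  # inspect the dimensions of the SHAP values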
    
    (2, 712, 9)
    

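    The SHAP values form a three-dimensional array: each sample has two sets of SHAP values, one per class.

    The first dimension selects the class whose SHAP values you want: index 0 is class 0 (negative) and index 1 is class 1 (positive).

    The remaining two dimensions index the samples and the features, so here there are 712 samples with 9 features each.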
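To make the indexing concrete, here is a minimal sketch of pulling out the SHAP values of one sample for one class. It assumes the (classes, samples, features) layout shown above; the all-zero array is only a placeholder with the same shape, not real SHAP output, and `class_idx` and `sample_idx` are illustrative names.

```python
import numpy as np

# Placeholder with the same layout as the result above: (classes, samples, features).
# Real values would come from explainer.shap_values(...).
shap_values = np.zeros((2, 712, 9))

class_idx = 1   # 1 = positive class
sample_idx = 0  # first sample

# One SHAP value per feature for this sample and class.
per_sample = shap_values[class_idx][sample_idx]
print(per_sample.shape)
```

This one-dimensional slice is the kind of input a single-sample plot works from.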

    (shap_values[0] == -1 * shap_values[1]).all()
    
    True
    

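    As you can see, the SHAP values for class 0 are exactly the negatives of those for class 1. This is expected for a binary classifier: the two class probabilities sum to 1, so whatever pushes one up pushes the other down by the same amount.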