How can I draw SHAP decision plots and force plots for a single sample, similar to what is done in this article?
!pip install shap
import warnings
warnings.filterwarnings("ignore")
import shap
shap.initjs()
Successfully installed shap-0.40.0 slicer-0.0.7
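# This can be very slow on large datasets
import numpy as np  # needed for the shape check below
explainer = shap.TreeExplainer(clf)
# Compute SHAP values for all feature columns (everything except the label)
shap_values = explainer.shap_values(train.drop(columns=['Survived']))
np.array(shap_values).shape  # inspect the dimensions of the SHAP values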
(2, 712, 9)
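The SHAP values form a three-dimensional array: each sample has two sets of SHAP values, one per class.
The first dimension selects the class whose SHAP values you want: index 0 is class 0 (negative) and index 1 is class 1 (positive).
The remaining two dimensions index the samples and the features, here 712 samples by 9 features.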
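To make the indexing concrete, here is a minimal sketch that uses random placeholder arrays in place of the real output of `explainer.shap_values(...)`. The placeholder values and the row index `i` are assumptions for illustration only. Only the shapes match the output shown above.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features = 712, 9

# Placeholder for explainer.shap_values(...): a list holding one
# (n_samples, n_features) array per class. The numbers are random, not real SHAP values.
shap_values = [rng.normal(size=(n_samples, n_features)) for _ in range(2)]

stacked = np.array(shap_values)
print(stacked.shape)             # (2, 712, 9)

i = 0                            # hypothetical row index of the sample of interest
sample_pos = shap_values[1][i]   # values of sample i for class 1 (positive)
print(sample_pos.shape)          # (9,)
```

A single-sample plot needs exactly this kind of one-dimensional slice: one value per feature, for one chosen class.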
(shap_values[0] == -1* shap_values[1]).all()
True
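As the check shows, the SHAP values for class 0 are the exact negatives of those for class 1. This is what one would expect for a binary classifier whose two class probabilities sum to 1: any contribution that raises one class's output lowers the other's by the same amount.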