shap 명목형 변수 표기 방법 바꾸기

    목차
반응형

1. 배경 상황

테이블 형태 데이터를 가지고 예측 모델을 만들 때 보통 명목형(object) 변수를 라벨 인코딩 혹은 원-핫 인코딩을 수행하여 모델이 학습할 수 있도록 변환한다. 이후 shap을 시각화하면 각 컬럼이 어떤 값이어서 영향을 줬는지에 대한 해석이 어려워진다.

예를 들어, 모든 컬럼값이 명목형 변수인 버섯 데이터의 식용 여부를 예측하는 모델을 만들고 shap value를 계산했을 때 아래처럼 확인이 된다.

더보기
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split

df_encoded = df.copy()

for column in df.select_dtypes('object'):
    df_encoded[column] = LabelEncoder().fit_transform(df_encoded[column])

y = df_encoded['class']
y = y.astype('category').cat.codes
X = df_encoded.drop('class', axis=1)

이걸 해석하기 위해서는 `gill-color=0`이 뭔지를 알아야 한다. 모델을 만든 사람이 아니거나 각 번호에 해당하는 값을 적어두지 않는다면 파악이 어렵다.

 

2. 원래 값으로 바꿔넣기

더보기
import shap

explainer = shap.TreeExplainer(model)
shap_values = explainer(X_test)

idx = 17
want_to_show = shap_values[idx]

보고자 하는 shap value에 대해 print를 찍어보면 아래와 같이 나온다.

여기서 `.data` 부분을 원래 값으로 바꿔주면 된다!

idx = 17
want_to_show = shap_values[idx]
want_to_show.data = df.drop(columns='class').loc[X_test.iloc[idx].name].to_numpy()
shap.plots.waterfall(want_to_show)

이렇게 하면 `gill-color=b`이어서 +4.2만큼의 영향을 주었다고 해석이 가능해진다.

 

 

3. 전체 코드

github

728x90
반응형