SHAP is a library for interpreting the predictions of machine learning models, and we can use it to help us with tabular data too! I wrote a library called FastSHAP which ports its functionality over to fastai. Let's do a walkthrough of what each of its methods does and how it works.
- Note: I only have it ported for tabular data
from fastai.tabular.all import *

# Grab the ADULT_SAMPLE dataset and set up a standard fastai tabular problem
path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')

dep_var = 'salary'
cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [Categorify, FillMissing, Normalize]

# Rows 800-999 become the validation set
splits = IndexSplitter(list(range(800,1000)))(range_of(df))
to = TabularPandas(df, procs, cat_names, cont_names, y_names="salary", splits=splits)
dls = to.dataloaders()

# Train a quick model to explain
learn = tabular_learner(dls, layers=[200,100], metrics=accuracy)
learn.fit(1, 1e-2)
Now let's go through some example usage!
from fastinference.tabular import *
And now we'll make a ShapInterpretation object. It expects your Learner along with some test data to look at and any keywords that SHAP can use. If you don't pass anything in, it will use a subset of your validation data:
exp = ShapInterpretation(learn, df.iloc[:100])
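To get a feel for what this wrapper is doing, here's a minimal sketch of how a fastai tabular Learner can be handed to plain SHAP directly. This is my own illustration, not the library's actual implementation: the predict wrapper, the 50-row background sample, the 10 explained rows, and the assumption that processed columns come out categorical-first are all mine.

import shap

# Illustrative only: wrap the trained model so SHAP can call it on numpy arrays
# of pre-processed rows (assumed layout: category codes first, then normalized conts)
model = learn.model.eval().cpu()
n_cat = len(learn.dls.cat_names)

def predict(data):
    cats  = tensor(data[:, :n_cat]).long()
    conts = tensor(data[:, n_cat:]).float()
    with torch.no_grad():
        return model(cats, conts).numpy()

# Background sample and rows to explain, taken from the processed TabularPandas object
background  = to.train.xs.iloc[:50].to_numpy()
explainer   = shap.KernelExplainer(predict, background)
shap_values = explainer.shap_values(to.valid.xs.iloc[:10].to_numpy())

The ShapInterpretation object handles this kind of wiring for you.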
Let's look at the various methods available to us. First, here's the raw row we'll be explaining:
df.iloc[10]
As we can see, our y value is '<50k'. Let's look at how the model performed and what could have been pushing the result in the opposite direction. The decision_plot traces how each feature shifts the output away from the model's base value for a given class:
exp.decision_plot(class_id=0, row_idx=10)
And now, turning this the other way around:
exp.decision_plot(class_id=1, row_idx=10)
We can visually see which variables had the largest impact on the model. (Note: the plot shows the pre-processed data points.)
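Since the plot labels features with their processed values, it can help to pull up the processed row itself and compare it with the raw one above. A quick check (assuming row 10 lands in the training split, since our validation indices were 800-999):

# Category codes plus normalized continuous values for the row we just explained
to.train.xs.iloc[10]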
The dependence_plot shows how the SHAP value assigned to a single feature (here age) changes with that feature's value, for each class:
exp.dependence_plot('age', class_id=0)
exp.dependence_plot('age', class_id=1)
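For reference, this corresponds roughly to the plain shap call below, reusing the explainer and shap_values from the earlier sketch (again an illustration, not what the library literally runs):

# SHAP value of 'age' for the first output (here '<50k'), against the normalized ages
shap.dependence_plot('age', shap_values[0], to.valid.xs.iloc[:10])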
The force_plot shows which features are pushing the prediction above or below the model's expected (base) value:
exp.force_plot(class_id=1)
The summary_plot gives a more global view, ranking features by the magnitude of their SHAP values across the sampled data:
exp.summary_plot()
And finally, the waterfall_plot breaks a single prediction down feature by feature. Notice that class_id can be passed as the class name ('<50k') instead of an index:
exp.waterfall_plot(row_idx=10, class_id='<50k')