This article is also available as a Jupyter Notebook that can be run from the top down. It contains code snippets that you can then run in any environment.

Below are the versions of fastai, fastcore, wwf, and fastinference in use at the time of writing:

  • fastai: 2.1.10
  • fastcore: 1.3.13
  • wwf: 0.0.8
  • fastinference: 0.0.35

ClassConfusion is similar to SHAP in that it gives us a view into how our model is behaving. Specifically, ClassConfusion plots how the distribution of each variable differs for the classes our model confused. For now, this only works in the Colab environment. Let's take a look:

We'll train a model on the ADULT_SAMPLE dataset again:

from fastai.tabular.all import *
path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')
dep_var = 'salary'
cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [Categorify, FillMissing, Normalize]
splits = IndexSplitter(list(range(800,1000)))(range_of(df))
to = TabularPandas(df, procs, cat_names, cont_names, y_names=dep_var, splits=splits)
dls = to.dataloaders()
learn = tabular_learner(dls, layers=[200,100], metrics=accuracy)
learn.fit(1, 1e-2)
epoch train_loss valid_loss accuracy time
0 0.369050 0.386868 0.820000 00:08

Now let's bring in ClassConfusion:

from fastinference.class_confusion import *

We'll build an instance of ClassConfusion, optionally passing in any variables we want to use, any test dataloaders we want, and whether our list of classes is ordered:

dl = dls.test_dl(df.iloc[:100])
classlist = ['<50k','>=50k']
ClassConfusion(learn, dl=dl, classlist=classlist)
100%|██████████| 9/9 [00:07<00:00,  1.28it/s]
<fastinference.class_confusion.ClassConfusion at 0x7ff1084bda58>

We can now look into each variable and see how its distribution for the confused classes differs from its distribution over the entire test set.
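To make the idea behind that comparison concrete, here is a minimal sketch in plain pandas. It does not use ClassConfusion or reproduce its internals. The `preds` table is hypothetical toy data (not output from the model trained above), and the helper names `confused_rows` and `compare_distribution` are made up for illustration. The sketch selects the rows where one class was mistaken for another, then puts the relative frequencies of one variable for those rows next to the frequencies over all rows.

```python
import pandas as pd

# Hypothetical toy predictions for illustration only;
# these are NOT results from the model trained above.
preds = pd.DataFrame({
    "actual":    ["<50k", "<50k", ">=50k", ">=50k", ">=50k", "<50k", ">=50k", "<50k"],
    "predicted": ["<50k", ">=50k", "<50k", ">=50k", "<50k", "<50k", "<50k", "<50k"],
    "workclass": ["Private", "Self-emp", "Private", "Private",
                  "State-gov", "Private", "Private", "State-gov"],
})

def confused_rows(df, actual, predicted):
    """Return rows whose true label is `actual` but were predicted as `predicted`."""
    mask = (df["actual"] == actual) & (df["predicted"] == predicted)
    return df[mask]

def compare_distribution(df, subset, column):
    """Relative frequencies of `column` over all rows and over the subset, side by side."""
    full = df[column].value_counts(normalize=True).rename("all_rows")
    sub = subset[column].value_counts(normalize=True).rename("confused")
    # Categories absent from the confused subset get a frequency of 0.
    return pd.concat([full, sub], axis=1).fillna(0.0)

# Rows that were actually '>=50k' but predicted as '<50k'
confused = confused_rows(preds, actual=">=50k", predicted="<50k")
table = compare_distribution(preds, confused, "workclass")
print(table)
```

A category whose share in the `confused` column is much larger than in `all_rows` is one where this kind of mistake is over-represented, which is the sort of signal the ClassConfusion plots are meant to surface visually.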