Extending the `Interpretation` class with the `show_at` method

This article is also a Jupyter Notebook available to be run from the top down. There will be code snippets that you can then run in any environment.

Below are the versions of fastai, fastcore, and wwf in use at the time of writing:

  • fastai: 2.2.5
  • fastcore: 1.3.19
  • wwf: 0.0.9

My problem

I often want to look at the predictions for specific items in the validation set, to see if I can find patterns in the errors made by the model. This notebook extends the Interpretation object created on top of a learner with a shortcut method, show_at, that does exactly this. As an example, let's use the "is a cat" classifier as trained in the fastbook.

from fastai.vision.all import *
path = untar_data(URLs.PETS)/'images'

def is_cat(x): return x[0].isupper()
dls = ImageDataLoaders.from_name_func(
    path, get_image_files(path), valid_pct=0.2, seed=42,
    label_func=is_cat, item_tfms=Resize(224))

learn = cnn_learner(dls, resnet34, metrics=error_rate)
learn.fine_tune(1)
Downloading: "https://download.pytorch.org/models/resnet34-333f7ec4.pth" to /root/.cache/torch/hub/checkpoints/resnet34-333f7ec4.pth

epoch train_loss valid_loss error_rate time
0 0.168753 0.024423 0.008119 00:50
epoch train_loss valid_loss error_rate time
0 0.049855 0.026575 0.009472 00:49

Creating a ClassificationInterpretation from the learner gives us shortcuts to interpret its results. By default, the Interpretation object is created around the validation dataloader.

interp = ClassificationInterpretation.from_learner(learn)

Let's say we are interested in the predictions for the first item of the validation set. Normally, to visualize both the item and its prediction, I first show the item from the dataset and then look up its prediction in interp.preds or interp.decoded.

item_idx = 0
show_at(dls.valid.dataset, item_idx)
print(interp.decoded[item_idx])
tensor(1)

It would be much easier if we could just call show_at as a method of the interp object and plot the results in the same way that learn.show_results does. Here's a piece of code by Zach Mueller, taken from the Discord community, that achieves this goal.

Interpretation.show_at

Interpretation.show_at(idxs, max_n=9, ctxs=None, show=True)

Show predictions on the items at idxs

@patch
@delegates(TfmdDL.show_results)
def show_at(self:Interpretation, idxs, **kwargs):
    "Show predictions on the items at `idxs`"
    inp, _, targ, dec, _ = self[idxs]
    self.dl.show_results((inp, dec), targ, **kwargs)

As you can see, the code of the show_at method is pretty simple. It uses fastcore's @patch decorator to add the method to the class Interpretation (hence the self:Interpretation as first argument), and the @delegates decorator to replace **kwargs in the signature of the method with the arguments of show_results. All the function does is grab the inputs, targets and decoded predictions from the corresponding attributes of the Interpretation object, and call show_results on its dataloader. By default, when the Interpretation object is created using the method from_learner, this dataloader is the validation dataloader used in training.
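
To make the two decorators less magical, here is a minimal sketch of the ideas behind them using only the standard library. The names simple_patch and simple_delegates, as well as the toy classes Viewer and Report, are hypothetical and invented for this illustration; they are not fastcore's implementation, which handles many more cases.

```python
import inspect

def simple_patch(func):
    "Attach `func` to the class named in the annotation of its first parameter"
    first_param = next(iter(inspect.signature(func).parameters.values()))
    setattr(first_param.annotation, func.__name__, func)
    return func

def simple_delegates(target):
    "Replace `**kwargs` in the reported signature with the keyword defaults of `target`"
    def decorator(func):
        sig = inspect.signature(func)
        # Keep every parameter except the catch-all **kwargs
        params = {n: p for n, p in sig.parameters.items()
                  if p.kind is not inspect.Parameter.VAR_KEYWORD}
        # Borrow the parameters of `target` that have a default value
        for name, p in inspect.signature(target).parameters.items():
            if p.default is not inspect.Parameter.empty and name not in params:
                params[name] = p.replace(kind=inspect.Parameter.KEYWORD_ONLY)
        # Only introspection changes; the function itself still accepts **kwargs
        func.__signature__ = sig.replace(parameters=list(params.values()))
        return func
    return decorator

class Viewer:
    "Toy stand-in for a dataloader that knows how to display items"
    def show_results(self, items, max_n=9, show=True):
        return [f"item {i}" for i in items[:max_n]] if show else []

class Report:
    "Toy stand-in for an interpretation object wrapping a viewer"
    def __init__(self, viewer): self.viewer = viewer

@simple_patch
@simple_delegates(Viewer.show_results)
def show_at(self: Report, idxs, **kwargs):
    "Show the items at `idxs`"
    return self.viewer.show_results(idxs, **kwargs)

report = Report(Viewer())
print(report.show_at([3, 5, 8], max_n=2))            # method now lives on Report
print(list(inspect.signature(Report.show_at).parameters))
```

The second print lists max_n and show in place of kwargs, which is the effect that makes tab completion and documentation useful on the patched method.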

Grabbing the inputs, targets and decoded predictions is done by calling self[idxs]. For that to work, we need a __getitem__ method in the class Interpretation. That method calls getattr for every indexable attribute within Interpretation (i.e., inputs, predictions, targets, decoded predictions, and losses).

Interpretation.__getitem__

Interpretation.__getitem__(idxs)

Get the inputs, preds, targets, decoded outputs, and losses at idxs

@patch
def __getitem__(self:Interpretation, idxs):
    "Get the inputs, preds, targets, decoded outputs, and losses at `idxs`"
    if not is_listy(idxs): idxs = [idxs]
    attrs = 'inputs,preds,targs,decoded,losses'
    res = L([getattr(self, attr)[idxs] for attr in attrs.split(',')])
    return res
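
The same getattr-over-attribute-names pattern can be tried without fastai on a toy class. This is only a sketch: ToyInterp is a hypothetical stand-in that stores plain Python lists, so it picks elements one by one instead of relying on the tensor indexing used above.

```python
class ToyInterp:
    "Hypothetical stand-in that stores plain lists instead of tensors"
    def __init__(self, inputs, preds, targs):
        self.inputs, self.preds, self.targs = inputs, preds, targs

    def __getitem__(self, idxs):
        # Normalise a single index to a list, as the patched method does
        if not isinstance(idxs, (list, tuple)): idxs = [idxs]
        attrs = 'inputs,preds,targs'
        # Plain lists lack fancy indexing, so select the elements one at a time
        return [[getattr(self, attr)[i] for i in idxs] for attr in attrs.split(',')]

toy = ToyInterp(inputs=['a.jpg', 'b.jpg', 'c.jpg'],
                preds=[0.9, 0.2, 0.7],
                targs=[1, 0, 1])

# One result per attribute, in the order given by `attrs`
inp, pred, targ = toy[[0, 2]]
print(inp, pred, targ)
```

Because the results come back in the order of the attribute string, the caller can unpack them positionally, which is what show_at relies on.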

Let's now see an example of how show_at works for a single item, in this case the first element of the validation dataset.

interp.show_at(0)

Here's another example that shows the predictions for multiple items, namely the three elements of the validation set with the largest loss.

interp.show_at(interp.top_losses(3)[1])
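
The [1] in the call above is there because, in the fastai version listed at the top, top_losses returns a pair of (loss values, indices), and only the indices are needed. A rough pure-Python sketch of that behaviour, assuming a plain list of losses and a hypothetical helper with the same name, could look like this:

```python
def top_losses(losses, k):
    "Return (values, indices) of the `k` largest losses, largest first"
    order = sorted(range(len(losses)), key=lambda i: losses[i], reverse=True)[:k]
    return [losses[i] for i in order], order

# Made-up loss values for five validation items
losses = [0.02, 1.35, 0.00, 0.48, 2.10]
values, idxs = top_losses(losses, 3)
print(idxs)    # positions of the three worst-predicted items
print(values)  # their losses, in the same order
```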

Additionally, the method __getitem__ is also very useful when you want to know everything (input data, prediction, target, decoded prediction, and loss) about specific items of the dataset.

interp[:3]
(#5) [TensorImage([[[[ 1.5125,  1.4954,  1.4954,  ..., -1.2445, -1.3473, -1.2617],
          [ 1.4612,  1.5125,  1.4612,  ..., -1.2617, -1.2788, -1.2103],
          [ 1.3584,  1.4098,  1.4269,  ..., -1.2445, -1.2274, -1.2103],
          ...,
          [-2.0152, -2.0152, -1.9980,  ..., -1.8268, -1.8268, -1.7240],
          [-1.9809, -2.0152, -2.0323,  ..., -1.7583, -1.7925, -1.8097],
          [-2.0152, -2.0152, -2.0323,  ..., -1.8782, -1.7925, -1.7925]],

         [[ 1.7108,  1.7283,  1.6933,  ..., -1.1253, -1.2479, -1.3004],
          [ 1.6408,  1.6583,  1.6408,  ..., -1.1078, -1.1779, -1.1429],
          [ 1.5182,  1.5532,  1.6057,  ..., -1.1954, -1.2129, -1.1604],
          ...,
          [-2.0182, -1.9307, -1.9132,  ..., -1.6856, -1.7206, -1.6506],
          [-1.9657, -1.9307, -1.9482,  ..., -1.6681, -1.6506, -1.6681],
          [-1.9832, -1.9482, -1.9657,  ..., -1.7906, -1.5805, -1.5630]],

         [[ 1.7860,  1.9254,  1.9603,  ..., -0.9853, -1.1247, -1.0550],
          [ 1.8034,  1.8383,  1.8905,  ..., -1.0027, -1.0027, -1.0550],
          [ 1.7337,  1.8383,  1.8731,  ..., -0.9853, -1.0376, -1.1770],
          ...,
          [-1.6824, -1.6476, -1.6476,  ..., -1.5256, -1.4733, -1.4384],
          [-1.6999, -1.6476, -1.6302,  ..., -1.4036, -1.5430, -1.5953],
          [-1.7173, -1.6302, -1.6302,  ..., -1.5256, -1.5604, -1.4907]]],


        [[[-0.2171, -0.2171, -0.2513,  ..., -0.3027, -0.3369, -0.3027],
          [-0.2684, -0.3198, -0.3712,  ..., -0.3541, -0.3541, -0.3369],
          [-0.2684, -0.3027, -0.3883,  ..., -0.4054, -0.3883, -0.3712],
          ...,
          [-0.1999, -0.0629, -0.2856,  ..., -0.8678, -0.9192, -1.2103],
          [ 0.1254, -0.3883, -0.1486,  ..., -0.7137, -0.2684, -0.3541],
          [ 0.0569, -0.7479, -0.4739,  ..., -0.8678, -0.9192, -1.1247]],

         [[ 0.1877,  0.1702,  0.1176,  ..., -0.0224, -0.0399, -0.0049],
          [ 0.1001,  0.0476, -0.0224,  ..., -0.0749, -0.0574, -0.0399],
          [ 0.0651,  0.0126, -0.0749,  ..., -0.1099, -0.0924, -0.0749],
          ...,
          [-0.4951, -0.3901, -0.6702,  ..., -0.6877, -0.7927, -1.2304],
          [-0.1800, -0.4601, -0.3550,  ..., -0.7227, -0.5826, -0.6527],
          [ 0.0301, -0.4951, -0.3550,  ..., -0.8803, -0.8803, -1.0903]],

         [[ 0.5311,  0.5311,  0.4788,  ...,  0.3393,  0.3393,  0.3742],
          [ 0.4439,  0.4091,  0.3393,  ...,  0.3045,  0.3219,  0.3393],
          [ 0.4265,  0.3393,  0.2348,  ...,  0.2696,  0.2871,  0.3045],
          ...,
          [-0.6890, -0.5321, -0.7761,  ..., -0.5670, -0.7936, -1.0027],
          [-0.3055, -0.4973, -0.4450,  ..., -0.6193, -0.6018, -0.6193],
          [ 0.1476, -0.3927, -0.3753,  ..., -0.8284, -0.8284, -1.0376]]],


        [[[-0.2342, -0.1999, -0.1828,  ..., -0.6794, -0.6965, -0.7137],
          [-0.3198, -0.2342, -0.2513,  ..., -0.6452, -0.7137, -0.7479],
          [-0.6794, -0.6281, -0.6623,  ..., -0.6623, -0.6965, -0.7479],
          ...,
          [-1.4329, -1.4329, -1.4672,  ..., -1.8097, -1.8268, -1.8268],
          [-1.4843, -1.4843, -1.5014,  ..., -1.8097, -1.8097, -1.8097],
          [-1.5185, -1.5357, -1.5357,  ..., -1.7583, -1.7754, -1.7412]],

         [[-0.8803, -0.8978, -0.8627,  ..., -1.0728, -1.1253, -1.1429],
          [-0.9328, -0.8978, -0.8978,  ..., -1.0728, -1.1253, -1.1429],
          [-1.1779, -1.1429, -1.1604,  ..., -1.0728, -1.1253, -1.1429],
          ...,
          [-1.3354, -1.3354, -1.3704,  ..., -1.7206, -1.7556, -1.7556],
          [-1.3880, -1.3880, -1.4055,  ..., -1.7206, -1.7381, -1.7556],
          [-1.4230, -1.4405, -1.4405,  ..., -1.6681, -1.6856, -1.6856]],

         [[-1.1944, -1.2119, -1.1770,  ..., -1.2641, -1.2641, -1.2816],
          [-1.2641, -1.2119, -1.1944,  ..., -1.2641, -1.2816, -1.2816],
          [-1.3861, -1.3164, -1.3513,  ..., -1.2467, -1.2641, -1.2816],
          ...,
          [-1.0898, -1.1073, -1.1421,  ..., -1.4907, -1.5604, -1.5604],
          [-1.1421, -1.1596, -1.1944,  ..., -1.4907, -1.5430, -1.5604],
          [-1.1944, -1.2293, -1.2467,  ..., -1.4384, -1.4733, -1.4907]]]]),tensor([[1.9801e-13, 1.0000e+00],
        [1.0000e+00, 3.0365e-06],
        [1.6535e-19, 1.0000e+00]]),TensorCategory([1, 0, 1]),tensor([1, 0, 1]),TensorBase([-0.0000e+00, 2.9802e-06, -0.0000e+00])]