My problem
I often want to look at the predictions for specific items in the validation set, to see whether I can find patterns in the errors made by the model. This notebook extends the Interpretation object created on top of a learner with a shortcut method, show_at, that does exactly this. As an example, let's use the "is a cat" classifier, as trained in the fastbook.
from fastai.vision.all import *
path = untar_data(URLs.PETS)/'images'
def is_cat(x): return x[0].isupper()
dls = ImageDataLoaders.from_name_func(
    path, get_image_files(path), valid_pct=0.2, seed=42,
    label_func=is_cat, item_tfms=Resize(224))
learn = cnn_learner(dls, resnet34, metrics=error_rate)
learn.fine_tune(1)
Creating a ClassificationInterpretation from the learner gives us shortcuts to interpret its results. By default, the Interpretation object is created around the validation dataloader.
interp = ClassificationInterpretation.from_learner(learn)
Let's say we are interested in the predictions for the first item of the validation set. Normally, to visualize both the item and its prediction, I first show the item from the dataset and then look up its prediction in interp.preds or interp.decoded.
item_idx = 0
show_at(dls.valid.dataset, item_idx)
print(interp.decoded[item_idx])
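To make the difference between preds and decoded concrete, here is a minimal sketch with plain lists. It assumes that, for a single-label classifier, decoding amounts to taking the index of the largest class probability. The probabilities and the decode helper are made up for illustration; they are not real model output or a fastai function.

```python
# Made-up probabilities for three items and two classes (not real model output)
preds = [[0.92, 0.08], [0.30, 0.70], [0.55, 0.45]]

def decode(probs):
    "Index of the largest probability, i.e. the predicted class (hypothetical helper)"
    return max(range(len(probs)), key=lambda i: probs[i])

decoded = [decode(p) for p in preds]

item_idx = 0
print(preds[item_idx], decoded[item_idx])
```

Looking at both values for the same item_idx shows not only which class was chosen but also how confident the model was about it.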
It would be much easier if we could just call show_at as a method of the interp object and plot the results the same way learn.show_results does. Here's a piece of code by Zach Mueller, taken from the Discord community, that achieves this.
@patch
@delegates(TfmdDL.show_results)
def show_at(self:Interpretation, idxs, **kwargs):
    "Show predictions on the items at `idxs`"
    inp, _, targ, dec, _ = self[idxs]
    self.dl.show_results((inp, dec), targ, **kwargs)
As you can see, the code of the show_at method is pretty simple. It uses fastcore's @patch decorator to add the method to the class Interpretation (hence the self:Interpretation as first argument), and the @delegates decorator to replace **kwargs in the signature of the method with the arguments of show_results. All the function does is grab the inputs, targets and decoded predictions from the corresponding attributes of the Interpretation object and call show_results from its dataloader. By default, when the Interpretation object is created using the method from_learner, this dataloader is the validation dataloader used during training.
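To get a feel for what the two decorators described above do, here is a rough sketch in plain Python. It is not fastcore's implementation: simple_patch, simple_delegates, Report and render are hypothetical names made up for this illustration, and the real decorators handle many more cases.

```python
import inspect

def simple_patch(func):
    "Attach `func` to the class named in the annotation of its first parameter"
    first_param = next(iter(inspect.signature(func).parameters.values()))
    setattr(first_param.annotation, func.__name__, func)
    return func

def simple_delegates(to):
    "Replace `**kwargs` in a function's signature with the keyword arguments of `to`"
    def decorator(func):
        sig = inspect.signature(func)
        # Keep every parameter except the catch-all **kwargs
        kept = [p for p in sig.parameters.values()
                if p.kind is not inspect.Parameter.VAR_KEYWORD]
        if len(kept) == len(sig.parameters): return func  # no **kwargs to replace
        names = {p.name for p in kept}
        # Borrow the parameters of `to` that have defaults, as keyword-only
        borrowed = [p.replace(kind=inspect.Parameter.KEYWORD_ONLY)
                    for p in inspect.signature(to).parameters.values()
                    if p.default is not inspect.Parameter.empty and p.name not in names]
        func.__signature__ = sig.replace(parameters=kept + borrowed)
        return func
    return decorator

class Report:
    "Hypothetical class standing in for Interpretation"
    def __init__(self, values): self.values = values

def render(values, max_n=9, title=None):
    "Hypothetical function standing in for show_results"
    return f"{title}: {values[:max_n]}"

@simple_patch
@simple_delegates(render)
def show(self: Report, idxs, **kwargs):
    "Render the values at `idxs`"
    return render([self.values[i] for i in idxs], **kwargs)

report = Report([10, 20, 30])
shown = report.show([0, 2], title="picked")
print(shown)
print(inspect.signature(Report.show))
```

The decorator order mirrors the one used for show_at: the signature is rewritten first, and the resulting function is then attached to the class. The practical benefit is that tab completion and help() list max_n and title instead of an opaque **kwargs.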
Grabbing the inputs, targets and decoded predictions is done by calling self[idxs]. For that to work, we need a __getitem__ method in the class Interpretation. That method calls getattr for every indexable attribute within Interpretation (i.e., inputs, predictions, targets, decoded predictions, and losses).
@patch
def __getitem__(self:Interpretation, idxs):
    "Get the inputs, preds, targets, decoded outputs, and losses at `idxs`"
    if not is_listy(idxs): idxs = [idxs]
    attrs = 'inputs,preds,targs,decoded,losses'
    res = L([getattr(self, attr)[idxs] for attr in attrs.split(',')])
    return res
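The same getattr-based gathering can be tried out without a trained model. The sketch below uses a hypothetical ToyInterp class that stores plain Python lists instead of tensors, so it selects elements one index at a time rather than relying on tensor indexing. All values are made up for illustration.

```python
class ToyInterp:
    "Hypothetical stand-in for Interpretation, holding plain lists instead of tensors"
    attrs = 'inputs,preds,targs,decoded,losses'

    def __init__(self, inputs, preds, targs, decoded, losses):
        self.inputs, self.preds, self.targs = inputs, preds, targs
        self.decoded, self.losses = decoded, losses

    def __getitem__(self, idxs):
        "Get the inputs, preds, targets, decoded outputs, and losses at `idxs`"
        # Normalise a slice or a single index into an iterable of indices
        if isinstance(idxs, slice): idxs = range(*idxs.indices(len(self.inputs)))
        elif not isinstance(idxs, (list, tuple)): idxs = [idxs]
        # One list per attribute, in the order given by `attrs`
        return [[getattr(self, attr)[i] for i in idxs] for attr in self.attrs.split(',')]

toy = ToyInterp(
    inputs=['img0', 'img1', 'img2'],
    preds=[[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]],
    targs=[0, 1, 1],
    decoded=[0, 1, 0],
    losses=[0.11, 0.22, 0.92],
)

# Same unpacking pattern as in show_at
inp, _, targ, dec, _ = toy[[0, 2]]
print(inp, targ, dec)
```

Because the result always has one entry per attribute in a fixed order, callers can unpack it positionally and ignore what they do not need, which is exactly what show_at does with the predictions and losses.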
Let's now see an example of how show_at works for a single item, in this case the first element of the validation dataset.
interp.show_at(0)
Here's another example that shows the predictions for multiple items, namely the three elements of the validation set with the largest loss:
interp.show_at(interp.top_losses(3)[1])
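The [1] in the call above picks the indices out of what top_losses returns, on the assumption that the result is a (losses, indices) pair sorted from largest to smallest loss. A minimal pure-Python sketch of that selection, with a hypothetical helper name and made-up loss values:

```python
def top_loss_indices(losses, k):
    "Return (values, indices) of the `k` largest losses, largest first (hypothetical helper)"
    order = sorted(range(len(losses)), key=lambda i: losses[i], reverse=True)[:k]
    return [losses[i] for i in order], order

# Made-up per-item losses for a four-item validation set
losses = [0.11, 0.92, 0.05, 0.47]

values, idxs = top_loss_indices(losses, 2)
print(values, idxs)
```

The indices are what show_at needs, since they identify positions in the validation set; the loss values themselves are only used for ranking.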
Additionally, the method __getitem__ is also very useful when you want to know everything (input, prediction, target, decoded prediction, and loss) about specific items of the dataset, for example the first three:
interp[:3]