Imports:
from fastai.vision.all import *
Below are the explicit imports for everything we use today:
from functools import partial
from fastcore.transform import Pipeline
import fastai.callback.fp16      # patches Learner.to_fp16
import fastai.callback.schedule  # patches Learner.fine_tune
from fastai.callback.progress import ProgressCallback
from fastai.data.block import DataBlock, MultiCategoryBlock
from fastai.data.external import untar_data, URLs
from fastai.data.transforms import RandomSplitter, RegexLabeller, Normalize, get_image_files
from fastai.losses import BCEWithLogitsLossFlat
from fastai.metrics import accuracy_multi
from fastai.vision.augment import RandomResizedCrop, aug_transforms
from fastai.vision.core import PILImage
from fastai.vision.data import ImageBlock, imagenet_stats
from fastai.vision.learner import cnn_learner  # renamed vision_learner in newer fastai releases
from torchvision.models.resnet import resnet34
In this notebook, we will use MultiCategoryBlock in a particularly clever way: to have our model return no labels when an example does not belong to any of the classes seen during training. We'll repurpose our previous code for the Pets dataset so that the model can tell when an image shows none of the breeds it was trained on, e.g. a picture of a donkey.
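To see why a multi-label setup can answer "none of the above", it helps to compare the two decision rules in isolation. The sketch below is plain Python with made-up logits for three hypothetical classes; it illustrates the idea and is not fastai's code. A softmax head always spreads a total probability of 1.0 over the known classes, so taking the argmax always yields some class. With one independent sigmoid per class, every score can stay below the threshold, and the predicted label set is empty.

```python
import math

def softmax(logits):
    # Subtract the max for numerical stability; the outputs always sum to 1.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def single_label_predict(logits, vocab):
    # Single-label rule: the argmax of the softmax always returns some class.
    probs = softmax(logits)
    return vocab[probs.index(max(probs))]

def multi_label_predict(logits, vocab, thresh=0.95):
    # Multi-label rule: score each class independently, keep those above thresh.
    return [label for label, x in zip(vocab, logits) if sigmoid(x) > thresh]

# Hypothetical vocab and logits for an image that matches no known class well.
vocab = ["persian", "pug", "beagle"]
donkey_logits = [-2.0, -1.5, -3.0]

print(single_label_predict(donkey_logits, vocab))  # pug
print(multi_label_predict(donkey_logits, vocab))   # []
```

The single-label rule is forced to pick "pug" even though every logit is negative, while the multi-label rule simply returns nothing.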
path = untar_data(URLs.PETS)/'images'
path.ls()[:3]
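We'll build the DataLoaders as before, with one important change: the get_y function, because MultiCategoryBlock expects a list of labels rather than a single label.
To produce that list, we pass get_y a Pipeline made of a labeller (here RegexLabeller, which extracts the breed from the file name) followed by a function that wraps that label in a one-element list.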
pets_multi = DataBlock(blocks=(ImageBlock, MultiCategoryBlock),
                       get_items=get_image_files,
                       splitter=RandomSplitter(),
                       # Extract the breed from the file name, then wrap it in a list for MultiCategoryBlock
                       get_y=Pipeline([RegexLabeller(pat=r'/([^/]+)_\d+\.jpg$'), lambda label: [label]]),
                       item_tfms=RandomResizedCrop(460, min_scale=0.75),
                       batch_tfms=[*aug_transforms(size=224, max_warp=0), Normalize.from_stats(*imagenet_stats)])
dls = pets_multi.dataloaders(path, bs=32)
dls.show_batch(max_n=9, figsize=(7,8))
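The get_y pipeline in the DataBlock above does two things in sequence: extract the breed from the file path, then wrap it in a one-element list, the shape MultiCategoryBlock expects. The plain-Python sketch below mimics those two steps with the standard `re` module so the pattern can be checked without downloading the dataset. `label_from_path` is a hypothetical helper, not part of fastai, and the example paths are made up; they only follow the dataset's `<breed>_<number>.jpg` naming.

```python
import re

# The breed pattern from the DataBlock, with the '.' before 'jpg' escaped
# so it only matches a literal dot.
BREED_PATTERN = re.compile(r'/([^/]+)_\d+\.jpg$')

def label_from_path(path):
    # Step 1: extract the breed (RegexLabeller similarly searches str(path)
    # and raises an error when nothing matches).
    match = BREED_PATTERN.search(str(path))
    if match is None:
        raise ValueError(f"no breed found in {path!r}")
    breed = match.group(1)
    # Step 2: wrap the single label in a list for MultiCategoryBlock.
    return [breed]

print(label_from_path('/data/oxford-iiit-pet/images/great_pyrenees_173.jpg'))  # ['great_pyrenees']
print(label_from_path('/data/oxford-iiit-pet/images/Persian_12.jpg'))          # ['Persian']
```

Because `[^/]+` is greedy and then backtracks, breeds that themselves contain underscores (such as `great_pyrenees`) are kept whole; only the final `_<digits>.jpg` is stripped.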
Training our Model
Note a very important detail below: there are two different thresholds. The metric uses a very high threshold (0.95), so only highly confident predictions count as correct. The loss function keeps the default threshold of 0.5. That threshold does not change the loss being optimized, since binary cross-entropy is computed directly on the logits; it only controls how predictions are decoded into labels, e.g. by predict and show_results.
learn = cnn_learner(dls, resnet34, pretrained=True,
                    metrics=[partial(accuracy_multi, thresh=0.95)],
                    loss_func=BCEWithLogitsLossFlat(thresh=0.5)).to_fp16()
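The sketch below reimplements the two ingredients in plain Python to show which one a threshold actually affects. It follows the definition of fastai's `accuracy_multi` (apply a sigmoid, compare to `thresh`, average over every image/class entry) and the standard binary cross-entropy on logits, but it is a simplified illustration rather than the library code; the logits and targets are made up.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def bce_with_logits(logits, targets):
    # Mean binary cross-entropy over all (image, class) entries.
    # There is no threshold anywhere in the loss. (Naive formula; fine for small logits.)
    total = 0.0
    for x, t in zip(logits, targets):
        p = sigmoid(x)
        total += -(t * math.log(p) + (1 - t) * math.log(1 - p))
    return total / len(logits)

def accuracy_multi(logits, targets, thresh=0.5):
    # Fraction of entries where (sigmoid(logit) > thresh) agrees with the target.
    hits = sum((sigmoid(x) > thresh) == bool(t) for x, t in zip(logits, targets))
    return hits / len(logits)

# One image, three hypothetical classes; the first class is the true one.
logits = [4.0, -3.0, 0.5]
targets = [1, 0, 0]

print(accuracy_multi(logits, targets, thresh=0.5))   # 0.6666666666666666
print(accuracy_multi(logits, targets, thresh=0.95))  # 1.0
print(bce_with_logits(logits, targets))              # no thresh argument at all
```

Since `accuracy_multi` averages over every image/class pair and each pet image has one positive class among many negatives, the metric is dominated by correctly rejected classes; keep that in mind when reading its values.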
learn.fine_tune(epochs=4, base_lr=2e-3)
learn.save('cats-vs-dogs')
learn.recorder.plot_loss()
That looks very nice! Let's see how the model does on some examples.
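# Decode predictions with the same 0.95 threshold the metric uses (this does not retrain the model)
learn.loss_func = BCEWithLogitsLossFlat(thresh=0.95)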
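Swapping the loss function at this point retrains nothing: in fastai, BCEWithLogitsLossFlat applies a sigmoid as its activation and decodes predictions as `probs > thresh`, so raising `thresh` to 0.95 only changes which labels show_results and predict report, bringing them in line with the metric. The sketch below imitates that decoding step in plain Python; `decode_labels` and the probabilities are hypothetical.

```python
def decode_labels(probs, vocab, thresh):
    # Keep every class whose (already sigmoid-activated) probability exceeds thresh.
    return [label for label, p in zip(vocab, probs) if p > thresh]

vocab = ["Persian", "Ragdoll", "pug"]
# Hypothetical per-class probabilities for one ambiguous image.
probs = [0.97, 0.62, 0.03]

print(decode_labels(probs, vocab, thresh=0.5))   # ['Persian', 'Ragdoll']
print(decode_labels(probs, vocab, thresh=0.95))  # ['Persian']
```

With the default 0.5 threshold the moderately confident "Ragdoll" score would also be reported; at 0.95 only the confident label survives, and an image with no confident score decodes to an empty list.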
learn.show_results()
img = PILImage.create('persian_cat.jpg')
img.show()
learn.predict(img)[0]
Awesome! The model returns only one label, and it is the correct one. Let's see what happens with a donkey picture...
img = PILImage.create('donkey.jpg')
img.show()
learn.predict(img)[0]
Nothing! Each breed gets its own independent score, and none of them exceeds the 0.95 threshold for this image, so the classifier returns no label for a picture that does not belong to any of the classes seen during training. Isn't it great?
img = PILImage.create('real-ai.jpg')
img.show()
learn.predict(img)[0]