This article is also a Jupyter Notebook available to be run from the top down. There will be code snippets that you can then run in any environment.

Below are the versions of fastai, fastcore, and wwf that were running at the time of writing:

  • fastai: 2.1.10
  • fastcore: 1.3.13
  • wwf: 0.0.7

Dataset:

Our dataset today will be ImageWoof.

Goal: Using no pre-trained weights, see how high an accuracy we can reach in x epochs.

This dataset is generally harder than ImageNette; both are subsets of ImageNet.

Models are leaning more towards being faster and more efficient.

Let's import the library:

from fastai.vision.all import *

Below you will find the exact imports for everything we use today:

import kornia
from torch import nn

from fastai.callback.progress import ProgressCallback
from fastai.callback.schedule import lr_find, fit_flat_cos, fit_one_cycle

from fastai.data.core import Datasets
from fastai.data.external import untar_data, URLs
from fastai.data.transforms import Categorize, GrandparentSplitter, IntToFloatTensor, Normalize, ToTensor, parent_label

from fastai.layers import Mish
from fastai.learner import Learner

from fastai.losses import LabelSmoothingCrossEntropy
from fastai.metrics import top_k_accuracy, accuracy

from fastai.optimizer import ranger, Lookahead, RAdam

from fastai.vision.augment import FlipItem, RandomResizedCrop, Resize
from fastai.vision.core import PILImage, imagenet_stats, get_image_files
from fastai.vision.models.xresnet import xresnet50

Let's grab our data. For the competition, we'll focus on 5 epochs at 128x128 pixels.

path = untar_data(URLs.IMAGEWOOF)

There are a few more datasets available:

  • ImageNette: Slightly easier than ImageWoof, 10 easily classified classes from Imagenet (tench, English springer, cassette player, chain saw, church, French horn, garbage truck, gas pump, golf ball, parachute)

  • ImageWoof: 10 different dog breeds, Australian terrier, Border terrier, Samoyed, Beagle, Shih-Tzu, English foxhound, Rhodesian ridgeback, Dingo, Golden retriever, Old English sheepdog

  • Image网 (pronounced "Imagewang"; 网 means "net" in Chinese): ImageNette and ImageWoof combined, but:

    • The validation set is the same as ImageWoof's (i.e. 30% of ImageWoof images); there are no ImageNette images in the validation set (they're all in the training set)
    • Only 10% of ImageWoof images are in the training set!
    • The remaining images are in the unsup ("unsupervised") directory, and you cannot use their labels in training!

We'll use the mid-level Datasets API for this.

tfms = [[PILImage.create], [parent_label, Categorize()]]
item_tfms = [ToTensor(), Resize(128)]
batch_tfms = [FlipItem(), RandomResizedCrop(128, min_scale=0.35),
              IntToFloatTensor(), Normalize.from_stats(*imagenet_stats)]
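As a rough sketch of what the last two batch transforms do to each pixel: IntToFloatTensor divides the 0-255 byte values by 255, and Normalize.from_stats then subtracts the per-channel mean and divides by the per-channel standard deviation. The helper normalize_pixel below is hypothetical plain Python, not a fastai function. The mean/std values are the standard ImageNet statistics that imagenet_stats holds.

```python
# Hypothetical helper (not part of fastai): what IntToFloatTensor followed by
# Normalize.from_stats(*imagenet_stats) does to one RGB pixel.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD, div=255.0):
    """Scale 0-255 channel values to 0-1, then standardize each channel."""
    return [((c / div) - m) / s for c, m, s in zip(rgb, mean, std)]

print(normalize_pixel((255, 255, 255)))  # white pixel: every channel ends up well above 0
```

A pixel equal to the ImageNet mean maps to roughly zero in every channel. That is the point of normalizing with these statistics: inputs end up roughly centered with unit spread.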

Let's make our split

items = get_image_files(path)
split_idx = GrandparentSplitter(valid_name='val')(items)
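GrandparentSplitter decides train vs. validation from the folder two levels above each file. For example, in .../train/n02086240/x.JPEG the grandparent is "train". The following is a minimal plain-Python sketch of that idea. It is our own re-implementation for illustration, not fastai's code, and the toy paths are made up.

```python
from pathlib import PurePosixPath

def grandparent_split(items, train_name="train", valid_name="val"):
    """Hypothetical re-implementation: split item indices by the name of the
    folder two levels above each file (e.g. woof/train/<class>/x.JPEG)."""
    train_idx, valid_idx = [], []
    for i, item in enumerate(items):
        grandparent = PurePosixPath(item).parent.parent.name
        if grandparent == train_name:
            train_idx.append(i)
        elif grandparent == valid_name:
            valid_idx.append(i)
    return train_idx, valid_idx  # files under any other folder are left out

items = [
    "woof/train/n02086240/a.JPEG",
    "woof/val/n02086240/b.JPEG",
    "woof/train/n02115641/c.JPEG",
]
print(grandparent_split(items))  # ([0, 2], [1])
```

This is why we pass valid_name='val' above: ImageWoof stores its validation images in a folder named val rather than the splitter's default.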

Now let's build our Datasets

dsets = Datasets(items, tfms, splits=split_idx)

And our DataLoaders

dls = dsets.dataloaders(after_item=item_tfms, after_batch=batch_tfms, bs=64)

Let's make sure they look okay

dls.show_batch()

Wait, those aren't breed names! How do we make them readable?

  • We can use a dictionary as a transform on our y's
lbl_dict = dict(
  n02086240= 'Shih-Tzu',
  n02087394= 'Rhodesian ridgeback',
  n02088364= 'Beagle',
  n02089973= 'English foxhound',
  n02093754= 'Australian terrier',
  n02096294= 'Border terrier',
  n02099601= 'Golden retriever',
  n02105641= 'Old English sheepdog',
  n02111889= 'Samoyed',
  n02115641= 'Dingo'
)

To do so, we pass the dictionary's __getitem__ method into our y transforms, right after parent_label:

tfms = [[PILImage.create], [parent_label, lbl_dict.__getitem__, Categorize()]]
dsets = Datasets(items, tfms, splits=split_idx)
dls = dsets.dataloaders(after_item=item_tfms, after_batch=batch_tfms, bs=64)
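Any callable can sit in a transform list, and the functions run left to right. The following plain-Python sketch traces the y pipeline: folder name, then readable breed name via dict.__getitem__, then an integer index. parent_label_sketch and make_y_pipeline are hypothetical stand-ins for parent_label and Categorize. We assume Categorize builds a sorted vocab, which we believe is its default.

```python
from pathlib import PurePosixPath

# Two entries copied from lbl_dict above; the toy file paths are made up.
lbl_dict = {"n02086240": "Shih-Tzu", "n02115641": "Dingo"}

def parent_label_sketch(path):
    """Hypothetical stand-in for fastai's parent_label: the parent folder's name."""
    return PurePosixPath(path).parent.name

def make_y_pipeline(names):
    """Chain parent folder -> readable name -> integer index, applied left to right.
    Sorting the vocab is an assumption meant to mimic Categorize's default."""
    vocab = sorted(set(names))
    name_to_idx = {name: i for i, name in enumerate(vocab)}
    steps = [parent_label_sketch, lbl_dict.__getitem__, name_to_idx.__getitem__]

    def pipeline(x):
        for step in steps:
            x = step(x)
        return x

    return pipeline, vocab

items = ["woof/train/n02115641/a.JPEG", "woof/train/n02086240/b.JPEG"]
names = [lbl_dict[parent_label_sketch(p)] for p in items]
y_pipeline, vocab = make_y_pipeline(names)
print(vocab)                 # ['Dingo', 'Shih-Tzu']
print(y_pipeline(items[1]))  # 1
```

Using __getitem__ rather than dict.get means a folder name missing from the dictionary raises a KeyError immediately instead of silently becoming None.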

Let's make sure it worked

dls.show_batch()

Much better!

The Architecture and the new Improvements

All of these ideas live in the library (except Dilated Convolutions, for now). But where? Let's walk through it.

XResNet50:

arch = xresnet50(pretrained=False)

Since the ImageWoof competition, almost all of the above has become a simple setting in xresnet.

Mish:

arch = xresnet50(pretrained=False, act_cls=Mish)
arch[0]
ConvLayer(
  (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): Mish()
)
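Mish is defined as x * tanh(softplus(x)), where softplus(x) = ln(1 + e^x). Here is a scalar sketch in plain Python to show the shape of the function. It is not fastai's implementation, which operates on tensors.

```python
import math

def softplus(x):
    """ln(1 + e^x), rewritten for large positive x to avoid overflowing exp."""
    return x + math.log1p(math.exp(-x)) if x > 0 else math.log1p(math.exp(x))

def mish(x):
    """Mish activation: x * tanh(softplus(x))."""
    return x * math.tanh(softplus(x))

for x in (-2.0, 0.0, 2.0, 10.0):
    print(x, round(mish(x), 4))
```

Like ReLU, Mish is close to the identity for large positive inputs. Unlike ReLU, it is smooth everywhere and lets small negative values through instead of zeroing them.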

Self-Attention:

arch = xresnet50(pretrained=False, act_cls=Mish, sa=True)

MaxBlurPool

For this one we'll need a custom function that swaps each nn.MaxPool2d in the model for kornia's MaxBlurPool2d:

# https://discuss.pytorch.org/t/how-can-i-replace-an-intermediate-layer-in-a-pre-trained-network/3586/7
import kornia
from torch import nn

def convert_MP_to_blurMP(model, layer_type_old):
    "Recursively replace every `layer_type_old` layer in `model` with kornia's MaxBlurPool2d"
    for name, module in reversed(model._modules.items()):
        if len(list(module.children())) > 0:
            # container module: recurse into its children
            model._modules[name] = convert_MP_to_blurMP(module, layer_type_old)

        if type(module) == layer_type_old:
            # kernel_size=3, ceil_mode=True (newer kornia releases expose this
            # layer as kornia.filters.MaxBlurPool2D instead of kornia.contrib)
            model._modules[name] = kornia.contrib.MaxBlurPool2d(3, True)

    return model
net = xresnet50(pretrained=False, act_cls=Mish, sa=True, n_out=10)
net = convert_MP_to_blurMP(net, nn.MaxPool2d)
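As we understand it, MaxBlurPool follows the anti-aliasing idea from "Making Convolutional Networks Shift-Invariant Again." It takes the max densely (stride 1), then blurs, then subsamples, instead of taking a strided max directly. The 1D sketch below is plain Python. The [1, 2, 1]/4 blur kernel and edge padding are our own choices and may not match kornia's exact kernel. It compares how much each pooling output changes when the input shifts by one step.

```python
def max_pool_1d(xs, k=3):
    """Max over a sliding window of size k with stride 1 (edges padded by repetition)."""
    pad = k // 2
    padded = [xs[0]] * pad + list(xs) + [xs[-1]] * pad
    return [max(padded[i:i + k]) for i in range(len(xs))]

def blur_downsample_1d(xs, kernel=(0.25, 0.5, 0.25), stride=2):
    """Low-pass filter with a small binomial kernel, then keep every stride-th value."""
    pad = len(kernel) // 2
    padded = [xs[0]] * pad + list(xs) + [xs[-1]] * pad
    blurred = [sum(w * v for w, v in zip(kernel, padded[i:i + len(kernel)]))
               for i in range(len(xs))]
    return blurred[::stride]

def strided_max_pool_1d(xs):
    """Ordinary max pooling: window 3, stride 2."""
    return max_pool_1d(xs)[::2]

def max_blur_pool_1d(xs):
    """Max pool densely (stride 1), then blur and subsample."""
    return blur_downsample_1d(max_pool_1d(xs))

def max_change(a, b):
    return max(abs(x - y) for x, y in zip(a, b))

signal = [0, 0, 1, 0, 0, 0, 0, 0]
shifted = [0, 0, 0, 1, 0, 0, 0, 0]  # same signal moved one step to the right
print(max_change(strided_max_pool_1d(signal), strided_max_pool_1d(shifted)))  # 1
print(max_change(max_blur_pool_1d(signal), max_blur_pool_1d(shifted)))        # 0.5
```

In this toy case, a one-step shift changes an output of the plain strided max pool by 1. With the blurred version, the largest change is halved.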

Now that we know how to use it all, let's test it out!

Ranger + Fit-Flat-Cosine

opt_func = ranger

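That is the same thing as the following, since ranger wraps RAdam in Lookahead with its own defaults for mom, wd, and eps:

def opt_func(ps, lr=1e-3): return Lookahead(RAdam(ps, lr=lr, mom=0.95, wd=0.01, eps=1e-6))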
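Lookahead keeps a second, "slow" copy of the weights. The inner optimizer takes k ordinary steps. Then the slow weights move a fraction alpha of the way toward the fast ones, and the fast weights restart from there. The following scalar sketch swaps in plain SGD for RAdam to stay short. The values k=6 and alpha=0.5 are what we believe fastai's Lookahead uses by default, and the quadratic loss is a made-up example.

```python
def lookahead_sgd(grad_fn, w0, lr=0.1, k=6, alpha=0.5, steps=60):
    """Lookahead around plain SGD (SGD stands in for RAdam to keep this short)."""
    fast = slow = w0
    for step in range(1, steps + 1):
        fast -= lr * grad_fn(fast)          # inner "fast" optimizer step
        if step % k == 0:                   # every k steps:
            slow += alpha * (fast - slow)   #   pull the slow weights toward the fast ones
            fast = slow                     #   and restart the fast weights from there
    return slow

def grad(w):
    """Gradient of the toy loss f(w) = (w - 3)^2, minimized at w = 3."""
    return 2 * (w - 3.0)

print(lookahead_sgd(grad, w0=0.0))  # close to 3
```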

We'll also use Label Smoothing Cross Entropy as our loss, and track both top-k accuracy and plain accuracy as metrics:

learn = Learner(dls, model=net, loss_func=LabelSmoothingCrossEntropy(), metrics=[top_k_accuracy, accuracy])
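Label smoothing replaces the one-hot target with one that puts 1 - eps on the true class and spreads eps evenly over all c classes. The loss therefore mixes the usual negative log-likelihood with the average -log p over every class. Below is a per-example sketch in plain Python. We believe fastai's LabelSmoothingCrossEntropy uses the same mixture with eps=0.1 by default, but treat this as an illustration rather than its code.

```python
import math

def log_softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_sum for z in logits]

def label_smoothing_ce(logits, target, eps=0.1):
    """Cross-entropy against a target with (1 - eps) on the true class and
    eps spread uniformly over all c classes."""
    log_p = log_softmax(logits)
    c = len(logits)
    smooth_term = -sum(log_p) / c  # average -log p over every class
    nll_term = -log_p[target]      # the usual cross-entropy term
    return eps * smooth_term + (1 - eps) * nll_term

logits = [2.0, 0.5, -1.0]
print(label_smoothing_ce(logits, target=0, eps=0.0))  # plain cross-entropy
print(label_smoothing_ce(logits, target=0, eps=0.1))  # slightly higher: confidence is penalized
```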

Finally, let's find a good learning rate and fit our model.

learn.lr_find()
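lr_find trains briefly while increasing the learning rate geometrically each mini-batch and records the loss. You then pick a rate somewhat before the loss starts to blow up. Below is a plain-Python sketch of that exponential schedule only. The defaults start_lr=1e-7, end_lr=10 and num_it=100 are what we believe fastai uses. Spacing the values so both endpoints are hit exactly is our simplification.

```python
def exp_lr_schedule(start_lr=1e-7, end_lr=10.0, num_it=100):
    """Learning rates growing geometrically from start_lr to end_lr, one per mini-batch."""
    return [start_lr * (end_lr / start_lr) ** (i / (num_it - 1)) for i in range(num_it)]

lrs = exp_lr_schedule()
print(f"{lrs[0]:.1e} {lrs[50]:.1e} {lrs[-1]:.1e}")
```

A geometric (rather than linear) sweep spends as many iterations between 1e-6 and 1e-5 as between 1 and 10, so every order of magnitude gets sampled.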

And fit for 5 epochs! We want to use fit_flat_cos, the flat-then-cosine schedule that Mikhail Grankin came up with. Why?

  • We noticed gradients blowing up. So instead of One-Cycle:

One-Cycle vs Fit-Flat-Cosine

from fastai.test_utils import synth_learner
synth = synth_learner()
synth.fit_one_cycle(1)
epoch  train_loss  valid_loss  time
0      21.796806   18.085186   00:00
synth.recorder.plot_sched()
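The one-cycle learning-rate curve plotted above can be sketched as two cosine segments. It warms up from lr_max/div to lr_max over the first pct_start of training, then anneals down to lr_max/div_final. The defaults div=25, div_final=1e5 and pct_start=0.25 are what we believe fit_one_cycle uses. This plain-Python version ignores the momentum schedule that fastai also applies.

```python
import math

def cos_anneal(start, end, pos):
    """Cosine interpolation from start (pos=0) to end (pos=1)."""
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2

def one_cycle_lr(pos, lr_max=2e-3, div=25.0, div_final=1e5, pct_start=0.25):
    """Learning rate at training progress pos in [0, 1]: warm up, then anneal."""
    if pos < pct_start:
        return cos_anneal(lr_max / div, lr_max, pos / pct_start)
    return cos_anneal(lr_max, lr_max / div_final, (pos - pct_start) / (1 - pct_start))

for pos in (0.0, 0.25, 0.5, 1.0):
    print(pos, f"{one_cycle_lr(pos):.2e}")
```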

We fit at a flat learning rate, then anneal it with a cosine:

synth.fit_flat_cos(1, pct_start=0.72)
epoch  train_loss  valid_loss  time
0      17.837477   12.652998   00:00
synth.recorder.plot_sched()
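The flat-then-cosine curve is even simpler. The learning rate stays at lr for the first pct_start of training, then follows a cosine down to lr/div_final. Below is a plain-Python sketch. The defaults pct_start=0.75 and div_final=1e5 are what we believe fit_flat_cos uses. The call above overrides pct_start with 0.72. The lr of 4e-3 matches the Ranger run below.

```python
import math

def cos_anneal(start, end, pos):
    """Cosine interpolation from start (pos=0) to end (pos=1)."""
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2

def flat_cos_lr(pos, lr=4e-3, div_final=1e5, pct_start=0.75):
    """Learning rate at training progress pos in [0, 1]: flat, then cosine decay."""
    if pos < pct_start:
        return lr  # flat phase: no warmup at all
    return cos_anneal(lr, lr / div_final, (pos - pct_start) / (1 - pct_start))

for pos in (0.0, 0.5, 0.75, 0.9, 1.0):
    print(pos, f"{flat_cos_lr(pos):.2e}")
```

Compared with one-cycle, there is no warmup phase. The model trains at full learning rate for most of the run and only decays at the end.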

Let's Train!

One Cycle:

net = xresnet50(pretrained=False, act_cls=Mish, sa=True, n_out=10)
net = convert_MP_to_blurMP(net, nn.MaxPool2d)
learn = Learner(dls, model=net, loss_func=LabelSmoothingCrossEntropy(), metrics=[top_k_accuracy, accuracy])
learn.fit_one_cycle(5, 2e-3)
epoch  train_loss  valid_loss  top_k_accuracy  accuracy  time
0      2.242338    2.251166    0.741410        0.251209  01:23
1      1.895956    2.003013    0.855179        0.325783  01:25
2      1.646342    1.774675    0.885722        0.472130  01:25
3      1.441252    1.400105    0.941970        0.617460  01:25
4      1.296093    1.319047    0.953678        0.652838  01:24
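The top_k_accuracy column counts a prediction as correct if the true class is anywhere among the k highest scores. We believe fastai's default is k=5. With only 10 classes that is a much looser target, which explains why it sits far above plain accuracy. Below is a plain-Python sketch. The function name top_k_acc_sketch and the toy scores are ours.

```python
def top_k_acc_sketch(scores, targets, k=5):
    """Fraction of rows whose true class is among the k highest scores."""
    hits = 0
    for row, target in zip(scores, targets):
        top_k = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        hits += target in top_k
    return hits / len(targets)

scores = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
targets = [2, 1, 0]
print(top_k_acc_sketch(scores, targets, k=1))  # 0.0: plain accuracy
print(top_k_acc_sketch(scores, targets, k=2))  # two of three rows hit
```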

Ranger:

net = xresnet50(pretrained=False, act_cls=Mish, sa=True, n_out=10)
net = convert_MP_to_blurMP(net, nn.MaxPool2d)
learn = Learner(dls, model=net, loss_func=LabelSmoothingCrossEntropy(), metrics=[top_k_accuracy, accuracy], opt_func=ranger)
learn.fit_flat_cos(5, 4e-3)
epoch  train_loss  valid_loss  top_k_accuracy  accuracy  time
0      2.120393    2.435370    0.769152        0.208705  01:24
1      1.818861    1.735399    0.897175        0.458386  01:24
2      1.641860    1.936957    0.855688        0.388394  01:24
3      1.490574    1.440254    0.939934        0.596335  01:23
4      1.279275    1.272718    0.958768        0.671418  01:23