Lesson Video:


This article is also a Jupyter Notebook, available to be run from the top down; the code snippets below can be run in any environment.

Below are the versions of fastai, fastcore, and wwf currently running at the time of writing this:

  • fastai: 2.1.10
  • fastcore: 1.3.13
  • wwf: 0.0.7
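
If you want to reproduce this environment exactly, you can pin those versions before running anything else (assuming a pip-based setup such as Colab):

!pip install fastai==2.1.10 fastcore==1.3.13 wwf==0.0.7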

What is the goal today?

In this lesson, we'll show how you can modify any architecture to fit any problem, so you can experiment. In our example we will use a U-Net architecture to do:

  • Keypoint Regression
  • Multi-Label Classification (covered in notebook, not class)
  • Classification (covered in notebook, not class)
from fastai.vision.all import *

Below you will find the exact imports for everything we use today

from zipfile import ZipFile

from fastcore.basics import store_attr, ifnone
from fastcore.meta import delegates
from fastcore.xtras import Path

from fastai.torch_core import tensor, apply_init

from fastai.callback.hook import num_features_model, summary, hook_outputs, model_sizes, dummy_eval
from fastai.callback.progress import ProgressCallback
from fastai.callback.schedule import lr_find, fit_one_cycle

from fastai.data.block import MultiCategoryBlock, DataBlock
from fastai.data.external import URLs, untar_data
from fastai.data.transforms import get_image_files, get_files, get_c, RandomSplitter, ColReader, Normalize

from fastai.layers import SequentialEx, defaults, ConvLayer, BatchNorm, ResBlock, MergeLayer, SigmoidRange, in_channels

from fastai.learner import Learner
from fastai.losses import MSELossFlat, CrossEntropyLossFlat, BCEWithLogitsLossFlat
from fastai.metrics import accuracy, accuracy_multi

from fastai.optimizer import Adam, defaults

from fastai.vision.augment import Resize, Flip, Rotate, Zoom, Warp, aug_transforms, RandomResizedCrop
from fastai.vision.core import PILImage, imagenet_stats
from fastai.vision.data import ImageBlock, PointBlock, ImageDataLoaders
from fastai.vision.learner import model_meta, _default_meta, _add_norm, create_head, create_body, create_unet_model
from fastai.vision.models.unet import _get_sz_change_idxs, UnetBlock, ResizeToOrig, PixelShuffle_ICNR

import warnings

import pandas as pd
import shutil

import numpy as np
from numpy import array

from torch import nn
from torchvision.models.resnet import resnet34

Prep Data

url = "https://drive.google.com/uc?id=1ffJr3NrYPqzutcXsYIVNLXzzUaC9RqYM"
!gdown {url}
from zipfile import ZipFile
with ZipFile('cat-dataset.zip', 'r') as zip_ref:
  zip_ref.extractall()
# Remove the top-level CAT_0x folders; we'll use the copies under cats/
for i in range(7):
  path = Path(f'CAT_0{i}')
  shutil.rmtree(path)

# Flatten cats/CAT_0x/* up one level so every image (and its
# .cat annotation file) lives directly under cats/
for i in range(7):
  paths = Path(f'cats/CAT_0{i}').ls()
  for path in paths:
    p = Path(path).absolute()
    par = p.parents[1]
    p.rename(par/p.name)
path = Path('cats')
lbls = get_files(path, extensions='.cat')
imgs = get_image_files(path)
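
Every image in this dataset ships with a matching .cat annotation file, so the two lists should line up one-to-one (a quick sanity check, assuming a clean download):

len(imgs), len(lbls)  # the counts should match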
def img2kpts(f): return f'{str(f)}.cat'

def sep_points(coords:array):
  "Seperate a set of points to groups"
  kpts = []
  for i in range(1, int(coords[0]*2), 2):
    kpts.append([coords[i], coords[i+1]])
  return tensor(kpts)
  
def get_y(f:Path):
  "Get keypoints for `f` image"
  pts = np.genfromtxt(img2kpts(f))
  return sep_points(pts)
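
Each .cat file is a single line of numbers: the point count first, then the x/y pairs. With the standard nine-point cat annotations, sep_points should hand back a 9x2 tensor. Here is a quick check on made-up values (your files will differ):

coords = array([9., 175., 160., 239., 162., 199., 199., 149., 121.,
                137., 78., 166., 93., 281., 101., 312., 96., 296., 133.])
sep_points(coords).shape  # torch.Size([9, 2])

Some of the annotations in this dataset fall outside their image, so next we find and delete those files: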
bad_imgs = []

for name in imgs:
  im = PILImage.create(name)
  y = get_y(name)
  for x in y:
    # A keypoint is bad if it falls outside the image bounds
    if not (0 <= x[0] < im.size[0] and 0 <= x[1] < im.size[1]):
      bad_imgs.append(name)
      
# Delete every flagged image (set() drops duplicate entries)
for name in list(set(bad_imgs)):
  name.unlink()

DataLoaders

The img2kpts, sep_points, and get_y helpers defined above serve as our labelling functions, so we can go straight to declaring our transforms and the DataBlock:
item_tfms = [Resize(224, method='squish')]
batch_tfms = [Flip(), Rotate(), Zoom(), Warp()]
dblock = DataBlock(blocks=(ImageBlock, PointBlock),
                   get_items=get_image_files,
                   splitter=RandomSplitter(),
                   get_y=get_y,
                   item_tfms=item_tfms,
                   batch_tfms=batch_tfms)
bs=16
dls = dblock.dataloaders(path, bs=bs)
dls.show_batch()
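
It's worth peeking at the shapes our model will receive: with bs=16 and the 224-pixel resize, images come through as 16 x 3 x 224 x 224, and the nine keypoints as a 16 x 9 x 2 TensorPoint with coordinates scaled to [-1, 1]. Flattened, that's 18 values per image, which is the output size our model will need:

xb, yb = dls.one_batch()
xb.shape, yb.shape  # (torch.Size([16, 3, 224, 224]), torch.Size([16, 9, 2]))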

Model

Now we need to build our model. First let's look at unet_learner's source code:

@delegates(create_unet_model)
def unet_learner(dls, arch, normalize=True, n_out=None, pretrained=True, config=None,
                 # learner args
                 loss_func=None, opt_func=Adam, lr=defaults.lr, splitter=None, cbs=None, metrics=None, path=None,
                 model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95,0.85,0.95),
                 # other model args
                 **kwargs):
    "Build a unet learner from `dls` and `arch`"

    if config:
        warnings.warn('config param is deprecated. Pass your args directly to unet_learner.')
        kwargs = {**config, **kwargs}

    meta = model_meta.get(arch, _default_meta)
    if normalize: _add_norm(dls, meta, pretrained)

    n_out = ifnone(n_out, get_c(dls))
    assert n_out, "`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`"
    img_size = dls.one_batch()[0].shape[-2:]
    assert img_size, "image size could not be inferred from data"
    model = create_unet_model(arch, n_out, img_size, pretrained=pretrained, **kwargs)

    splitter=ifnone(splitter, meta['split'])
    learn = Learner(dls=dls, model=model, loss_func=loss_func, opt_func=opt_func, lr=lr, splitter=splitter, cbs=cbs,
                   metrics=metrics, path=path, model_dir=model_dir, wd=wd, wd_bn_bias=wd_bn_bias, train_bn=train_bn,
                   moms=moms)
    if pretrained: learn.freeze()
    # keep track of args for loggers
    store_attr('arch,normalize,n_out,pretrained', self=learn)
    if kwargs: store_attr(self=learn, **kwargs)
    return learn

We'll want to mimic how this is done, looking specifically at DynamicUnet (the model unet_learner builds).

from fastai.vision.models.unet import _get_sz_change_idxs, UnetBlock, ResizeToOrig
from fastai.vision.learner import _default_meta, _add_norm
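
If you'd like the original implementation open beside you as we go, you can print its source directly:

import inspect
from fastai.vision.models.unet import DynamicUnet
print(inspect.getsource(DynamicUnet))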

We want to replace the last ConvLayer with our own custom head (we'll use create_head). To do that we need the number of incoming features, which we can get from the size of the last layer.
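
To see what those two helpers do on their own, here is a minimal sketch on a bare resnet34 body (in this version of fastai, create_head doubles nf internally to account for its concatenated average and max pooling):

body = create_body(resnet34, pretrained=False)
nf = num_features_model(body)  # 512 for a resnet34 body
head = create_head(nf, 18)     # AdaptiveConcatPool2d -> Flatten -> BatchNorm -> Dropout -> Linear layers

In CustomUnet below we do exactly this, with the U-Net's final ResBlock standing in for the body: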

class CustomUnet(SequentialEx):
    "Create a U-Net from a given architecture."
    def __init__(self, encoder, n_classes, img_size, blur=False, blur_final=True, self_attention=False,
                 y_range=None, last_cross=True, bottle=False, act_cls=defaults.activation,
                 init=nn.init.kaiming_normal_, norm_type=None, **kwargs):
        imsize = img_size
        sizes = model_sizes(encoder, size=imsize)
        sz_chg_idxs = list(reversed(_get_sz_change_idxs(sizes)))
        self.sfs = hook_outputs([encoder[i] for i in sz_chg_idxs], detach=False)
        x = dummy_eval(encoder, imsize).detach()

        ni = sizes[-1][1]
        middle_conv = nn.Sequential(ConvLayer(ni, ni*2, act_cls=act_cls, norm_type=norm_type, **kwargs),
                                    ConvLayer(ni*2, ni, act_cls=act_cls, norm_type=norm_type, **kwargs)).eval()
        x = middle_conv(x)
        layers = [encoder, BatchNorm(ni), nn.ReLU(), middle_conv]

        for i,idx in enumerate(sz_chg_idxs):
            not_final = i!=len(sz_chg_idxs)-1
            up_in_c, x_in_c = int(x.shape[1]), int(sizes[idx][1])
            do_blur = blur and (not_final or blur_final)
            sa = self_attention and (i==len(sz_chg_idxs)-3)
            unet_block = UnetBlock(up_in_c, x_in_c, self.sfs[i], final_div=not_final, blur=do_blur, self_attention=sa,
                                   act_cls=act_cls, init=init, norm_type=norm_type, **kwargs).eval()
            layers.append(unet_block)
            x = unet_block(x)

        ni = x.shape[1]
        if imsize != sizes[0][-2:]: layers.append(PixelShuffle_ICNR(ni, act_cls=act_cls, norm_type=norm_type))
        layers.append(ResizeToOrig())
        if last_cross:
            layers.append(MergeLayer(dense=True))
            ni += in_channels(encoder)
            layers.append(ResBlock(1, ni, ni//2 if bottle else ni, act_cls=act_cls, norm_type=norm_type, **kwargs))
        #layers += [ConvLayer(ni, n_classes, ks=1, act_cls=None, norm_type=norm_type, **kwargs)] # HERE
        nf = num_features_model(nn.Sequential(layers[-1]))  # features coming out of the final ResBlock
        layers += create_head(nf, n_classes)                # swap in our custom head
        apply_init(nn.Sequential(layers[3], layers[-2]), init)
        if y_range is not None: layers.append(SigmoidRange(*y_range))
        super().__init__(*layers)

    def __del__(self):
        if hasattr(self, "sfs"): self.sfs.remove()

And make a create_custom_unet_model function to call it:

@delegates(CustomUnet.__init__)
def create_custom_unet_model(arch, n_out, img_size, pretrained=True, cut=None, n_in=3, **kwargs):
    "Create custom unet architecture"
    meta = model_meta.get(arch, _default_meta)
    body = create_body(arch, n_in, pretrained, ifnone(cut, meta['cut']))
    model = CustomUnet(body, n_out, img_size, **kwargs)
    return model
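
A quick smoke test before building a Learner: construct the model stand-alone and push a dummy batch through it (a sketch assuming the 224px cat data, where n_out=18 is 9 keypoints x 2 coordinates):

import torch
model = create_custom_unet_model(resnet34, n_out=18, img_size=(224,224), pretrained=False)
model.eval()
with torch.no_grad():
    out = model(torch.randn(2, 3, 224, 224))
out.shape  # torch.Size([2, 18])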
@delegates(create_custom_unet_model)
def custom_unet_learner(dls, arch, normalize=True, n_out=None, pretrained=True, config=None,
                 # learner args
                 loss_func=None, opt_func=Adam, lr=defaults.lr, splitter=None, cbs=None, metrics=None, path=None,
                 model_dir='models', wd=None, wd_bn_bias=False, train_bn=True, moms=(0.95,0.85,0.95),
                 # other model args
                 **kwargs):
    "Build a unet learner from `dls` and `arch`"

    if config:
        warnings.warn('config param is deprecated. Pass your args directly to custom_unet_learner.')
        kwargs = {**config, **kwargs}

    meta = model_meta.get(arch, _default_meta)
    if normalize: _add_norm(dls, meta, pretrained)

    n_out = ifnone(n_out, get_c(dls))
    assert n_out, "`n_out` is not defined, and could not be inferred from data, set `dls.c` or pass `n_out`"
    img_size = dls.one_batch()[0].shape[-2:]
    assert img_size, "image size could not be inferred from data"
    model = create_custom_unet_model(arch, n_out, img_size, pretrained=pretrained, **kwargs) # HERE

    splitter=ifnone(splitter, meta['split'])
    learn = Learner(dls=dls, model=model, loss_func=loss_func, opt_func=opt_func, lr=lr, splitter=splitter, cbs=cbs,
                   metrics=metrics, path=path, model_dir=model_dir, wd=wd, wd_bn_bias=wd_bn_bias, train_bn=train_bn,
                   moms=moms)
    if pretrained: learn.freeze()
    # keep track of args for loggers
    store_attr('arch,normalize,n_out,pretrained', self=learn)
    if kwargs: store_attr(self=learn, **kwargs)
    return learn
learn = custom_unet_learner(dls, resnet34, loss_func=MSELossFlat())
learn.summary()
CustomUnet (Input shape: 16 x 3 x 224 x 224)
============================================================================
Layer (type)         Output Shape         Param #    Trainable 
============================================================================
                     16 x 64 x 112 x 112 
Conv2d                                    9408       False     
BatchNorm2d                               128        True      
ReLU                                                           
MaxPool2d                                                      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
ReLU                                                           
Conv2d                                    36864      False     
BatchNorm2d                               128        True      
____________________________________________________________________________
                     16 x 128 x 28 x 28  
Conv2d                                    73728      False     
BatchNorm2d                               256        True      
ReLU                                                           
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
Conv2d                                    8192       False     
BatchNorm2d                               256        True      
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
ReLU                                                           
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
ReLU                                                           
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
ReLU                                                           
Conv2d                                    147456     False     
BatchNorm2d                               256        True      
____________________________________________________________________________
                     16 x 256 x 14 x 14  
Conv2d                                    294912     False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
Conv2d                                    32768      False     
BatchNorm2d                               512        True      
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
ReLU                                                           
Conv2d                                    589824     False     
BatchNorm2d                               512        True      
____________________________________________________________________________
                     16 x 512 x 7 x 7    
Conv2d                                    1179648    False     
BatchNorm2d                               1024       True      
ReLU                                                           
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
Conv2d                                    131072     False     
BatchNorm2d                               1024       True      
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
ReLU                                                           
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
ReLU                                                           
Conv2d                                    2359296    False     
BatchNorm2d                               1024       True      
BatchNorm2d                               1024       True      
ReLU                                                           
____________________________________________________________________________
                     16 x 1024 x 7 x 7   
Conv2d                                    4719616    True      
ReLU                                                           
____________________________________________________________________________
                     16 x 512 x 7 x 7    
Conv2d                                    4719104    True      
ReLU                                                           
____________________________________________________________________________
                     16 x 1024 x 7 x 7   
Conv2d                                    525312     True      
ReLU                                                           
PixelShuffle                                                   
BatchNorm2d                               512        True      
Conv2d                                    2359808    True      
ReLU                                                           
Conv2d                                    2359808    True      
ReLU                                                           
ReLU                                                           
____________________________________________________________________________
                     16 x 1024 x 14 x 14 
Conv2d                                    525312     True      
ReLU                                                           
PixelShuffle                                                   
BatchNorm2d                               256        True      
Conv2d                                    1327488    True      
ReLU                                                           
Conv2d                                    1327488    True      
ReLU                                                           
ReLU                                                           
____________________________________________________________________________
                     16 x 768 x 28 x 28  
Conv2d                                    295680     True      
ReLU                                                           
PixelShuffle                                                   
BatchNorm2d                               128        True      
Conv2d                                    590080     True      
ReLU                                                           
Conv2d                                    590080     True      
ReLU                                                           
ReLU                                                           
____________________________________________________________________________
                     16 x 512 x 56 x 56  
Conv2d                                    131584     True      
ReLU                                                           
PixelShuffle                                                   
BatchNorm2d                               128        True      
____________________________________________________________________________
                     16 x 96 x 112 x 112 
Conv2d                                    165984     True      
ReLU                                                           
Conv2d                                    83040      True      
ReLU                                                           
ReLU                                                           
____________________________________________________________________________
                     16 x 384 x 112 x 112
Conv2d                                    37248      True      
ReLU                                                           
PixelShuffle                                                   
ResizeToOrig                                                   
MergeLayer                                                     
Conv2d                                    88308      True      
ReLU                                                           
Conv2d                                    88308      True      
Sequential                                                     
ReLU                                                           
AdaptiveAvgPool2d                                              
AdaptiveMaxPool2d                                              
Flatten                                                        
BatchNorm1d                               396        True      
Dropout                                                        
____________________________________________________________________________
                     16 x 512            
Linear                                    101376     True      
ReLU                                                           
BatchNorm1d                               1024       True      
Dropout                                                        
____________________________________________________________________________
                     16 x 18             
Linear                                    9216       True      
____________________________________________________________________________

Total params: 41,332,980
Total trainable params: 20,065,332
Total non-trainable params: 21,267,648

Optimizer used: <function Adam at 0x7f6b8df76e18>
Loss function: FlattenedLoss of MSELoss()

Model frozen up to parameter group #2

Callbacks:
  - TrainEvalCallback
  - Recorder
  - ProgressCallback

And now we can train!

learn.lr_find()
SuggestedLRs(lr_min=0.0019054606556892395, lr_steep=1.5848931980144698e-06)
learn.fit_one_cycle(3, 1e-3)
epoch  train_loss  valid_loss  time
0      0.119708    0.160717    02:46
1      0.040552    0.019095    02:44
2      0.029885    0.013922    02:45
learn.show_results()
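
One easy experiment from here: PointBlock targets live in [-1, 1], and we kept DynamicUnet's y_range argument in CustomUnet, so you can bound the head's predictions to that range and see whether it helps (not something we ran above):

learn = custom_unet_learner(dls, resnet34, loss_func=MSELossFlat(), y_range=(-1., 1.))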

Other Examples

These are more of a proof of concept to practice with; whether the model performs better or worse than a standard head is unknown.

PETS

path = untar_data(URLs.PETS)
fnames = get_image_files(path/'images')
pat = r'(.+)_\d+\.jpg$'
batch_tfms = [*aug_transforms(size=224, max_warp=0), Normalize.from_stats(*imagenet_stats)]
item_tfms = RandomResizedCrop(460, min_scale=0.75, ratio=(1.,1.))
bs=18

data = ImageDataLoaders.from_name_re(path, fnames, pat, batch_tfms=batch_tfms, 
                                   item_tfms=item_tfms, bs=bs)
learn = custom_unet_learner(data, resnet34, loss_func=CrossEntropyLossFlat(), metrics=accuracy)
learn.lr_find()
SuggestedLRs(lr_min=0.0015848932787775993, lr_steep=3.019951861915615e-07)
learn.fit_one_cycle(5, 1e-3)
epoch  train_loss  valid_loss  accuracy  time
0      2.646517    3.048354    0.147497  02:20
1      1.532955    1.528544    0.517591  02:20
2      1.032563    1.075425    0.650203  02:20
3      0.714636    0.439413    0.864682  02:20
4      0.597098    0.383284    0.893099  02:21
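
Because custom_unet_learner calls learn.freeze() on a pretrained model, only the decoder and our new head trained above. The usual fastai next step applies here as well: unfreeze and fine-tune the whole network with discriminative learning rates (a sketch, not run in the lesson):

learn.unfreeze()
learn.fit_one_cycle(3, slice(1e-5, 1e-3))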

Planets

planet_source = untar_data(URLs.PLANET_SAMPLE)
df = pd.read_csv(planet_source/'labels.csv')
batch_tfms = aug_transforms(flip_vert=True, max_lighting=0.1, max_zoom=1.05, max_warp=0.)
planet = DataBlock(blocks=(ImageBlock, MultiCategoryBlock),
                   get_x=ColReader(0, pref=f'{planet_source}/train/', suff='.jpg'),
                   splitter=RandomSplitter(),
                   get_y=ColReader(1, label_delim=' '),
                   batch_tfms = batch_tfms)
dls = planet.dataloaders(df, bs=12)
dls.show_batch(max_n=9, figsize=(12,9))
learn = custom_unet_learner(dls, resnet34, loss_func=BCEWithLogitsLossFlat(), metrics=[accuracy_multi])
learn.lr_find()
SuggestedLRs(lr_min=0.014454397559165954, lr_steep=0.03981071710586548)
learn.fit_one_cycle(5, slice(1e-2))
epoch  train_loss  valid_loss  accuracy_multi  time
0      0.534084    0.344417    0.881765        00:26
1      0.323197    0.270720    0.903235        00:23
2      0.265935    0.274014    0.907059        00:23
3      0.244585    0.207186    0.916177        00:23
4      0.229131    0.204003    0.916471        00:23