This article is also a Jupyter Notebook available to be run from the top down. There will be code snippets that you can then run in any environment.
Below are the versions of fastai, fastcore, wwf, and timm currently running at the time of writing this:

- fastai: 2.1.10
- fastcore: 1.3.13
- wwf: 0.0.7
- timm: 0.3.2
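If you want to reproduce this environment, one option (a suggestion on my part, not something the notebook requires) is to pin those versions when installing:

!pip install fastai==2.1.10 fastcore==1.3.13 wwf==0.0.7 timm==0.3.2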
This notebook will cover:

- Using a PyTorch model
- Using pre-trained weights for transfer learning
- Setting up a `cnn_learner`-style `Learner`
The Problem:

The problem today will be a familiar one, PETs, as we are going to focus on the `Learner` instead of the data.
from fastai.vision.all import *
from fastai.vision.learner import _update_first_layer
Below you will find the exact imports for everything we use today
from fastcore.xtras import Path
from fastai.callback.hook import summary
from fastai.callback.progress import ProgressCallback
from fastai.callback.schedule import lr_find, fit_one_cycle
from fastai.data.block import DataBlock, CategoryBlock
from fastai.data.core import Datasets
from fastai.data.external import untar_data, URLs
from fastai.data.transforms import get_image_files, Normalize, RandomSplitter, GrandparentSplitter, RegexLabeller, ToTensor, IntToFloatTensor, Categorize, parent_label
from fastai.learner import Learner
from fastai.losses import LabelSmoothingCrossEntropy
from fastai.metrics import error_rate, accuracy
from fastai.vision.augment import aug_transforms, RandomResizedCrop, Resize, FlipItem
from fastai.vision.core import PILImage, imagenet_stats
from fastai.vision.data import ImageBlock
from fastai.vision.learner import cnn_learner, create_head, create_body, num_features_model, default_split, has_pool_type, apply_init, _update_first_layer
import torch
from torch import nn
from torchvision.models.resnet import resnet18
from timm import create_model
Let's make our usual dataloaders real quick
path = untar_data(URLs.PETS)/'images'
fnames = get_image_files(path)
pat = r'/([^/]+)_\d+.*'
batch_tfms = [*aug_transforms(size=224, max_warp=0), Normalize.from_stats(*imagenet_stats)]
item_tfms = RandomResizedCrop(460, min_scale=0.75, ratio=(1.,1.))
bs=64
pets = DataBlock(blocks=(ImageBlock, CategoryBlock),
get_items=get_image_files,
splitter=RandomSplitter(),
get_y=RegexLabeller(pat = r'/([^/]+)_\d+.*'),
item_tfms=item_tfms,
batch_tfms=batch_tfms)
dls = pets.dataloaders(path, bs=bs)
dls.show_batch(max_n=9, figsize=(6,7))
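As a quick aside (an illustrative check, not part of the original notebook), we can see what that regex pattern pulls out of a PETS-style filename. The filename below is made up but follows the dataset's naming scheme:

# The capture group grabs the breed name sitting between the '/' and the
# trailing '_<index>.jpg' in a PETS filename
labeller = RegexLabeller(pat=r'/([^/]+)_\d+.*')
labeller('images/great_pyrenees_173.jpg')  # -> 'great_pyrenees'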
Now let's focus on our EfficientNet model. We'll be working out of Ross Wightman's repository here. Included in this repository are tons of pretrained models for almost every major model in Computer Vision, all trained and validated at a 224x224 image size. Let's install it:
!pip install timm
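As another quick aside (not in the original notebook), timm can list the pretrained models it ships, so something like the following should show a few of the available EfficientNet variants:

import timm
# Wildcard search over timm's registered model names, keeping only those
# that have pretrained weights available
timm.list_models('efficientnet*', pretrained=True)[:5]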
Now we can use his weights in one of two ways. First we'll show the direct way to load them in, then we'll load in the weights ourselves:
from timm import create_model
net = create_model('efficientnet_b3a', pretrained=True)
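The second way would be to build the architecture without pretrained weights and load a state dict in ourselves. A rough sketch of that idea (the checkpoint path below is a placeholder, not a real file):

# Build the same architecture with random weights, then load a checkpoint manually.
# 'path/to/efficientnet_b3a.pth' is a hypothetical path to weights you downloaded yourself.
net_manual = create_model('efficientnet_b3a', pretrained=False)
state_dict = torch.load('path/to/efficientnet_b3a.pth', map_location='cpu')
net_manual.load_state_dict(state_dict)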
Now let's take a look at our downloaded model, so we know how to modify it for transfer learning. With fastai models we can do something like so:
learn = cnn_learner(dls, resnet18)
learn.model[-1]
Sequential( (0): AdaptiveConcatPool2d( (ap): AdaptiveAvgPool2d(output_size=1) (mp): AdaptiveMaxPool2d(output_size=1) ) (1): Flatten(full=False) (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (3): Dropout(p=0.25, inplace=False) (4): Linear(in_features=1024, out_features=512, bias=False) (5): ReLU(inplace=True) (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (7): Dropout(p=0.5, inplace=False) (8): Linear(in_features=512, out_features=37, bias=False) )
And here we see the head of our model! Let's see if we can do the same for our EfficientNet:
net[-1]
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) <ipython-input-20-9c0a096ee31d> in <module>() ----> 1 net[-1] TypeError: 'EfficientNet' object does not support indexing
No! Why?
len(learn.model)
2
len(net)
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) <ipython-input-13-fefbadd730e7> in <module>() ----> 1 len(net) TypeError: object of type 'EfficientNet' has no len()
We can see that our fastai model was split into two different layer groups:

- Group 1: Our encoder, which is everything but the last layer of our original model
- Group 2: Our head, which is a fastai version of a `Linear` layer plus a few extra bits
create_head(2048, 10)
Sequential( (0): AdaptiveConcatPool2d( (ap): AdaptiveAvgPool2d(output_size=1) (mp): AdaptiveMaxPool2d(output_size=1) ) (1): Flatten(full=False) (2): BatchNorm1d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (3): Dropout(p=0.25, inplace=False) (4): Linear(in_features=2048, out_features=512, bias=False) (5): ReLU(inplace=True) (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (7): Dropout(p=0.5, inplace=False) (8): Linear(in_features=512, out_features=10, bias=False) )
How do we do this for our model? Let's take a look at it:
net
EfficientNet( (conv_stem): Conv2d(3, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (bn1): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (blocks): Sequential( (0): Sequential( (0): DepthwiseSeparableConv( (conv_dw): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=40, bias=False) (bn1): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(40, 10, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(10, 40, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pw): Conv2d(40, 24, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn2): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): Identity() ) (1): DepthwiseSeparableConv( (conv_dw): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=24, bias=False) (bn1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(24, 6, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(6, 24, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pw): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn2): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): Identity() ) ) (1): Sequential( (0): InvertedResidual( (conv_pw): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(144, 144, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=144, bias=False) (bn2): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(144, 6, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(6, 144, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(144, 32, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (1): InvertedResidual( (conv_pw): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192, bias=False) (bn2): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(192, 8, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(8, 192, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (2): InvertedResidual( (conv_pw): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192, bias=False) (bn2): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(192, 8, kernel_size=(1, 1), stride=(1, 1)) (act1): 
SiLU(inplace=True) (conv_expand): Conv2d(8, 192, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (2): Sequential( (0): InvertedResidual( (conv_pw): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(192, 192, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), groups=192, bias=False) (bn2): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(192, 8, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(8, 192, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (1): InvertedResidual( (conv_pw): Conv2d(48, 288, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(288, 288, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=288, bias=False) (bn2): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(288, 12, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(12, 288, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(288, 48, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (2): InvertedResidual( (conv_pw): Conv2d(48, 288, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(288, 288, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=288, bias=False) (bn2): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(288, 12, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(12, 288, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(288, 48, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (3): Sequential( (0): InvertedResidual( (conv_pw): Conv2d(48, 288, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(288, 288, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=288, bias=False) (bn2): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(288, 12, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(12, 288, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(288, 96, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (1): InvertedResidual( (conv_pw): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, 
track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False) (bn2): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(576, 24, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(24, 576, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (2): InvertedResidual( (conv_pw): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False) (bn2): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(576, 24, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(24, 576, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (3): InvertedResidual( (conv_pw): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False) (bn2): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(576, 24, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(24, 576, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (4): InvertedResidual( (conv_pw): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False) (bn2): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(576, 24, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(24, 576, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (4): Sequential( (0): InvertedResidual( (conv_pw): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(576, 576, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=576, bias=False) (bn2): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(576, 24, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(24, 576, kernel_size=(1, 
1), stride=(1, 1)) ) (conv_pwl): Conv2d(576, 136, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(136, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (1): InvertedResidual( (conv_pw): Conv2d(136, 816, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(816, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(816, 816, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=816, bias=False) (bn2): BatchNorm2d(816, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(816, 34, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(34, 816, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(816, 136, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(136, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (2): InvertedResidual( (conv_pw): Conv2d(136, 816, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(816, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(816, 816, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=816, bias=False) (bn2): BatchNorm2d(816, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(816, 34, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(34, 816, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(816, 136, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(136, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (3): InvertedResidual( (conv_pw): Conv2d(136, 816, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(816, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(816, 816, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=816, bias=False) (bn2): BatchNorm2d(816, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(816, 34, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(34, 816, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(816, 136, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(136, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (4): InvertedResidual( (conv_pw): Conv2d(136, 816, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(816, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(816, 816, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=816, bias=False) (bn2): BatchNorm2d(816, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(816, 34, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(34, 816, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(816, 136, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(136, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (5): Sequential( (0): InvertedResidual( (conv_pw): Conv2d(136, 816, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(816, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): 
Conv2d(816, 816, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), groups=816, bias=False) (bn2): BatchNorm2d(816, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(816, 34, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(34, 816, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(816, 232, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(232, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (1): InvertedResidual( (conv_pw): Conv2d(232, 1392, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(1392, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(1392, 1392, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=1392, bias=False) (bn2): BatchNorm2d(1392, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(1392, 58, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(58, 1392, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(1392, 232, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(232, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (2): InvertedResidual( (conv_pw): Conv2d(232, 1392, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(1392, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(1392, 1392, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=1392, bias=False) (bn2): BatchNorm2d(1392, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(1392, 58, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(58, 1392, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(1392, 232, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(232, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (3): InvertedResidual( (conv_pw): Conv2d(232, 1392, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(1392, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(1392, 1392, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=1392, bias=False) (bn2): BatchNorm2d(1392, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(1392, 58, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(58, 1392, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(1392, 232, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(232, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (4): InvertedResidual( (conv_pw): Conv2d(232, 1392, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(1392, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(1392, 1392, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=1392, bias=False) (bn2): BatchNorm2d(1392, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(1392, 58, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(58, 1392, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): 
Conv2d(1392, 232, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(232, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (5): InvertedResidual( (conv_pw): Conv2d(232, 1392, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(1392, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(1392, 1392, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=1392, bias=False) (bn2): BatchNorm2d(1392, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(1392, 58, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(58, 1392, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(1392, 232, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(232, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (6): Sequential( (0): InvertedResidual( (conv_pw): Conv2d(232, 1392, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(1392, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(1392, 1392, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1392, bias=False) (bn2): BatchNorm2d(1392, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(1392, 58, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(58, 1392, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(1392, 384, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (1): InvertedResidual( (conv_pw): Conv2d(384, 2304, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(2304, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(2304, 2304, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2304, bias=False) (bn2): BatchNorm2d(2304, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(2304, 96, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(96, 2304, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(2304, 384, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) ) (conv_head): Conv2d(384, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn2): BatchNorm2d(1536, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (global_pool): SelectAdaptivePool2d (pool_type=avg, flatten=True) (classifier): Linear(in_features=1536, out_features=1000, bias=True) )
We can see that our Pooling layer and our Linear layer are the last two layers of our model. Let's pop those off.
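Since this EfficientNet isn't an nn.Sequential, we can't simply index into it, but as a rough manual sketch (illustrative only; we'll write a more general helper below) we could rebuild its top-level children and drop the last two:

# Drop the final pooling and classifier layers by rebuilding the model's
# top-level children as an nn.Sequential (a quick manual cut, not the
# approach we'll actually use)
cut_net = nn.Sequential(*list(net.children())[:-2])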
Now if we use the original fastai `create_body` function, we'll get an error:
body = create_body(net, pretrained=False, cut=-1)
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) <ipython-input-21-d13ee58c7af6> in <module>() ----> 1 body = create_body(net, pretrained=False, cut=-1) /usr/local/lib/python3.6/dist-packages/fastai/vision/learner.py in create_body(arch, n_in, pretrained, cut) 63 def create_body(arch, n_in=3, pretrained=True, cut=None): 64 "Cut off the body of a typically pretrained `arch` as determined by `cut`" ---> 65 model = arch(pretrained=pretrained) 66 _update_first_layer(model, n_in, pretrained) 67 #cut = ifnone(cut, cnn_config(arch)['cut']) /usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs) 725 result = self._slow_forward(*input, **kwargs) 726 else: --> 727 result = self.forward(*input, **kwargs) 728 for hook in itertools.chain( 729 _global_forward_hooks.values(), TypeError: forward() got an unexpected keyword argument 'pretrained'
Why? Let's take a look
def create_body(arch, pretrained=True, cut=None):
    "Cut off the body of a typically pretrained `arch` as determined by `cut`"
    model = arch(pretrained=pretrained)
    #cut = ifnone(cut, cnn_config(arch)['cut'])
    if cut is None:
        ll = list(enumerate(model.children()))
        cut = next(i for i,o in reversed(ll) if has_pool_type(o))
    if isinstance(cut, int): return nn.Sequential(*list(model.children())[:cut])
    elif callable(cut): return cut(model)
    else: raise NamedError("cut must be either integer or a function")
We can see that `create_body` expects `arch` to be a callable that generates the model (like `resnet18`), not a model that has already been built. Let's make a function to help us with this, specifically for his library:
def create_timm_body(arch:str, pretrained=True, cut=None, n_in=3):
    "Creates a body from any model in the `timm` library."
    model = create_model(arch, pretrained=pretrained, num_classes=0, global_pool='')
    _update_first_layer(model, n_in, pretrained)
    if cut is None:
        ll = list(enumerate(model.children()))
        cut = next(i for i,o in reversed(ll) if has_pool_type(o))
    if isinstance(cut, int): return nn.Sequential(*list(model.children())[:cut])
    elif callable(cut): return cut(model)
    else: raise NameError("cut must be either integer or function")
Let's try it out!
body = create_timm_body('efficientnet_b3a', pretrained=True)
len(body)
7
Now we can see that we have seven separate layer groups:
body
Sequential( (0): Conv2d(3, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) (1): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (2): SiLU(inplace=True) (3): Sequential( (0): Sequential( (0): DepthwiseSeparableConv( (conv_dw): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=40, bias=False) (bn1): BatchNorm2d(40, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(40, 10, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(10, 40, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pw): Conv2d(40, 24, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn2): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): Identity() ) (1): DepthwiseSeparableConv( (conv_dw): Conv2d(24, 24, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=24, bias=False) (bn1): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(24, 6, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(6, 24, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pw): Conv2d(24, 24, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn2): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): Identity() ) ) (1): Sequential( (0): InvertedResidual( (conv_pw): Conv2d(24, 144, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(144, 144, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=144, bias=False) (bn2): BatchNorm2d(144, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(144, 6, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(6, 144, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(144, 32, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (1): InvertedResidual( (conv_pw): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192, bias=False) (bn2): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(192, 8, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(8, 192, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (2): InvertedResidual( (conv_pw): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(192, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=192, bias=False) (bn2): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(192, 8, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) 
(conv_expand): Conv2d(8, 192, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(192, 32, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (2): Sequential( (0): InvertedResidual( (conv_pw): Conv2d(32, 192, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(192, 192, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), groups=192, bias=False) (bn2): BatchNorm2d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(192, 8, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(8, 192, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(192, 48, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (1): InvertedResidual( (conv_pw): Conv2d(48, 288, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(288, 288, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=288, bias=False) (bn2): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(288, 12, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(12, 288, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(288, 48, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (2): InvertedResidual( (conv_pw): Conv2d(48, 288, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(288, 288, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=288, bias=False) (bn2): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(288, 12, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(12, 288, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(288, 48, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(48, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (3): Sequential( (0): InvertedResidual( (conv_pw): Conv2d(48, 288, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(288, 288, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), groups=288, bias=False) (bn2): BatchNorm2d(288, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(288, 12, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(12, 288, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(288, 96, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (1): InvertedResidual( (conv_pw): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, 
track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False) (bn2): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(576, 24, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(24, 576, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (2): InvertedResidual( (conv_pw): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False) (bn2): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(576, 24, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(24, 576, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (3): InvertedResidual( (conv_pw): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False) (bn2): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(576, 24, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(24, 576, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (4): InvertedResidual( (conv_pw): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(576, 576, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=576, bias=False) (bn2): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(576, 24, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(24, 576, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(576, 96, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (4): Sequential( (0): InvertedResidual( (conv_pw): Conv2d(96, 576, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(576, 576, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=576, bias=False) (bn2): BatchNorm2d(576, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(576, 24, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(24, 576, kernel_size=(1, 
1), stride=(1, 1)) ) (conv_pwl): Conv2d(576, 136, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(136, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (1): InvertedResidual( (conv_pw): Conv2d(136, 816, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(816, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(816, 816, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=816, bias=False) (bn2): BatchNorm2d(816, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(816, 34, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(34, 816, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(816, 136, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(136, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (2): InvertedResidual( (conv_pw): Conv2d(136, 816, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(816, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(816, 816, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=816, bias=False) (bn2): BatchNorm2d(816, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(816, 34, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(34, 816, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(816, 136, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(136, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (3): InvertedResidual( (conv_pw): Conv2d(136, 816, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(816, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(816, 816, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=816, bias=False) (bn2): BatchNorm2d(816, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(816, 34, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(34, 816, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(816, 136, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(136, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (4): InvertedResidual( (conv_pw): Conv2d(136, 816, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(816, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(816, 816, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=816, bias=False) (bn2): BatchNorm2d(816, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(816, 34, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(34, 816, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(816, 136, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(136, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (5): Sequential( (0): InvertedResidual( (conv_pw): Conv2d(136, 816, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(816, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): 
Conv2d(816, 816, kernel_size=(5, 5), stride=(2, 2), padding=(2, 2), groups=816, bias=False) (bn2): BatchNorm2d(816, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(816, 34, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(34, 816, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(816, 232, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(232, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (1): InvertedResidual( (conv_pw): Conv2d(232, 1392, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(1392, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(1392, 1392, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=1392, bias=False) (bn2): BatchNorm2d(1392, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(1392, 58, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(58, 1392, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(1392, 232, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(232, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (2): InvertedResidual( (conv_pw): Conv2d(232, 1392, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(1392, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(1392, 1392, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=1392, bias=False) (bn2): BatchNorm2d(1392, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(1392, 58, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(58, 1392, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(1392, 232, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(232, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (3): InvertedResidual( (conv_pw): Conv2d(232, 1392, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(1392, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(1392, 1392, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=1392, bias=False) (bn2): BatchNorm2d(1392, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(1392, 58, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(58, 1392, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(1392, 232, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(232, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (4): InvertedResidual( (conv_pw): Conv2d(232, 1392, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(1392, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(1392, 1392, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=1392, bias=False) (bn2): BatchNorm2d(1392, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(1392, 58, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(58, 1392, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): 
Conv2d(1392, 232, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(232, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (5): InvertedResidual( (conv_pw): Conv2d(232, 1392, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(1392, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(1392, 1392, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2), groups=1392, bias=False) (bn2): BatchNorm2d(1392, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(1392, 58, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(58, 1392, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(1392, 232, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(232, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) (6): Sequential( (0): InvertedResidual( (conv_pw): Conv2d(232, 1392, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(1392, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(1392, 1392, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=1392, bias=False) (bn2): BatchNorm2d(1392, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(1392, 58, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(58, 1392, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(1392, 384, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) (1): InvertedResidual( (conv_pw): Conv2d(384, 2304, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn1): BatchNorm2d(2304, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act1): SiLU(inplace=True) (conv_dw): Conv2d(2304, 2304, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=2304, bias=False) (bn2): BatchNorm2d(2304, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (act2): SiLU(inplace=True) (se): SqueezeExcite( (conv_reduce): Conv2d(2304, 96, kernel_size=(1, 1), stride=(1, 1)) (act1): SiLU(inplace=True) (conv_expand): Conv2d(96, 2304, kernel_size=(1, 1), stride=(1, 1)) ) (conv_pwl): Conv2d(2304, 384, kernel_size=(1, 1), stride=(1, 1), bias=False) (bn3): BatchNorm2d(384, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) ) ) ) (4): Conv2d(384, 1536, kernel_size=(1, 1), stride=(1, 1), bias=False) (5): BatchNorm2d(1536, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (6): SiLU(inplace=True) )
And we've popped off those last layers, just like we needed! Let's move on to the head of our model. We know the input to the head should be 3072: the last linear layer of the original model took in 1536 features (see the classifier above), and we need twice that because our AdaptiveConcatPool2d concatenates an average pool and a max pool. We also want the head's output to match our number of classes. But what if we don't know those numbers?
nf = num_features_model(nn.Sequential(*body.children())); nf
3072
head = create_head(nf, dls.c)
head
Sequential( (0): AdaptiveConcatPool2d( (ap): AdaptiveAvgPool2d(output_size=1) (mp): AdaptiveMaxPool2d(output_size=1) ) (1): Flatten(full=False) (2): BatchNorm1d(3072, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (3): Dropout(p=0.25, inplace=False) (4): Linear(in_features=3072, out_features=512, bias=False) (5): ReLU(inplace=True) (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (7): Dropout(p=0.5, inplace=False) (8): Linear(in_features=512, out_features=37, bias=False) )
Now finally we need to wrap it together
model = nn.Sequential(body, head)
And then we initialize our model's new head:
apply_init(model[1], nn.init.kaiming_normal_)
Now we have our model split into two layer groups! What's next?
len(model)
2
Let's try making a Learner
learn = Learner(dls, model, loss_func=LabelSmoothingCrossEntropy())
learn.summary()
Sequential (Input shape: 64)
============================================================================
Layer (type)         Output Shape         Param #    Trainable
============================================================================
                     64 x 40 x 112 x 112
Conv2d                                    1080       True
BatchNorm2d                               80         True
SiLU
Conv2d                                    360        True
BatchNorm2d                               80         True
SiLU
____________________________________________________________________________
(... the remaining EfficientNet blocks repeat this Conv2d / BatchNorm2d / SiLU
pattern; every layer reports Trainable = True ...)
____________________________________________________________________________
                     64 x 1536 x 7 x 7
Conv2d                                    589824     True
BatchNorm2d                               3072       True
SiLU
AdaptiveAvgPool2d
AdaptiveMaxPool2d
Flatten
BatchNorm1d                               6144       True
Dropout
____________________________________________________________________________
                     64 x 512
Linear                                    1572864    True
ReLU
BatchNorm1d                               1024       True
Dropout
____________________________________________________________________________
                     64 x 37
Linear                                    18944      True
____________________________________________________________________________

Total params: 12,295,208
Total trainable params: 12,295,208
Total non-trainable params: 0

Optimizer used: <function Adam at 0x7f8f350cfc80>
Loss function: LabelSmoothingCrossEntropy()

Callbacks:
  - TrainEvalCallback
  - Recorder
  - ProgressCallback
Oh no! It isn't frozen. What do we do? We never split the model! Since we set things up so that `model[0]` is the first parameter group (the body) and `model[1]` is the second (the head), we can use the `default_split` splitter. Let's try again:
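For reference, `default_split` is just a two-group splitter over those two pieces of the model. A minimal sketch of it (roughly mirroring the version that ships in `fastai.vision.learner`) looks like this:

from fastcore.foundation import L
from fastai.torch_core import params

def default_split(m):
    # one parameter group for the body (m[0]) and one for the head (m[1])
    return L(m[0], m[1]).map(params)

With the model split this way, `learn.freeze()` leaves only the last parameter group (the head), plus the BatchNorm layers, trainable, which is exactly what the summary below shows.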
learn = Learner(dls, model, loss_func=LabelSmoothingCrossEntropy(),
splitter=default_split, metrics=error_rate)
learn.freeze()
learn.summary()
Sequential (Input shape: 64)
============================================================================
Layer (type)         Output Shape         Param #    Trainable
============================================================================
                     64 x 40 x 112 x 112
Conv2d                                    1080       False
BatchNorm2d                               80         True
SiLU
Conv2d                                    360        False
BatchNorm2d                               80         True
SiLU
____________________________________________________________________________
(... the remaining blocks repeat the same pattern: every Conv2d is now frozen
(Trainable = False) while the BatchNorm2d layers remain trainable ...)
____________________________________________________________________________
                     64 x 1536 x 7 x 7
Conv2d                                    589824     False
BatchNorm2d                               3072       True
SiLU
AdaptiveAvgPool2d
AdaptiveMaxPool2d
Flatten
BatchNorm1d                               6144       True
Dropout
____________________________________________________________________________
                     64 x 512
Linear                                    1572864    True
ReLU
BatchNorm1d                               1024       True
Dropout
____________________________________________________________________________
                     64 x 37
Linear                                    18944      True
____________________________________________________________________________

Total params: 12,295,208
Total trainable params: 1,686,272
Total non-trainable params: 10,608,936

Optimizer used: <function Adam at 0x7f8f350cfc80>
Loss function: LabelSmoothingCrossEntropy()

Model frozen up to parameter group #1

Callbacks:
  - TrainEvalCallback
  - Recorder
  - ProgressCallback
That looks much better. Let's train!
learn.lr_find()
learn.fit_one_cycle(5, slice(3e-2))
epoch | train_loss | valid_loss | error_rate | time |
---|---|---|---|---|
0 | 1.935569 | 1.548306 | 0.198917 | 00:38 |
1 | 1.360869 | 1.022956 | 0.100135 | 00:37 |
2 | 1.070680 | 0.958684 | 0.085250 | 00:37 |
3 | 0.941763 | 0.930160 | 0.080514 | 00:37 |
4 | 0.876002 | 0.905405 | 0.071719 | 00:37 |
learn.save('stage_1')
Then we can unfreeze it and train a little more
learn.unfreeze()
learn.lr_find()
learn.fit_one_cycle(5, 1e-4)
epoch | train_loss | valid_loss | error_rate | time |
---|---|---|---|---|
0 | 0.856545 | 0.891247 | 0.066306 | 00:50 |
1 | 0.835565 | 0.874920 | 0.054804 | 00:50 |
2 | 0.801183 | 0.863947 | 0.051421 | 00:50 |
3 | 0.779860 | 0.863559 | 0.056157 | 00:50 |
4 | 0.772397 | 0.860884 | 0.052774 | 00:50 |
learn.save('model_2')
One of the hardest parts about training the EfficientNet models is finding a learning rate that won't break everything, so choose cautiously, and after unfreezing pick something a bit lower than you might otherwise be tempted to use.
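For example, above we dropped from 3e-2 while frozen to 1e-4 once unfrozen. An even more cautious option (purely illustrative, not what we did above) is to pass a slice so the earliest layers get a smaller learning rate than the head:

# discriminative learning rates: body layers nearer 1e-5, head nearer 1e-4
learn.unfreeze()
learn.fit_one_cycle(5, slice(1e-5, 1e-4))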
We fell just barely short of our ResNet34 results, but we're using a model that is only 57% of the ResNet34's size! Now let's see how we can transfer these newly trained weights to a different dataset. We'll use IMAGEWOOF, so first let's build its DataLoaders:
path = untar_data(URLs.IMAGEWOOF)
tfms = [[PILImage.create], [parent_label, Categorize()]]
item_tfms = [ToTensor(), Resize(128)]
batch_tfms = [FlipItem(), RandomResizedCrop(128, min_scale=0.35),
IntToFloatTensor(), Normalize.from_stats(*imagenet_stats)]
items = get_image_files(path)
split_idx = GrandparentSplitter(valid_name='val')(items)
dsets = Datasets(items, tfms, splits=split_idx)
dls = dsets.dataloaders(after_item=item_tfms, after_batch=batch_tfms, bs=64)
dls.show_batch()
Let's walk through how we would do that. First, let's grab our model and make our `Learner` like we did before, with everything except the pretrained weights:
body = create_timm_body('efficientnet_b3a', pretrained=False)
head = create_head(1536, dls.c)
model = nn.Sequential(body, head)
apply_init(model[1], nn.init.kaiming_normal_)
learn = Learner(dls, model, loss_func=LabelSmoothingCrossEntropy(),
splitter=default_split, metrics=accuracy)
Now, remember that these are all random weights right now. How do we change this? We look at the `state_dict`:
learn.model.state_dict()
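The full output is far too long to show here, but peeking at a few entries gives a feel for what's inside (just an illustration; the exact key names depend on how the model was assembled):

sd = learn.model.state_dict()
# each entry maps a parameter name to the tensor holding its weights
for key in list(sd.keys())[:5]:
    print(key, tuple(sd[key].shape))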
The state dict tells us the parameters and weights at every layer. To work with it, we'll borrow some code from viraat. Basically, what we want to do is:
- Keep two state_dicts: one from our new model and one from the old
- If a layer's name and shape match, copy the weights over
- Move on until there are no more layers
- Finally, load the merged state_dict back into the model
learn.model_dir
'models'
def transfer_learn(learn:Learner, name:Path, device:torch.device=None):
    "Load model `name` from `learn.model_dir` using `device`, defaulting to `learn.dls.device`."
    if device is None: device = learn.dls.device
    learn.model_dir = Path(learn.model_dir)
    # look for a saved checkpoint first, otherwise treat `name` as a path to a weight file
    if (learn.model_dir/name).with_suffix('.pth').exists():
        model_path = (learn.model_dir/name).with_suffix('.pth')
    else:
        model_path = name
    new_state_dict = torch.load(model_path, map_location=device)['model']
    learn_state_dict = learn.model.state_dict()
    for name, param in learn_state_dict.items():
        if name in new_state_dict:
            input_param = new_state_dict[name]
            # only copy weights whose shapes line up (the new head won't match)
            if input_param.shape == param.shape:
                param.copy_(input_param)
            else:
                print('Shape mismatch at:', name, 'skipping')
        else:
            print(f'{name} weight of the model not in pretrained weights')
    learn.model.load_state_dict(learn_state_dict)
    return learn
Now let's load some in!
learn = transfer_learn(learn, 'stage_1')
Shape mismatch at: 1.8.weight skipping
learn.model[1][8]
Linear(in_features=512, out_features=10, bias=False)
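If you want to convince yourself that the matching layers really did pick up our saved weights, a quick sanity check (purely illustrative; adjust the path if your `stage_1.pth` lives somewhere else) might look like:

# load the checkpoint the same way transfer_learn does and compare one parameter
old_sd = torch.load((Path(learn.model_dir)/'stage_1').with_suffix('.pth'), map_location='cpu')['model']
new_sd = learn.model.state_dict()
key = next(iter(new_sd))  # any parameter present in both models
print(torch.allclose(new_sd[key].cpu(), old_sd[key]))  # should print True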
And we can see the only weight that wasn't loaded in was our new layer! Let's freeze and train our model
learn.freeze()
Let's see if it worked. We'll do a comparison test: five frozen epochs with our transferred weights, and then five without them.
learn.fit_one_cycle(5, 3e-3)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 1.282371 | 0.934233 | 0.832018 | 00:37 |
1 | 0.912630 | 0.814222 | 0.883685 | 00:38 |
2 | 0.811893 | 0.780937 | 0.893357 | 00:38 |
3 | 0.759423 | 0.770062 | 0.896920 | 00:38 |
4 | 0.730576 | 0.767888 | 0.898193 | 00:39 |
And now let's try a regular, non-transferred learner (same learning rate, frozen, etc.):
body = create_timm_body('efficientnet_b3a', pretrained=False)
head = create_head(1536, dls.c)
model = nn.Sequential(body, head)
apply_init(model[1], nn.init.kaiming_normal_)
learn = Learner(dls, model, loss_func=LabelSmoothingCrossEntropy(),
splitter=default_split, metrics=accuracy)
learn.freeze()
learn.fit_one_cycle(5, 3e-3)
epoch | train_loss | valid_loss | accuracy | time |
---|---|---|---|---|
0 | 2.869522 | 2.830859 | 0.114024 | 00:38 |
1 | 2.684337 | 2.571606 | 0.124714 | 00:38 |
2 | 2.543600 | 2.666231 | 0.135912 | 00:38 |
3 | 2.485797 | 2.387705 | 0.146602 | 00:38 |
4 | 2.430820 | 2.360096 | 0.154238 | 00:38 |