Other Pretrained Models

Using Custom PyTorch Weights in fastai
Lesson 4

Introduction

So far we have looked at using vision_learner out of the box to bring in a resnet34.

But how does it work?

What are these new layers being added at the end of the model?

And can it be recreated without fastai?

That is what today’s lesson will cover!

Get Some DataLoaders

Since today we are focused on model building rather than a data application, we’ll quickly set up the tried-and-true PETs dataset with the DataBlock API:

from fastai.vision.all import *

path = untar_data(URLs.PETS)/'images'
fnames = get_image_files(path)
pat = r'/([^/]+)_\d+.*'
batch_tfms = [*aug_transforms(size=224, max_warp=0), Normalize.from_stats(*imagenet_stats)]
item_tfms = RandomResizedCrop(460, min_scale=0.75, ratio=(1.,1.))
bs = 64

pets = DataBlock(
    blocks=(ImageBlock, CategoryBlock),
    get_items=get_image_files,
    splitter=RandomSplitter(),
    get_y=RegexLabeller(pat=pat),
    item_tfms=item_tfms,
    batch_tfms=batch_tfms,
)
dls = pets.dataloaders(path, bs=bs)

Now that we have some dataloaders, let’s focus on the model.

timm

In the next part of this lecture we’ll attempt to recreate (in a similar fashion) how vision_learner creates a model from timm. First, make sure timm is installed:

!pip install timm > /dev/null
What is timm?

timm is a library by Ross Wightman that offers a plethora of vision-based PyTorch models, many with his own SOTA trained weights to choose from. It’s practically the image-model zoo for PyTorch!

Next we’ll download a model off timm:

from timm import create_model
net = create_model("vit_tiny_patch16_224", pretrained=True)

from timm import create_model

To create a PyTorch model in the timm library, the create_model function can be used.


 create_model("vit_tiny_patch16_224", pretrained=True)

The create_model function takes in the name of a registered model; every model in timm registers itself with timm’s model registry when it is defined. To download the ImageNet-1k weights, we pass in pretrained=True
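If you aren’t sure which names are registered, timm can list them for you. For example, to find the ViT-tiny variants (list_models takes a shell-style wildcard pattern, and pretrained=True limits the results to models with downloadable weights):

import timm

# All registered names matching the pattern
print(timm.list_models("vit_tiny_*")[:5])
# Only the ones that ship pretrained weights
print(timm.list_models("vit_tiny_*", pretrained=True)[:5])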

Read the Docs!

To learn more about timm, be sure to check out its official documentation.

Now that we’ve downloaded a PyTorch model, how do we change it to act like the PyTorch models fastai creates?

Let’s revisit a resnet18 made through vision_learner to start.

Starting with Something Familiar

learn = vision_learner(dls, models.resnet18)

body

noun

The backbone of a neural network, typically pretrained

The body of a Resnet 34 model

head

noun

The last, or last few, layers of a neural network; typically consists of everything after the final pooling layer

Predictions from the model are the outputs from the head of the network

learn.model[-1]
Sequential(
  (0): AdaptiveConcatPool2d(
    (ap): AdaptiveAvgPool2d(output_size=1)
    (mp): AdaptiveMaxPool2d(output_size=1)
  )
  (1): fastai.layers.Flatten(full=False)
  (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): Dropout(p=0.25, inplace=False)
  (4): Linear(in_features=1024, out_features=512, bias=False)
  (5): ReLU(inplace=True)
  (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (7): Dropout(p=0.5, inplace=False)
  (8): Linear(in_features=512, out_features=37, bias=False)
)

What I’ve just shown here is the output of a function called create_head; this is not the original head of our model! Instead, fastai added a new head on top of the model, consisting of linear layers, ReLU activations, and batch normalization, plus a new pooling layer.

What’s more, watch what happens when we try to index into the new model we made:

net[-1]
TypeError: 'VisionTransformer' object is not subscriptable

Uh oh, why does this not work?

len(learn.model)
2
len(net)
TypeError: object of type 'VisionTransformer' has no len()

The answer lies in what the model is made of.

The fastai model is wrapped in an nn.Sequential, whereas net is not!

Types of PyTorch Models

Typically there are a few ways to create a PyTorch model class. The first is as an individual nn.Module, which is what gave us our error before:

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()        # required before assigning submodules
        self.l1 = nn.Linear(1,1)
        self.l2 = nn.Linear(1,1)  # was nn.linear, which doesn't exist
    def forward(self, x):
        return self.l2(self.l1(x))

Basically, when we create a PyTorch model this way, calling an instance on an input, e.g. model(x), will run the forward function automatically. However, because of this setup the model technically has no length, nor is it indexable: the layers are attributes we access by name rather than by index.

Another way to think about it is that nn.Module.__call__ really just calls nn.Module.forward
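As a quick sanity check of that claim, using the fixed MyModel from above:

import torch

m = MyModel()
x = torch.randn(4, 1)
# __call__ runs any hooks and then dispatches to forward,
# so calling the module and calling forward agree:
assert torch.equal(m(x), m.forward(x))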

The other way to do it is to wrap the model definition in something called nn.Sequential.

As the name implies, the layers are called in sequence, one after the other. So our model from before can be rewritten as:

class MyModel(nn.Sequential):
    def __init__(self):
        layers = [
            nn.Linear(1,1),
            nn.Linear(1,1),
        ]
        super().__init__(*layers)
Why no def forward?

Because nn.Sequential passes the input through each layer given to its __init__, one after the other, no forward function needs to be defined; that behavior is built in.
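Conceptually, the forward that nn.Sequential provides boils down to a simple loop (a simplified sketch of what PyTorch actually implements):

# Roughly what nn.Sequential's built-in forward does:
def forward(self, x):
    for layer in self:  # the layers, in the order they were passed in
        x = layer(x)
    return x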

Now when we do MyModel() it will automatically create an nn.Sequential model that we can index into to get a particular layer:

net = MyModel()
net[0], net[1]
(Linear(in_features=1, out_features=1, bias=True),
 Linear(in_features=1, out_features=1, bias=True))

How does this relate to what we just saw?

vision_learner takes this idea and creates a nn.Sequential model consisting of a body and a head.

The body is the original model minus the final pooling layer and last linear layer. The head is the custom head we saw a moment ago.

Where does the fastai course talk about this?

It actually won’t, really, because this isn’t fastai at this point; this is now raw PyTorch. Getting much more intimate and familiar with PyTorch is how you succeed at fastai, and at Deep Learning in general.

Don’t believe me? Let’s look at a cut resnet18 vs a created resnet18:
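First, here is roughly how the “cut” version is produced in raw PyTorch. This is a sketch using torchvision’s resnet18, whose last two children happen to be the pooling and linear layers:

import torch
from torch import nn
from torchvision.models import resnet18

full = resnet18()
# Drop the final AdaptiveAvgPool2d and Linear (the last two children)
cut = nn.Sequential(*list(full.children())[:-2])

x = torch.randn(1, 3, 224, 224)
print(cut(x).shape)  # torch.Size([1, 512, 7, 7]): feature maps, not logits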

----------------------------------------------------------------
          Layer (type)               Output Shape         Param #
================================================================
              Conv2d-1         [-1, 64, 112, 112]           9,408
         BatchNorm2d-2         [-1, 64, 112, 112]             128
                ReLU-3         [-1, 64, 112, 112]               0
           MaxPool2d-4           [-1, 64, 56, 56]               0
              Conv2d-5           [-1, 64, 56, 56]          36,864
         BatchNorm2d-6           [-1, 64, 56, 56]             128
                ReLU-7           [-1, 64, 56, 56]               0
              Conv2d-8           [-1, 64, 56, 56]          36,864
         BatchNorm2d-9           [-1, 64, 56, 56]             128
               ReLU-10           [-1, 64, 56, 56]               0
         BasicBlock-11           [-1, 64, 56, 56]               0
             Conv2d-12           [-1, 64, 56, 56]          36,864
        BatchNorm2d-13           [-1, 64, 56, 56]             128
               ReLU-14           [-1, 64, 56, 56]               0
             Conv2d-15           [-1, 64, 56, 56]          36,864
        BatchNorm2d-16           [-1, 64, 56, 56]             128
               ReLU-17           [-1, 64, 56, 56]               0
         BasicBlock-18           [-1, 64, 56, 56]               0
             Conv2d-19          [-1, 128, 28, 28]          73,728
        BatchNorm2d-20          [-1, 128, 28, 28]             256
               ReLU-21          [-1, 128, 28, 28]               0
             Conv2d-22          [-1, 128, 28, 28]         147,456
        BatchNorm2d-23          [-1, 128, 28, 28]             256
             Conv2d-24          [-1, 128, 28, 28]           8,192
        BatchNorm2d-25          [-1, 128, 28, 28]             256
               ReLU-26          [-1, 128, 28, 28]               0
         BasicBlock-27          [-1, 128, 28, 28]               0
             Conv2d-28          [-1, 128, 28, 28]         147,456
        BatchNorm2d-29          [-1, 128, 28, 28]             256
               ReLU-30          [-1, 128, 28, 28]               0
             Conv2d-31          [-1, 128, 28, 28]         147,456
        BatchNorm2d-32          [-1, 128, 28, 28]             256
               ReLU-33          [-1, 128, 28, 28]               0
         BasicBlock-34          [-1, 128, 28, 28]               0
             Conv2d-35          [-1, 256, 14, 14]         294,912
        BatchNorm2d-36          [-1, 256, 14, 14]             512
               ReLU-37          [-1, 256, 14, 14]               0
             Conv2d-38          [-1, 256, 14, 14]         589,824
        BatchNorm2d-39          [-1, 256, 14, 14]             512
             Conv2d-40          [-1, 256, 14, 14]          32,768
        BatchNorm2d-41          [-1, 256, 14, 14]             512
               ReLU-42          [-1, 256, 14, 14]               0
         BasicBlock-43          [-1, 256, 14, 14]               0
             Conv2d-44          [-1, 256, 14, 14]         589,824
        BatchNorm2d-45          [-1, 256, 14, 14]             512
               ReLU-46          [-1, 256, 14, 14]               0
             Conv2d-47          [-1, 256, 14, 14]         589,824
        BatchNorm2d-48          [-1, 256, 14, 14]             512
               ReLU-49          [-1, 256, 14, 14]               0
         BasicBlock-50          [-1, 256, 14, 14]               0
             Conv2d-51            [-1, 512, 7, 7]       1,179,648
        BatchNorm2d-52            [-1, 512, 7, 7]           1,024
               ReLU-53            [-1, 512, 7, 7]               0
             Conv2d-54            [-1, 512, 7, 7]       2,359,296
        BatchNorm2d-55            [-1, 512, 7, 7]           1,024
             Conv2d-56            [-1, 512, 7, 7]         131,072
        BatchNorm2d-57            [-1, 512, 7, 7]           1,024
               ReLU-58            [-1, 512, 7, 7]               0
         BasicBlock-59            [-1, 512, 7, 7]               0
             Conv2d-60            [-1, 512, 7, 7]       2,359,296
        BatchNorm2d-61            [-1, 512, 7, 7]           1,024
               ReLU-62            [-1, 512, 7, 7]               0
             Conv2d-63            [-1, 512, 7, 7]       2,359,296
        BatchNorm2d-64            [-1, 512, 7, 7]           1,024
               ReLU-65            [-1, 512, 7, 7]               0
         BasicBlock-66            [-1, 512, 7, 7]               0
- AdaptiveAvgPool2d-67            [-1, 512, 1, 1]               0
-            Linear-68                 [-1, 1000]         513,000
================================================================
Total params (full resnet18): 11,689,512
Total params (cut body):      11,176,512

The two lines prefixed with - (the final pooling and linear layers) are exactly what the cut removes; the 513,000-parameter difference is the Linear layer.

Now that we know how it works, let’s create a body and a head for a timm model by using the most minimal fastai we can.

fastai has a special TimmBody class which creates the body similarly to what I just showed above. I’ve recreated its core below:

import typing

def custom_cut_model(model:nn.Module, cut:typing.Union[int, typing.Callable]):
    """
    Cuts `model` into an `nn.Sequential` based on `cut`.
    """
    if isinstance(cut, int):
        # Integer cut: keep the first `cut` children of the model
        return nn.Sequential(*list(model.children())[:cut])
    elif callable(cut):
        # Callable cut: let the function decide how to slice the model
        return cut(model)
    else:
        raise NameError("`cut` must either be an integer or a function")
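For example, both forms of cut work on a torchvision resnet18 (hypothetical usage of the helper above):

from torch import nn
from torchvision.models import resnet18

# Integer cut: keep the children up to (but not including) index -2
body_a = custom_cut_model(resnet18(), -2)
# Callable cut: any function that maps the model to its body works too
body_b = custom_cut_model(resnet18(), lambda m: nn.Sequential(*list(m.children())[:-2]))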
class CustomTimmBody(nn.Module):
    """
    A small submodule to work with `timm` models more easily
    """
    def __init__(
        self,
        model,
        pretrained:bool=True,
        cut=None,
        n_in:int=3
    ):
        super().__init__()
        # Models whose config reports a `pool_size` output spatial feature
        # maps, so we must grab features from before timm's own pooling
        self.needs_pooling = model.default_cfg.get('pool_size', None)
        if cut is None:
            self.model = model
        else:
            self.model = custom_cut_model(model, cut)

    def forward(self, x):
        if self.needs_pooling:
            return self.model.forward_features(x)
        else:
            return self.model(x)

# num_classes=0 tells timm to drop its own classification head
body = CustomTimmBody(
    create_model("vit_tiny_patch16_224", pretrained=True, num_classes=0, in_chans=3)
).train()
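Before building a head, it’s worth sanity-checking what the body outputs. The 192 below is vit_tiny’s embedding size, which timm exposes as num_features:

x = torch.randn(2, 3, 224, 224)
feats = body(x)
print(feats.shape)  # torch.Size([2, 192])
assert feats.shape[-1] == body.model.num_features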

Now that we’ve created a body, we can use fastai’s create_head to make a head for the model. It takes in the number of features coming out of the body and the number of classes for our final layer:

head = create_head(body.model.num_features, dls.c, pool=None)
head
Sequential(
  (0): BatchNorm1d(192, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (1): Dropout(p=0.25, inplace=False)
  (2): Linear(in_features=192, out_features=512, bias=False)
  (3): ReLU(inplace=True)
  (4): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): Dropout(p=0.5, inplace=False)
  (6): Linear(in_features=512, out_features=37, bias=False)
)

Note that we don’t have a pooling layer in this case; ViT already returns a flat feature vector rather than a spatial feature map, so there is nothing to pool.
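We can see why by checking the same default_cfg entry that CustomTimmBody inspected earlier; ViT configs report no pool_size, while convnets report one (the values shown are what timm reported at the time of writing):

print(body.model.default_cfg.get('pool_size', None))  # None: no pooling needed

conv = create_model("resnet18", pretrained=False)
print(conv.default_cfg.get('pool_size', None))        # (7, 7): pooling needed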

Let’s pass in a random input to make sure everything gets output correctly:

x = torch.randn(2,3,224,224)
out = head(body(x))
out, out.shape
(tensor([[-0.0650, -0.1741,  0.1089, -1.1668, -0.6229,  0.8892,  0.4859, -0.1704,
          -1.4127,  0.7338,  1.0354,  0.6033,  0.3576, -0.2332,  0.7073, -0.7090,
           0.3852, -0.3440,  0.4645,  0.4209,  1.2090,  0.3201,  0.6480, -1.4800,
           0.7253, -0.1806,  0.7261,  0.6329,  0.5336, -1.4665, -0.9681, -0.3387,
          -0.3044, -0.6216,  2.3369, -0.0941,  0.3703],
         [-0.4785,  1.2014, -0.2310,  1.4840, -0.4752,  0.3363,  0.1472, -0.1076,
           0.8156, -0.6819, -0.6366, -0.0721, -0.8710,  0.2871, -0.4673,  0.5040,
           0.5288,  1.5585, -0.3499,  0.5983, -0.1188,  0.1523, -0.7708,  0.8939,
          -0.0318, -0.8048, -0.2581,  0.5921,  0.1012,  0.1626,  0.2249,  0.4605,
           0.1858, -0.4212, -0.0047,  0.6470, -0.7384]], grad_fn=<MmBackward0>),
 torch.Size([2, 37]))

The last thing we need to do is initialize the weights of the new head. fastai uses nn.init.kaiming_normal_ by default:

apply_init?
Signature: apply_init(m, func=<function kaiming_normal_ at 0x7f21f43d5630>)
Docstring: Initialize all non-batchnorm layers of `m` with `func`.
File:      /opt/conda/lib/python3.10/site-packages/fastai/torch_core.py
Type:      function
apply_init(head)

We can see that the outputs are a bit different now:

head(body(x))
tensor([[ 0.2204, -3.4587, -0.5113, -1.4922, -1.2036,  3.9744, -1.5592, -1.1304,
          1.1073,  0.4745,  1.4827,  0.8954, -2.0673,  0.3289,  1.6994,  0.0623,
          1.7268,  2.5922, -1.4811, -1.4121,  0.7921,  1.5231,  1.2327, -0.0762,
          0.5696, -1.2702,  3.3962, -2.2976,  2.4296, -0.0874, -0.0975,  0.0168,
          2.2922,  2.0433,  1.1191,  1.1637, -2.1250],
        [ 1.1871,  0.2985,  2.6397, -2.9931,  3.5329, -3.3390,  3.3316, -0.8618,
          0.0611,  1.0972, -1.8489, -3.1779,  0.2882,  1.3150,  0.7034, -0.7141,
         -0.5197, -3.5473,  1.0325,  1.3873,  2.3772, -3.8408, -0.3776,  0.0446,
         -1.7974,  1.3227, -0.8745,  3.6397, -2.2262, -0.2738,  1.7177,  0.8619,
         -3.6088, -4.8258,  0.2685,  2.7378,  1.7348]], grad_fn=<MmBackward0>)
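Under the hood, apply_init walks the module tree and re-initializes the weights of everything except the normalization layers. Here is a simplified, hypothetical sketch of the idea (fastai’s real implementation differs in the details):

def my_apply_init(m: nn.Module, func=nn.init.kaiming_normal_):
    "A rough, simplified take on fastai's apply_init"
    for module in m.modules():
        # Leave norm layers at their default init (weight=1, bias=0)
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.LayerNorm)):
            continue
        # kaiming_normal_ needs tensors of 2+ dims (Linear/Conv weights)
        weight = getattr(module, 'weight', None)
        if weight is not None and weight.dim() > 1:
            func(weight)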

Body and Head? Check. What else is needed?

So far we’ve seen how to create a body and a head for a particular model.

But how do we tell fastai how to split the body from the head? And what even is that?

Some Definitions:

split

noun

An arrangement of groups of layers by some criteria

The model was split between the body and the head

freeze

verb

To make certain layers of a model untrainable

We froze the backbone of the pretrained model, but not the head

As the definitions imply, we need to create a split that tells fastai’s Learner that the body should be made untrainable at first, and the head should be the only part of the model that gets updated.

The Learner takes in a splitter. Splitters are defined as:

def my_split_func(model:nn.Module):
    "A function that splits layers by their parameters"
    return L(model[0], model[1:]).map(params)

def my_split_func(model:nn.Module):

A splitter needs to take in some model that gets passed to the Learner object. Typically these are nn.Sequential models, as we saw earlier, but they can be any PyTorch model arrangement.


 L(model[0], model[1:])

We need to create a list of each group we want to have. In this case, since the model we made earlier is an nn.Sequential, this splits the model into the body and the head.


map(params)

Finally, each layer group in the arrangement gets passed to fastai’s params function, which returns every parameter inside whatever gets sent to it. fastai then uses this mapping to know which parameters belong to which group.

We can write a similar splitting function for our purposes:

def splitter(model):
    "Splits a model by head and body"
    return L(model[0], model[1]).map(params)

How the params actually work: these groups are passed to the Optimizer as layer groups, which lets us freeze them or selectively update them in PyTorch.
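We can peek at the split and relate it to raw PyTorch terms (a minimal sketch; note that learn.freeze() used below also keeps norm layers trainable by default via the Learner’s train_bn=True, so it isn’t a blanket requires_grad toggle):

groups = splitter(nn.Sequential(body, head))
print(len(groups))  # 2: one parameter group for the body, one for the head

# In raw PyTorch, freezing a group would just be:
#     for p in groups[0]: p.requires_grad_(False)
# We'll let learn.freeze() handle this for us below.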

And now let’s create a Learner:

learn = Learner(
    dls,
    nn.Sequential(body, head),
    splitter=splitter
)
print(learn.summary()[-250:])
Total trainable params: 5,605,056
Total non-trainable params: 0

Optimizer used: <function Adam>
Loss function: FlattenedLoss of CrossEntropyLoss()

Callbacks:
  - TrainEvalCallback
  - CastToTensor
  - Recorder
  - ProgressCallback

You’ll notice that the non-trainable params are currently 0. This is because vision_learner, tabular_learner, and all the others call learn.freeze() when using a pretrained model, to freeze that backbone. We need to do the same here.

After calling learn.freeze(), learn.summary() should show a different split of frozen and non-frozen parameters:

learn.freeze()
print(learn.summary()[-295:])
Total trainable params: 128,256
Total non-trainable params: 5,476,800

Optimizer used: <function Adam>
Loss function: FlattenedLoss of CrossEntropyLoss()

Model frozen up to parameter group #1

Callbacks:
  - TrainEvalCallback
  - CastToTensor
  - Recorder
  - ProgressCallback

Which we do! Now you can use learn.fine_tune, learn.unfreeze, and more as before!