Deploying without fastai

Taking a trained model and removing all the magic
Lesson 5

Introduction

In this version of the lesson, we’re going to take that model we just trained and remove every part of fastai when deploying it. This will involve:

  1. Recreating the PyTorch model from “scratch” (allowing timm only)
  2. Recreating the fastai transforms in raw PyTorch
  3. Recreating the decoding stage

Starting with the Model

First let’s tackle the model. Since we can’t bring in fastai, we will only use imports from torch, torchvision, PIL, timm, and the standard library:

import torch
import torch.nn.functional as F
import torchvision.transforms.functional as tvf
import torchvision.transforms as tvtfms
import operator as op
from PIL import Image
from torch import nn
from timm import create_model

# For type hinting later on
import collections
import typing

Next let’s create the core “body” of the model with timm’s create_model:

net = create_model("vit_tiny_patch16_224", pretrained=False, num_classes=0, in_chans=3)

The last step is to recreate fastai’s create_head. We can do this directly with an nn.Sequential:

head = nn.Sequential(
    nn.BatchNorm1d(192),
    nn.Dropout(0.25),
    nn.Linear(192, 512, bias=False),
    nn.ReLU(inplace=True),
    nn.BatchNorm1d(512),
    nn.Dropout(0.5),
    nn.Linear(512, 37, bias=False)
)

Now let’s merge them together:

model = nn.Sequential(net, head)
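Before worrying about the trained weights, we can sanity-check that the two pieces fit together. Here is a minimal sketch (my own addition), assuming the vit_tiny_patch16_224 body above, which outputs 192 features when num_classes=0, and the 37-class head:

# Quick shape check with a dummy batch (a sketch, not part of the trained pipeline)
model.eval()  # put BatchNorm/Dropout layers into inference mode
with torch.no_grad():
    dummy_batch = torch.randn(2, 3, 224, 224)  # two fake 3x224x224 images
    output = model(dummy_batch)
print(output.shape)  # expected: torch.Size([2, 37])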

But how do we know it works?

Let’s bring in that state dictionary:

state = torch.load("models/MyModel.pth")

We load in the weights directly with model.load_state_dict and pass in the state we want to use:

model.load_state_dict(state);
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[72], line 1
----> 1 model.load_state_dict(state);

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1671, in Module.load_state_dict(self, state_dict, strict)
   1666         error_msgs.insert(
   1667             0, 'Missing key(s) in state_dict: {}. '.format(
   1668                 ', '.join('"{}"'.format(k) for k in missing_keys)))
   1670 if len(error_msgs) > 0:
-> 1671     raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   1672                        self.__class__.__name__, "\n\t".join(error_msgs)))
   1673 return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for Sequential:
    Missing key(s) in state_dict: "0.cls_token", "0.pos_embed", "0.patch_embed.proj.weight", "0.patch_embed.proj.bias", "0.blocks.0.norm1.weight", "0.blocks.0.norm1.bias", "0.blocks.0.attn.qkv.weight", "0.blocks.0.attn.qkv.bias", "0.blocks.0.attn.proj.weight", "0.blocks.0.attn.proj.bias", "0.blocks.0.norm2.weight", "0.blocks.0.norm2.bias", "0.blocks.0.mlp.fc1.weight", "0.blocks.0.mlp.fc1.bias", "0.blocks.0.mlp.fc2.weight", "0.blocks.0.mlp.fc2.bias", "0.blocks.1.norm1.weight", "0.blocks.1.norm1.bias", "0.blocks.1.attn.qkv.weight", "0.blocks.1.attn.qkv.bias", "0.blocks.1.attn.proj.weight", "0.blocks.1.attn.proj.bias", "0.blocks.1.norm2.weight", "0.blocks.1.norm2.bias", "0.blocks.1.mlp.fc1.weight"

Uh oh, we have an issue! The state dict has missing keys. Why is that?

Messing with the state_dict

When fastai created our model, the layers in the model’s state_dict were given particular names, and those names are what get saved to disk.

What happened here is that the names in the saved state dict don’t quite align with the names in our new model:

list(model.state_dict().keys())[:5]
['0.cls_token',
 '0.pos_embed',
 '0.patch_embed.proj.weight',
 '0.patch_embed.proj.bias',
 '0.blocks.0.norm1.weight']
list(state.keys())[:5]
['0.model.cls_token',
 '0.model.pos_embed',
 '0.model.patch_embed.proj.weight',
 '0.model.patch_embed.proj.bias',
 '0.model.blocks.0.norm1.weight']

See the extra .model attribute in there?

How can we get past this?

The technique I’m introducing here is also the general idea behind “transfer learning, twice”: re-using pretrained weights from a different model setup and re-applying them to your own model.

The basic idea is that for every single parameter, we look up its equivalent layer name in the other model and load it in. This also lets us add or remove certain parts of these layer names, such as the .model we are currently missing:

def copy_weight(name, parameter, state_dict):
    """
    Takes in a layer `name`, model `parameter`, and `state_dict`
    and loads the weights from `state_dict` into `parameter`
    if it exists.
    """
    # Part of the body
    if name[0] == "0":
        name = name[:2] + "model." + name[2:]
    if name in state_dict.keys():
        input_parameter = state_dict[name]
        if input_parameter.shape == parameter.shape:
            parameter.copy_(input_parameter)
        else:
            print(f'Shape mismatch at layer: {name}, skipping')
    else:
        print(f'{name} is not in the state_dict, skipping.')
def apply_weights(input_model:nn.Module, input_weights:collections.OrderedDict, application_function:callable):
    """
    Takes an input state_dict and applies those weights to the `input_model`, potentially 
    with a modifier function.
    
    Args:
        input_model (`nn.Module`):
            The model that weights should be applied to
        input_weights (`collections.OrderedDict`):
            A dictionary of weights, the trained model's `state_dict()`
        application_function (`callable`):
            A function that takes in one parameter and layer name from `input_model`
            and the `input_weights`. Should apply the weights from the state dict into `input_model`.
    """
    model_dict = input_model.state_dict()
    for name, parameter in model_dict.items():
        application_function(name, parameter, input_weights)
    input_model.load_state_dict(model_dict)

Now we can use this on our model:

apply_weights(model, state, copy_weight)

And now our new weights are loaded in!
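As a quick sanity check (my own sketch, using the key names we saw earlier), we can confirm that a parameter in the rebuilt model now matches its renamed counterpart in the saved state dict:

# Verify one remapped parameter against the saved weights
old_key = "0.model.cls_token"  # key as saved by fastai
new_key = "0.cls_token"        # key in our rebuilt model
assert torch.equal(model.state_dict()[new_key].cpu(), state[old_key].cpu())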

The Transforms

Next we have to apply the transforms. Our old model had the following fastai transform setup:

Item Transforms:

  • RandomResizedCrop(460,460)
  • ToTensor()

Batch Transforms:

  • IntToFloatTensor()
  • aug_transforms(size=224)
  • Normalize.from_stats(*imagenet_stats)

Given this, let’s examine some transforms:

Recreating CropPad

Our first experiment is to find something that exactly recreates the CropPad transform fastai provides. Let’s run a test with the native one in torchvision:

from fastai.vision.data import PILImage
from fastai.data.external import untar_data, URLs
from fastai.data.transforms import get_image_files
import fastai.vision.augment as fastai_aug
import math
import numpy as np
Important

These imports are only needed so we can test on an image and verify that the image we create lines up with the augmentation fastai performs.

numpy is here so we can easily perform an assert statement on the image data.

First we need an image to test on. We’ll grab one of the first ones from the PETs dataset:

path = untar_data(URLs.PETS)/'images'
fname = get_image_files(path)[0]
fname
Path('/home/zach/.fastai/data/oxford-iiit-pet/images/Siamese_160.jpg')

Next we need to open the image in each respective API:

im_pil = Image.open(fname)
im_fastai = PILImage.create(fname)

Now we can make our first assert statement, ensuring that these two images are the same:

assert (np.array(im_pil) == np.array(im_fastai)).all()

np.array(im)

To compare these two images, we convert them into numpy arrays. Their shapes will look a bit odd, but we only care about the pixel values in each image, which should be exactly the same.


(a==b).func()

Wrapping parentheses around an expression lets you call further operations on its output once it has been produced.


.all()

.all() takes a boolean array (what our np.array == np.array comparison created) and checks that every single value in it is True.
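As a toy example of this comparison idiom (my own sketch):

# Element-wise comparison followed by .all()
import numpy as np
a = np.array([[1, 2], [3, 4]])
b = a.copy()
print((a == b).all())  # True: every element matches
b[0, 0] = 9
print((a == b).all())  # False: one element differs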

We’ll create what should be equivalent transforms based on their names. fastai uses CropPad, and the closest transform in torchvision seems to be CenterCrop (since we center crop and pad). Let’s see if they are the same:

crop_fastai = fastai_aug.CropPad((460,460))
crop_torch = tvtfms.CenterCrop((460,460))
assert (np.array(crop_fastai(im_fastai)) == np.array(crop_torch(im_pil))).all()
AssertionError: 

Oh no! They aren’t the same after all! What can we do?

I’ve gone through and recreated how fastai performs CropPad using pure PIL, PyTorch, and Python below. I highly recommend reading the explanations to understand what is going on:

def crop(image:typing.Union[Image.Image, torch.tensor], size:typing.Tuple[int,int]) -> Image:
    """
    Takes a `PIL.Image` and crops it `size` unless one 
    dimension is larger than the actual image. Padding 
    must be performed afterwards if so.
    
    Args:
        image (`PIL.Image`):
            An image to perform cropping on
        size (`tuple` of integers):
            A size to crop to, should be in the form
            of (width, height)
            
    Returns:
        An augmented `PIL.Image`
    """
    top = (image.shape[-1] - size[0]) // 2
    left = (image.shape[-2] - size[1]) // 2
    
    top = max(top, 0)
    left = max(left, 0)
    
    height = min(top + size[0], image.shape[-1])
    width = min(left + size[1], image.shape[-2])
    return image.crop((top, left, height, width))

    top = (image.shape[-1] - size[0]) // 2
    left = (image.shape[-2] - size[1]) // 2

We need to calculate the coordinates of the top-left corner of the crop, since Image.crop expects a box to crop to. We do so by taking each dimension, subtracting the new dimension from it, and dividing by two:
value = (dimension - new_dimension) // 2


    top = max(top, 0)
    left = max(left, 0)

Any negatives resulting from our previous calculation are a sign that padding is needed. Since we’re only cropping here, we zero these values out to get a valid point on the image.


    height = min(top + size[0], image.shape[-1])
    width = min(left + size[1], image.shape[-2])

Finally, we take the smaller of the two: the proposed new dimension or the current dimension of the image.


    return image.crop((top, left, height, width))

All of this gets passed into PIL.Image.crop as the box to crop to.
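To make that arithmetic concrete, here is a small worked example with hypothetical numbers, assuming image.shape[-1] == 600, image.shape[-2] == 400, and a target size of (460, 460):

# Worked example of the crop math (hypothetical 600 x 400 shape values)
size = (460, 460)
shape = (400, 600)                      # stands in for image.shape
top = (shape[-1] - size[0]) // 2        # (600 - 460) // 2 = 70
left = (shape[-2] - size[1]) // 2       # (400 - 460) // 2 = -30
top, left = max(top, 0), max(left, 0)   # 70, 0
height = min(top + size[0], shape[-1])  # min(530, 600) = 530
width = min(left + size[1], shape[-2])  # min(460, 400) = 400
print((top, left, height, width))       # (70, 0, 530, 400): a 460 x 400 region

The crop only gets us to 460 x 400; the remaining 60 pixels have to come from padding, which is exactly what the next function handles.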

def pad(image, size:typing.Tuple[int,int]) -> Image:
    """
    Takes a `PIL.Image` and pads it to `size` with
    zeros.
    
    Args:
        image (`PIL.Image`):
            An image to perform padding on
        size (`tuple` of integers):
            A size to pad to, should be in the form
            of (width, height)
            
    Returns:
        An augmented `PIL.Image`
    """
    top = (image.shape[-1] - size[0]) // 2
    left = (image.shape[-2] - size[1]) // 2
    
    pad_top = max(-top, 0)
    pad_left = max(-left, 0)
    
    height, width = (
        max(size[1] - image.shape[-1] + top, 0), 
        max(size[0] - image.shape[-2] + left, 0)
    )
    return tvf.pad(
        image, 
        [pad_top, pad_left, height, width], 
        padding_mode="constant"
    )

    top = (image.shape[-1] - size[0]) // 2
    left = (image.shape[-2] - size[1]) // 2

As with crop, we first calculate the centering offsets by taking each dimension, subtracting the target dimension from it, and dividing by two:
value = (dimension - new_dimension) // 2


    pad_top = max(-top, 0)
    pad_left = max(-left, 0)

Earlier we took the maximum of top and left against zero to find the corner. This time we take the maximum of the negated offsets, which tells us how much padding is needed: the more negative the number originally was, the more padding we add.


    height, width = (
        max(size[1] - image.shape[-1] + top, 0), 
        max(size[0] - image.shape[-2] + left, 0)
    )

value = (new_dimension - dimension + offset)
The maximum of this value and zero is what gets used in the padding operation.


    return tvf.pad(
        image, 
        [pad_top, pad_left, height, width], 
        padding_mode="constant"
    )

Finally we call the pad transform from torchvision.transforms.functional, passing in the image, the padding amounts, and a padding mode. "constant" here fills all the new values with 0, which is what fastai does.
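Continuing the hypothetical example from the crop step, the cropped image is 460 wide by 400 tall, so its shape under the same convention is (400, 460), and the pad amounts work out to:

# Worked example of the pad math on the cropped 460 x 400 image
size = (460, 460)
shape = (400, 460)                          # the cropped image's shape
top = (shape[-1] - size[0]) // 2            # (460 - 460) // 2 = 0
left = (shape[-2] - size[1]) // 2           # (400 - 460) // 2 = -30
pad_top = max(-top, 0)                      # 0
pad_left = max(-left, 0)                    # 30
height = max(size[1] - shape[-1] + top, 0)  # max(0, 0) = 0
width = max(size[0] - shape[-2] + left, 0)  # max(460 - 400 - 30, 0) = 30
print([pad_top, pad_left, height, width])   # [0, 30, 0, 30]

Thirty pixels get added to each of the two short sides, growing the 400-pixel dimension back to 460 and giving the full 460 x 460 output.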

Let’s test out our new augmentation!

size = (460,460)
tfmd_img = pad(crop(im_pil, size),size)
(np.array(tfmd_img) == crop_fastai(im_fastai)).all()
True

It works!

However, it’s not quite fastai’s regular RandomResizedCrop; there’s still a little more to be done. On the validation set, this transform adds some extra space and then resizes the image based on that. The code below takes the pad and crop we just made and applies a new “center crop” based on RandomResizedCrop:

random_crop_fastai = fastai_aug.RandomResizedCrop((460,460))
(np.array(tfmd_img) == random_crop_fastai(im_fastai, split_idx=1)).all()
False
def resized_crop_pad(
    image: typing.Union[Image.Image, torch.tensor],
    size: typing.Tuple[int, int],
    extra_crop_ratio: float = 0.14,
) -> Image:
    """
    Takes a `PIL.Image`, resize it according to the
    `extra_crop_ratio`, and then crops and pads
    it to `size`.

    Args:
        image (`PIL.Image`):
            An image to perform padding on
        size (`tuple` of integers):
            A size to crop and pad to, should be in the form
            of (width, height)
        extra_crop_ratio (`float`):
            The ratio of size at the edge cropped out. Default 0.14
    """

    maximum_space = max(size[0], size[1])
    extra_space = maximum_space * extra_crop_ratio
    extra_space = math.ceil(extra_space / 8) * 8
    extended_size = (size[0] + extra_space, size[1] + extra_space)
    resized_image = image.resize(extended_size, resample=Image.Resampling.BILINEAR)

    if extended_size != size:
        resized_image = pad(crop(resized_image, size), size)

    return resized_image
tfmd_img = resized_crop_pad(im_pil, size)
(np.array(tfmd_img) == random_crop_fastai(im_fastai, split_idx=1)).all()
True
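To see what resized_crop_pad does dimension-wise with size=(460, 460) and the default extra_crop_ratio=0.14, the sizing arithmetic works out roughly as follows (a sketch of the intermediate values):

# Rough walk-through of the resized_crop_pad sizing math
import math
size = (460, 460)
maximum_space = max(size)                     # 460
extra_space = maximum_space * 0.14            # 64.4
extra_space = math.ceil(extra_space / 8) * 8  # 72, rounded up to a multiple of 8
extended_size = (size[0] + extra_space, size[1] + extra_space)
print(extended_size)                          # (532, 532): resize to this, then crop/pad back to 460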

Random Resized Crop on GPU

As you can imagine, the batch-applied resize transform is also a bit different.

There’s quite a bit happening in the code below, but trust that it adapts the RandomResizedCropGPU transform (the one added when calling aug_transforms with a size) so that it performs what occurs on the validation set, also known as center cropping:

def gpu_crop(
    batch:torch.tensor, 
    size:typing.Tuple[int,int]
):
    """
    Crops each image in `batch` to a particular `size`.
    
    Args:
        batch (array of `torch.Tensor`):
            A batch of images, should be of shape `NxCxWxH`
        size (`tuple` of integers):
            A size to pad to, should be in the form
            of (width, height)
            
    Returns:
        A batch of cropped images
    """
    # Split into multiple lines for clarity
    affine_matrix = torch.eye(3, device=batch.device).float()
    affine_matrix = affine_matrix.unsqueeze(0)
    affine_matrix = affine_matrix.expand(batch.size(0), 3, 3)
    affine_matrix = affine_matrix.contiguous()[:,:2]
    
    coords = F.affine_grid(
        affine_matrix, batch.shape[:2] + size, align_corners=True
    )
    
    top_range, bottom_range = coords.min(), coords.max()
    zoom = 1/(bottom_range - top_range).item()*2
    
    resizing_limit = min(
        batch.shape[-2]/coords.shape[-2],
        batch.shape[-1]/coords.shape[-1]
    )/2
    
    if resizing_limit > 1 and resizing_limit > zoom:
        batch = F.interpolate(
            batch, 
            scale_factor=1/resizing_limit, 
            mode='area', 
            recompute_scale_factor=True
        )
    return F.grid_sample(batch, coords, mode='bilinear', padding_mode='reflection', align_corners=True)

    affine_matrix = torch.eye(3, device=batch.device).float()
    affine_matrix = affine_matrix.unsqueeze(0)
    affine_matrix = affine_matrix.expand(batch.size(0), 3, 3)
    affine_matrix = affine_matrix.contiguous()[:,:2]
First we define a 3x3 identity affine matrix:
[1, 0, 0]
[0, 1, 0]
[0, 0, 1]

We expand it across the batch and keep only the first two rows of each matrix, since F.affine_grid expects a theta of shape Nx2x3.

Then we create a sampling grid based on theta (the affine_matrix we just made) for an output of size NxCxW2xH2, where W2 and H2 are the new width and height.


    top_range, bottom_range = coords.min(), coords.max()
    zoom = 1/(bottom_range - top_range).item()*2

Then we calculate whether we are zooming in or out of the image: if zoom > 1, we are zooming in.


    resizing_limit = min(
        batch.shape[-2]/coords.shape[-2],
        batch.shape[-1]/coords.shape[-1]
    )/2

Next we calculate the amount we are resizing by, with a 100% extra margin (hence the division by two).


    if resizing_limit > 1 and resizing_limit > zoom:
        batch = F.interpolate(
            batch, 
            scale_factor=1/resizing_limit, 
            mode='area', 
            recompute_scale_factor=True
        )

If we are resizing by over 200% and zooming in by less than that, we perform an area interpolation first.


    return F.grid_sample(batch, coords, mode='bilinear', padding_mode='reflection', align_corners=True)

Finally, we perform a grid sample to squish the image down to the new shape based on coords.
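For the concrete case in this lesson, a 460x460 batch being cropped to (224, 224), the numbers work out roughly like this (my own back-of-the-envelope sketch):

# Rough numbers for a 460x460 batch cropped to (224, 224)
# With an identity theta and align_corners=True, coords spans [-1, 1]
zoom = 1 / (1 - (-1)) * 2                       # 1.0, i.e. no zoom
resizing_limit = min(460 / 224, 460 / 224) / 2  # ~1.027
# Since resizing_limit > 1 and > zoom, F.interpolate first shrinks the batch
# by a factor of 1 / 1.027, then grid_sample resamples it down to 224 x 224.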

Let’s test to make sure this aligns properly:

# fastai augmentations
tt_fastai = fastai_aug.ToTensor()
i2f_fastai = fastai_aug.IntToFloatTensor()
rrc_fastai = fastai_aug.RandomResizedCropGPU((224,224))

# torchvision augmentations
tt_torch = tvtfms.ToTensor()

# apply fastai augmentations
base_im_fastai = crop_fastai(im_fastai)
result_im_fastai = rrc_fastai(
    i2f_fastai(
        tt_fastai(base_im_fastai).unsqueeze(0)
    ), split_idx=1
)

# apply torchvision augmentations
result_im_tv = gpu_crop(tt_torch(tfmd_img).unsqueeze(0), (224,224))
torch.allclose(result_im_fastai, result_im_tv)
True

Excellent! We’ve successfully recreated the entire fastai data augmentation pipeline!

All that’s left is Normalize:

norm_torch = tvtfms.Normalize([0.485, 0.456, 0.406], [0.229,0.224,0.225])
Note

Because torchvision’s ToTensor already converts the image to a float tensor scaled to [0, 1], we don’t need a separate IntToFloatTensor step, as we saw a moment ago.
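A minimal check of that claim (my own sketch, using the im_pil image from earlier):

# ToTensor converts uint8 PIL pixels straight to float32 in [0, 1]
t = tvtfms.ToTensor()(im_pil)
print(t.dtype)                         # torch.float32
print(t.min().item(), t.max().item())  # values already lie in [0, 1]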

Let’s try them out fully this time:

# fastai augmentations
norm_fastai = fastai_aug.Normalize.from_stats(*fastai_aug.imagenet_stats, cuda=False)
# apply fastai augmentations
base_im_fastai = crop_fastai(im_fastai)
result_im_fastai = norm_fastai(
    rrc_fastai(
        i2f_fastai(
            tt_fastai(base_im_fastai).unsqueeze(0)
        ), split_idx=1
    )
)

# apply torchvision augmentations
result_im_tv = norm_torch(gpu_crop(tt_torch(tfmd_img).unsqueeze(0), (224,224)))
torch.allclose(result_im_fastai, result_im_tv)
True

We’ve fully recreated the fastai validation preprocessing pipeline without fastai. Now we can deploy it all!
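Putting it all together, a deployment-time predict function could look something like the sketch below. It uses only the pieces built above (resized_crop_pad, gpu_crop, the torchvision ToTensor/Normalize, and the rebuilt model); the vocab list and the softmax/argmax decoding step are my own assumptions about how the final decoding stage would be wired up:

# Sketch of end-to-end inference without fastai
def predict(fname: str, vocab: list):
    "Return the predicted class name and its probability for the image at `fname`."
    image = Image.open(fname).convert("RGB")
    # Item transform: resize with the extra crop ratio, then center crop/pad to 460
    image = resized_crop_pad(image, (460, 460))
    # Batch transforms: to tensor, center crop to 224, normalize with ImageNet stats
    batch = gpu_crop(tt_torch(image).unsqueeze(0), (224, 224))
    batch = norm_torch(batch)
    model.eval()
    with torch.no_grad():
        logits = model(batch)
    probs = logits.softmax(dim=-1)
    idx = probs.argmax(dim=-1).item()
    return vocab[idx], probs[0, idx].item()

# Example usage (vocab here is hypothetical; it would be the class list saved at training time)
# vocab = ["Abyssinian", ..., "yorkshire_terrier"]
# print(predict(fname, vocab))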

Code

Below is the full code for the transforms. To my knowledge, this is currently the only version that fully recreates these transforms in a clear and correct way.

import math
import typing

import torch
import torch.nn.functional as F
import torchvision.transforms.functional as tvf
from PIL import Image
def crop(image:typing.Union[Image.Image, torch.tensor], size:typing.Tuple[int,int]) -> Image:
    """
    Takes a `PIL.Image` and crops it `size` unless one 
    dimension is larger than the actual image. Padding 
    must be performed afterwards if so.
    
    Args:
        image (`PIL.Image`):
            An image to perform cropping on
        size (`tuple` of integers):
            A size to crop to, should be in the form
            of (width, height)
            
    Returns:
        An augmented `PIL.Image`
    """
    top = (image.shape[-1] - size[0]) // 2
    left = (image.shape[-2] - size[1]) // 2
    
    top = max(top, 0)
    left = max(left, 0)
    
    height = min(top + size[0], image.shape[-1])
    width = min(left + size[1], image.shape[-2])
    return image.crop((top, left, height, width))
def pad(image, size:typing.Tuple[int,int]) -> Image:
    """
    Takes a `PIL.Image` and pads it to `size` with
    zeros.
    
    Args:
        image (`PIL.Image`):
            An image to perform padding on
        size (`tuple` of integers):
            A size to pad to, should be in the form
            of (width, height)
            
    Returns:
        An augmented `PIL.Image`
    """
    top = (image.shape[-1] - size[0]) // 2
    left = (image.shape[-2] - size[1]) // 2
    
    pad_top = max(-top, 0)
    pad_left = max(-left, 0)
    
    height, width = (
        max(size[1] - image.shape[-1] + top, 0), 
        max(size[0] - image.shape[-2] + left, 0)
    )
    return tvf.pad(
        image, 
        [pad_top, pad_left, height, width], 
        padding_mode="constant"
    )
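# resized_crop_pad (defined earlier in the lesson) is also part of the item-transform
# stage, so it belongs in this file as well:
def resized_crop_pad(
    image: typing.Union[Image.Image, torch.tensor],
    size: typing.Tuple[int, int],
    extra_crop_ratio: float = 0.14,
) -> Image:
    """
    Takes a `PIL.Image`, resize it according to the
    `extra_crop_ratio`, and then crops and pads
    it to `size`.

    Args:
        image (`PIL.Image`):
            An image to perform padding on
        size (`tuple` of integers):
            A size to crop and pad to, should be in the form
            of (width, height)
        extra_crop_ratio (`float`):
            The ratio of size at the edge cropped out. Default 0.14
    """
    maximum_space = max(size[0], size[1])
    extra_space = maximum_space * extra_crop_ratio
    extra_space = math.ceil(extra_space / 8) * 8
    extended_size = (size[0] + extra_space, size[1] + extra_space)
    resized_image = image.resize(extended_size, resample=Image.Resampling.BILINEAR)

    if extended_size != size:
        resized_image = pad(crop(resized_image, size), size)

    return resized_image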
def gpu_crop(
    batch:torch.tensor, 
    size:typing.Tuple[int,int]
):
    """
    Crops each image in `batch` to a particular `size`.
    
    Args:
        batch (array of `torch.Tensor`):
            A batch of images, should be of shape `NxCxWxH`
        size (`tuple` of integers):
            A size to pad to, should be in the form
            of (width, height)
            
    Returns:
        A batch of cropped images
    """
    # Split into multiple lines for clarity
    affine_matrix = torch.eye(3, device=batch.device).float()
    affine_matrix = affine_matrix.unsqueeze(0)
    affine_matrix = affine_matrix.expand(batch.size(0), 3, 3)
    affine_matrix = affine_matrix.contiguous()[:,:2]
    
    coords = F.affine_grid(
        affine_matrix, batch.shape[:2] + size, align_corners=True
    )
    
    top_range, bottom_range = coords.min(), coords.max()
    zoom = 1/(bottom_range - top_range).item()*2
    
    resizing_limit = min(
        batch.shape[-2]/coords.shape[-2],
        batch.shape[-1]/coords.shape[-1]
    )/2
    
    if resizing_limit > 1 and resizing_limit > zoom:
        batch = F.interpolate(
            batch, 
            scale_factor=1/resizing_limit, 
            mode='area', 
            recompute_scale_factor=True
        )
    return F.grid_sample(batch, coords, mode='bilinear', padding_mode='reflection', align_corners=True)

Gradio Apps

See the source for both gradio apps here:

  • No fastai: https://huggingface.co/spaces/muellerzr/deployment-no-fastai/
  • fastai: https://huggingface.co/spaces/muellerzr/deployment-fastai