import torch
import torch.nn.functional as F
import torchvision.transforms.functional as tvf
import torchvision.transforms as tvtfms
import operator as op
from PIL import Image
from torch import nn
from timm import create_model
# For type hinting later on
import collections
import typing
Introduction
In this version of the lesson, we're going to take the model we just trained and remove every part of fastai when deploying it. This will involve:

- Recreating the PyTorch model from "scratch" (allowing timm only)
- Recreating the fastai transforms in raw PyTorch
- Recreating the decoding stage
Starting with the Model
First, let's tackle the model. Since we can't bring in fastai, we will only use imports from torch, torchvision, PIL, and timm (and the standard library):
Next, let's create the core "body" of the model with timm's create_model:

net = create_model("vit_tiny_patch16_224", pretrained=False, num_classes=0, in_chans=3)
The last step is to recreate fastai's create_head. We can do this easily with a nn.Sequential directly:

head = nn.Sequential(
    nn.BatchNorm1d(192),
    nn.Dropout(0.25),
    nn.Linear(192, 512, bias=False),
    nn.ReLU(inplace=True),
    nn.BatchNorm1d(512),
    nn.Dropout(0.5),
    nn.Linear(512, 37, bias=False)
)
Now let's merge them together:

model = nn.Sequential(net, head)
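Before loading any weights, we can sanity-check the shapes (a quick check of my own, not part of the lesson code): in eval() mode a single dummy image should come out as 37 logits, one per pet breed.

dummy_input = torch.randn(1, 3, 224, 224)
# eval() avoids BatchNorm complaints about a batch of one
model.eval()(dummy_input).shape  # torch.Size([1, 37])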
But how do we know it works? Let's bring in that state dictionary:

state = torch.load("models/MyModel.pth")
We load in the weights directly with model.load_state_dict and pass in the state we want to use:

model.load_state_dict(state);
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[72], line 1
----> 1 model.load_state_dict(state);
File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1671, in Module.load_state_dict(self, state_dict, strict)
1666 error_msgs.insert(
1667 0, 'Missing key(s) in state_dict: {}. '.format(
1668 ', '.join('"{}"'.format(k) for k in missing_keys)))
1670 if len(error_msgs) > 0:
-> 1671 raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
1672 self.__class__.__name__, "\n\t".join(error_msgs)))
1673 return _IncompatibleKeys(missing_keys, unexpected_keys)
RuntimeError: Error(s) in loading state_dict for Sequential:
Missing key(s) in state_dict: "0.cls_token", "0.pos_embed", "0.patch_embed.proj.weight", "0.patch_embed.proj.bias", "0.blocks.0.norm1.weight", "0.blocks.0.norm1.bias", "0.blocks.0.attn.qkv.weight", "0.blocks.0.attn.qkv.bias", "0.blocks.0.attn.proj.weight", "0.blocks.0.attn.proj.bias", "0.blocks.0.norm2.weight", "0.blocks.0.norm2.bias", "0.blocks.0.mlp.fc1.weight", "0.blocks.0.mlp.fc1.bias", "0.blocks.0.mlp.fc2.weight", "0.blocks.0.mlp.fc2.bias", "0.blocks.1.norm1.weight", "0.blocks.1.norm1.bias", "0.blocks.1.attn.qkv.weight", "0.blocks.1.attn.qkv.bias", "0.blocks.1.attn.proj.weight", "0.blocks.1.attn.proj.bias", "0.blocks.1.norm2.weight", "0.blocks.1.norm2.bias", "0.blocks.1.mlp.fc1.weight"
Uh oh, we have an issue! Something about missing keys. Why is that?
Messing with the state_dict
When fastai created our model, the layers in the model's state_dict followed a particular order, and that order got translated into how the weights were saved. What happened here is that the names don't quite align between our new model and the saved one. For example:
list(model.state_dict().keys())[:5]
['0.cls_token',
'0.pos_embed',
'0.patch_embed.proj.weight',
'0.patch_embed.proj.bias',
'0.blocks.0.norm1.weight']
list(state.keys())[:5]
['0.model.cls_token',
'0.model.pos_embed',
'0.model.patch_embed.proj.weight',
'0.model.patch_embed.proj.bias',
'0.model.blocks.0.norm1.weight']
See the extra .model attribute in there?
How can we get past this?
The technique I'm introducing here is also the general idea behind performing "transfer learning, twice": re-using pretrained weights from a different model setup and re-applying them to your own model.
The basic idea is that for every single parameter, we look up its equivalent layer name in the other model and load it in. This also lets us add or remove certain parts of these layer names, such as the .model prefix we are currently missing:
def copy_weight(name, parameter, state_dict):
    """
    Takes in a layer `name`, model `parameter`, and `state_dict`
    and loads the weights from `state_dict` into `parameter`
    if it exists.
    """
    # Part of the body
    if name[0] == "0":
        name = name[:2] + "model." + name[2:]
    if name in state_dict.keys():
        input_parameter = state_dict[name]
        if input_parameter.shape == parameter.shape:
            parameter.copy_(input_parameter)
        else:
            print(f'Shape mismatch at layer: {name}, skipping')
    else:
        print(f'{name} is not in the state_dict, skipping.')
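As a quick illustration of the rename copy_weight performs (my own example, not part of the lesson code), the string slicing turns a key from our new model into the key fastai saved:

name = "0.cls_token"
# insert "model." right after the "0." prefix of body parameters
name[:2] + "model." + name[2:]  # '0.model.cls_token', matching the saved state_dict keys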
def apply_weights(input_model:nn.Module, input_weights:collections.OrderedDict, application_function:callable):
    """
    Takes an input state_dict and applies those weights to the `input_model`, potentially
    with a modifier function.

    Args:
        input_model (`nn.Module`):
            The model that weights should be applied to
        input_weights (`collections.OrderedDict`):
            A dictionary of weights, the trained model's `state_dict()`
        application_function (`callable`):
            A function that takes in one parameter and layer name from `input_model`
            and the `input_weights`. Should apply the weights from the state dict into `input_model`.
    """
    model_dict = input_model.state_dict()
    for name, parameter in model_dict.items():
        application_function(name, parameter, input_weights)
    input_model.load_state_dict(model_dict)
Now we can use this on our model:
apply_weights(model, state, copy_weight)
And now our new weights are loaded in!
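As a hedged sanity check of my own (not in the original lesson), we can spot-check one parameter against the trained weights we loaded from disk:

# the loaded body parameter should now match the saved "0.model." key exactly
torch.allclose(model.state_dict()["0.cls_token"], state["0.model.cls_token"])  # True if the copy worked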
The Transforms
Next we have to apply the transforms. Our old model had the following fastai transform setup:
Item Transforms:
- RandomResizedCrop(460,460)
- ToTensor()

Batch Transforms:
- IntToFloatTensor()
- aug_transforms(size=224)
- Normalize.from_stats(*imagenet_stats)
Given this, let's examine some transforms:

- fastai.vision.augment.RandomResizedCrop, when applied on the validation set, instead uses fastai.vision.augment.CropPad
- fastai.vision.augment.Flip and fastai.vision.augment.Brightness aren't applied on the validation set.
  - However, because there is a size parameter, another fastai.vision.augment.RandomResizedCrop is performed
- The rest we can recreate easily enough
Recreating CropPad
Our first experiment should be ensuring we can find something that exactly recreates the CropPad fastai provides. Let's run a test against the native one in torchvision:
from fastai.vision.data import PILImage
from fastai.data.external import untar_data, URLs
from fastai.data.transforms import get_image_files
import fastai.vision.augment as fastai_aug
import math
import numpy as np
These imports are only needed so we can test on an image and verify that the image we create lines up with the augmentation fastai performs. numpy is here so we can easily perform an assert statement on the image data.
First we need an image to test on. We’ll grab one of the first ones from the PETs dataset:
path = untar_data(URLs.PETS)/'images'
fname = get_image_files(path)[0]
fname
Path('/home/zach/.fastai/data/oxford-iiit-pet/images/Siamese_160.jpg')
Next we need to open it in each respective API:

im_pil = Image.open(fname)
im_fastai = PILImage.create(fname)
Our first assert statement can be made, ensuring that these two images are the same:
assert (np.array(im_pil) == np.array(im_fastai)).all()
Let's break this statement down:

np.array(im)

To compare these two images, we convert them into numpy arrays. Their shapes will be a bit odd, but we only care about the pixel values in each image, as they should be exactly the same.

(a == b).all()

By wrapping parentheses around the comparison, we can call further methods on its result. .all() takes the boolean array that np.array(im_pil) == np.array(im_fastai) created and ensures that every single value in it is True.
We'll create what should be equivalent transforms based on their names. CropPad from fastai is what gets used, and the closest transform in torchvision seems to be CenterCrop (as we center crop and pad). Let's see if they are the same:

crop_fastai = fastai_aug.CropPad((460,460))
crop_torch = tvtfms.CenterCrop((460,460))
assert (np.array(crop_fastai(im_fastai)) == np.array(crop_torch(im_pil))).all()
AssertionError:
Oh no! They aren't the same after all! What can we do?
I've gone through and recreated how fastai performs CropPad using pure PIL, PyTorch, and Python below. I highly recommend reading the explanations to understand what is going on:
def crop(image:typing.Union[Image.Image, torch.tensor], size:typing.Tuple[int,int]) -> Image:
    """
    Takes a `PIL.Image` and crops it `size` unless one
    dimension is larger than the actual image. Padding
    must be performed afterwards if so.

    Args:
        image (`PIL.Image`):
            An image to perform cropping on
        size (`tuple` of integers):
            A size to crop to, should be in the form
            of (width, height)

    Returns:
        An augmented `PIL.Image`
    """
    top = (image.shape[-1] - size[0]) // 2
    left = (image.shape[-2] - size[1]) // 2

    top = max(top, 0)
    left = max(left, 0)

    height = min(top + size[0], image.shape[-1])
    width = min(left + size[1], image.shape[-2])

    return image.crop((top, left, height, width))
Let's walk through it step by step:

top = (image.shape[-1] - size[0]) // 2
left = (image.shape[-2] - size[1]) // 2

We need to calculate the coordinates of the top-left corner of the crop, as Image.crop expects a top, left, height, and width to crop to. We do so by taking each dimension, subtracting the new dimension from it, and dividing by two: value = (dimension - new_dimension) // 2
top = max(top, 0)
left = max(left, 0)

Any negatives that result from the previous calculation are a sign that padding is needed. Since we're only cropping here, we zero these values out to get a valid point on the image.
height = min(top + size[0], image.shape[-1])
width = min(left + size[1], image.shape[-2])

Finally, we take the smaller of the two dimensions: the proposed new dimension or the current dimension of the image.

return image.crop((top, left, height, width))

All of this gets passed into PIL.Image.crop as a bounding box to crop to.
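As a worked example (my own numbers, not from the lesson), suppose image.shape[-1] is 375 and image.shape[-2] is 500, and we crop to size = (460, 460):

top  = (375 - 460) // 2                  # -43: negative, so padding will be needed later
left = (500 - 460) // 2                  # 20
top, left = max(top, 0), max(left, 0)    # (0, 20)
height = min(0 + 460, 375)               # 375: we can't crop past the edge of the image
width  = min(20 + 460, 500)              # 480
# image.crop((0, 20, 375, 480)) leaves 460 - 375 = 85 pixels for `pad` to fill in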
def pad(image, size:typing.Tuple[int,int]) -> Image:
    """
    Takes a `PIL.Image` and pads it to `size` with
    zeros.

    Args:
        image (`PIL.Image`):
            An image to perform padding on
        size (`tuple` of integers):
            A size to pad to, should be in the form
            of (width, height)

    Returns:
        An augmented `PIL.Image`
    """
    top = (image.shape[-1] - size[0]) // 2
    left = (image.shape[-2] - size[1]) // 2

    pad_top = max(-top, 0)
    pad_left = max(-left, 0)

    height, width = (
        max(size[1] - image.shape[-1] + top, 0),
        max(size[0] - image.shape[-2] + left, 0)
    )
    return tvf.pad(
        image,
        [pad_top, pad_left, height, width],
        padding_mode="constant"
    )
Again, let's walk through it:

top = (image.shape[-1] - size[0]) // 2
left = (image.shape[-2] - size[1]) // 2

As before, we calculate the coordinates of the top-left corner by taking each dimension, subtracting the new dimension from it, and dividing by two: value = (dimension - new_dimension) // 2
pad_top = max(-top, 0)
pad_left = max(-left, 0)

Before, we took the maximum of top and left compared to zero to find the corner. This time we take the maximum of the negated coordinates to tell us how much padding is needed. The more negative the number originally was, the more padding is needed.
height, width = (
    max(size[1] - image.shape[-1] + top, 0),
    max(size[0] - image.shape[-2] + left, 0)
)

value = (new_dimension - dimension + offset)

The maximum between this and zero will be used in our padding operation as the new dimension length.
return tvf.pad(
    image,
    [pad_top, pad_left, height, width],
    padding_mode="constant"
)

Finally, we call the pad transform from torchvision.transforms.functional, passing in the image, the padding amounts, and a padding mode. "constant" here fills all the new values with 0, which is what fastai does.
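Continuing the hypothetical example from the crop walkthrough (again my own numbers): after cropping, the image has image.shape[-1] = 375 and image.shape[-2] = 460, and padding it to (460, 460) gives:

top  = (375 - 460) // 2                          # -43
left = (460 - 460) // 2                          # 0
pad_top, pad_left = max(-top, 0), max(-left, 0)  # (43, 0)
height = max(460 - 375 + top, 0)                 # 42
width  = max(460 - 460 + left, 0)                # 0
# tvf.pad(image, [43, 0, 42, 0]) adds back the missing 43 + 42 = 85 pixels of zeros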
Let’s test out our new augmentation!
size = (460,460)
tfmd_img = pad(crop(im_pil, size), size)

(np.array(tfmd_img) == crop_fastai(im_fastai)).all()
True
It works!
However, it's not quite fastai's regular RandomResizedCrop, as there's still a little more to do. When using that transform, the validation set adds some extra space and resizes the image based on it. The code below takes the pad and crop we just made and applies a new "center crop" based on RandomResizedCrop:
random_crop_fastai = fastai_aug.RandomResizedCrop((460,460))
(np.array(tfmd_img) == random_crop_fastai(im_fastai, split_idx=1)).all()
False
def resized_crop_pad(
    image: typing.Union[Image.Image, torch.tensor],
    size: typing.Tuple[int, int],
    extra_crop_ratio: float = 0.14,
) -> Image:
    """
    Takes a `PIL.Image`, resize it according to the
    `extra_crop_ratio`, and then crops and pads
    it to `size`.

    Args:
        image (`PIL.Image`):
            An image to perform padding on
        size (`tuple` of integers):
            A size to crop and pad to, should be in the form
            of (width, height)
        extra_crop_ratio (`float`):
            The ratio of size at the edge cropped out. Default 0.14
    """
    maximum_space = max(size[0], size[1])
    extra_space = maximum_space * extra_crop_ratio
    extra_space = math.ceil(extra_space / 8) * 8
    extended_size = (size[0] + extra_space, size[1] + extra_space)
    resized_image = image.resize(extended_size, resample=Image.Resampling.BILINEAR)

    if extended_size != size:
        resized_image = pad(crop(resized_image, size), size)

    return resized_image
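For intuition, here's the arithmetic for our (460, 460) target with the default extra_crop_ratio of 0.14 (a quick check of my own, not part of the lesson code):

maximum_space = max(460, 460)            # 460
extra_space = 460 * 0.14                 # 64.4
extra_space = math.ceil(64.4 / 8) * 8    # 72
extended_size = (460 + 72, 460 + 72)     # (532, 532): resize to this, then crop/pad back to (460, 460)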
tfmd_img = resized_crop_pad(im_pil, size)
(np.array(tfmd_img) == random_crop_fastai(im_fastai, split_idx=1)).all()
True
Random Resized Crop on GPU
As you can imagine, a batch-applied Resize transform is also a bit different.

There's quite a bit happening in the code below, but in short it adapts the RandomResizedCropGPU transform (which gets added when passing a size to aug_transforms) and performs what occurs on the validation set, also known as center cropping:
def gpu_crop(
    batch:torch.tensor,
    size:typing.Tuple[int,int]
):
    """
    Crops each image in `batch` to a particular `size`.

    Args:
        batch (array of `torch.Tensor`):
            A batch of images, should be of shape `NxCxWxH`
        size (`tuple` of integers):
            A size to pad to, should be in the form
            of (width, height)

    Returns:
        A batch of cropped images
    """
    # Split into multiple lines for clarity
    affine_matrix = torch.eye(3, device=batch.device).float()
    affine_matrix = affine_matrix.unsqueeze(0)
    affine_matrix = affine_matrix.expand(batch.size(0), 3, 3)
    affine_matrix = affine_matrix.contiguous()[:,:2]

    coords = F.affine_grid(
        affine_matrix, batch.shape[:2] + size, align_corners=True
    )

    top_range, bottom_range = coords.min(), coords.max()
    zoom = 1/(bottom_range - top_range).item()*2

    resizing_limit = min(
        batch.shape[-2]/coords.shape[-2],
        batch.shape[-1]/coords.shape[-1]
    )/2

    if resizing_limit > 1 and resizing_limit > zoom:
        batch = F.interpolate(
            batch,
            scale_factor=1/resizing_limit,
            mode='area',
            recompute_scale_factor=True
        )
    return F.grid_sample(batch, coords, mode='bilinear', padding_mode='reflection', align_corners=True)
Let's walk through it:

affine_matrix = torch.eye(3, device=batch.device).float()
affine_matrix = affine_matrix.unsqueeze(0)
affine_matrix = affine_matrix.expand(batch.size(0), 3, 3)
affine_matrix = affine_matrix.contiguous()[:,:2]

First we create an identity affine matrix for each image in the batch, a 3x3 matrix such that:

[1,0,0]
[0,1,0]
[0,0,1]

keeping only the first two rows (the Nx2x3 shape F.affine_grid expects). Then F.affine_grid creates a sampling grid based on theta, the affine_matrix we just made, for a target of shape NxCxW2xH2, where W2 and H2 are the new width and height.
top_range, bottom_range = coords.min(), coords.max()
zoom = 1/(bottom_range - top_range).item()*2

Then we calculate whether we are zooming in or out of the image. If zoom > 1, we are zooming in.
resizing_limit = min(
    batch.shape[-2]/coords.shape[-2],
    batch.shape[-1]/coords.shape[-1]
)/2

This calculates the amount we are resizing by, with 100% extra margin.
if resizing_limit > 1 and resizing_limit > zoom:
    batch = F.interpolate(
        batch,
        scale_factor=1/resizing_limit,
        mode='area',
        recompute_scale_factor=True
    )

If we are resizing by over 200% and zooming in by less than that, we perform an interpolation first.
return F.grid_sample(batch, coords, mode='bilinear', padding_mode='reflection', align_corners=True)

Finally, we perform a grid sample to squish the image down to the new shape based on coords.
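As a quick shape check of my own (not from the lesson), gpu_crop should take any batch down to the requested spatial size:

dummy_batch = torch.randn(1, 3, 460, 460)
gpu_crop(dummy_batch, (224, 224)).shape  # torch.Size([1, 3, 224, 224])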
Let’s test to make sure this aligns properly:
# fastai augmentations
tt_fastai = fastai_aug.ToTensor()
i2f_fastai = fastai_aug.IntToFloatTensor()
rrc_fastai = fastai_aug.RandomResizedCropGPU((224,224))

# torchvision augmentations
tt_torch = tvtfms.ToTensor()

# apply fastai augmentations
base_im_fastai = crop_fastai(im_fastai)
result_im_fastai = rrc_fastai(
    i2f_fastai(
        tt_fastai(base_im_fastai).unsqueeze(0)
    ), split_idx=1
)

# apply torchvision augmentations
result_im_tv = gpu_crop(tt_torch(tfmd_img).unsqueeze(0), (224,224))
torch.allclose(result_im_fastai, result_im_tv)
True
Excellent! We've successfully recreated the entire fastai preprocessing and data augmentation pipeline!

All that's left is Normalize:

norm_torch = tvtfms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
Because of torchvision's ToTensor implementation, we don't need to worry about IntToFloatTensor, as we saw a moment ago.
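To see why, here's a tiny illustration of my own (with a hypothetical dummy image): torchvision's ToTensor already converts uint8 pixels in [0, 255] to floats in [0.0, 1.0].

dummy = Image.fromarray(np.full((2, 2, 3), 255, dtype=np.uint8))  # a tiny all-white image
tvtfms.ToTensor()(dummy).max()  # tensor(1.)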
Let’s try them out fully this time:
# fastai augmentations
norm_fastai = fastai_aug.Normalize.from_stats(*fastai_aug.imagenet_stats, cuda=False)

# apply fastai augmentations
base_im_fastai = crop_fastai(im_fastai)
result_im_fastai = norm_fastai(
    rrc_fastai(
        i2f_fastai(
            tt_fastai(base_im_fastai).unsqueeze(0)
        ), split_idx=1
    )
)

# apply torchvision augmentations
result_im_tv = norm_torch(gpu_crop(tt_torch(tfmd_img).unsqueeze(0), (224,224)))
torch.allclose(result_im_fastai, result_im_tv)
True
We’ve fully recreated the fastai validation preprocessing pipeline without fastai. Now we can deploy it all!
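To round out the "decoding stage" mentioned in the introduction, here is a minimal sketch of how the model's output could be decoded, assuming a prepared `batch` from the pipeline above and a hypothetical `vocab` list of the 37 pet breeds (it is not part of the lesson code):

with torch.inference_mode():
    preds = model.eval()(batch)          # raw logits, shape (1, 37)
probs = preds.softmax(dim=-1)            # logits -> probabilities
idx = probs.argmax(dim=-1).item()        # index of the most likely class
print(vocab[idx], probs[0, idx].item())  # predicted breed and its confidence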
Code
Below is the full code for the transforms. To my knowledge, this is currently the only version that fully recreates these transforms in a clear and proper way.
import typing
import torch
import torch.nn.functional as F
from PIL import Image
import torchvision.transforms.functional as tvf
def crop(image:typing.Union[Image.Image, torch.tensor], size:typing.Tuple[int,int]) -> Image:
    """
    Takes a `PIL.Image` and crops it `size` unless one
    dimension is larger than the actual image. Padding
    must be performed afterwards if so.

    Args:
        image (`PIL.Image`):
            An image to perform cropping on
        size (`tuple` of integers):
            A size to crop to, should be in the form
            of (width, height)

    Returns:
        An augmented `PIL.Image`
    """
    top = (image.shape[-1] - size[0]) // 2
    left = (image.shape[-2] - size[1]) // 2

    top = max(top, 0)
    left = max(left, 0)

    height = min(top + size[0], image.shape[-1])
    width = min(left + size[1], image.shape[-2])

    return image.crop((top, left, height, width))
def pad(image, size:typing.Tuple[int,int]) -> Image:
    """
    Takes a `PIL.Image` and pads it to `size` with
    zeros.

    Args:
        image (`PIL.Image`):
            An image to perform padding on
        size (`tuple` of integers):
            A size to pad to, should be in the form
            of (width, height)

    Returns:
        An augmented `PIL.Image`
    """
    top = (image.shape[-1] - size[0]) // 2
    left = (image.shape[-2] - size[1]) // 2

    pad_top = max(-top, 0)
    pad_left = max(-left, 0)

    height, width = (
        max(size[1] - image.shape[-1] + top, 0),
        max(size[0] - image.shape[-2] + left, 0)
    )
    return tvf.pad(
        image,
        [pad_top, pad_left, height, width],
        padding_mode="constant"
    )
def gpu_crop(
    batch:torch.tensor,
    size:typing.Tuple[int,int]
):
    """
    Crops each image in `batch` to a particular `size`.

    Args:
        batch (array of `torch.Tensor`):
            A batch of images, should be of shape `NxCxWxH`
        size (`tuple` of integers):
            A size to pad to, should be in the form
            of (width, height)

    Returns:
        A batch of cropped images
    """
    # Split into multiple lines for clarity
    affine_matrix = torch.eye(3, device=batch.device).float()
    affine_matrix = affine_matrix.unsqueeze(0)
    affine_matrix = affine_matrix.expand(batch.size(0), 3, 3)
    affine_matrix = affine_matrix.contiguous()[:,:2]

    coords = F.affine_grid(
        affine_matrix, batch.shape[:2] + size, align_corners=True
    )

    top_range, bottom_range = coords.min(), coords.max()
    zoom = 1/(bottom_range - top_range).item()*2

    resizing_limit = min(
        batch.shape[-2]/coords.shape[-2],
        batch.shape[-1]/coords.shape[-1]
    )/2

    if resizing_limit > 1 and resizing_limit > zoom:
        batch = F.interpolate(
            batch,
            scale_factor=1/resizing_limit,
            mode='area',
            recompute_scale_factor=True
        )
    return F.grid_sample(batch, coords, mode='bilinear', padding_mode='reflection', align_corners=True)
Gradio Apps
See the source for both gradio apps here:
- No fastai: https://huggingface.co/spaces/muellerzr/deployment-no-fastai/
- fastai: https://huggingface.co/spaces/muellerzr/deployment-fastai