
Lesson Video:


This article is also a Jupyter Notebook available to be run from the top down. There will be code snippets that you can then run in any environment.

Below are the versions of fastai, fastcore, wwf, and nbdev currently running at the time of writing this:

  • fastai: 2.2.5
  • fastcore: 1.3.19
  • wwf: 0.0.10
  • nbdev: 1.1.12

Before we run anything we want to ensure we have a P100 GPU or better:

!nvidia-smi
Fri Feb  5 01:42:56 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.39       Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   38C    P0    26W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|  No running processes found                                                 |
+-----------------------------------------------------------------------------+

What will we cover?

  • Custom loss functions
  • Custom models
  • Utilizing the nbdev library
  • Deployment Code

This notebook is largely based on the work by Lucas Vazquez on the fastai forum.

Intro to nbdev

  • Library by Jeremy and Sylvain for writing libraries
  • Used to make the entire fastai library
  • Converts an ipynb into documentation, as well as into a .py module containing only specific cells

  • A few base terms:

    • #default_exp
    • #hide
    • #export

We always need the show_doc import when exporting
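Below is a minimal sketch (not part of this lesson's library) of how those flags might look inside an nbdev v1 notebook; the module name style_transfer and the greet function are placeholders:

#default_exp style_transfer

#hide
from nbdev.showdoc import *

#export
def greet(name:str) -> str:
  "A placeholder exported function"
  return f'Hello, {name}!'

Cells flagged with #export end up in style_transfer.py (the module named by #default_exp), while #hide keeps a cell out of the rendered documentation.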

fastai libraries

from fastai.vision.all import *

What is Style Transfer?

  • Feature Loss
    • Also known as perceptual loss
  • Image Transformation network
    • Input images -> Output images
  • Pre-Trained Loss Network
    • Perceptual loss functions
      • A measurement of the differences in content and style between two images


Our Pre-Trained Network:

  • vgg-19:
    • Small CNN pre-trained on ImageNet
    • Only 19 layers deep
from torchvision.models import vgg19, vgg16
feat_net = vgg19(pretrained=True).features.cuda().eval()
Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth

We'll get rid of the head and use the internal activations (these will feed our generator model's loss). As a result, we want to set every layer to be un-trainable:

for p in feat_net.parameters(): p.requires_grad=False

We will be using the feature detections that our model picks up, similar to the heatmaps we generated for our classification models:

layers = [feat_net[i] for i in [1, 6, 11, 20, 29, 22]]; layers
[ReLU(inplace=True),
 ReLU(inplace=True),
 ReLU(inplace=True),
 ReLU(inplace=True),
 ReLU(inplace=True),
 ReLU(inplace=True)]

The outputs are ReLU layers. Below is a configuration for the vgg16 and vgg19 models:

_vgg_config = {
    'vgg16' : [1, 11, 18, 25, 20],
    'vgg19' : [1, 6, 11, 20, 29, 22]
}

Let's write a quick get_layers function to grab our network and the layers

def _get_layers(arch:str, pretrained=True):
  "Get the layers and arch for a VGG Model (16 and 19 are supported only)"
  feat_net = vgg19(pretrained=pretrained).cuda() if arch.find('9') > 1 else vgg16(pretrained=pretrained).cuda()
  config = _vgg_config.get(arch)
  features = feat_net.features.cuda().eval()
  for p in features.parameters(): p.requires_grad=False
  return feat_net, [features[i] for i in config]

Now let's do it all in one go: we'll use our private helpers and just pass in an architecture name and a pretrained parameter

get_feats[source]

get_feats(arch:str, pretrained=True)

Get the features of an architecture

def get_feats(arch:str, pretrained=True):
  "Get the features of an architecture"
  feat_net, layers = _get_layers(arch, pretrained)
  hooks = hook_outputs(layers, detach=False)
  def _inner(x):
    feat_net(x)
    return hooks.stored
  return _inner
feats = get_feats('vgg19')

The Loss Function

Our loss function needs:

  • Our original image
  • Some artwork / style
  • Activation features from our encoder

What image will we be using?


Let's grab the image

url = 'https://static.greatbigcanvas.com/images/singlecanvas_thick_none/megan-aroon-duncanson/little-village-abstract-art-house-painting,1162125.jpg'
!wget {url} -O 'style.jpg'
--2021-02-05 01:43:21--  https://static.greatbigcanvas.com/images/singlecanvas_thick_none/megan-aroon-duncanson/little-village-abstract-art-house-painting,1162125.jpg
Resolving static.greatbigcanvas.com (static.greatbigcanvas.com)... 3.86.0.166, 52.205.76.115
Connecting to static.greatbigcanvas.com (static.greatbigcanvas.com)|3.86.0.166|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 484954 (474K) [image/jpeg]
Saving to: ‘style.jpg’

style.jpg           100%[===================>] 473.59K  1.80MB/s    in 0.3s    

2021-02-05 01:43:22 (1.80 MB/s) - ‘style.jpg’ saved [484954/484954]

fn = 'style.jpg'

We can now make a Pipeline to convert our image into a tensor to use in our loss function. We'll want to use Datasets for this

dset = Datasets(fn, tfms=[PILImage.create])
dl = dset.dataloaders(after_item=[ToTensor()], after_batch=[IntToFloatTensor(), Normalize.from_stats(*imagenet_stats)], bs=1)
dl.show_batch(figsize=(7,7))
style_im = dl.one_batch()[0]
style_im.shape
(1, 3, 831, 1000)

get_style_im[source]

get_style_im(url)

def get_style_im(url):
  download_url(url, 'style.jpg')
  fn = 'style.jpg'
  dset = Datasets(fn, tfms=[PILImage.create])
  dl = dset.dataloaders(after_item=[ToTensor()], after_batch=[IntToFloatTensor(), Normalize.from_stats(*imagenet_stats)], bs=1)
  return dl.one_batch()[0]
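As a usage sketch (re-using the same url from above), this one function is equivalent to the manual steps we just did:

style_im = get_style_im(url)
style_im.shape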

We can then grab the features using the feats function we made earlier

im_feats = feats(style_im)

Let's look at their sizes

for feat in im_feats:
  print(feat.shape)
(1, 64, 831, 1000)
(1, 128, 415, 500)
(1, 256, 207, 250)
(1, 512, 103, 125)
(1, 512, 51, 62)
(1, 512, 103, 125)

Now we can collapse those feature maps down to the channel dimension by computing their Gram matrices

gram[source]

gram(x:Tensor)

Compute the normalized Gram matrix of a (n, c, h, w) tensor

def gram(x:Tensor):
  "Compute the normalized Gram matrix of a (n, c, h, w) tensor"
  n, c, h, w = x.shape
  # Flatten the spatial dimensions: (n, c, h*w)
  x = x.view(n, c, -1)
  # Channel-by-channel correlations, normalized by c*w*h: (n, c, c)
  return (x @ x.transpose(1, 2))/(c*w*h)
im_grams = [gram(f) for f in im_feats]
for feat in im_grams:
  print(feat.shape)
(1, 64, 64)
(1, 128, 128)
(1, 256, 256)
(1, 512, 512)
(1, 512, 512)
(1, 512, 512)
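To make those shapes concrete, here is a small sanity check with an illustrative random tensor (an assumption for demonstration, not part of the lesson): gram collapses an (n, c, h, w) tensor into an (n, c, c) matrix of channel correlations, normalized by c*w*h:

x = torch.randn(1, 4, 8, 8)                # illustrative (n, c, h, w) tensor
g = gram(x)
assert g.shape == (1, 4, 4)                # spatial dims are gone, only channels remain
# Manual equivalent for the first item in the batch
flat = x[0].view(4, -1)                    # (c, h*w)
manual = (flat @ flat.t()) / (4 * 8 * 8)   # (c, c), normalized by c*h*w
assert torch.allclose(g[0], manual)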

get_stl_fs[source]

get_stl_fs(fs)

def get_stl_fs(fs): return fs[:-1]

We're almost there! get_stl_fs simply drops the last set of features, which we will reserve for the content (activation) loss. Let's look at why that is important

style_loss[source]

style_loss(inp:Tensor, out_feat:Tensor)

Calculate style loss, assumes we have im_grams

def style_loss(inp:Tensor, out_feat:Tensor):
  "Calculate style loss, assumes we have `im_grams`"
  # Get batch size
  bs = inp[0].shape[0]
  loss = []
  # For every item in our inputs
  for y, f in zip(*map(get_stl_fs, [im_grams, inp])):
    # Calculate MSE
    loss.append(F.mse_loss(y.repeat(bs, 1, 1), gram(f)))
  # Multiply their sum by 3e5 (300,000)
  return 3e5 * sum(loss)

Great, so what now? Let's make a loss function for fastai!

  • Remember, we do not care to use any regular metrics here; we will instead track the style and content losses inside the loss function itself

class FeatureLoss[source]

FeatureLoss(feats, style_loss, act_loss) :: Module

Combines two losses and features into a usable loss function

class FeatureLoss(Module):
  "Combines two losses and features into a useable loss function"
  def __init__(self, feats, style_loss, act_loss):
    store_attr()
    self.reset_metrics()

  def forward(self, pred, targ):
    # First get the features of our prediction and target
    pred_feat, targ_feat = self.feats(pred), self.feats(targ)
    # Calculate style and activation loss
    style_loss = self.style_loss(pred_feat, targ_feat)
    act_loss = self.act_loss(pred_feat, targ_feat)
    # Store the loss
    self._add_loss(style_loss, act_loss)
    # Return the sum
    return style_loss + act_loss

  def reset_metrics(self):
    # Generates a blank metric
    self.metrics = dict(style = [], content = [])

  def _add_loss(self, style_loss, act_loss):
    # Add to our metrics
    self.metrics['style'].append(style_loss)
    self.metrics['content'].append(act_loss)

act_loss[source]

act_loss(inp:Tensor, targ:Tensor)

Calculate the MSE loss of the activation layers

def act_loss(inp:Tensor, targ:Tensor):
  "Calculate the MSE loss of the activation layers"
  return F.mse_loss(inp[-1], targ[-1])

Let's declare our loss function by passing in our features and our two 'mini' loss functions

loss_func = FeatureLoss(feats, style_loss, act_loss)
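As a quick smoke test (purely illustrative, using random tensors on the GPU), we can call it like any other fastai loss function and peek at the metrics it stores:

xb = torch.randn(2, 3, 224, 224).cuda()   # stand-in prediction batch
yb = torch.randn(2, 3, 224, 224).cuda()   # stand-in target batch
loss = loss_func(xb, yb)
loss, loss_func.metrics['style'][-1], loss_func.metrics['content'][-1]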

The Model Architecture

Let's now build our model

class ReflectionLayer[source]

ReflectionLayer(in_channels, out_channels, ks=3, stride=2) :: Module

A series of Reflection Padding followed by a ConvLayer

class ReflectionLayer(Module):
    "A series of Reflection Padding followed by a ConvLayer"
    def __init__(self, in_channels, out_channels, ks=3, stride=2):
        reflection_padding = ks // 2
        self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
        self.conv2d = nn.Conv2d(in_channels, out_channels, ks, stride)

    def forward(self, x):
        out = self.reflection_pad(x)
        out = self.conv2d(out)
        return out
ReflectionLayer(3, 3)
ReflectionLayer(
  (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
  (conv2d): Conv2d(3, 3, kernel_size=(3, 3), stride=(2, 2))
)

class ResidualBlock[source]

ResidualBlock(channels) :: Module

Two reflection layers and an added activation function with residual

class ResidualBlock(Module):
    "Two reflection layers and an added activation function with residual"
    def __init__(self, channels):
        self.conv1 = ReflectionLayer(channels, channels, ks=3, stride=1)
        self.in1 = nn.InstanceNorm2d(channels, affine=True)
        self.conv2 = ReflectionLayer(channels, channels, ks=3, stride=1)
        self.in2 = nn.InstanceNorm2d(channels, affine=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        residual = x
        out = self.relu(self.in1(self.conv1(x)))
        out = self.in2(self.conv2(out))
        out = out + residual
        return out
ResidualBlock(3)
ResidualBlock(
  (conv1): ReflectionLayer(
    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
    (conv2d): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
  )
  (in1): InstanceNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (conv2): ReflectionLayer(
    (reflection_pad): ReflectionPad2d((1, 1, 1, 1))
    (conv2d): Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))
  )
  (in2): InstanceNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
  (relu): ReLU()
)

class UpsampleConvLayer[source]

UpsampleConvLayer(in_channels, out_channels, ks=3, stride=1, upsample=None) :: Module

Upsample with a ReflectionLayer

class UpsampleConvLayer(Module):
    "Upsample with a ReflectionLayer"
    def __init__(self, in_channels, out_channels, ks=3, stride=1, upsample=None):
        self.upsample = upsample
        reflection_padding = ks // 2
        self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
        self.conv2d = nn.Conv2d(in_channels, out_channels, ks, stride)

    def forward(self, x):
        x_in = x
        if self.upsample:
            x_in = torch.nn.functional.interpolate(x_in, mode='nearest', scale_factor=self.upsample)
        out = self.reflection_pad(x_in)
        out = self.conv2d(out)
        return out
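A quick shape check (with illustrative sizes matching what the model below will use) shows that upsample=2 doubles the spatial size before the padded convolution changes the channel count:

up = UpsampleConvLayer(128, 64, ks=3, stride=1, upsample=2)
up(torch.randn(1, 128, 56, 56)).shape   # -> torch.Size([1, 64, 112, 112])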

Let's put everything together into a model

class TransformerNet[source]

TransformerNet() :: Module

A simple network for style transfer

class TransformerNet(Module):
    "A simple network for style transfer"
    def __init__(self):
        # Initial convolution layers
        self.conv1 = ReflectionLayer(3, 32, ks=9, stride=1)
        self.in1 = nn.InstanceNorm2d(32, affine=True)
        self.conv2 = ReflectionLayer(32, 64, ks=3, stride=2)
        self.in2 = nn.InstanceNorm2d(64, affine=True)
        self.conv3 = ReflectionLayer(64, 128, ks=3, stride=2)
        self.in3 = nn.InstanceNorm2d(128, affine=True)
        # Residual layers
        self.res1 = ResidualBlock(128)
        self.res2 = ResidualBlock(128)
        self.res3 = ResidualBlock(128)
        self.res4 = ResidualBlock(128)
        self.res5 = ResidualBlock(128)
        # Upsampling Layers
        self.deconv1 = UpsampleConvLayer(128, 64, ks=3, stride=1, upsample=2)
        self.in4 = nn.InstanceNorm2d(64, affine=True)
        self.deconv2 = UpsampleConvLayer(64, 32, ks=3, stride=1, upsample=2)
        self.in5 = nn.InstanceNorm2d(32, affine=True)
        self.deconv3 = ReflectionLayer(32, 3, ks=9, stride=1)
        # Non-linearities
        self.relu = nn.ReLU()

    def forward(self, X):
        y = self.relu(self.in1(self.conv1(X)))
        y = self.relu(self.in2(self.conv2(y)))
        y = self.relu(self.in3(self.conv3(y)))
        y = self.res1(y)
        y = self.res2(y)
        y = self.res3(y)
        y = self.res4(y)
        y = self.res5(y)
        y = self.relu(self.in4(self.deconv1(y)))
        y = self.relu(self.in5(self.deconv2(y)))
        y = self.deconv3(y)
        return y
net = TransformerNet()
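As a quick sanity check (an illustrative sketch), the network should map a 3-channel image back to a 3-channel image of the same size:

with torch.no_grad():
  out = net(torch.randn(1, 3, 224, 224))
out.shape   # -> torch.Size([1, 3, 224, 224])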

DataLoaders and Learner

We will be using the COCO_SAMPLE dataset

path = untar_data(URLs.COCO_SAMPLE)

Our DataBlock needs to be Image -> Image

dblock = DataBlock(blocks=(ImageBlock, ImageBlock),
                   get_items=get_image_files,
                   splitter=RandomSplitter(0.1, seed=42),
                   item_tfms=[Resize(224)],
                   batch_tfms=[Normalize.from_stats(*imagenet_stats)])

If you do not pass in a get_y, fastai will assume that your input is also your output (an explicit equivalent is sketched below)
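To make that assumption explicit, an equivalent sketch would pass get_y=noop, where noop simply hands the same file name to both ImageBlocks:

dblock = DataBlock(blocks=(ImageBlock, ImageBlock),
                   get_items=get_image_files,
                   get_y=noop,
                   splitter=RandomSplitter(0.1, seed=42),
                   item_tfms=[Resize(224)],
                   batch_tfms=[Normalize.from_stats(*imagenet_stats)])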

dls = dblock.dataloaders(path, bs=22)
dls.show_batch()

We can now make our Learner!

learn = Learner(dls, TransformerNet(), loss_func=loss_func)
learn.summary()
TransformerNet (Input shape: ['22 x 3 x 224 x 224'])
================================================================
Layer (type)         Output Shape         Param #    Trainable 
================================================================
ReflectionPad2d      22 x 3 x 232 x 232   0          False     
________________________________________________________________
Conv2d               22 x 32 x 224 x 224  7,808      True      
________________________________________________________________
InstanceNorm2d       22 x 32 x 224 x 224  64         True      
________________________________________________________________
ReflectionPad2d      22 x 32 x 226 x 226  0          False     
________________________________________________________________
Conv2d               22 x 64 x 112 x 112  18,496     True      
________________________________________________________________
InstanceNorm2d       22 x 64 x 112 x 112  128        True      
________________________________________________________________
ReflectionPad2d      22 x 64 x 114 x 114  0          False     
________________________________________________________________
Conv2d               22 x 128 x 56 x 56   73,856     True      
________________________________________________________________
InstanceNorm2d       22 x 128 x 56 x 56   256        True      
________________________________________________________________
ReflectionPad2d      22 x 128 x 58 x 58   0          False     
________________________________________________________________
Conv2d               22 x 128 x 56 x 56   147,584    True      
________________________________________________________________
InstanceNorm2d       22 x 128 x 56 x 56   256        True      
________________________________________________________________
ReflectionPad2d      22 x 128 x 58 x 58   0          False     
________________________________________________________________
Conv2d               22 x 128 x 56 x 56   147,584    True      
________________________________________________________________
InstanceNorm2d       22 x 128 x 56 x 56   256        True      
________________________________________________________________
ReLU                 22 x 128 x 56 x 56   0          False     
________________________________________________________________
ReflectionPad2d      22 x 128 x 58 x 58   0          False     
________________________________________________________________
Conv2d               22 x 128 x 56 x 56   147,584    True      
________________________________________________________________
InstanceNorm2d       22 x 128 x 56 x 56   256        True      
________________________________________________________________
ReflectionPad2d      22 x 128 x 58 x 58   0          False     
________________________________________________________________
Conv2d               22 x 128 x 56 x 56   147,584    True      
________________________________________________________________
InstanceNorm2d       22 x 128 x 56 x 56   256        True      
________________________________________________________________
ReLU                 22 x 128 x 56 x 56   0          False     
________________________________________________________________
ReflectionPad2d      22 x 128 x 58 x 58   0          False     
________________________________________________________________
Conv2d               22 x 128 x 56 x 56   147,584    True      
________________________________________________________________
InstanceNorm2d       22 x 128 x 56 x 56   256        True      
________________________________________________________________
ReflectionPad2d      22 x 128 x 58 x 58   0          False     
________________________________________________________________
Conv2d               22 x 128 x 56 x 56   147,584    True      
________________________________________________________________
InstanceNorm2d       22 x 128 x 56 x 56   256        True      
________________________________________________________________
ReLU                 22 x 128 x 56 x 56   0          False     
________________________________________________________________
ReflectionPad2d      22 x 128 x 58 x 58   0          False     
________________________________________________________________
Conv2d               22 x 128 x 56 x 56   147,584    True      
________________________________________________________________
InstanceNorm2d       22 x 128 x 56 x 56   256        True      
________________________________________________________________
ReflectionPad2d      22 x 128 x 58 x 58   0          False     
________________________________________________________________
Conv2d               22 x 128 x 56 x 56   147,584    True      
________________________________________________________________
InstanceNorm2d       22 x 128 x 56 x 56   256        True      
________________________________________________________________
ReLU                 22 x 128 x 56 x 56   0          False     
________________________________________________________________
ReflectionPad2d      22 x 128 x 58 x 58   0          False     
________________________________________________________________
Conv2d               22 x 128 x 56 x 56   147,584    True      
________________________________________________________________
InstanceNorm2d       22 x 128 x 56 x 56   256        True      
________________________________________________________________
ReflectionPad2d      22 x 128 x 58 x 58   0          False     
________________________________________________________________
Conv2d               22 x 128 x 56 x 56   147,584    True      
________________________________________________________________
InstanceNorm2d       22 x 128 x 56 x 56   256        True      
________________________________________________________________
ReLU                 22 x 128 x 56 x 56   0          False     
________________________________________________________________
ReflectionPad2d      22 x 128 x 114 x 11  0          False     
________________________________________________________________
Conv2d               22 x 64 x 112 x 112  73,792     True      
________________________________________________________________
InstanceNorm2d       22 x 64 x 112 x 112  128        True      
________________________________________________________________
ReflectionPad2d      22 x 64 x 226 x 226  0          False     
________________________________________________________________
Conv2d               22 x 32 x 224 x 224  18,464     True      
________________________________________________________________
InstanceNorm2d       22 x 32 x 224 x 224  64         True      
________________________________________________________________
ReflectionPad2d      22 x 32 x 232 x 232  0          False     
________________________________________________________________
Conv2d               22 x 3 x 224 x 224   7,779      True      
________________________________________________________________
ReLU                 22 x 32 x 224 x 224  0          False     
________________________________________________________________

Total params: 1,679,235
Total trainable params: 1,679,235
Total non-trainable params: 0

Optimizer used: <function Adam at 0x7fe082f1ef28>
Loss function: FeatureLoss()

Callbacks:
  - TrainEvalCallback
  - Recorder
  - ProgressCallback

Let's find a learning rate and fit for one epoch

learn.lr_find()
learn.fit_one_cycle(1, 1e-3)
epoch train_loss valid_loss time
0 28.079075 28.128078 07:40

And take a look at some of our results!

learn.show_results()
learn.save('stage1')

Now let's try learn.predict

pred = learn.predict('cat.jpg')
pred[0].show()
<matplotlib.axes._subplots.AxesSubplot at 0x7fe014c07e10>

Well, while that looks cool, we lost a lot of the features! How can we fix this? Let's try something similar to what we did for our style_im

learn.load('stage1')
<fastai.learner.Learner at 0x7ff7890267f0>
dset = Datasets('cat.jpg', tfms=[PILImage.create])
dl = dset.dataloaders(after_item=[ToTensor()], after_batch=[IntToFloatTensor(), Normalize.from_stats(*imagenet_stats)], bs=1)
t_im = dl.one_batch()[0]
with torch.no_grad():
  res = learn.model(t_im)

Now let's try that again

TensorImage(res[0]).show()
Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).
<matplotlib.axes._subplots.AxesSubplot at 0x7ff788d839e8>

Much better! But the colours seem 'off' because the output activations had not been 'decoded' with the reverse transforms before being shown as an image, so we will do that below.

dec_res = dl.decode_batch(tuplify(res))[0][0]
dec_res.show();
learn.save('224')
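For reference, decode_batch is doing roughly the same thing as undoing the Normalize transform by hand; a sketch of that (assuming the same imagenet_stats used above):

mean, std = [torch.tensor(o).view(1, 3, 1, 1) for o in imagenet_stats]
denorm = (res.cpu() * std + mean).clamp(0, 1)
TensorImage(denorm[0]).show();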

Now we can increase our image size, similar to how we did in the segmentation example (this is homework; we will not run it here, as each epoch would take ~40 minutes)

Homework

dblock = DataBlock(blocks=(ImageBlock, ImageBlock),
                   get_items=get_image_files,
                   splitter=RandomSplitter(0.1, seed=42),
                   item_tfms=[Resize(448)],
                   batch_tfms=[Normalize.from_stats(*imagenet_stats)])
dls = dblock.dataloaders(path, bs=8)
learn = Learner(dls, net, loss_func=loss_func).load('224')
learn.fit_one_cycle(1, 1e-3)
learn.show_results()
learn.save('final')

Rest of the lesson

Let's export our model to use. Note that we swap in a standard loss function first; our FeatureLoss wraps hooks and a closure around the VGG network, which cannot be pickled by learn.export.

learn.loss_func = CrossEntropyLossFlat()
learn.export('myModel')
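Once exported, a deployment sketch (assuming the same 'myModel' file, which learn.export saved under learn.path, and a local cat.jpg) would load the Learner back in and run the raw model just like we did above:

learn_inf = load_learner(learn.path/'myModel')
dset = Datasets('cat.jpg', tfms=[PILImage.create])
dl = dset.dataloaders(after_item=[ToTensor()],
                      after_batch=[IntToFloatTensor(), Normalize.from_stats(*imagenet_stats)],
                      bs=1)
with torch.no_grad():
  res = learn_inf.model(dl.one_batch()[0].cpu())
dl.decode_batch(tuplify(res))[0][0].show();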

From here: Download the notebook and upload it to the main

Export

reset_nbdev_module()