Before we run anything we want to ensure we have a P100 GPU or better:
!nvidia-smi
We always need the nbdev showdoc import in order to export:
from fastai.vision.all import *
Source: https://arxiv.org/abs/1603.08155
from torchvision.models import vgg19, vgg16
feat_net = vgg19(pretrained=True).features.cuda().eval()
We'll get rid of the head and use the internal activations as part of our generator model's loss. As a result, we want to set every layer to untrainable:
for p in feat_net.parameters(): p.requires_grad=False
We will be using the feature activations our model picks up at particular layers, similar to the heatmaps we generated for our classification models:
layers = [feat_net[i] for i in [1, 6, 11, 20, 29, 22]]; layers
The outputs are ReLU layers. Below is a configuration for the VGG16 and VGG19 models:
_vgg_config = {
'vgg16' : [1, 11, 18, 25, 20],
'vgg19' : [1, 6, 11, 20, 29, 22]
}
Let's write a quick _get_layers function to grab our network and the layers we want:
def _get_layers(arch:str, pretrained=True):
"Get the layers and arch for a VGG Model (16 and 19 are supported only)"
feat_net = vgg19(pretrained=pretrained).cuda() if arch.find('9') > 1 else vgg16(pretrained=pretrained).cuda()
config = _vgg_config.get(arch)
features = feat_net.features.cuda().eval()
for p in features.parameters(): p.requires_grad=False
return feat_net, [features[i] for i in config]
Now let's make it all in one go, utilizing our private functions, by passing in an architecture name and a pretrained parameter:
def get_feats(arch:str, pretrained=True):
"Get the features of an architecture"
feat_net, layers = _get_layers(arch, pretrained)
hooks = hook_outputs(layers, detach=False)
def _inner(x):
feat_net(x)
return hooks.stored
return _inner
feats = get_feats('vgg19')
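As a quick sanity check (just a sketch, assuming a CUDA device and a made-up random batch), calling feats should return one stored activation per hooked layer:
# Hypothetical smoke test: push a random batch through the hooked VGG
x = torch.randn(2, 3, 224, 224).cuda()
len(feats(x))  # 6 hooked layers for vgg19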
Our loss function needs three things (we'll sketch how they combine right after this list):
- Our original image
- Some artwork / style
- Activation features from our encoder
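As sketched here (in the notation of the Johnson et al. paper linked above; the exact layer choices and weights are ours), the full objective combines a content (activation) term and a style term:

\mathcal{L}(\hat{y}, y, y_s) = \lVert \phi_J(\hat{y}) - \phi_J(y) \rVert_2^2 \;+\; \lambda \sum_j \lVert G(\phi_j(\hat{y})) - G(\phi_j(y_s)) \rVert_2^2

where \phi_j are the hooked VGG activations, G is the Gram matrix we'll define shortly, y is the original (content) image, y_s is the style image, and \lambda is the style weighting (3e5 in our code below).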
What image will we be using for our style? Let's grab it:
url = 'https://static.greatbigcanvas.com/images/singlecanvas_thick_none/megan-aroon-duncanson/little-village-abstract-art-house-painting,1162125.jpg'
!wget {url} -O 'style.jpg'
fn = 'style.jpg'
We can now make a Pipeline to convert our image into a tensor to use in our loss function. We'll want to use the Datasets API for this:
dset = Datasets(fn, tfms=[PILImage.create])
dl = dset.dataloaders(after_item=[ToTensor()], after_batch=[IntToFloatTensor(), Normalize.from_stats(*imagenet_stats)], bs=1)
dl.show_batch(figsize=(7,7))
style_im = dl.one_batch()[0]
style_im.shape
def get_style_im(url):
download_url(url, 'style.jpg')
fn = 'style.jpg'
dset = Datasets(fn, tfms=[PILImage.create])
dl = dset.dataloaders(after_item=[ToTensor()], after_batch=[IntToFloatTensor(), Normalize.from_stats(*imagenet_stats)], bs=1)
return dl.one_batch()[0]
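With this helper, grabbing a style batch becomes a one-liner (a sketch re-using the url from above; it should produce the same tensor as style_im):
# Hypothetical re-use of the helper above
style_im2 = get_style_im(url)
style_im2.shape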
We can then grab its features using the feats function we made earlier:
im_feats = feats(style_im)
Let's look at their sizes
for feat in im_feats:
print(feat.shape)
Now we can turn those feature maps into Gram matrices, reducing each one down to a channels-by-channels matrix:
def gram(x:Tensor):
    "Compute the normalized Gram matrix of a batch of feature maps"
    n, c, h, w = x.shape
    # Flatten the spatial dimensions so each image is (c, h*w)
    x = x.view(n, c, -1)
    return (x @ x.transpose(1, 2))/(c*w*h)
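In other words (a sketch of what gram computes): if we flatten each image's feature map into a matrix F of shape (c, h*w), then

G = \frac{1}{c\,h\,w} F F^{\top}

so every image contributes a (c, c) matrix, and the batched result has shape (n, c, c) regardless of the spatial size of the features.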
im_grams = [gram(f) for f in im_feats]
for feat in im_grams:
print(feat.shape)
def get_stl_fs(fs): return fs[:-1]
We're almost there! get_stl_fs simply drops the last set of features (we reserve that one for the content loss); the style loss below shows why that matters:
def style_loss(inp:Tensor, out_feat:Tensor):
"Calculate style loss, assumes we have `im_grams`"
# Get batch size
bs = inp[0].shape[0]
loss = []
# For every item in our inputs
for y, f in zip(*map(get_stl_fs, [im_grams, inp])):
# Calculate MSE
loss.append(F.mse_loss(y.repeat(bs, 1, 1), gram(f)))
    # Scale their sum by 3e5 (300,000)
return 3e5 * sum(loss)
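As a hedged sanity check (assuming the im_feats we computed from the style image above), matching the style image's own features should give a loss of essentially zero:
# Style image against its own Gram matrices: expect a value extremely close to 0
style_loss(im_feats, im_feats)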
Great, so what now? Let's make a loss function for fastai!
- Remember, we don't pass any metrics in here; the loss function itself will keep track of the style and content losses
class FeatureLoss(Module):
"Combines two losses and features into a useable loss function"
def __init__(self, feats, style_loss, act_loss):
store_attr()
self.reset_metrics()
def forward(self, pred, targ):
# First get the features of our prediction and target
pred_feat, targ_feat = self.feats(pred), self.feats(targ)
# Calculate style and activation loss
style_loss = self.style_loss(pred_feat, targ_feat)
act_loss = self.act_loss(pred_feat, targ_feat)
# Store the loss
self._add_loss(style_loss, act_loss)
# Return the sum
return style_loss + act_loss
def reset_metrics(self):
# Generates a blank metric
self.metrics = dict(style = [], content = [])
def _add_loss(self, style_loss, act_loss):
# Add to our metrics
self.metrics['style'].append(style_loss)
self.metrics['content'].append(act_loss)
def act_loss(inp:Tensor, targ:Tensor):
"Calculate the MSE loss of the activation layers"
return F.mse_loss(inp[-1], targ[-1])
Let's declare our loss function by passing in our features and our two 'mini' loss functions
loss_func = FeatureLoss(feats, style_loss, act_loss)
class ReflectionLayer(Module):
"A series of Reflection Padding followed by a ConvLayer"
def __init__(self, in_channels, out_channels, ks=3, stride=2):
reflection_padding = ks // 2
self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
self.conv2d = nn.Conv2d(in_channels, out_channels, ks, stride)
def forward(self, x):
out = self.reflection_pad(x)
out = self.conv2d(out)
return out
ReflectionLayer(3, 3)
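To see what this buys us (a sketch with made-up channel counts and sizes), a stride-1 ReflectionLayer preserves the spatial size while a stride-2 one halves it, and the reflection padding avoids zero-padding artifacts at the borders:
# Hypothetical shape checks on random tensors
ReflectionLayer(3, 8, ks=3, stride=1)(torch.randn(1, 3, 64, 64)).shape  # -> (1, 8, 64, 64)
ReflectionLayer(3, 8, ks=3, stride=2)(torch.randn(1, 3, 64, 64)).shape  # -> (1, 8, 32, 32)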
class ResidualBlock(Module):
"Two reflection layers and an added activation function with residual"
def __init__(self, channels):
self.conv1 = ReflectionLayer(channels, channels, ks=3, stride=1)
self.in1 = nn.InstanceNorm2d(channels, affine=True)
self.conv2 = ReflectionLayer(channels, channels, ks=3, stride=1)
self.in2 = nn.InstanceNorm2d(channels, affine=True)
self.relu = nn.ReLU()
def forward(self, x):
residual = x
out = self.relu(self.in1(self.conv1(x)))
out = self.in2(self.conv2(out))
out = out + residual
return out
ResidualBlock(3)
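A quick (hypothetical) shape check: because of the residual connection, the block has to return a tensor with the same shape it received, which is what lets us stack several of them later:
# Input and output shapes match, so blocks can be chained freely
ResidualBlock(8)(torch.randn(1, 8, 32, 32)).shape  # -> (1, 8, 32, 32)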
class UpsampleConvLayer(Module):
"Upsample with a ReflectionLayer"
def __init__(self, in_channels, out_channels, ks=3, stride=1, upsample=None):
self.upsample = upsample
reflection_padding = ks // 2
self.reflection_pad = nn.ReflectionPad2d(reflection_padding)
self.conv2d = nn.Conv2d(in_channels, out_channels, ks, stride)
def forward(self, x):
x_in = x
if self.upsample:
x_in = torch.nn.functional.interpolate(x_in, mode='nearest', scale_factor=self.upsample)
out = self.reflection_pad(x_in)
out = self.conv2d(out)
return out
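And a similar sketch for the upsampling layer: with upsample=2 the feature map is first doubled with nearest-neighbor interpolation, then convolved:
# 32x32 -> interpolated to 64x64 -> reflection pad + 3x3 conv keeps 64x64
UpsampleConvLayer(8, 4, ks=3, stride=1, upsample=2)(torch.randn(1, 8, 32, 32)).shape  # -> (1, 4, 64, 64)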
Let's put everything together into a model
class TransformerNet(Module):
"A simple network for style transfer"
def __init__(self):
# Initial convolution layers
self.conv1 = ReflectionLayer(3, 32, ks=9, stride=1)
self.in1 = nn.InstanceNorm2d(32, affine=True)
self.conv2 = ReflectionLayer(32, 64, ks=3, stride=2)
self.in2 = nn.InstanceNorm2d(64, affine=True)
self.conv3 = ReflectionLayer(64, 128, ks=3, stride=2)
self.in3 = nn.InstanceNorm2d(128, affine=True)
# Residual layers
self.res1 = ResidualBlock(128)
self.res2 = ResidualBlock(128)
self.res3 = ResidualBlock(128)
self.res4 = ResidualBlock(128)
self.res5 = ResidualBlock(128)
# Upsampling Layers
self.deconv1 = UpsampleConvLayer(128, 64, ks=3, stride=1, upsample=2)
self.in4 = nn.InstanceNorm2d(64, affine=True)
self.deconv2 = UpsampleConvLayer(64, 32, ks=3, stride=1, upsample=2)
self.in5 = nn.InstanceNorm2d(32, affine=True)
self.deconv3 = ReflectionLayer(32, 3, ks=9, stride=1)
# Non-linearities
self.relu = nn.ReLU()
def forward(self, X):
y = self.relu(self.in1(self.conv1(X)))
y = self.relu(self.in2(self.conv2(y)))
y = self.relu(self.in3(self.conv3(y)))
y = self.res1(y)
y = self.res2(y)
y = self.res3(y)
y = self.res4(y)
y = self.res5(y)
y = self.relu(self.in4(self.deconv1(y)))
y = self.relu(self.in5(self.deconv2(y)))
y = self.deconv3(y)
return y
net = TransformerNet()
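Before we build our data, we can sanity-check (with a made-up random batch) that the network maps an image to an image of the same size: downsample by 4, run the residual blocks, then upsample back by 4:
# Shapes should round-trip: (1, 3, 224, 224) in, (1, 3, 224, 224) out
with torch.no_grad():
    out = net(torch.randn(1, 3, 224, 224))
out.shape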
path = untar_data(URLs.COCO_SAMPLE)
Our DataBlock needs to be Image -> Image:
dblock = DataBlock(blocks=(ImageBlock, ImageBlock),
get_items=get_image_files,
splitter=RandomSplitter(0.1, seed=42),
item_tfms=[Resize(224)],
batch_tfms=[Normalize.from_stats(*imagenet_stats)])
If you do not pass in a get_y, fastai will assume your input = output:
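If you would rather be explicit, a roughly equivalent (hypothetical) version passes fastcore's noop, which simply returns its input, as the get_y; we'll keep using the dblock defined above:
# Hypothetical explicit form: y is literally the same item as x
dblock_explicit = DataBlock(blocks=(ImageBlock, ImageBlock),
                            get_items=get_image_files,
                            get_y=noop,
                            splitter=RandomSplitter(0.1, seed=42),
                            item_tfms=[Resize(224)],
                            batch_tfms=[Normalize.from_stats(*imagenet_stats)])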
dls = dblock.dataloaders(path, bs=22)
dls.show_batch()
We can now make our Learner!
learn = Learner(dls, TransformerNet(), loss_func=loss_func)
learn.summary()
Let's find a learning rate and fit for one epoch
learn.lr_find()
learn.fit_one_cycle(1, 1e-3)
And take a look at some of our results!
learn.show_results()
learn.save('stage1')
Now let's try learn.predict
pred = learn.predict('cat.jpg')
pred[0].show()
Well, while that looks cool, we lost a lot of the features! How can we fix this? Let's try something similar to what we did for our style_im:
learn.load('stage1')
dset = Datasets('cat.jpg', tfms=[PILImage.create])
dl = dset.dataloaders(after_item=[ToTensor()], after_batch=[IntToFloatTensor(), Normalize.from_stats(*imagenet_stats)], bs=1)
t_im = dl.one_batch()[0]
with torch.no_grad():
res = learn.model(t_im)
Now let's try that again
TensorImage(res[0]).show()
Much better! But the colours seem 'off' because the output tensor has not been decoded with the reverse transforms (undoing our Normalize) before being shown as an image, so we will do that below:
dec_res = dl.decode_batch(tuplify(res))[0][0]
dec_res.show();
learn.save('224')
Now we can increase our image size (to 448 below), similar to the progressive resizing we did in the segmentation example (this is left as homework; we will not run it here, as one epoch would take ~40 minutes):
dblock = DataBlock(blocks=(ImageBlock, ImageBlock),
get_items=get_image_files,
splitter=RandomSplitter(0.1, seed=42),
item_tfms=[Resize(448)],
batch_tfms=[Normalize.from_stats(*imagenet_stats)])
dls = dblock.dataloaders(path, bs=8)
learn = Learner(dls, net, loss_func=loss_func).load('224')
learn.fit_one_cycle(1, 1e-3)
learn.show_results()
learn.save('final')
Let's export our model to use later. Since our FeatureLoss holds references to the VGG network and its hooks, it can't easily be serialized, so we first swap in a simple placeholder loss:
learn.loss_func = CrossEntropyLossFlat()
learn.export('myModel')
From here: Download the notebook and upload it to the main
reset_nbdev_module()