Multi-Label Classification

Identifying images with multiple labels
Lesson 3

Introduction

In this lesson we will focus on dealing with multi-labeled images. In the prior edition of Walk with fastai this was done using the high-level API; in the spirit of revisited we will instead use the mid-level API, and we will continue to use it throughout the rest of this course.

This will be a vision problem so again we will import the vision library:

from fastai.vision.all import *

Below are the exact imports for what we are using today:

import torch
from torch import tensor
from torchvision.models.resnet import resnet34
from PIL import Image
from itertools import compress

import pandas as pd
from pathlib import Path
from fastcore.xtras import Path # @patched pathlib.Path

from fastai.data.core import show_at, Datasets
from fastai.data.external import URLs, untar_data
from fastai.data.transforms import (
    ColReader,
    IntToFloatTensor, 
    MultiCategorize, 
    Normalize,
    OneHotEncode, 
    RandomSplitter,
    ToTensor,
)

from fastai.metrics import accuracy_multi

from fastai.vision.augment import aug_transforms
from fastai.vision.core import PILImage, imagenet_stats
from fastai.vision.learner import vision_learner
from fastai.learner import Learner
from fastai.callback.schedule import Learner # @patched Learner functions like lr_find and fit_one_cycle

Exploring the data

For this problem we will use the Planet dataset, a collection of satellite images with multiple labels describing the scene.

First let’s download the data:

src = untar_data(URLs.PLANET_SAMPLE)
df = pd.read_csv(src/'labels.csv')

And then take a peek:

df.head()
    image_name                             tags
0  train_21983            partly_cloudy primary
1   train_9516  clear cultivation primary water
2  train_12664                     haze primary
3  train_36960                    clear primary
4   train_5302                haze primary road

For this problem we have an image_name column and a tags column. The labels within tags are separated by a space.

Similar to what we did for the Kaggle dataset, let’s look at how the labels are distributed:

all_tags = df["tags"].values
all_labels = []
for row in all_tags:
    all_labels += row.split(" ")
len(all_labels)
2899

all_tags = df["tags"].values

The labels are located in the "tags" column, and we can extract the raw values as a regular array using the .values attribute.


all_labels = []
for row in all_tags:
    all_labels += row.split(" ")

Since each row’s tags are separated by a space, we can split the string into a list and add those values directly into our all_labels list. This will be a giant list that intentionally contains many repeated values.

In total there are 2899 labels, but this doesn’t tell us how many different labels there are. Let’s find out:

different_labels = set(all_labels)
len(different_labels)
17

Only 17! Let’s see the distribution of these labels:

counts = {
    label: all_labels.count(label) 
    for label in different_labels
}

counts = {
    key: value 
    for key, value in 
    sorted(
        counts.items(), 
        key = lambda item: -item[1]
    )
}

    label: all_labels.count(label)

Python lists contain a method called count which can take in an item and count how many times it occurs in the list.


counts = {
    key: value 
    for key, value in 
    sorted(
        counts.items(), 
        key = lambda item: -item[1]
    )
}

To make our lives easier, we can sort the dictionary by the number of occurrences found.


key = lambda item: -item[1]

To sort the dictionary from highest to lowest occurrence count, we sort by the negative of each value.
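
As an aside, Python's collections.Counter can do the counting and the sorting in one step; a minimal equivalent:

from collections import Counter

# most_common() returns (label, count) pairs sorted from highest to lowest
counts = dict(Counter(all_labels).most_common())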

counts
{'primary': 934,
 'clear': 701,
 'agriculture': 318,
 'road': 209,
 'partly_cloudy': 194,
 'water': 169,
 'cultivation': 124,
 'habitation': 93,
 'haze': 55,
 'cloudy': 50,
 'bare_ground': 19,
 'blooming': 9,
 'selective_logging': 8,
 'artisinal_mine': 7,
 'slash_burn': 6,
 'conventional_mine': 2,
 'blow_down': 1}

What we find is that blooming, selective_logging, artisinal_mine, slash_burn, conventional_mine, and blow_down each occur fewer than 10 times. For the sake of today’s lesson we will get rid of the rows containing these labels.

Typically a Data Scientist has two choices when it comes to dealing with rare values: either leave them in as they are, or perform oversampling. We’re dropping them for convenience, but normally one would want to oversample the training dataset for rare values and ensure the validation dataset has a few instances of them to test on.

Next we’ll use some pandas magic to filter our dataframe by these values:

len(df)
1000
for key, count in counts.items():
    if count < 10:
        df = df[df["tags"].str.contains(key) == False]

if count < 10: 

Since we’re limiting it based on rare values, we’ll arbitrarily get rid of classes that occur fewer than 10 times


df['tags'].str

The .str accessor lets us apply string methods to every single row of the column at once


.contains(key) == False

From here we check whether each row contains the rare class, and keep only the rows that do not.
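
The same filtering can also be done in a single pass by joining the rare labels into one pattern, since str.contains interprets its argument as a regular expression; a sketch of that alternative:

rare = [key for key, count in counts.items() if count < 10]
# "a|b|c" matches any of the rare labels
df = df[~df["tags"].str.contains("|".join(rare))]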

len(df)
968

Despite what seemed like getting rid of quite a lot of classes, they only showed up in 32 rows. In the real world one would want to try to get more data for these underrepresented classes if possible, or perform oversampling on them.
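Here’s a hypothetical sketch of that oversampling route, applied to the original dataframe before any rows are dropped (the repeat count of 50 is arbitrary):

rare = [key for key, count in counts.items() if count < 10]
df_raw = pd.read_csv(src/'labels.csv')
mask = df_raw["tags"].str.contains("|".join(rare))
# duplicate randomly-chosen rare rows so rare labels are seen more often during training
df_oversampled = pd.concat([df_raw, df_raw[mask].sample(n=50, replace=True, random_state=42)])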

Next let’s take a look at one of the images:

df["image_name"].head(), src.ls()
(0    train_21983
 1     train_9516
 2    train_12664
 3    train_36960
 4     train_5302
 Name: image_name, dtype: object,
 (#2) [Path('/home/zach/.fastai/data/planet_sample/train'),Path('/home/zach/.fastai/data/planet_sample/labels.csv')])

We get a partial string of the filename, and based on how the data source is set up, the images live in the train folder. Let’s look in there:

(src/'train').ls()[:3]
(#3) [Path('/home/zach/.fastai/data/planet_sample/train/train_13505.jpg'),Path('/home/zach/.fastai/data/planet_sample/train/train_34206.jpg'),Path('/home/zach/.fastai/data/planet_sample/train/train_2407.jpg')]

Each image has the extension .jpg:

PILImage.create((src/'train'/'train_2407.jpg'))

This looks like just a gray blob because it is a satellite image of Earth for a specific quadrant. Sadly, many of these will look quite boring!

We now have enough to build ourselves a Datasets object!

Working with fastai Datasets

Next let’s build a Datasets object. Here’s what we know:

  • x: Our x’s are colored images, meaning we should use PILImage
  • y: Our y’s are multilabeled images, meaning we should use MultiCategorize

We need to write a set of getters to extract each one based on looking at a single row of data:

  • x: Our x’s are located at src/'train'/{fname}.jpg, so this can be written as a function that looks at the image_name column
  • y: Our y’s are a split string based on the tags column, and this too can be written as a function

Technically these can be written as lambda functions, such as lambda x: print(x), however lambdas are not pickleable, meaning if you export the Learner it will give you an error along the lines of "Cannot pickle lambda ...". The solution is to define them as separate, named functions that are importable wherever you load the Learner.

There is no set value for what is “train” or “validation”, so we can randomly split the dataset again

Let’s build everything that was just described:

def get_x(row:pd.Series) -> Path:
    return (src/'train'/row.image_name).with_suffix(".jpg")

Using with_suffix on a pathlib.Path object will either replace or add the defined suffix to the last item in the path.
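
For example, applied to a bare filename:

Path("train_21983").with_suffix(".jpg")
Path('train_21983.jpg')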

def get_y(row:pd.Series) -> list[str]:
    return row.tags.split(" ")
row = df.iloc[0]
get_x(row), get_y(row)
(Path('/home/zach/.fastai/data/planet_sample/train/train_21983.jpg'),
 ['partly_cloudy', 'primary'])

Now that we’ve seen how these are written, fastai has a helper called ColReader we can use instead, which takes in the index of the column and any adjustments we want to make:

get_x = ColReader(0, pref=f'{src}/train/', suff=".jpg")
get_y = ColReader(1, label_delim=" ")
tfms = [
    [get_x, PILImage.create], 
    [
        get_y,
        MultiCategorize(vocab=different_labels), 
        OneHotEncode(len(different_labels))
    ]
]

MultiCategorize(vocab=different_labels)

We need to explicitly pass in the list of different class labels to use here.


OneHotEncode(len(different_labels))

For our particular problem, just applying MultiCategorize isn’t quite enough; we also need OneHotEncode. MultiCategorize alone turns our labels into something like [4, 12] (the indices into the vocab), and we need to turn that into [0, 0, ...1, ...1, ...0] where each 1 represents a label present in the image. To do so we must pass in the total number of different labels.
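
To make those two steps concrete, here’s a minimal sketch using a tiny three-label vocab (the labels are just for illustration):

mc = MultiCategorize(vocab=["clear", "primary", "water"])
ohe = OneHotEncode(c=3)
encoded = mc(["clear", "water"])  # TensorMultiCategory([0, 2]), the indices into the vocab
ohe(encoded)                      # TensorMultiCategory([1., 0., 1.]), one slot per label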

train_idxs, valid_idxs = (
    RandomSplitter(valid_pct=0.2, seed=42)(df)
)

The RandomSplitter can accept a validation percentage as well as a random seed to be set during the splitting. After instantiating the class we can then pass in any items we want to have split, such as our dataframe here.

train_idxs, valid_idxs
((#775) [888,918,313,104,751,504,661,774,492,634...],
 (#193) [622,943,48,152,686,547,418,768,558,863...])

Now we can build the datasets object!

dsets = Datasets(df, tfms=tfms, splits=[train_idxs, valid_idxs])
dsets.train[0]
(PILImage mode=RGB size=256x256,
 TensorMultiCategory([1., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
                      0., 0.]))

We now return the PILImage expected as well as our one-hot encoded labels!

show_at(dsets.train, 0);

Building some DataLoaders

Lastly we need to build some DataLoaders. The fastai Datasets class has a .dataloaders() function for us to do this easily; we just need to pass in some transforms to use:

batch_tfms = [
    IntToFloatTensor(), 
    *aug_transforms(
        flip_vert=True, 
        max_lighting=0.1, 
        max_zoom=1.05, 
        max_warp=0.
    ), 
    Normalize.from_stats(*imagenet_stats)
]
dls = dsets.dataloaders(
    after_item=[ToTensor], 
    after_batch=batch_tfms
)
dls.device
device(type='cuda', index=0)

You may notice that we don’t pass any Resize or other augmentation to the item transforms, just ToTensor. This is because all of the images are already 256x256, so there isn’t a need to, and we can jump straight to augmenting the data on the GPU.
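
If the images weren’t already uniformly sized, we would add a Resize to the item transforms so every image can be collated into a batch; a hypothetical version (the 224 is arbitrary):

from fastai.vision.augment import Resize

dls = dsets.dataloaders(
    after_item=[Resize(224), ToTensor], 
    after_batch=batch_tfms
)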

Let’s look at a batch of data to make sure everything looks correct:

dls.show_batch()

Great! Now to train a model

Training a Model

Similar to our previous problem, we will use the baseline resnet34 for this task, and since we are looking at multiple labels we will want to use the accuracy_multi metric:

learn = vision_learner(dls, resnet34, metrics=[accuracy_multi])

Let’s take a look at a few of the defaults fastai set for us:

learn.model[1]
Sequential(
  (0): AdaptiveConcatPool2d(
    (ap): AdaptiveAvgPool2d(output_size=1)
    (mp): AdaptiveMaxPool2d(output_size=1)
  )
  (1): fastai.layers.Flatten(full=False)
  (2): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): Dropout(p=0.25, inplace=False)
  (4): Linear(in_features=1024, out_features=512, bias=False)
  (5): ReLU(inplace=True)
  (6): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (7): Dropout(p=0.5, inplace=False)
  (8): Linear(in_features=512, out_features=17, bias=False)
)

We can see that the head of our model is shaped exactly as before, ending in 17 outputs since our vocab contains 17 labels. So what needs to change for our multi-label problem?

learn.loss_func
FlattenedLoss of BCEWithLogitsLoss()

The loss function.

The difference between normal Cross Entropy and Binary Cross Entropy with Logits is that rather than performing a softmax, we perform a sigmoid operation, and use nn.BCEWithLogitsLoss instead of nn.CrossEntropyLoss:

t = tensor([[0.1, 0.5, 0.3, 0.7, 0.2]])
torch.sigmoid(t)
tensor([[0.5250, 0.6225, 0.5744, 0.6682, 0.5498]])
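
For contrast, a softmax over those same logits forces the scores to compete and sum to 1, which is the single-label assumption cross entropy makes, whereas the sigmoid outputs above are each scored independently:

torch.softmax(t, dim=1)
tensor([[0.1506, 0.2247, 0.1839, 0.2744, 0.1664]])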

After scaling we can also decide what is perceived as “seen” vs “not seen” through a threshold:

learn.loss_func.thresh
0.5

This essentially means that any sigmoid output below 0.5 is ignored, and we assume that label is not present.

Keeping the thresholds aligned!

It’s extremely important to remember that the metric’s and the loss function’s thresholds should be exactly the same, otherwise you’re looking at two different versions of the same result. E.g. if you only change the metric’s threshold to 0.6, the loss function will still decode at 0.5, so you’re not actually training with the assumption that the right answer should be > 0.6.
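
If you do decide on a different threshold, here’s a minimal sketch of keeping the two in sync (the 0.6 is arbitrary):

from functools import partial

learn.loss_func.thresh = 0.6                           # threshold used when decoding predictions
learn.metrics = [partial(accuracy_multi, thresh=0.6)]  # keep the metric's threshold in sync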

Now that we have everything set up, let’s find a learning rate and train!

learn.lr_find()
SuggestedLRs(valley=0.002511886414140463)

We find that 2e-3 is a pretty good learning rate, so let’s do some fine-tuning!

learn.fit_one_cycle(1, slice(2e-3))
epoch train_loss valid_loss accuracy_multi time
0 0.957430 0.729192 0.603170 00:05

Then we’ll unfreeze and train a bit more:

learn.unfreeze()
learn.fit_one_cycle(5, slice(2e-3/2.6**4, 2e-3))
epoch train_loss valid_loss accuracy_multi time
0 0.852019 0.676423 0.644925 00:08
1 0.779269 0.730228 0.687900 00:08
2 0.711121 0.519590 0.793051 00:08
3 0.649912 0.406981 0.901250 00:08
4 0.597333 0.384169 0.919537 00:08

This lr/2.6**4 is a general rule of thumb that Jeremy Howard found works quite well when doing gradual unfreezing; see the ULMFiT notebook to see it in practice!

And now let’s look at our results:

learn.show_results(figsize=(15,15))

Predictions in the wild

While we’ve looked at how to train, let’s look at how to predict and get back our answers without using the fastai API.

model = learn.model
fname = get_x(df.iloc[0])
fname = '/home/zach/.fastai/data/planet_sample/train/train_21983.jpg'

First, the item transforms:

from torchvision.transforms import PILToTensor
im = Image.open(fname)
im = im.convert("RGB")
t_im = PILToTensor()(im)

im = Image.open(fname)
im = im.convert("RGB")

First we will open the file using Pillow and convert it to an RGB image


t_im = PILToTensor()(im)

This converts the PIL image into a Tensor through a torchvision transform for us.

Then the batch transforms:

t_im = t_im.unsqueeze(0)
t_im = t_im.float().div_(255.)

t_im = t_im.unsqueeze(0)

First we turn the single image into a batch of 1


t_im = t_im.float().div_(255.)

Then we turn the tensor into a float and divide it by 255, as pixel values range from 0-255.

mean, std = (
    [0.485, 0.456, 0.406], 
    [0.229, 0.224, 0.225]
)
vector = [1]*4
vector[1] = -1
mean = tensor(mean).view(*vector)
std = tensor(std).view(*vector)

mean, std = (
    [0.485, 0.456, 0.406], 
    [0.229, 0.224, 0.225]
)

First we need the mean and standard deviation of ImageNet


vector = [1]*4
vector[1] = -1

Then we create a shape describing how these two sets of three numbers should be viewed, so that broadcasting between the image tensor and these statistics can be performed


mean = tensor(mean).view(*vector)
std = tensor(std).view(*vector)

Finally we reshape both tensors to (1, 3, 1, 1) so they line up with our batch of images

I highly recommend Lesson 1 from Deep Learning from the Foundations to learn more about this!

mean.shape, std.shape
(torch.Size([1, 3, 1, 1]), torch.Size([1, 3, 1, 1]))
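
As a quick check of what that (1, 3, 1, 1) shape buys us: broadcasting stretches the per-channel statistics across the full image, so each channel is normalized by its own mean and standard deviation. A small demo with a fake image batch:

x = torch.rand(1, 3, 256, 256)  # stand-in for a batch of one RGB image
(x - mean).shape                # the (1, 3, 1, 1) stats broadcast across (1, 3, 256, 256)
torch.Size([1, 3, 256, 256])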

And now we can normalize the data!

t_im = (t_im - mean) / std
t_im.shape
torch.Size([1, 3, 256, 256])

Now all that’s left is to get our predictions:

with torch.inference_mode():
    model.eval()
    preds = model(t_im.cuda())

We use inference_mode here instead of no_grad as inference_mode is a more powerful version of no_grad. Find out more in the docs

preds.shape
torch.Size([1, 17])

Now that we have our predictions, we need to perform the sigmoid operation, apply the threshold, and grab the labels that are present:

decoded_preds = torch.sigmoid(preds) > 0.5
decoded_preds
tensor([[False, False, False, False, False, False, False, False, False, False,
         False,  True, False,  True, False, False, False]], device='cuda:0')
from itertools import compress
present_labels = list(compress(
        data=list(different_labels), selectors=decoded_preds[0]
    ))

compress(
        data=list(different_labels), selectors=decoded_preds[0]
    )

The compress function creates an iterator that filters elements based on a boolean selector, which is exactly what our decoded_preds are. We can use this to find what labels are actually present! One caveat: this relies on list(different_labels) iterating in the same order as the vocab we passed to MultiCategorize, which holds within a single Python session since a set’s iteration order is stable for the life of the process; across sessions it is safer to read the vocab from learn.dls.vocab.
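
A tiny standalone example of compress, with made-up values:

list(compress(["clear", "primary", "water"], [True, False, True]))
['clear', 'water']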

present_labels
['partly_cloudy', 'primary']

And now we’ve successfully done what fastai does during predictions end-to-end:

learn.predict(fname)[0]
(#2) ['partly_cloudy','primary']

Here’s the code again for a quick copy-paste:

im = Image.open(fname)
im = im.convert("RGB")
t_im = PILToTensor()(im)
t_im = t_im.unsqueeze(0)        # batch of 1
t_im = t_im.float().div_(255.)  # scale pixel values to [0, 1]

mean, std = (
    [0.485, 0.456, 0.406], 
    [0.229, 0.224, 0.225]
)
vector = [1]*4
vector[1] = -1
mean = tensor(mean).view(*vector)
std = tensor(std).view(*vector)
t_im = (t_im - mean) / std
with torch.inference_mode():
    model.eval()
    preds = model(t_im.cuda())
    
decoded_preds = torch.sigmoid(preds) > 0.5

present_labels = list(compress(
        data=list(different_labels), selectors=decoded_preds[0]
    ))