A guide to using the learning rate finder to select an ideal learning rate

This article is also a Jupyter Notebook available to be run from the top down. There will be code snippets that you can then run in any environment.

Below are the versions of fastai, fastcore, and wwf in use at the time of writing:

  • fastai: 2.0.16
  • fastcore: 1.2.2
  • wwf: 0.0.5

The importance of a good Learning Rate

Before we get started, there are a few questions we need to understand.

Why bother selecting a learning rate? Why not use the defaults?

Quite simply, a bad learning rate can mean bad performance. There are 2 ways this can happen.

  1. Learning too slowly: If the learning rate is too small, it will take a really long time to train your model. Reaching the same accuracy then costs more time or more money: either training takes longer on the same hardware, or you need more expensive hardware (or some combination of the two).

  2. Learning too quickly: If the learning rate is too large, each step is so big that it overshoots the optimal model. Your accuracy will bounce all over the place rather than steadily improving.
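Here is a toy illustration of the two failure modes. It is not fastai code. It runs plain gradient descent on the made-up loss f(x) = x², whose gradient is 2x. Each step multiplies x by (1 - 2 * lr), which gives three regimes:

  • A tiny learning rate shrinks the loss very slowly.
  • A moderate learning rate converges quickly.
  • Any learning rate above 1 makes |1 - 2 * lr| > 1, so the loss grows instead of shrinking.

The three learning rates below are arbitrary choices for the demonstration.

```python
def gradient_descent(lr, steps=20, x0=1.0):
    """Minimise the toy loss f(x) = x**2 and return the loss after every step."""
    x = x0
    losses = []
    for _ in range(steps):
        x -= lr * 2 * x          # one gradient step; the gradient of x**2 is 2*x
        losses.append(x * x)
    return losses

for lr in (0.01, 0.4, 1.1):      # too small, reasonable, too large
    print(f"lr={lr}: final loss {gradient_descent(lr)[-1]:.3g}")
```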

So we need a learning rate that is not too big, but not too small. How can we thread the needle?

Isn't there some automated way to select a learning rate?

The short answer is no, there isn't. Some guidelines are available and will be covered, but ultimately there is no sure-fire way to select a learning rate automatically. The best method is to use the learning rate finder.

The Problem

We will be identifying cats vs dogs.

To get our model started, we will import the library.

from fastai.vision.all import *

Then we download the data.

path = untar_data(URLs.PETS)

Now, let's organize the data in a dataloader.

files = get_image_files(path/"images")
# Cat breeds have capitalized file names, so True means "cat"
def label_func(f): return f[0].isupper()
dls = ImageDataLoaders.from_name_func(path, files, label_func, item_tfms=Resize(224))
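ImageDataLoaders.from_name_func passes label_func the file's name, not the full Path. The name is a plain string such as "Abyssinian_1.jpg", which is why f[0] is the first character of the file name. The snippet below checks the labeling logic on its own, without fastai. The two file names are illustrative examples in the Pets naming style.

```python
from pathlib import Path

def label_func(f):
    # f is the file *name* (a str); an uppercase first letter marks a cat breed
    return f[0].isupper()

# Illustrative paths in the style of the Pets images folder
files = [Path("images/Abyssinian_1.jpg"), Path("images/beagle_3.jpg")]
labels = {p.name: label_func(p.name) for p in files}
print(labels)  # {'Abyssinian_1.jpg': True, 'beagle_3.jpg': False}
```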

Now we can look at the pictures we want to classify. We are predicting whether each image is a cat or not.

dls.show_batch(max_n=3)

Now, we can create a learner to do the classification.

learn = cnn_learner(dls, resnet34, metrics=error_rate)
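The error_rate metric passed to cnn_learner is the fraction of validation images the model gets wrong, i.e. 1 - accuracy. fastai computes it from the raw model outputs by first taking the most likely class. The sketch below assumes that step has already happened and works on plain predicted labels. The predictions and targets are made up.

```python
def error_rate(preds, targets):
    """Fraction of predictions that disagree with the targets (1 - accuracy)."""
    if len(preds) != len(targets):
        raise ValueError("preds and targets must have the same length")
    wrong = sum(p != t for p, t in zip(preds, targets))
    return wrong / len(targets)

# Hypothetical decoded predictions: True = cat, False = dog
preds   = [True, False, True, True]
targets = [True, False, False, True]
print(error_rate(preds, targets))  # 0.25
```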

Recorder.plot_lr_find

Recorder.plot_lr_find(suggestions=False, skip_end=5, lr_min=None, lr_steep=None)

Plot the result of an LR Finder test (won't work if you didn't do learn.lr_find() before)

Learner.lr_find

Learner.lr_find(start_lr=1e-07, end_lr=10, num_it=100, stop_div=True, show_plot=True, suggestions=True)

Launch a mock training to find a good learning rate, return lr_min, lr_steep if suggestions is True
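The "mock training" in that docstring is an LR range test, which works as follows:

  1. Starting from start_lr, train for up to num_it mini-batches, growing the learning rate exponentially towards end_lr.
  2. Record the loss at each step.
  3. With stop_div=True, stop once the loss blows up.

fastai saves the model before the sweep and restores it afterwards, so the real weights are left untouched. This minimal sketch replays the idea on the toy loss f(x) = x² instead of a network. Two details are assumptions about fastai's implementation:

  • the schedule lr = start_lr * (end_lr / start_lr) ** (i / num_it)
  • the divergence test "loss > 4 * best loss so far". fastai applies it to a smoothed loss, while this sketch uses the raw loss.

```python
def lr_range_test(start_lr=1e-7, end_lr=10.0, num_it=100, stop_div=True, x0=1.0):
    """Mock LR range test on the toy loss f(x) = x**2.

    The learning rate grows exponentially from start_lr towards end_lr, with one
    gradient step per iteration. The sweep stops early once the loss exceeds
    4x the best loss seen so far (assumed divergence threshold).
    """
    x = x0
    best = float("inf")
    lrs, losses = [], []
    for i in range(num_it):
        lr = start_lr * (end_lr / start_lr) ** (i / num_it)  # exponential schedule
        x -= lr * 2 * x                                     # gradient of x**2 is 2*x
        loss = x * x
        lrs.append(lr)
        losses.append(loss)
        best = min(best, loss)
        if stop_div and loss > 4 * best:
            break                                           # loss has diverged
    return lrs, losses

lrs, losses = lr_range_test()
print(f"stopped after {len(lrs)} steps at lr={lrs[-1]:.3g}")
```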

Learning Rate Finder

Finally we can get to the main topic of this tutorial. I have modified fastai's learning rate finder to add dots at the recommended locations. The red dots serve as quick reference points, but it is still up to us to pick the value. It's a bit of an art.

What we are looking for is a sensible point on the graph where the loss is decreasing. The red dots mark two suggestions: the learning rate at the minimum loss divided by 10, and the learning rate where the loss is falling most steeply.

We can see that in this case, both the dots line up on the curve. Anywhere in that range will be a good guess for a starting learning rate.

learn.lr_find()
SuggestedLRs(lr_min=0.010000000149011612, lr_steep=0.0008317637839354575)
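The two numbers in SuggestedLRs correspond to the red dots. Here is a rough sketch of how they can be derived from a recorded sweep:

  • lr_min is the learning rate at the lowest loss, divided by 10.
  • lr_steep is where the loss falls fastest as a function of log(learning rate).

fastai 2.0 does something along these lines on its smoothed losses, and also trims a few points at each end of the sweep first. This sketch skips the trimming. The five-point sweep is invented for illustration.

```python
import math

def suggest_lrs(lrs, losses):
    """Approximate the two suggestions: (lr at minimum loss / 10, steepest-descent lr).

    Steepness is the finite-difference slope of loss against log(lr);
    the most negative slope marks the steepest point.
    """
    i_min = min(range(len(losses)), key=losses.__getitem__)
    lr_min = lrs[i_min] / 10
    slopes = [
        (losses[i + 1] - losses[i]) / (math.log(lrs[i + 1]) - math.log(lrs[i]))
        for i in range(len(lrs) - 1)
    ]
    i_steep = min(range(len(slopes)), key=slopes.__getitem__)
    return lr_min, lrs[i_steep]

# Invented sweep, for illustration only
lrs = [1e-4, 1e-3, 1e-2, 1e-1, 1.0]
losses = [1.0, 0.95, 0.6, 0.3, 2.0]
lr_min, lr_steep = suggest_lrs(lrs, losses)
print(lr_min, lr_steep)
```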

Now we will fine-tune the model as a first training step.

learn.fine_tune(1, base_lr=9e-3)
epoch train_loss valid_loss error_rate time
0 0.087485 0.025303 0.006766 00:15
epoch train_loss valid_loss error_rate time
0 0.077561 0.028005 0.010825 00:15
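fine_tune printed two tables because it trains in two phases. Our reading of fastai 2.0's defaults (freeze_epochs=1, lr_mult=100) is as follows:

  1. Train only the new head for one epoch with the pretrained body frozen, at base_lr.
  2. Halve base_lr, unfreeze the whole network, and train for the requested number of epochs. Discriminative learning rates run from base_lr / lr_mult for the earliest layers up to base_lr for the head.

The sketch below only computes those peak learning rates for base_lr = 9e-3. Treat the defaults as assumptions and check them against the fastai source for your version.

```python
def fine_tune_lrs(base_lr, freeze_epochs=1, lr_mult=100):
    """Sketch of the peak learning rates fine_tune is assumed to use.

    Phase 1: body frozen, head trained for freeze_epochs at base_lr.
    Phase 2: everything unfrozen; base_lr is halved and spread across layer
    groups from base_lr/lr_mult (earliest layers) up to base_lr (head).
    """
    frozen_lr = base_lr
    unfrozen_max = base_lr / 2
    unfrozen_min = unfrozen_max / lr_mult
    return {
        "frozen": (freeze_epochs, frozen_lr),
        "unfrozen": (unfrozen_min, unfrozen_max),
    }

print(fine_tune_lrs(9e-3))
```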

Now that we have done some training, we need to re-run the learning rate finder: as the model trains and its weights change, the 'best' learning rate changes too.

When we run it below, the graph is a bit trickier. We definitely don't want the point on the far left where the loss is spiking, but we also don't want the point on the right where the loss is increasing. Instead, we will pick a value between the two, on the part of the curve where the loss is decreasing, and train some more.

learn.lr_find()
SuggestedLRs(lr_min=9.12010818865383e-08, lr_steep=1.0964781722577754e-06)
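The eyeballing step above can be roughly encoded as a heuristic. The following is a hypothetical helper, not part of fastai or wwf. It finds the longest stretch of the sweep where the loss keeps falling and returns the geometric midpoint of that stretch. The midpoint is geometric because learning rates are compared on a log scale. A real finder curve is noisy, so treat this as an illustration of the reasoning rather than a replacement for looking at the plot. The data is made up to mimic a spike on the left and a rise on the right.

```python
import math

def longest_falling_run(losses):
    """Return (start, end) indices of the longest run where the loss strictly decreases."""
    best = (0, 0)
    start = 0
    for i in range(1, len(losses)):
        if losses[i] >= losses[i - 1]:
            start = i                      # run broken: restart from here
        elif i - start > best[1] - best[0]:
            best = (start, i)              # new longest falling run
    return best

def pick_lr_from_sweep(lrs, losses):
    """Hypothetical heuristic: geometric midpoint of the longest falling stretch."""
    s, e = longest_falling_run(losses)
    return math.sqrt(lrs[s] * lrs[e])

# Invented sweep: spike on the left, falling middle, rise on the right
lrs = [1e-7, 1e-6, 1e-5, 1e-4, 1e-3]
losses = [0.5, 0.9, 0.8, 0.7, 1.2]
lr = pick_lr_from_sweep(lrs, losses)
print(f"picked lr = {lr:.3g}")
```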

Training for one more cycle at 9e-7, a value between the two suggestions, brings the error rate down to about 0.6%. Not bad!

learn.fit_one_cycle(1, 9e-7)
epoch train_loss valid_loss error_rate time
0 0.040149 0.021365 0.006089 00:15
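fit_one_cycle does not hold 9e-7 constant; it treats it as the peak of a one-cycle schedule. Assuming fastai 2.0's defaults (div=25, div_final=1e5, pct_start=0.25), the schedule has two phases:

  1. Over the first quarter of training, the learning rate warms up along a cosine curve from lr_max / 25 to lr_max.
  2. It then anneals down to lr_max / 1e5.

Here is a sketch of that schedule as a pure function of training progress:

```python
import math

def cos_anneal(start, end, pos):
    """Cosine interpolation from start (pos=0) to end (pos=1)."""
    return start + (1 + math.cos(math.pi * (1 - pos))) * (end - start) / 2

def one_cycle_lr(pos, lr_max, div=25.0, div_final=1e5, pct_start=0.25):
    """Learning rate at training progress pos in [0, 1] for a one-cycle schedule.

    Warm up from lr_max/div to lr_max over the first pct_start of training,
    then anneal down to lr_max/div_final (defaults assumed from fastai 2.0).
    """
    if pos < pct_start:
        return cos_anneal(lr_max / div, lr_max, pos / pct_start)
    return cos_anneal(lr_max, lr_max / div_final, (pos - pct_start) / (1 - pct_start))

for pos in (0.0, 0.25, 0.5, 1.0):
    print(f"{pos:.2f}: {one_cycle_lr(pos, 9e-7):.3g}")
```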