Extending fastai's `show_training_loop` to be more verbose about event triggers

This article is also a Jupyter Notebook that can be run from top to bottom. The code snippets it contains can also be run in any environment.

Below are the versions of fastai, fastcore, and wwf used at the time of writing:

  • fastai: 2.2.5
  • fastcore: 1.3.19
  • wwf: 0.0.10

Understanding fastai's Training Loop

fastai's training loop is unusual in that every customization is handled through Callbacks. This means that if you need to modify the training loop, you should never have to modify Learner's source code.

Instead, we hook into the various events (trigger points) that Callbacks can respond to. fastai already provides a way to see which Callbacks are called at each point in the training loop: the Learner.show_training_loop method.
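
To make the idea of event trigger points concrete, here is a toy callback system in plain Python. It is only loosely modeled on fastai's design (the learner broadcasts an event name, and each callback runs a method of that name if it has one); ToyLearner and EventLogger are hypothetical names, and this is not fastai's source code.

```python
# A toy callback system, loosely modeled on fastai's design.
# ToyLearner and EventLogger are hypothetical; this is NOT fastai's source.

class Callback:
    order = 0

    def __call__(self, event_name):
        # Run the handler for this event only if the callback defines one
        handler = getattr(self, event_name, None)
        if handler is not None:
            handler()


class EventLogger(Callback):
    "Record which events this callback responded to"
    def __init__(self):
        self.seen = []

    def before_fit(self):
        self.seen.append("before_fit")

    def after_batch(self):
        self.seen.append("after_batch")


class ToyLearner:
    def __init__(self, cbs):
        # Callbacks run in ascending `order`, as in fastai
        self.cbs = sorted(cbs, key=lambda cb: cb.order)

    def __call__(self, event_name):
        # Broadcast the event to every callback
        for cb in self.cbs:
            cb(event_name)

    def fit(self, n_epochs, n_batches):
        # The loop itself never changes; behaviour lives in the callbacks
        self("before_fit")
        for _ in range(n_epochs):
            self("before_epoch")
            for _ in range(n_batches):
                self("before_batch")
                self("after_batch")
            self("after_epoch")
        self("after_fit")


logger = EventLogger()
ToyLearner([logger]).fit(n_epochs=2, n_batches=3)
print(logger.seen)
```

Changing what happens during training then means writing a new callback class, not editing ToyLearner.fit.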

show_training_loop

The goal of show_training_loop is to show the user which Callbacks are triggered during fastai's entire training cycle. An example is provided below:

from fastai.callback.all import *
from fastai.test_utils import synth_learner

learn = synth_learner()
learn.show_training_loop()
Start Fit
   - before_fit     : [TrainEvalCallback, Recorder, ProgressCallback]
  Start Epoch Loop
     - before_epoch   : [Recorder, ProgressCallback]
    Start Train
       - before_train   : [TrainEvalCallback, Recorder, ProgressCallback]
      Start Batch Loop
         - before_batch   : []
         - after_pred     : []
         - after_loss     : []
         - before_backward: []
         - before_step    : []
         - after_step     : []
         - after_cancel_batch: []
         - after_batch    : [TrainEvalCallback, Recorder, ProgressCallback]
      End Batch Loop
    End Train
     - after_cancel_train: [Recorder]
     - after_train    : [Recorder, ProgressCallback]
    Start Valid
       - before_validate: [TrainEvalCallback, Recorder, ProgressCallback]
      Start Batch Loop
         - **CBs same as train batch**: []
      End Batch Loop
    End Valid
     - after_cancel_validate: [Recorder]
     - after_validate : [Recorder, ProgressCallback]
  End Epoch Loop
   - after_cancel_epoch: []
   - after_epoch    : [Recorder]
End Fit
 - after_cancel_fit: []
 - after_fit      : [ProgressCallback]

As we can see, every major stage is marked with a Start and an End, and the intermediate events at each level are listed. Next to each event are the Callbacks that get triggered at that point.

However, I think we can take this a step further and explain exactly what happens during each step. To do so, I've written a revised version of Learner.show_training_loop:

Learner.show_training_loop

Learner.show_training_loop(verbose:bool=False, cbs:Union[NoneType, list, Callback]=None)

Show each step in the training loop, potentially with Callback event descriptions

With this new version, we can pass verbose=True, and for every event the method pulls the docstring of each Callback's handler for that event, so we can see what happens at each step:

learn.show_training_loop(verbose=True)
Start Fit
    - before_fit:
        - TrainEvalCallback: 
            - Set the iter and epoch counters to 0, put the model and the right device
        - Recorder: 
            - Prepare state for training
        - ProgressCallback: 
            - Setup the master bar over the epochs
   Start Epoch Loop
       - before_epoch:
           - Recorder: 
               - Set timer if `self.add_time=True`
           - ProgressCallback: 
               - Update the master bar
      Start Train
          - before_train:
              - TrainEvalCallback: 
                  - Set the model in training mode
              - Recorder: 
                  - Reset loss and metrics state
              - ProgressCallback: 
                  - Launch a progress bar over the training dataloader
         Start Batch Loop
             - before_batch:
             - after_pred:
             - after_loss:
             - before_backward:
             - before_step:
             - after_step:
             - after_cancel_batch:
             - after_batch:
                 - TrainEvalCallback: 
                     - Update the iter counter (in training mode)
                 - Recorder: 
                     - Update all metrics and records lr and smooth loss in training
                 - ProgressCallback: 
                     - Update the current progress bar
         End Batch Loop
      End Train
       - after_cancel_train:
           - Recorder: 
               - Ignore training metrics for this epoch
       - after_train:
           - Recorder: 
               - Log loss and metric values on the training set (if `self.training_metrics=True`)
           - ProgressCallback: 
               - Close the progress bar over the training dataloader
      Start Valid
          - before_validate:
              - TrainEvalCallback: 
                  - Set the model in validation mode
              - Recorder: 
                  - Reset loss and metrics state
              - ProgressCallback: 
                  - Launch a progress bar over the validation dataloader
         Start Batch Loop
             - **CBs same as train batch**:
         End Batch Loop
      End Valid
       - after_cancel_validate:
           - Recorder: 
               - Ignore validation metrics for this epoch
       - after_validate:
           - Recorder: 
               - Log loss and metric values on the validation set
           - ProgressCallback: 
               - Close the progress bar over the validation dataloader
   End Epoch Loop
    - after_cancel_epoch:
    - after_epoch:
        - Recorder: 
            - Store and log the loss/metric values
End Fit
 - after_cancel_fit:
 - after_fit:
     - ProgressCallback: 
         - Close the master bar
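
The verbose output above relies on the fact that each Callback's event method carries a docstring. A minimal sketch of that lookup in plain Python follows; it is an assumption about the general approach, not wwf's actual code, and TrainEval, Recorder, and describe_event are hypothetical names. It uses the standard-library inspect.getdoc to read each handler's docstring.

```python
import inspect

# Hypothetical stand-ins for fastai callbacks; the docstrings on the
# event methods are what a verbose printout would surface.

class TrainEval:
    def before_fit(self):
        "Set the iter and epoch counters to 0"

    def before_train(self):
        "Set the model in training mode"


class Recorder:
    def before_fit(self):
        "Prepare state for training"


def describe_event(event, cbs, indent=4):
    "List each callback that implements `event`, followed by its handler's docstring"
    pad = " " * indent
    lines = [f"{pad}- {event}:"]
    for cb in cbs:
        handler = getattr(cb, event, None)
        if handler is None:
            continue  # this callback does not react to the event
        doc = inspect.getdoc(handler) or "(no docstring)"
        lines.append(f"{pad}    - {type(cb).__name__}:")
        lines.append(f"{pad}        - {doc}")
    return "\n".join(lines)


print(describe_event("before_fit", [TrainEval(), Recorder()]))
```

Because the descriptions come from docstrings, any custom Callback that documents its event methods would show up in such a printout without extra work.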

Usage Example:

To use this functionality, first import the patched method from wwf:

from wwf.basics.training_loop import *

Then call learn.show_training_loop(verbose=True) on your Learner.