Fine-tuning a pre-trained LM from the HuggingFace model hub on the GLUE benchmark
 
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from fastai.text.all import *

from datasets import load_dataset, concatenate_datasets
from inspect import signature
import gc

This article is also a Jupyter Notebook available to be run from the top down. There will be code snippets that you can then run in any environment.

Below are the versions of fastai, fastcore, transformers, and datasets currently running at the time of writing this:

  • fastai : 2.3.1
  • fastcore : 1.3.19
  • transformers : 4.6.0
  • datasets : 1.6.1

Setup

In this notebook we will look at how to combine the power of HuggingFace with the great flexibility of fastai. For this purpose we will be fine-tuning distilroberta-base on the General Language Understanding Evaluation (GLUE) benchmark tasks.

To give you a grasp of what we are dealing with, here is a brief summary of the GLUE tasks:

Name | Task description | Size | Metrics
cola (Corpus of Linguistic Acceptability) | Determine whether it is a grammatical sentence | 8.5k | matthews_corrcoef
sst2 (Stanford Sentiment Treebank) | Predict the sentiment of a given sentence | 67k | accuracy
mrpc (Microsoft Research Paraphrase Corpus) | Determine whether the sentences in the pair are semantically equivalent | 3.7k | f1/accuracy
stsb (Semantic Textual Similarity Benchmark) | Determine the similarity score for 2 sentences | 7k | pearsonr/spearmanr
qqp (Quora Question Pairs) | Determine if 2 questions are the same (paraphrase) | 364k | f1/accuracy
mnli (Multi-Genre Natural Language Inference) | Predict whether the premise entails, contradicts or is neutral to the hypothesis | 393k | accuracy
qnli (Stanford Question Answering Dataset) | Determine whether the context sentence contains the answer to the question | 105k | accuracy
rte (Recognizing Textual Entailment) | Determine whether one sentence entails another | 2.5k | accuracy
wnli (Winograd Schema Challenge) | Predict if the sentence with the pronoun substituted is entailed by the original sentence | 634 | accuracy

Let's define the main settings for the run in one place. You can choose any model from the wide variety presented in the HuggingFace model hub. Some might need special treatment to work, but most models of the appropriate class should be plug-and-play.

ds_name = 'glue'
model_name = "distilroberta-base"

max_len = 512
bs = 32
val_bs = bs*2

n_epoch = 4
lr = 2e-5
opt_func = Adam

To make switching between datasets smooth, I'll define a couple of dictionaries containing per-task information. We'll need the metrics, the text fields to retrieve the data, and the number of outputs for the model.

GLUE_TASKS = ["cola", "mnli", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb", "wnli"]
def validate_task():
    assert task in GLUE_TASKS
glue_metrics = {
    'cola':[MatthewsCorrCoef()],
    'sst2':[accuracy],
    'mrpc':[F1Score(), accuracy],
    'stsb':[PearsonCorrCoef(), SpearmanCorrCoef()],
    'qqp' :[F1Score(), accuracy],
    'mnli':[accuracy],
    'qnli':[accuracy],
    'rte' :[accuracy],
    'wnli':[accuracy],
}

glue_textfields = {
    'cola':['sentence', None],
    'sst2':['sentence', None],
    'mrpc':['sentence1', 'sentence2'],
    'stsb':['sentence1', 'sentence2'],
    'qqp' :['question1', 'question2'],
    'mnli':['premise', 'hypothesis'],
    'qnli':['question', 'sentence'],
    'rte' :['sentence1', 'sentence2'],
    'wnli':['sentence1', 'sentence2'],
}

glue_num_labels = {'mnli':3, 'stsb':1}

Data preprocessing

We'll be using the datasets library from HuggingFace to get the data:

task = 'mrpc'; validate_task()
ds = load_dataset(ds_name, task)
Reusing dataset glue (/root/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)

The MNLI dataset contains 2 validation sets: matched and mismatched. The matched set is selected here for validation when fine-tuning on MNLI.

valid_ = 'validation_matched' if task=='mnli' else 'validation'
len(ds['train']), len(ds[valid_])
(3668, 408)
nt, nv = len(ds['train']), len(ds[valid_])
train_idx, valid_idx = L(range(nt)), L(range(nt, nt+nv))
train_ds = concatenate_datasets([ds['train'], ds[valid_]])

One can inspect a single example for the given task:

train_ds[0]
{'idx': 0,
 'label': 1,
 'sentence1': 'Amrozi accused his brother , whom he called " the witness " , of deliberately distorting his evidence .',
 'sentence2': 'Referring to him as only " the witness " , Amrozi accused his brother of deliberately distorting his evidence .'}

Here I use the number of characters as a proxy for the length of the tokenized text to speed up dls creation.

lens = train_ds.map(lambda s: {'len': sum([len(s[i]) for i in glue_textfields[task] if i])},
                    remove_columns=train_ds.column_names, num_proc=2, keep_in_memory=True)
train_lens = lens.select(train_idx)['len']
valid_lens = lens.select(valid_idx)['len']

DataBlock and Transforms

TextGetter is analogous to ItemGetter but retrieves either one or two text fields from the source (e.g. "sentence1" and "sentence2").

class TextGetter(ItemTransform):
    def __init__(self, s1='text', s2=None):
        self.s1, self.s2 = s1, s2
    def encodes(self, sample):
        if self.s2 is None: return sample[self.s1]
        else: return sample[self.s1], sample[self.s2]
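
For example, with the MRPC data loaded above one can check what the getter returns for a single sample (an illustrative check; tg is just a throwaway name):

tg = TextGetter(*glue_textfields[task])
# returns a 2-tuple of texts for pair tasks, a single string otherwise
tg(train_ds[0])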

Transformers expect the two parts of text to be concatenated with a SEP token in between. But when displaying a batch it's better to have those texts in separate columns for readability. To make this work I define a version of show_batch dispatched on the TransTensorText class. It handles both the case of a single decoded text and that of a tuple of two texts.

class TransTensorText(TensorBase): pass

@typedispatch
def show_batch(x:TransTensorText, y, samples, ctxs=None, max_n=10, trunc_at=150, **kwargs):
    if ctxs is None: ctxs = get_empty_df(min(len(samples), max_n))
    if isinstance(samples[0][0], tuple):
        samples = L((*s[0], *s[1:]) for s in samples)
        if trunc_at is not None: samples = L((s[0].truncate(trunc_at), s[1].truncate(trunc_at), *s[2:]) for s in samples)
    elif trunc_at is not None: samples = L((s[0].truncate(trunc_at),*s[1:]) for s in samples)
    ctxs = show_batch[object](x, y, samples, max_n=max_n, ctxs=ctxs, **kwargs)
    display_df(pd.DataFrame(ctxs))
def find_first(t, e):
    for i, v in enumerate(t):
        if v == e: return i
        
def split_by_sep(t, sep_tok_id):
    idx = find_first(t, sep_tok_id)
    return t[:idx], t[idx:]
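
These helpers find the first separator token and split the id sequence at that point. A quick illustration with made-up token ids (here 2 stands in for the separator id):

# everything before the first separator forms the first chunk,
# the rest (separators included) the second
split_by_sep(torch.tensor([0, 11, 12, 2, 2, 21, 22, 2]), sep_tok_id=2)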

Tokenization of the inputs is done by TokBatchTransform, which wraps a pre-trained HuggingFace tokenizer. The text processing is done in batches for speed: we want to avoid explicit Python loops when possible.

class TokBatchTransform(Transform):
    """
    Tokenizes texts in batches using pretrained HuggingFace tokenizer.
    The first element in a batch can be single string or 2-tuple of strings.
    If `with_labels=True` the "labels" are added to the output dictionary.
    """
    def __init__(self, pretrained_model_name=None, tokenizer_cls=AutoTokenizer, 
                 config=None, tokenizer=None, with_labels=False,
                 padding=True, truncation=True, max_length=None, **kwargs):
        if tokenizer is None:
            tokenizer = tokenizer_cls.from_pretrained(pretrained_model_name, config=config)
        self.tokenizer = tokenizer
        self.kwargs = kwargs
        self._two_texts = False
        store_attr()
    
    def encodes(self, batch):
        # batch is a list of tuples of ({text or (text1, text2)}, {targets...})
        if is_listy(batch[0][0]): # 1st element is tuple
            self._two_texts = True
            texts = ([s[0][0] for s in batch], [s[0][1] for s in batch])
        elif is_listy(batch[0]): 
            texts = ([s[0] for s in batch],)

        inps = self.tokenizer(*texts,
                              add_special_tokens=True,
                              padding=self.padding,
                              truncation=self.truncation,
                              max_length=self.max_length,
                              return_tensors='pt',
                              **self.kwargs)
        # inps are batched, collate targets into batches too
        labels = default_collate([s[1:] for s in batch])
        if self.with_labels:
            inps['labels'] = labels[0]
            res = (inps, )
        else:
            res = (inps, ) + tuple(labels)
        return res
    
    def decodes(self, x:TransTensorText):
        if self._two_texts:
            x1, x2 = split_by_sep(x, self.tokenizer.sep_token_id)
            return (TitledStr(self.tokenizer.decode(x1.cpu(), skip_special_tokens=True)),
                    TitledStr(self.tokenizer.decode(x2.cpu(), skip_special_tokens=True)))
        return TitledStr(self.tokenizer.decode(x.cpu(), skip_special_tokens=True))
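
To see what the wrapped tokenizer produces, one can call a tokenizer directly on a sentence pair (an illustrative check that creates a tokenizer just for this purpose). The two texts end up in a single sequence joined by separator tokens, which is what split_by_sep relies on when decoding:

tok = AutoTokenizer.from_pretrained(model_name)
# for RoBERTa-style tokenizers a pair is encoded as "<s> text1 </s></s> text2 </s>"
tok.decode(tok("first sentence", "second sentence")['input_ids'])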

The batches processed by TokBatchTransform contain a dictionary as the first element. For decoding it's handy to have a tensor instead. The Undict transform fetches input_ids from the batch and creates a TransTensorText, which works with typedispatch.

class Undict(Transform):
    def decodes(self, x:dict):
        if 'input_ids' in x: res = TransTensorText(x['input_ids'])
        return res
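
A quick check of the decode step with made-up token ids (illustrative only):

# decoding a dict pulls out `input_ids` and wraps them in TransTensorText
Undict().decode({'input_ids': torch.tensor([[0, 11, 12, 2]])})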

Now the transforms are combined inside a data block used for dls creation. The inputs are pre-batched by TokBatchTransform, so we don't need fa_collate for batching; fa_convert is passed in for "create_batch" instead.

The texts we are processing are of different lengths. Each sample in the batch is padded to the length of the longest input to make them "collatable". Shuffling samples randomly would therefore produce longer batches on average. As the compute time depends on the sequence length, this is undesired. SortedDL groups the inputs by length and, if shuffle=True, shuffles them within a certain window, keeping samples of similar length together.

dls_kwargs = {
    'before_batch': TokBatchTransform(pretrained_model_name=model_name, max_length=max_len),
    'create_batch': fa_convert
}
text_block = TransformBlock(dl_type=SortedDL, dls_kwargs=dls_kwargs, batch_tfms=Undict(), )

dblock = DataBlock(blocks = [text_block, CategoryBlock()],
                   get_x=TextGetter(*glue_textfields[task]),
                   get_y=ItemGetter('label'),
                   splitter=IndexSplitter(valid_idx))
%%time
dl_kwargs=[{'res':train_lens}, {'val_res':valid_lens}]
dls = dblock.dataloaders(train_ds, bs=bs, val_bs=val_bs, dl_kwargs=dl_kwargs)
CPU times: user 2.92 s, sys: 858 ms, total: 3.78 s
Wall time: 3.79 s
dls.show_batch(max_n=4)
text text_ category
0 The Securities and Exchange Commission yesterday said companies trading on the biggest U.S. markets must win shareholder approval before granting stock options and other stock-based compensation plans to corporate executives. Companies trading on the biggest stock markets must get shareholder approval before granting stock options and other equity compensation under rules cleared yesterday by the Securities and Exchange Commission. 1
1 " The investigation appears to be focused on certain accounting practices common to the interactive entertainment industry, with specific emphasis on revenue recognition, " Activision said in an SEC filing. According to the company filings, the investigation " appears to be focused on certain accounting practices common to the interactive entertainment industry, with specific emphasis on revenue recognition. " 1
2 The U.N. nuclear watchdog reprimanded Iran on Thursday for failing to comply with its nuclear safeguards obligations and called on Tehran to unconditionally accept stricter inspections by the agency. The U.N. atomic watchdog rapped Iran Thursday for failing to comply with nuclear safeguards, issuing a statement Washington said underlined international opposition to Tehran developing any banned weapons. 0
3 Microsoft favors setting up " independent e-mail trust authorities to establish and maintain commercial email guidelines, certify senders who follow the guidelines, and resolve customer disputes. " Gates says he wants to see " independent e-mail trust authorities " who " establish and maintain commercial email guidelines, certify senders who follow the guidelines, and resolve customer disputes. " 1

Customized Learner

Now the xb we get from the dataloader contains a dictionary, and HuggingFace transformers accept keyword arguments as input. But the fastai Learner feeds the model a sequence of positional arguments (self.pred = self.model(*self.xb)). To make this work smoothly we can create a callback that unrolls the input dict into a proper xb tuple.
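
The idea is to read the model's forward signature and build a positional tuple from the input dict, falling back to each argument's default value when a key is missing. A minimal standalone sketch of that idea (the actual handling is implemented in TransCallback below; dict_to_model_args is only for illustration):

def dict_to_model_args(model, inputs):
    "Illustrative sketch: order a dict of model inputs by the model's forward signature"
    params = signature(model.forward).parameters
    return tuple(inputs.get(name, p.default) for name, p in params.items())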

But first we need to define some utility functions. default_splitter is used to divide model parameters into groups:

def default_splitter(model):
    groups = L(model.base_model.children()) + L(m for m in list(model.children())[1:] if params(m))
    return groups.map(params)
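
As a sanity check (purely illustrative, it instantiates a throwaway model) one can count the parameter groups the splitter produces; having several groups makes discriminative learning rates possible, although a single learning rate is used later in this notebook:

tmp_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
# the groups cover the base model children plus the classification head
len(default_splitter(tmp_model))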

Similar to show_batch, one has to customize show_results:

@typedispatch
def show_results(x: TransTensorText, y, samples, outs, ctxs=None, max_n=10, trunc_at=150, **kwargs):
    if ctxs is None: ctxs = get_empty_df(min(len(samples), max_n))
    if isinstance(samples[0][0], tuple):
        samples = L((*s[0], *s[1:]) for s in samples)
        if trunc_at is not None: samples = L((s[0].truncate(trunc_at), s[1].truncate(trunc_at), *s[2:]) for s in samples)
    elif trunc_at is not None: samples = L((s[0].truncate(trunc_at),*s[1:]) for s in samples)
    ctxs = show_results[object](x, y, samples, outs, ctxs=ctxs, max_n=max_n, **kwargs)
    display_df(pd.DataFrame(ctxs))
    return ctxs

TransLearner itself doesn't do much: it adds TransCallback and sets the splitter to default_splitter if None is provided.

@delegates(Learner.__init__)
class TransLearner(Learner):
    "Learner for training transformers from HuggingFace"
    def __init__(self, dls, model, **kwargs):
        splitter = kwargs.get('splitter', None)
        if splitter is None: kwargs['splitter'] = default_splitter
        super().__init__(dls, model, **kwargs)
        self.add_cb(TransCallback(model))

The main piece of work needed to train a transformer model happens in TransCallback. It saves the valid model arguments and turns the input dict yielded by the dataloader into a tuple.

By default the model returns a dictionary-like object containing logits and possibly other outputs defined by the model config (e.g. intermediate hidden representations). In the fastai training loop we usually expect preds to be a tensor containing the model predictions (logits). The callback formats the preds properly.

Notice that if labels are found in the input, transformer models compute the loss themselves and return it together with the output logits. The callback below is designed to utilise the loss returned by the model instead of recomputing it with learn.loss_func. This is not actually used in this example but might be handy in some use cases.

class TransCallback(Callback):
    "Handles HuggingFace model inputs and outputs"
    
    def __init__(self, model):
        self.labels = tuple()
        self.model_args = {k:v.default for k, v in signature(model.forward).parameters.items()}
    
    def before_batch(self):
        if 'labels' in self.xb[0].keys():
            self.labels = (self.xb[0]['labels'], )
        # make a tuple containing an element for each argument the model expects;
        # if an argument is not in xb it is set to its default value
        self.learn.xb = tuple([self.xb[0].get(k, self.model_args[k]) for k in self.model_args.keys()])
    
    def after_pred(self):
        if 'loss' in self.pred:
            self.learn.loss_grad = self.pred.loss
            self.learn.loss = self.pred.loss.clone()
        self.learn.pred = self.pred.logits
    
    def after_loss(self):
        if len(self.labels):
            self.learn.yb = self.labels
            self.labels = tuple()
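
As a standalone illustration of the HuggingFace convention the callback relies on (re-creating a tokenizer and model just for this check): when labels are passed in, the model output carries both the loss and the logits.

tmp_tok = AutoTokenizer.from_pretrained(model_name)
tmp_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
out = tmp_model(**tmp_tok(["a sentence to classify"], return_tensors='pt'), labels=torch.tensor([1]))
out.loss, out.logits.shape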

Training

After all the preparations the training is straightforward. Setting num_labels for the model and choosing the appropriate metrics is automated.

model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=glue_num_labels.get(task, 2))
metrics = glue_metrics[task]
learn = TransLearner(dls, model, metrics=metrics, opt_func=opt_func)
learn.summary()
RobertaForSequenceClassification (Input shape: 32)
============================================================================
Layer (type)         Output Shape         Param #    Trainable 
============================================================================
                     32 x 101 x 768      
Embedding                                 38603520   True      
Embedding                                 394752     True      
Embedding                                 768        True      
LayerNorm                                 1536       True      
Dropout                                                        
Linear                                    590592     True      
Linear                                    590592     True      
Linear                                    590592     True      
Dropout                                                        
Linear                                    590592     True      
LayerNorm                                 1536       True      
Dropout                                                        
____________________________________________________________________________
                     32 x 101 x 3072     
Linear                                    2362368    True      
____________________________________________________________________________
                     32 x 101 x 768      
Linear                                    2360064    True      
LayerNorm                                 1536       True      
Dropout                                                        
Linear                                    590592     True      
Linear                                    590592     True      
Linear                                    590592     True      
Dropout                                                        
Linear                                    590592     True      
LayerNorm                                 1536       True      
Dropout                                                        
____________________________________________________________________________
                     32 x 101 x 3072     
Linear                                    2362368    True      
____________________________________________________________________________
                     32 x 101 x 768      
Linear                                    2360064    True      
LayerNorm                                 1536       True      
Dropout                                                        
Linear                                    590592     True      
Linear                                    590592     True      
Linear                                    590592     True      
Dropout                                                        
Linear                                    590592     True      
LayerNorm                                 1536       True      
Dropout                                                        
____________________________________________________________________________
                     32 x 101 x 3072     
Linear                                    2362368    True      
____________________________________________________________________________
                     32 x 101 x 768      
Linear                                    2360064    True      
LayerNorm                                 1536       True      
Dropout                                                        
Linear                                    590592     True      
Linear                                    590592     True      
Linear                                    590592     True      
Dropout                                                        
Linear                                    590592     True      
LayerNorm                                 1536       True      
Dropout                                                        
____________________________________________________________________________
                     32 x 101 x 3072     
Linear                                    2362368    True      
____________________________________________________________________________
                     32 x 101 x 768      
Linear                                    2360064    True      
LayerNorm                                 1536       True      
Dropout                                                        
Linear                                    590592     True      
Linear                                    590592     True      
Linear                                    590592     True      
Dropout                                                        
Linear                                    590592     True      
LayerNorm                                 1536       True      
Dropout                                                        
____________________________________________________________________________
                     32 x 101 x 3072     
Linear                                    2362368    True      
____________________________________________________________________________
                     32 x 101 x 768      
Linear                                    2360064    True      
LayerNorm                                 1536       True      
Dropout                                                        
Linear                                    590592     True      
Linear                                    590592     True      
Linear                                    590592     True      
Dropout                                                        
Linear                                    590592     True      
LayerNorm                                 1536       True      
Dropout                                                        
____________________________________________________________________________
                     32 x 101 x 3072     
Linear                                    2362368    True      
____________________________________________________________________________
                     32 x 101 x 768      
Linear                                    2360064    True      
LayerNorm                                 1536       True      
Dropout                                                        
Linear                                    590592     True      
Dropout                                                        
____________________________________________________________________________
                     32 x 2              
Linear                                    1538       True      
____________________________________________________________________________

Total params: 82,119,938
Total trainable params: 82,119,938
Total non-trainable params: 0

Optimizer used: <function Adam at 0x7fa9069c1268>
Loss function: FlattenedLoss of CrossEntropyLoss()

Callbacks:
  - TrainEvalCallback
  - TransCallback
  - Recorder
  - ProgressCallback
metric_to_monitor = metrics[0].name if isinstance(metrics[0], Metric) else metrics[0].__name__
cbs = [SaveModelCallback(monitor=metric_to_monitor)]
learn.fit_one_cycle(4, lr, cbs=cbs)
epoch train_loss valid_loss f1_score accuracy time
0 0.599429 0.508199 0.832551 0.737745 00:21
1 0.436955 0.337732 0.892193 0.857843 00:22
2 0.318495 0.331339 0.900175 0.860294 00:22
3 0.232149 0.354381 0.897747 0.855392 00:22
Better model found at epoch 0 with f1_score value: 0.8325508607198748.
Better model found at epoch 1 with f1_score value: 0.8921933085501859.
Better model found at epoch 2 with f1_score value: 0.9001751313485115.

After training the model it's useful to verify that the results make sense:

learn.show_results()
text text_ category category_
0 The delegates said raising and distributing funds has been complicated by the U.S. crackdown on jihadi charitable foundations, bank accounts of terror-related organizations and money transfers. Bin Laden ’ s men pointed out that raising and distributing funds has been complicated by the U.S. crackdown on jihadi charitable foundations, bank accounts of terror-related organizations and money transfers. 1 1
1 The attack followed several days of disturbances in the city where American soldiers exchanged fire with an unknown number of attackers as civilians carried out demonstrations against the American presence. The attack came after several days of disturbance in the city in which U.S. soldiers exchanged fire with an unknown number of attackers as civilians protested the American presence. 1 1
2 Massachusetts regulators and the Securities and Exchange Commission on Tuesday pressed securities fraud charges against Putnam Investments and two of its former portfolio managers for alleged improper mutual fund trading. State and federal securities regulators filed civil charges against Putnam Investments and two portfolio managers in the ever-expanding mutual fund trading scandal. 1 1
3 Justice Minister Martin Cauchon and Prime Minister Jean Chrétien have both said the Liberal government will introduce legislation soon to decriminalize possession of small amounts of pot for personal use. Justice Minister Martin Cauchon and Prime Minister Jean Chretien both have said the government will introduce legislation to decriminalize possession of small amounts of pot. 1 1
4 Myanmar's pro-democracy leader Aung San Suu Kyi will return home late Friday but will remain in detention after recovering from surgery at a Yangon hospital, her personal physician said. Myanmar's pro-democracy leader Aung San Suu Kyi will be kept under house arrest following her release from a hospital where she underwent surgery, her personal physician said Friday. 1 1
5 President Bush raised a record-breaking $ 49.5 million for his re-election campaign over the last three months, with contributions from 262,000 Americans, the president's campaign chairman said Tuesday. President Bush has raised $ 83.9 million since beginning his re-election campaign in May, and has $ 70 million of that left to spend, his campaign said Tuesday. 0 0
6 Barry Callebaut will be able to use Brach's retail network to sell products made from its German subsidiary Stollwerck, which makes chocolate products not sold in the United States. Barry Callebaut will be able to use Brach's retail network to sell products made from its German subsidiary Stollwerck, which makes chocolate products unknown to the American market. 1 1
7 Donations stemming from the Sept. 11 attacks helped push up contributions to human service organizations and large branches of the United Way by 15 percent and 28.6 percent, respectively. Donations stemming from the Sept. 11 attacks helped push up contributions to human service organizations by 15 percent and to large branches of the United Way by 28.6 percent. 1 1
8 The Guru microcontroller serves four functions : hardware monitoring, overclocking management, BIOS ( Basic Input Output System ) update and a troubleshooting-assistance feature called Black Box. The µGuru microcontroller serves four functions : hardware monitoring, overclocking management, BIOS update and a troubleshooting-assistance feature called Black Box. 1 1

Finally, we can run our model on the test set to get predictions.

test_dl = dls.test_dl(ds['test'])
preds = learn.get_preds(dl=test_dl)
preds[0]
tensor([[0.0121, 0.9879],
        [0.1296, 0.8704],
        [0.0072, 0.9928],
        ...,
        [0.0159, 0.9841],
        [0.0046, 0.9954],
        [0.1612, 0.8388]])
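
get_preds returns class probabilities; to turn them into hard labels, e.g. for a GLUE submission file, one can take the argmax over the class dimension:

# predicted class index for each test example
preds[0].argmax(dim=-1)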

Final remarks

Generalised versions of the "wrapper" code used in this notebook can be found in the fasthugs library. You can also check out some extra info on fine-tuning models on GLUE tasks in this blogpost. Another option for training HuggingFace transformers with fastai is the blurr library.