This article is also a Jupyter Notebook available to be run from the top down. There will be code snippets that you can then run in any environment.

Below are the versions of fastai, fastcore, fastdot, wwf, fast_tabnet, and pytorch_tabnet currently running at the time of writing this:

  • fastai: 2.1.10
  • fastcore: 1.3.13
  • fastdot: 0.1.4
  • wwf: 0.0.8
  • fast_tabnet: 0.0.8
  • pytorch_tabnet: 1.0.6

In this notebook we'll be comparing the TabNet architecture to our regular fastai fully connected models. We'll be utilizing Michael Grankin's fast_tabnet wrapper to use the model.

TabNet is an attention-based network for tabular data, introduced by Arik and Pfister in "TabNet: Attentive Interpretable Tabular Learning" (arXiv:1908.07442). Let's first look at our fastai architecture and then compare it with TabNet, utilizing the fastdot library.

First let's build our data real quick so we know just what we're visualizing. We'll use the ADULT sample dataset again:

from fastai.tabular.all import *
path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')

We'll build our TabularPandas object and the DataLoaders

cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [Categorify, FillMissing, Normalize]
splits = RandomSplitter()(range_of(df))
to = TabularPandas(df, procs, cat_names, cont_names, y_names="salary", splits=splits)
dls = to.dataloaders(bs=1)  # a batch size of 1 just so we can look at a single row

And let's look at one batch to understand how the data is coming into the model:

dls.one_batch()
(tensor([[ 8, 10,  3,  5,  1,  5,  1]], device='cuda:0'),
 tensor([[-0.1178, -1.5307,  1.1407]], device='cuda:0'),
 tensor([[1]], device='cuda:0', dtype=torch.int8))

So we can see the first tensor holds our categoricals, the second our continuous variables, and the third our targets. With this in mind, let's make a TabularModel with layer sizes of 200 and 100:

learn = tabular_learner(dls, layers=[200,100], metrics=accuracy)
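Before we visualize it, you can also print the model to see exactly what fastai built: the embedding layers, the embedding dropout, the BatchNorm1d over our continuous variables, and the stacked LinBnDrop blocks:

# Inspect the module tree fastai built for us
learn.model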

Now a basic visualization of this model can be made with fastdot like below:

from fastdot import *
(Diagram: the fastai TabularModel. The Preprocessed Input splits into Categorical Embeddings (Input → Embedding Matrix → Dropout) and Continuous Batch Normalization (Input → BatchNorm1d), which both feed the Fully Connected Layers: LinBnDrop (ni, 200) → LinBnDrop (200, 100) → LinBnDrop (100, 2), where each LinBnDrop block is BatchNorm1d → Dropout → Linear → ReLU.)

How does this compare to TabNet? This is TabNet:

(Diagram: TabNet. The same Preprocessed Input, with Categorical Embeddings (Input → Embedding Matrix → Dropout) and Continuous Batch Normalization (Input → BatchNorm1d), feeds the TabNet block: Attention Transformer → Feature Transformer → Final Mapping (Linear), which produces the Output along with Mask_Loss, Mask_Explain, and the Masks.)

(Diagram: the Feature Transformer, made up of Linear (ni, 80) → Linear (ni-2, 80) → a GLU Block.)

(Diagram: the Attention Transformer, made up of Linear → GhostBatchNorm → torch.mul(x, prior) → Sparsemax.)

So, a few things to note: we now have two transformers, one that keeps an eye on the features and another that handles the attention. You could think of the Attention Transformer as the encoder and the Feature Transformer as the decoder. What this attention lets us do is see exactly how our model is behaving, more so than the "guesses" that feature importance and similar techniques give us.
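One more detail worth calling out: the Attention Transformer ends in Sparsemax, which is what makes the masks directly readable, since unlike softmax it can assign a feature a weight of exactly zero. pytorch_tabnet ships its own Sparsemax implementation; the minimal sketch below is only meant to give intuition for that sparsity:

import torch

def sparsemax(z, dim=-1):
    "Sketch of sparsemax (Martins & Astudillo, 2016): softmax-like, but exact zeros are possible"
    z_sorted, _ = torch.sort(z, dim=dim, descending=True)
    k = torch.arange(1, z.size(dim)+1, device=z.device, dtype=z.dtype)
    shape = [1] * z.dim(); shape[dim] = -1
    k = k.view(shape)
    z_cumsum = z_sorted.cumsum(dim)
    support = (1 + k * z_sorted) > z_cumsum          # entries that stay non-zero
    k_z = support.sum(dim=dim, keepdim=True).to(z.dtype)
    tau = (z_cumsum.gather(dim, k_z.long()-1) - 1) / k_z
    return torch.clamp(z - tau, min=0)

sparsemax(torch.tensor([[1.0, 0.5, -2.0]]))  # tensor([[0.7500, 0.2500, 0.0000]]) -- note the exact zero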

Now that we have this done, how do we make a model?

Using TabNet

I have found in my experiments that TabNet isn't quite as good as fastai's fully connected tabular model, but since attention can be important and is a hot topic, we'll use it here. Another downside of this model is that it can take many epochs to reach a decent accuracy, as we will see.

First we need to grab the embedding matrix sizes:

emb_szs = get_emb_sz(to); emb_szs
[(10, 6), (17, 8), (8, 5), (16, 8), (7, 5), (6, 4), (3, 3)]
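Each tuple is (number of classes in that column, including the #na# placeholder, embedding width). The widths above line up with fastai's default heuristic; a rough sketch of that rule (emb_sz_rule, sub-linear growth capped at 600) looks like this:

def emb_sz_rule(n_cat):
    "Rough sketch of fastai's default embedding-width heuristic"
    return min(600, round(1.6 * n_cat**0.56))

[emb_sz_rule(n) for n,_ in emb_szs]  # [6, 8, 5, 8, 5, 4, 3], matching the widths above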

And now we can make use of our model! There are many different values we can pass in; here's a brief summary (a short sketch after this list shows how they map onto the underlying module):

  • n_d: Dimensions of the prediction layer (usually between 4 and 64)
  • n_a: Dimensions of the attention layer (similar to n_d)
  • n_steps: Number of successive steps in our network (usually 3 to 10)
  • gamma: A scaling factor for updating attention (usually between 1.0 and 2.0)
  • momentum: Momentum for all of the batch normalization layers
  • n_independent: Number of independent GLU layers in each block (default is 2)
  • n_shared: Number of shared GLU layers in each block (default is 2)
  • epsilon: Should be kept very low (to avoid log(0))
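As that sketch, here is the core pytorch_tabnet block built directly with the defaults listed above. The argument order is the same one our wrapper below uses; the 42 is our post-embedding input width and the 2 is our number of classes:

from pytorch_tabnet.tab_network import TabNetNoEmbeddings

# Minimal sketch: the core TabNet block with the default hyperparameters above.
# Arguments: input width (post-embedding), output size, n_d, n_a, n_steps, gamma,
# n_independent, n_shared, epsilon, virtual_batch_size, momentum
tabnet_core = TabNetNoEmbeddings(42, 2,
                                 8, 8, 3, 1.3,
                                 2, 2, 1e-15,
                                 128, 0.02)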

Let's build one similar to the model we showed above. To do so we'll wrap TabNetNoEmbeddings so it accepts fastai's (categorical, continuous) batches; when we instantiate it further below we'll set the prediction layer dimension (n_d) to 8, the attention dimension (n_a) to 32, and the number of steps (n_steps) to 1:

from pytorch_tabnet.tab_network import TabNetNoEmbeddings
class TabNetModel(Module):
    "Attention model for tabular data."
    def __init__(self, emb_szs, n_cont, out_sz, embed_p=0., y_range=None,
                 n_d=8, n_a=8,
                 n_steps=3, gamma=1.3,
                 n_independent=2, n_shared=2, epsilon=1e-15,
                 virtual_batch_size=128, momentum=0.02):
        self.embeds = nn.ModuleList([Embedding(ni, nf) for ni,nf in emb_szs])
        self.emb_drop = nn.Dropout(embed_p)
        self.bn_cont = nn.BatchNorm1d(n_cont)
        n_emb = sum(e.embedding_dim for e in self.embeds)
        self.n_emb,self.n_cont,self.y_range = n_emb,n_cont,y_range
        self.tab_net = TabNetNoEmbeddings(n_emb + n_cont, out_sz, n_d, n_a, n_steps,
                                          gamma, n_independent, n_shared, epsilon, virtual_batch_size, momentum)

    def forward(self, x_cat, x_cont, att=False):
        if self.n_emb != 0:
            # Embed each categorical column and concatenate the results
            x = [e(x_cat[:,i]) for i,e in enumerate(self.embeds)]
            x = torch.cat(x, 1)
            x = self.emb_drop(x)
        if self.n_cont != 0:
            # Batch-normalize the continuous variables and append them
            x_cont = self.bn_cont(x_cont)
            x = torch.cat([x, x_cont], 1) if self.n_emb != 0 else x_cont
        # TabNet returns its prediction plus the sparsity loss and attention masks
        x, m_loss, m_explain, masks = self.tab_net(x)
        if self.y_range is not None:
            x = (self.y_range[1]-self.y_range[0]) * torch.sigmoid(x) + self.y_range[0]
        # Pass att=True to also get the explainability outputs back
        if att:
            return x, m_loss, m_explain, masks
        else:
            return x
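This wrapper is essentially what the fast_tabnet package listed at the top of the article provides; if you'd rather import it than copy it, something along these lines should be equivalent (the module path is an assumption based on fast_tabnet's documentation, so double-check it against your installed version):

from fast_tabnet.core import TabNetModel  # assumed path; ships an equivalent wrapper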

First we need to make new DataLoaders, since the ones we built earlier have a batch size of 1:

dls = to.dataloaders(bs=1024)

Then build the model:

net = TabNetModel(emb_szs, len(to.cont_names), dls.c, n_d=8, n_a=32, n_steps=1); 

Finally we'll build our Learner and use the ranger optimizer:

learn = Learner(dls, net, CrossEntropyLossFlat(), metrics=accuracy, opt_func=ranger)
learn.fit_flat_cos(5, 1e-1)
epoch train_loss valid_loss accuracy time
0 0.494824 0.398702 0.814343 00:01
1 0.418295 0.385029 0.823864 00:01
2 0.388235 0.360466 0.835227 00:01
3 0.373567 0.361757 0.832924 00:01
4 0.363297 0.353453 0.835842 00:01

As you can see, it actually reached roughly 83% accuracy fairly quickly. In my other tests I wasn't able to do quite as well, but try it out for yourself! The code is here for you to use and play with.
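If you want to reproduce the comparison with the fully connected model yourself, one quick sketch is to train the earlier tabular_learner on the same DataLoaders with the same optimizer and schedule; no numbers are claimed here, so run it and compare:

# Baseline sketch: the fully connected fastai model under the same training regime
# (you may want to run learn_fc.lr_find() first and adjust the learning rate)
learn_fc = tabular_learner(dls, layers=[200,100], metrics=accuracy, opt_func=ranger)
learn_fc.fit_flat_cos(5, 1e-1)

Next let's look at what really sets TabNet apart: the attention masks. We'll start with a small test DataLoader built from the first 20 rows of our DataFrame: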

dl = learn.dls.test_dl(df.iloc[:20], bs=1)
batch = dl.one_batch()
batch
(tensor([[5, 8, 3, 0, 6, 5, 1]], device='cuda:0'),
 tensor([[ 0.7599, -0.8427,  0.7501]], device='cuda:0'),
 tensor([[1]], device='cuda:0', dtype=torch.int8))
from pytorch_tabnet.tab_model import *
cat_dims = [emb_szs[i][1] for i in range(len(emb_szs))]
cat_dims
[6, 8, 5, 8, 5, 4, 3]

cat_dims are the embedding widths, i.e. the second element of each tuple in emb_szs. cat_idxs tells create_explain_matrix which input columns are categorical; with the three continuous variables in the first three slots, the categoricals take indices 3 through 9:

cat_idxs = [3, 4, 5, 6, 7, 8, 9]
cat_idxs
[3, 4, 5, 6, 7, 8, 9]
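Rather than hard-coding that list you could derive it from the processed column counts; the three continuous variables take the first three slots here, so this just reproduces the list above:

# The categorical columns occupy the indices after the continuous ones
cat_idxs = list(range(len(to.cont_names), len(to.cont_names) + len(to.cat_names)))
cat_idxs  # [3, 4, 5, 6, 7, 8, 9]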
tot = len(to.cont_names) + len(to.cat_names)

The 42 is the post-embedding width: the sum of the embedding dimensions (39) plus our three continuous variables. That's the size of the input TabNet actually sees.
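As a quick check, that 42 is just the arithmetic from the sizes above:

# Post-embedding width: summed embedding dims plus the continuous columns
post_emb_dim = sum(cat_dims) + len(to.cont_names)
post_emb_dim  # 39 + 3 = 42

With that we can build the explain matrix: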

matrix = create_explain_matrix(tot,
                      cat_dims,
                      cat_idxs,
                      42)
dl = learn.dls.test_dl(df.iloc[:20], bs=1)

Now let's patch in an explainer function to Learner

from scipy.sparse import csc_matrix

@patch
def explain(x:Learner, dl:TabDataLoader):
  "Get explain values for a set of predictions"
  dec_y = []
  x.model.eval()
  for batch_nb, data in enumerate(dl):
    with torch.no_grad():
      out, M_loss, M_explain, masks = x.model(data[0], data[1], True)
    for key, value in masks.items():
      masks[key] = csc_matrix.dot(value.cpu().numpy(), matrix)
    if batch_nb == 0:
      res_explain = csc_matrix.dot(M_explain.cpu().numpy(),
                                  matrix)
      res_masks = masks
    else:
      res_explain = np.vstack([res_explain,
                              csc_matrix.dot(M_explain.cpu().numpy(),
                                              matrix)])
      for key, value in masks.items():
        res_masks[key] = np.vstack([res_masks[key], value])

    dec_y.append(int(x.loss_func.decodes(out)))
  return dec_y, res_masks, res_explain

We'll pass in a DataLoader

dec_y, res_masks, res_explain = learn.explain(dl)

And now we can visualize them with plot_explain

def plot_explain(masks, lbls, figsize=(12,12)):
  "Plots masks with `lbls` (`dls.x_names`)"
  fig = plt.figure(figsize=figsize)
  ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
  plt.yticks(np.arange(0, len(masks[0]), 1.0))
  plt.xticks(np.arange(0, len(masks[0][0]), 1.0))
  ax.set_xticklabels(lbls, rotation=90)
  plt.ylabel('Sample Number')
  plt.xlabel('Variable')
  plt.imshow(masks[0])

We pass in the masks and the x_names and we can see for each input how it affected the output:

lbls = dls.x_names
plot_explain(res_masks, lbls)