In this notebook we'll compare the TabNet architecture to our regular fastai fully connected models. We'll use Michael Grankin's fast_tabnet wrapper to work with the model.
TabNet is an attention-based network for tabular data, introduced in the paper "TabNet: Attentive Interpretable Tabular Learning" by Arik and Pfister. Let's first look at our fastai architecture and then compare it with TabNet using the fastdot library.
First let's quickly build our data so we know exactly what we're visualizing. We'll use the ADULT dataset again:
from fastai.tabular.all import *
path = untar_data(URLs.ADULT_SAMPLE)
df = pd.read_csv(path/'adult.csv')
We'll build our TabularPandas object and the DataLoaders:
cat_names = ['workclass', 'education', 'marital-status', 'occupation', 'relationship', 'race']
cont_names = ['age', 'fnlwgt', 'education-num']
procs = [Categorify, FillMissing, Normalize]
splits = RandomSplitter()(range_of(df))
to = TabularPandas(df, procs, cat_names, cont_names, y_names="salary", splits=splits)
dls = to.dataloaders(bs=1)
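The three procs do the preprocessing. As a rough mental model, here is a simplified NumPy sketch of what each one does to a single column. The helper names are hypothetical and this is not fastai's implementation (fastai, for example, computes its statistics on the training split only). It assumes FillMissing's default behavior of filling with the median and adding a boolean `_na` column, which fastai appends to the categorical names, so a continuous column with missing values can add an extra categorical variable.

```python
import numpy as np

def categorify(values):
    "Map categories to integer codes, reserving index 0 for missing values ('#na#')."
    vocab = ['#na#'] + sorted({v for v in values if v is not None})
    o2i = {v: i for i, v in enumerate(vocab)}
    return vocab, np.array([o2i[v] if v is not None else 0 for v in values])

def fill_missing(x):
    "Fill NaNs with the median and return a boolean column flagging which rows were missing."
    x = np.asarray(x, dtype=float)
    is_na = np.isnan(x)
    filled = np.where(is_na, np.nanmedian(x), x)
    return filled, is_na

def normalize(x):
    "Standardize to zero mean and unit standard deviation."
    x = np.asarray(x, dtype=float)
    return (x - x.mean()) / x.std()

vocab, codes = categorify(['Private', 'State-gov', None, 'Private'])
ages, age_na = fill_missing([39.0, np.nan, 50.0, 28.0])
ages = normalize(ages)
```

The `age_na` flags are what would become an extra categorical column, which is why `to.cat_names` can end up longer than the list we passed in.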
And let's look at one batch to understand how the data is coming into the model:
dls.one_batch()
The batch is a tuple: first our categorical variables, second our continuous variables, and third our target y. With this in mind, let's make a TabularModel with hidden layer sizes of 200 and 100:
learn = tabular_learner(dls, layers=[200,100], metrics=accuracy)
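Under the hood, the linear layers of this TabularModel run from the concatenated embeddings plus continuous inputs, through the hidden sizes we passed, to the number of classes. The hypothetical helper below just traces those linear-layer shapes; the embedding widths in the example are made up, and it ignores the batch norm, dropout, and ReLU layers that fastai interleaves.

```python
def linear_shapes(emb_dims, n_cont, layers, out_sz):
    "Return (in_features, out_features) for each linear layer of a fastai-style TabularModel."
    sizes = [sum(emb_dims) + n_cont] + list(layers) + [out_sz]
    return list(zip(sizes[:-1], sizes[1:]))

# Hypothetical embedding widths for four categorical columns, three continuous columns, two classes
shapes = linear_shapes([6, 8, 5, 8], n_cont=3, layers=[200, 100], out_sz=2)
print(shapes)  # [(30, 200), (200, 100), (100, 2)]
```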
A basic visualization of this model can be made with fastdot:
from fastdot import *
How does this compare to TabNet? This is TabNet:
A few things to note: we now have two transformers, a feature transformer that processes the input features and an attentive transformer that produces the attention masks. We could loosely call the attentive transformer the encoder and the feature transformer the decoder. What this attention lets us do is see more directly how our model is behaving, more so than feature importance and other techniques that only "guess".
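In the TabNet paper, the feature transformer is built from fully connected layers, each followed by batch normalization and a gated linear unit (GLU), with residual connections scaled by sqrt(0.5); some of these layers are shared across decision steps and some are step-specific. The NumPy sketch below shows that structure under simplifying assumptions: batch norm is omitted, the weights are random placeholders rather than learned, and the function names are hypothetical.

```python
import numpy as np

def glu(x):
    "Gated linear unit: the first half of the features, gated by a sigmoid of the second half."
    a, b = np.split(x, 2, axis=-1)
    return a * (1.0 / (1.0 + np.exp(-b)))

def feature_transformer(x, weights):
    "FC -> GLU layers with sqrt(0.5)-scaled residuals (batch norm omitted for brevity)."
    h = glu(x @ weights[0])  # the first layer changes the width, so it has no residual
    for w in weights[1:]:
        h = (h + glu(h @ w)) * np.sqrt(0.5)
    return h

rng = np.random.default_rng(0)
n_in, n_hidden = 10, 16  # hypothetical input and hidden widths
weights = [rng.normal(size=(n_in, 2 * n_hidden))] + [
    rng.normal(size=(n_hidden, 2 * n_hidden)) for _ in range(3)
]
h = feature_transformer(rng.normal(size=(4, n_in)), weights)
print(h.shape)  # (4, 16)
```

Each linear layer outputs twice the hidden width because the GLU consumes half of it as the gate.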
Now that we have this done, how do we make a model?
First we need to grab the embedding matrix sizes:
emb_szs = get_emb_sz(to); emb_szs
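get_emb_sz returns a (number of categories, embedding width) pair for each categorical column. When no size is specified, fastai picks the width with the rule of thumb `min(600, round(1.6 * n_cat**0.56))`, where n_cat includes the `#na#` placeholder. Here is a small pure-Python sketch of that rule (reimplemented here rather than imported from fastai), plus the input width a model sees after concatenating the embeddings and continuous variables. The cardinalities are hypothetical.

```python
def emb_sz_rule(n_cat):
    "fastai's rule of thumb for an embedding width, reimplemented for illustration."
    return min(600, round(1.6 * n_cat ** 0.56))

def emb_szs_for(cardinalities):
    "Build (n_categories, embedding_width) pairs the way get_emb_sz does by default."
    return [(n, emb_sz_rule(n)) for n in cardinalities]

# Hypothetical cardinalities, each already counting the '#na#' slot
example_szs = emb_szs_for([10, 17, 8])
# Width after concatenation: all embedding widths plus three continuous columns
post_embed_width = sum(sz for _, sz in example_szs) + 3
```

The same arithmetic (sum of embedding widths plus the number of continuous variables) gives the input width of the TabNet model we build below.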
And now we can make use of our model! There are many different values we can pass in; here's a brief summary:
- n_d: Dimension of the prediction layer (usually between 4 and 64)
- n_a: Dimension of the attention layer (similar to n_d)
- n_steps: Number of successive steps in our network (usually 3 to 10)
- gamma: A scaling factor for updating attention (usually between 1.0 and 2.0)
- momentum: Momentum in all batch normalization layers
- n_independent: Number of independent GLU layers in each block (default is 2)
- n_shared: Number of shared GLU layers in each block (default is 2)
- epsilon: Should be kept very low (to avoid log(0))
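To make gamma and epsilon more concrete: in the TabNet paper, each step's mask is M[i] = sparsemax(P[i-1] · h_i(a[i-1])), the prior starts at all ones and is updated as P[i] = P[i-1] · (gamma - M[i]), and a sparsity regularizer averages -M log(M + epsilon). The NumPy sketch below follows those formulas, with random logits standing in for the learned h_i(a[i-1]) (an assumption for illustration). The function names are hypothetical and this is not pytorch_tabnet's code.

```python
import numpy as np

def sparsemax(z):
    "Row-wise sparsemax: a projection onto the probability simplex that can output exact zeros."
    z = np.atleast_2d(np.asarray(z, dtype=float))
    z_sorted = -np.sort(-z, axis=-1)  # descending order
    k = np.arange(1, z.shape[-1] + 1)
    cumsum = np.cumsum(z_sorted, axis=-1)
    k_z = (1 + k * z_sorted > cumsum).sum(axis=-1, keepdims=True)  # size of the support
    tau = (np.take_along_axis(cumsum, k_z - 1, axis=-1) - 1) / k_z
    return np.maximum(z - tau, 0.0)

def attention_masks(logits_per_step, gamma=1.3):
    "Compute one mask per step, down-weighting features already used through the prior."
    prior = np.ones_like(logits_per_step[0])
    masks = []
    for logits in logits_per_step:
        mask = sparsemax(prior * logits)
        # With gamma=1, a feature that just received weight 1 gets prior 0 at later steps
        prior = prior * (gamma - mask)
        masks.append(mask)
    return masks

def sparsity_loss(masks, epsilon=1e-15):
    "Average entropy of the masks; epsilon keeps log() away from log(0)."
    n_steps, batch = len(masks), masks[0].shape[0]
    return sum(np.sum(-m * np.log(m + epsilon)) for m in masks) / (n_steps * batch)

rng = np.random.default_rng(0)
demo_masks = attention_masks([rng.normal(size=(4, 6)) for _ in range(3)], gamma=1.3)
print(sparsity_loss(demo_masks))
```

Larger gamma lets a feature be reused across steps; gamma close to 1 pushes each step toward different features.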
Let's build one similar to the model shown above. To do so we'll set the dimension of the prediction layer (n_d) to 8, the dimension of the attention layer (n_a) to 32, and the number of steps (n_steps) to 1:
from pytorch_tabnet.tab_network import TabNetNoEmbeddings
class TabNetModel(Module):
    "Attention model for tabular data."
    def __init__(self, emb_szs, n_cont, out_sz, embed_p=0., y_range=None,
                 n_d=8, n_a=8,
                 n_steps=3, gamma=1.3,
                 n_independent=2, n_shared=2, epsilon=1e-15,
                 virtual_batch_size=128, momentum=0.02):
        # One embedding per categorical variable, with dropout applied to the embeddings
        self.embeds = nn.ModuleList([Embedding(ni, nf) for ni,nf in emb_szs])
        self.emb_drop = nn.Dropout(embed_p)
        # Batch norm on the continuous variables
        self.bn_cont = nn.BatchNorm1d(n_cont)
        n_emb = sum(e.embedding_dim for e in self.embeds)
        self.n_emb,self.n_cont,self.y_range = n_emb,n_cont,y_range
        # TabNet itself sees the concatenated embeddings and continuous variables
        self.tab_net = TabNetNoEmbeddings(n_emb + n_cont, out_sz, n_d, n_a, n_steps,
                                          gamma, n_independent, n_shared, epsilon, virtual_batch_size, momentum)
    def forward(self, x_cat, x_cont, att=False):
        if self.n_emb != 0:
            x = [e(x_cat[:,i]) for i,e in enumerate(self.embeds)]
            x = torch.cat(x, 1)
            x = self.emb_drop(x)
        if self.n_cont != 0:
            x_cont = self.bn_cont(x_cont)
            # Embeddings come first, then the continuous variables
            x = torch.cat([x, x_cont], 1) if self.n_emb != 0 else x_cont
        x, m_loss, m_explain, masks = self.tab_net(x)
        if self.y_range is not None:
            x = (self.y_range[1]-self.y_range[0]) * torch.sigmoid(x) + self.y_range[0]
        # att=True also returns the sparsity loss, the aggregated explanation and the per-step masks
        if att:
            return x, m_loss, m_explain, masks
        else:
            return x
First we need to make new DataLoaders, because we currently have a batch size of 1:
dls = to.dataloaders(bs=1024)
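With a batch size of 1024 and the model's default virtual_batch_size=128, pytorch_tabnet uses Ghost Batch Normalization: each batch is split into smaller virtual batches that are normalized separately. The NumPy sketch below illustrates the idea in training mode only, without the learnable scale and shift or the running statistics a real batch-norm layer keeps; the function name is hypothetical.

```python
import numpy as np

def ghost_batch_norm(x, virtual_batch_size=128, eps=1e-5):
    "Normalize each virtual batch of rows with its own mean and variance."
    n_chunks = int(np.ceil(len(x) / virtual_batch_size))
    chunks = np.array_split(x, n_chunks, axis=0)
    normed = [(c - c.mean(axis=0)) / np.sqrt(c.var(axis=0) + eps) for c in chunks]
    return np.concatenate(normed, axis=0), n_chunks

rng = np.random.default_rng(0)
demo_x = rng.normal(loc=3.0, scale=2.0, size=(1024, 5))
demo_out, n_chunks = ghost_batch_norm(demo_x, virtual_batch_size=128)
```

With these numbers each batch is normalized as 8 virtual batches of 128 rows, which adds some regularizing noise compared to normalizing all 1024 rows at once.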
Then build the model:
net = TabNetModel(emb_szs, len(to.cont_names), dls.c, n_d=8, n_a=32, n_steps=1);
learn = Learner(dls, net, CrossEntropyLossFlat(), metrics=accuracy, opt_func=ranger)
learn.fit_flat_cos(5, 1e-1)
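fit_flat_cos keeps the learning rate flat for most of training and then anneals it along a cosine curve. The pure-Python sketch below reproduces that shape, assuming fastai's defaults of pct_start=0.75 and div_final=1e5 (so the final rate is lr/1e5); it illustrates the schedule and is not fastai's scheduler code.

```python
import math

def flat_cos_lr(pos, lr, pct_start=0.75, div_final=1e5):
    "Learning rate at training progress pos in [0, 1]: flat, then cosine-annealed to lr/div_final."
    if pos < pct_start:
        return lr
    p = (pos - pct_start) / (1 - pct_start)  # progress through the annealing phase
    end = lr / div_final
    return end + (lr - end) * (1 + math.cos(math.pi * p)) / 2

schedule = [flat_cos_lr(i / 10, 1e-1) for i in range(11)]
```

For the 5 epochs at 1e-1 above, that means roughly the first 3.75 epochs run at the full rate before it decays.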
As you can see, it reached about 83% accuracy fairly quickly. In my other tests I wasn't able to do quite as well, but try it out! The code is here for you to use and play with.
dl = learn.dls.test_dl(df.iloc[:20], bs=1)
batch = dl.one_batch()
batch
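from pytorch_tabnet.tab_model import *
from pytorch_tabnet.utils import create_explain_matrix  # maps post-embedding columns back to features
from scipy.sparse import csc_matrix  # used by the explain function below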
cat_dims = [emb_szs[i][1] for i in range(len(emb_szs))]
cat_dims
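cat_dims are the embedding dimensions, i.e. the second element of each embedding size tuple. cat_idxs are the positions of our categorical variables in the model's input. Because TabNetModel concatenates the embeddings before the continuous variables, the categoricals occupy the first len(to.cat_names) positions:
cat_idxs = list(range(len(to.cat_names)))  # embeddings come first in TabNetModel.forward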
cat_idxs
tot = len(to.cont_names) + len(to.cat_names)
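The final argument, 42, is the width of the input TabNet actually receives after the embeddings: the sum of the embedding dimensions plus the three continuous variables (net.n_emb + net.n_cont):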
matrix = create_explain_matrix(tot,
cat_dims,
cat_idxs,
42)
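The reducing matrix has one row per post-embedding column and one column per original feature, with a 1 wherever a post-embedding column belongs to that feature; multiplying a mask by it sums each categorical variable's embedding columns back into a single value. Below is a dense NumPy sketch of that idea, a hypothetical reimplementation for illustration (pytorch_tabnet's version returns a sparse matrix). The example assumes the categorical embeddings come first, as they do in TabNetModel.forward, and shows that the same mask is attributed differently if cat_idxs don't match that order.

```python
import numpy as np

def reducing_matrix(input_dim, cat_emb_dims, cat_idxs, post_embed_dim):
    "Dense 0/1 matrix mapping each post-embedding column back to its original feature."
    width = dict(zip(cat_idxs, cat_emb_dims))  # continuous features are 1 column wide
    out = np.zeros((post_embed_dim, input_dim))
    row = 0
    for feat in range(input_dim):
        n = width.get(feat, 1)
        out[row:row + n, feat] = 1.0
        row += n
    if row != post_embed_dim:
        raise ValueError(f"expected post_embed_dim={row}, got {post_embed_dim}")
    return out

# Two categorical features (embedding widths 3 and 2) concatenated first, then one continuous feature
demo_matrix = reducing_matrix(3, [3, 2], cat_idxs=[0, 1], post_embed_dim=6)
column_mask = np.array([[0.1, 0.1, 0.1, 0.2, 0.2, 0.3]])  # one sample's mask over the 6 inputs
feature_mask = column_mask @ demo_matrix
# Same mask, but cat_idxs that don't match the model's input order
misaligned = column_mask @ reducing_matrix(3, [3, 2], cat_idxs=[1, 2], post_embed_dim=6)
```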
dl = learn.dls.test_dl(df.iloc[:20], bs=1)
Now let's patch an explain method onto Learner:
@patch
def explain(x:Learner, dl:TabDataLoader):
    "Get explain values for a set of predictions"
    dec_y = []
    x.model.eval()
    for batch_nb, data in enumerate(dl):
        with torch.no_grad():
            # att=True makes the model also return the explanation and the per-step masks
            out, M_loss, M_explain, masks = x.model(data[0], data[1], True)
        # Reduce each step's mask from post-embedding columns back to the original features
        for key, value in masks.items():
            masks[key] = csc_matrix.dot(value.cpu().numpy(), matrix)
        if batch_nb == 0:
            res_explain = csc_matrix.dot(M_explain.cpu().numpy(),
                                         matrix)
            res_masks = masks
        else:
            res_explain = np.vstack([res_explain,
                                     csc_matrix.dot(M_explain.cpu().numpy(),
                                                    matrix)])
            for key, value in masks.items():
                res_masks[key] = np.vstack([res_masks[key], value])
        # Decoded prediction; int() assumes a batch size of 1
        dec_y.append(int(x.loss_func.decodes(out)))
    return dec_y, res_masks, res_explain
We'll pass in the test DataLoader we built above:
dec_y, res_masks, res_explain = learn.explain(dl)
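res_explain holds one row of feature-level importances per sample. A common way to turn it into a single global ranking, and roughly what pytorch_tabnet does for its feature_importances_ attribute, is to sum over samples and normalize to 1. The sketch below uses a small made-up array and hypothetical labels in place of res_explain and dls.x_names.

```python
import numpy as np

def global_importance(explain, labels):
    "Sum per-sample importances over samples, normalize to 1, and sort in descending order."
    totals = np.asarray(explain, dtype=float).sum(axis=0)
    shares = totals / totals.sum()
    return sorted(zip(labels, shares), key=lambda t: t[1], reverse=True)

# Made-up per-sample importances for three samples and three features
demo_explain = np.array([[0.2, 0.5, 0.3],
                         [0.1, 0.7, 0.2],
                         [0.3, 0.6, 0.1]])
ranking = global_importance(demo_explain, ['age', 'education', 'occupation'])
```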
And now we can visualize them with a plot_explain function:
def plot_explain(masks, lbls, figsize=(12,12)):
    "Plots masks with `lbls` (`dls.x_names`)"
    fig = plt.figure(figsize=figsize)
    ax = fig.add_axes([0.1, 0.1, 0.8, 0.8])
    plt.yticks(np.arange(0, len(masks[0]), 1.0))
    plt.xticks(np.arange(0, len(masks[0][0]), 1.0))
    ax.set_xticklabels(lbls, rotation=90)
    plt.ylabel('Sample Number')
    plt.xlabel('Variable')
    # Only the mask from the first decision step (key 0) is plotted
    plt.imshow(masks[0])
We pass in the masks and dls.x_names, and we can see for each sample how much weight the first step's attention mask put on each input variable:
lbls = dls.x_names
plot_explain(res_masks, lbls)