Lesson Video:


This article is also a Jupyter Notebook available to be run from the top down. There will be code snippets that you can then run in any environment.

Below are the versions of fastai, fastcore, wwf, fastaudio, and torchaudio currently running at the time of writing this:

  • fastai: 2.1.5
  • fastcore: 1.3.4
  • wwf: 0.0.7
  • fastaudio: 0.1.3
  • torchaudio: 0.7.2

(Largely based on rbracco's tutorial, big thanks to him for his work on getting this going for us!)

fastai's audio module, fastaudio, has been in development for a while by active forum members.

What makes Audio different?

While it is possible to train on raw audio (we simply pass in a 1D tensor of the signal), the standard approach today is to convert the audio into what is called a spectrogram and train on that.
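
To get a sense of what that conversion involves, here is a minimal sketch using torchaudio directly rather than fastaudio's own pipeline ('example.wav' is a placeholder path):

import torchaudio

# Load a waveform as a [channels, samples] tensor plus its sample rate
sig, sr = torchaudio.load('example.wav')  # 'example.wav' is a placeholder path

# Turn the 1D signal into a 2D mel spectrogram of shape [channels, n_mels, time]
to_spec = torchaudio.transforms.MelSpectrogram(sample_rate=sr, n_fft=1024,
                                               hop_length=128, n_mels=128)
spec = to_spec(sig)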

Free Digit Dataset

Essentially the audio version of MNIST, it contains 2,000 recordings from 10 speakers saying each digit 5 times. First, we'll grab the data and use a custom extract function:

from fastai.vision.all import *
from fastaudio.core.all import *
from fastaudio.augment.all import *
/usr/local/lib/python3.6/dist-packages/torchaudio/backend/utils.py:54: UserWarning: "sox" backend is being deprecated. The default backend will be changed to "sox_io" backend in 0.8.0 and "sox" backend will be removed in 0.9.0. Please migrate to "sox_io" backend. Please refer to https://github.com/pytorch/audio/issues/903 for the detail.
  '"sox" backend is being deprecated. '

tar_extract_at_filename simply extracts at the file name (as the name suggests)
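
If you're curious what such an extract function might look like, here's a rough, hypothetical sketch (not the library's actual implementation) that extracts the archive into a folder named after the file:

import tarfile
from pathlib import Path

def my_tar_extract_at_filename(fname, dest):
  "Extract the tar archive `fname` into a subfolder of `dest` named after the file (illustration only)"
  out = Path(dest)/Path(fname).name.split('.')[0]
  out.mkdir(parents=True, exist_ok=True)
  with tarfile.open(fname) as tar:
    tar.extractall(out)
  return out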

path_dig = untar_data(URLs.SPEAKERS10, extract_func=tar_extract_at_filename)

Now we want to grab just the audio files.

audio_extensions[:5]
('.aif', '.aifc', '.aiff', '.au', '.m3u')
fnames = get_files(path_dig, extensions=audio_extensions)
fnames[:5]
(#5) [Path('/root/.fastai/data/ST-AEDS-20180100_1-OS/f0004_us_f0004_00268.wav'),Path('/root/.fastai/data/ST-AEDS-20180100_1-OS/f0004_us_f0004_00111.wav'),Path('/root/.fastai/data/ST-AEDS-20180100_1-OS/m0003_us_m0003_00309.wav'),Path('/root/.fastai/data/ST-AEDS-20180100_1-OS/f0003_us_f0003_00255.wav'),Path('/root/.fastai/data/ST-AEDS-20180100_1-OS/f0002_us_f0002_00334.wav')]

We can convert any audio file to a tensor with AudioTensor. Let's try opening a file:

at = AudioTensor.create(fnames[0])
at, at.shape
(AudioTensor([[0.0000, 0.0000, 0.0000,  ..., 0.0002, 0.0002, 0.0003]]),
 torch.Size([1, 75520]))
at.show()
<matplotlib.axes._subplots.AxesSubplot at 0x7f0e0419edd8>

Preparing the dataset

fastaudio has an AudioConfig class which allows us to prepare different settings for our dataset. Currently it has:

  • BasicMelSpectrogram
  • BasicMFCC
  • BasicSpectrogram
  • Voice

We'll be using the Voice configuration today, as this dataset just contains human voices.

cfg = AudioConfig.Voice()

Our configuration sets options like the maximum frequency and the sampling rate:

cfg.f_max, cfg.sample_rate
(8000.0, 16000)
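
These look like plain attributes, so (as an assumption about the config object) you could tweak one before building the transform in the next step:

cfg.f_max = 10000  # hypothetical tweak: raise the maximum frequency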

We can then make a transform from this configuration to turn raw audio into a workable spectrogram per our settings:

aud2spec = AudioToSpec.from_cfg(cfg)

For our example, we'll crop the original audio down to 1,000 ms (1 second):

crop1s = ResizeSignal(1000)

Let's build a Pipeline reflecting how we'd expect our data to come in:

pipe = Pipeline([AudioTensor.create, crop1s, aud2spec])

And try visualizing what our newly transformed data looks like.

First, we'll remove that cropping:

pipe = Pipeline([AudioTensor.create, aud2spec])
for fn in fnames[:3]:
  audio = AudioTensor.create(fn)
  audio.show()
  pipe(fn).show()
/usr/local/lib/python3.6/dist-packages/torch/functional.py:516: UserWarning: stft will require the return_complex parameter be explicitly  specified in a future PyTorch release. Use return_complex=False  to preserve the current behavior or return_complex=True to return  a complex output. (Triggered internally at  /pytorch/aten/src/ATen/native/SpectralOps.cpp:653.)
  normalized, onesided, return_complex)
/usr/local/lib/python3.6/dist-packages/torch/functional.py:516: UserWarning: The function torch.rfft is deprecated and will be removed in a future PyTorch release. Use the new torch.fft module functions, instead, by importing torch.fft and calling torch.fft.fft or torch.fft.rfft. (Triggered internally at  /pytorch/aten/src/ATen/native/SpectralOps.cpp:590.)
  normalized, onesided, return_complex)

You can see that they're not all the same size here. Let's add that cropping back in:

pipe = Pipeline([AudioTensor.create, crop1s, aud2spec])
for fn in fnames[:3]:
  audio = AudioTensor.create(fn)
  audio.show()
  pipe(fn).show()

And now everything is 128x63.
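
We can sanity-check that shape on a single item (the expected size follows from our 1 second crop, hop length, and number of mel bins):

spec = pipe(fnames[0])
spec.shape  # expecting torch.Size([1, 128, 63])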

Using the DataBlock API:

  • The same transforms we used for the Pipeline
  • An appropriate getter
  • An appropriate labeller

For our transforms, we'll want the same ones we used before

item_tfms = [ResizeSignal(1000), aud2spec]

Our filenames are labelled by the number followed by the name of the individual:

  • 4_theo_37.wav
  • 2_nicolas_7.wav

get_y = lambda x: x.name[0]
aud_digit = DataBlock(blocks=(AudioBlock, CategoryBlock),  
                 get_items=get_audio_files, 
                 splitter=RandomSplitter(),
                 item_tfms = item_tfms,
                 get_y=get_y)
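
One caveat: lambdas can't be pickled, so if you plan on exporting your Learner later, a named function (a minimal equivalent sketch) does the same job:

def get_y(fname):
  "The label is the first character of the file name"
  return fname.name[0]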

And now we can build our DataLoaders

dls = aud_digit.dataloaders(path_dig, bs=64)

Let's look at a batch

dls.show_batch(max_n=3)

Training

Now that we have our DataLoaders, we need to make a model. We'll make a function that changes a Learner's first layer to accept a 1-channel input (similar to how we did for the Bengali.AI model).

def alter_learner(learn, n_channels=1):
  "Adjust a `Learner`'s model to accept `n_channels` inputs (1 by default)"
  # The first layer of an xresnet is a Conv2d expecting 3 channels
  layer = learn.model[0][0]
  layer.in_channels = n_channels
  # Keep a single channel's worth of the existing weights and restore the channel dimension
  layer.weight = nn.Parameter(layer.weight[:,1,:,:].unsqueeze(1))
  learn.model[0][0] = layer

learn = Learner(dls, xresnet18(), CrossEntropyLossFlat(), metrics=accuracy)

Now we need to grab our number of channels:

n_c = dls.one_batch()[0].shape[1]; n_c
1
alter_learner(learn, n_c)

Now we can find our learning rate and fit!

learn.lr_find()
SuggestedLRs(lr_min=0.03019951581954956, lr_steep=0.0030199517495930195)
learn.fit_one_cycle(5, 1e-2)
epoch train_loss valid_loss accuracy time
0 1.161456 0.916622 0.861979 00:20
1 0.405886 0.429746 0.815104 00:20
2 0.190736 0.975245 0.850260 00:20
3 0.094304 0.052621 0.983073 00:20
4 0.049124 0.023142 0.990885 00:20
learn.fit_one_cycle(5, 1e-3)
epoch train_loss valid_loss accuracy time
0 0.030099 0.034156 0.986979 00:11
1 0.031116 0.021291 0.996094 00:12
2 0.032417 0.017661 0.997396 00:11
3 0.025745 0.017490 0.994792 00:11
4 0.022410 0.016554 0.997396 00:11

Not bad for zero data augmentation! But let's see if augmentation can help us out here!

Data Augmentation

We can use the SpectrogramTransformer class to prepare some transforms for us:

DBMelSpec = SpectrogramTransformer(mel=True, to_db=True)

Let's take a look at our original settings:

aud2spec.settings
{'mel': 'True',
 'to_db': 'False',
 'sample_rate': 16000,
 'n_fft': 1024,
 'win_length': 1024,
 'hop_length': 128,
 'f_min': 50.0,
 'f_max': 8000.0,
 'pad': 0,
 'n_mels': 128,
 'window_fn': <function _VariableFunctionsClass.hann_window>,
 'power': 2.0,
 'normalized': False,
 'wkwargs': None,
 'stype': 'power',
 'top_db': None,
 'sr': 16000,
 'nchannels': 1}

And we'll narrow this down a bit:

aud2spec = DBMelSpec(n_mels=128, f_max=10000, n_fft=1024, hop_length=128, top_db=100)

For our transforms, we'll use:

  • RemoveSilence
    • Splits a signal at points of silence longer than 2 * pad_ms (default is 20)
  • ResizeSignal
    • Crops a signal to a given duration and adds padding if needed
  • aud2spec
    • Our SpectrogramTransformer with parameters
  • MaskTime
    • Wrapper for MaskFreq that applies the mask along the time axis (internally the spectrogram is transposed with an einsum operation)
  • MaskFreq
    • Masks out random bands along the frequency axis (see the sketch after this list)
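
To get a feel for what the masking transforms do, here's a minimal hand-rolled sketch of frequency masking on a spectrogram tensor (illustrative only, not fastaudio's actual implementation):

import torch

def mask_freq_demo(spec, size=10):
  "Zero out a random band of `size` mel bins; `spec` is assumed to be [channels, n_mels, time]"
  start = torch.randint(0, spec.shape[1] - size, (1,)).item()
  masked = spec.clone()
  masked[:, start:start+size, :] = 0.
  return masked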

Let's look a bit more at the padding ResizeSignal uses:

There are three different types (a short usage sketch follows the list):

  • AudioPadType.Zeros: The default, random zeros before and after
  • AudioPadType.Repeat: Repeat the signal until it reaches the proper length (great for acoustic scene classification and voice recognition, terrible for speech recognition)
  • AudioPadType.ZerosAfter: The default in many other libraries, just pad with zeros after the signal until it reaches the specified length
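
As a quick usage sketch (assuming ResizeSignal accepts a pad_mode argument for this):

# Repeat-pad short clips instead of zero-padding them (pad_mode here is an assumption)
crop1s_repeat = ResizeSignal(1000, pad_mode=AudioPadType.Repeat)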

Now let's rebuild our DataBlock:

item_tfms = [RemoveSilence(), ResizeSignal(1000), aud2spec, MaskTime(size=4), MaskFreq(size=10)]
aud_digit = DataBlock(blocks=(AudioBlock, CategoryBlock),  
                 get_items=get_audio_files, 
                 splitter=RandomSplitter(),
                 item_tfms = item_tfms,
                 get_y=get_y)
dls = aud_digit.dataloaders(path_dig, bs=128)

Let's look at some augmented data:

dls.show_batch(max_n=3)

Let's try training again. Also, since we have to keep making an adjustment to our model, let's make an audio_learner function similar to cnn_learner:

def audio_learner(dls, arch, loss_func, metrics):
  "Prepares a `Learner` for audio processing"
  learn = Learner(dls, arch, loss_func, metrics=metrics)
  n_c = dls.one_batch()[0].shape[1]
  if n_c == 1: alter_learner(learn)
  return learn

learn = audio_learner(dls, xresnet18(), CrossEntropyLossFlat(), accuracy)
learn.lr_find()
SuggestedLRs(lr_min=0.04365158379077912, lr_steep=0.002511886414140463)
learn.fit_one_cycle(10, 3e-3)
epoch train_loss valid_loss accuracy time
0 4.650460 1.274344 0.776042 00:17
1 1.882244 0.165905 0.928385 00:17
2 0.958342 0.023563 0.993490 00:17
3 0.537469 0.020922 0.994792 00:17
4 0.317845 0.012210 0.996094 00:17
5 0.193368 0.030548 0.990885 00:17
6 0.118598 0.014498 0.996094 00:17
7 0.074519 0.002216 1.000000 00:17
8 0.047171 0.003482 1.000000 00:17
9 0.030317 0.002047 1.000000 00:17
learn.fit_one_cycle(10, 3e-4)
epoch train_loss valid_loss accuracy time
0 0.003665 0.003438 1.000000 00:17
1 0.004991 0.003278 1.000000 00:17
2 0.005746 0.017678 0.997396 00:17
3 0.005073 0.002702 0.998698 00:17
4 0.003955 0.002549 1.000000 00:17
5 0.003551 0.001628 1.000000 00:17
6 0.003751 0.001688 1.000000 00:17
7 0.003312 0.003636 0.998698 00:17
8 0.003430 0.003046 1.000000 00:17
9 0.003002 0.002288 1.000000 00:17

With the help of some of our data augmentation, we were able to perform a bit better!

Mel Frequency Cepstral Coefficient (MFCC)

Now let's look at that MFCC option we mentioned earlier. MFCCs are a "linear cosine transform of a log power spectrum on a nonlinear mel scale of frequency" (Wikipedia). But what does that mean? In short, MFCCs compress each frame of the mel spectrogram down to a handful of coefficients that describe its overall spectral shape.

Let's try it out!

aud2mfcc = AudioToMFCC(n_mfcc=40, melkwargs={'n_fft':2048, 'hop_length':256,
                                             'n_mels':128})
item_tfms = [ResizeSignal(1000), aud2mfcc]

There's a shortcut for replacing the item transforms in a DataBlock:

aud_digit.item_tfms
(#8) [ToTensor:
encodes: (PILMask,object) -> encodes
(PILBase,object) -> encodes
decodes: ,Resample:
encodes: (AudioTensor,object) -> encodes
decodes: ,DownmixMono:
encodes: (AudioTensor,object) -> encodes
decodes: ,RemoveSilence:
encodes: (AudioTensor,object) -> encodes
decodes: ,ResizeSignal:
encodes: (AudioTensor,object) -> encodes
decodes: ,AudioToSpec:
encodes: (AudioTensor,object) -> encodes
decodes: ,MaskTime:
encodes: (AudioSpectrogram,object) -> encodes
decodes: ,MaskFreq:
encodes: (AudioSpectrogram,object) -> encodes
decodes: ]
aud_digit.item_tfms = item_tfms
dls = aud_digit.dataloaders(path_dig, bs=128)
dls.show_batch(max_n=3)

Now let's build our learner and train again!

learn = audio_learner(dls, xresnet18(), CrossEntropyLossFlat(), accuracy)
learn.lr_find()
SuggestedLRs(lr_min=0.07585775852203369, lr_steep=0.0020892962347716093)
learn.fit_one_cycle(5, 1e-2)
epoch train_loss valid_loss accuracy time
0 1.969720 1.324741 0.812500 00:09
1 0.805437 0.774530 0.738281 00:09
2 0.427395 0.071427 0.970052 00:09
3 0.251631 0.052916 0.980469 00:09
4 0.156765 0.026019 0.989583 00:09

Now we can begin to see why choosing your augmentation is important!

MFCC + Delta:

The last transform we'll discuss is the Delta transform:

Local estimate of the derivative of the input data along the selected axis.

This gives us a multi-channel input from a single signal.
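
Roughly, this amounts to the sketch below, which stacks the MFCCs with their first and second deltas along the channel axis (using torchaudio.transforms.ComputeDeltas for illustration; fastaudio's Delta transform may differ in its details):

import torch, torchaudio

compute_deltas = torchaudio.transforms.ComputeDeltas()

def add_deltas(mfcc):
  "Stack MFCC, delta, and delta-delta as 3 channels; `mfcc` is assumed to be [1, n_mfcc, time]"
  d1 = compute_deltas(mfcc)
  d2 = compute_deltas(d1)
  return torch.cat([mfcc, d1, d2], dim=0)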

item_tfms = [ResizeSignal(1000), aud2mfcc, Delta()]
aud_digit.item_tfms = item_tfms
dls = aud_digit.dataloaders(path_dig, bs=128)
dls.show_batch(max_n=3)

Let's try training one more time:

learn = audio_learner(dls, xresnet18(), CrossEntropyLossFlat(), accuracy)
learn.lr_find()
SuggestedLRs(lr_min=0.15848932266235352, lr_steep=0.0014454397605732083)
learn.fit_one_cycle(5, 1e-2)
epoch train_loss valid_loss accuracy time
0 1.581891 0.626583 0.882812 00:13
1 0.678377 0.197535 0.912760 00:13
2 0.367090 0.094837 0.962240 00:13
3 0.220476 0.022505 0.992188 00:13
4 0.141336 0.033642 0.988281 00:13

Let's try fitting for a few more:

learn.fit_one_cycle(5, 1e-2/10)
epoch train_loss valid_loss accuracy time
0 0.029490 0.022874 0.993490 00:13
1 0.027482 0.022706 0.992188 00:13
2 0.025360 0.015615 0.994792 00:13
3 0.026397 0.026350 0.990885 00:13
4 0.027179 0.022311 0.993490 00:13