Building an Audio Classification Model for Automatic Drum Transcription — Here’s What I Learnt

Last Updated on July 17, 2023 by Editorial Team

Author(s): Yoshi Man

Originally published on Towards AI.

An end-to-end journey from collecting labels to training a computer vision-based audio classification model.

Photo by Nicholas Jeffries on Unsplash

TL;DR — We fine-tuned an InceptionResNetV2 model in Keras to classify audio clips of individual drum hits into 6 main classes — Crash, Hihat, Snare, Kick Drum, Ride, and Toms.

Automatic Music Transcription (AMT), one of the core tasks within Music Information Retrieval (MIR), refers to the task of converting acoustic music signals into music notation. Within this lies the subtask of Automatic Drum Transcription (ADT), which focuses specifically on notating drum signals.

For me, the challenge originated from my tiring attempts to transcribe drum tracks into drum sheets for my favorite music, so I shifted my focus to training a model that could do it for me.

This article documents my end-to-end journey of building a model from scratch, what the obstacles were, and where I see potential improvements moving forward. The hope is that this can give you inspiration, and we can collectively improve and build upon this task.

I can’t possibly go through every single detail. Where appropriate, I will include links for the avid reader (especially theoretical details). This article is split into 6 parts:

  1. Research & Background — What were the methods and approaches used in recent studies on ADT? What are common difficulties/obstacles out there?
  2. Method & Approach — The action plan used in our approach. Also, what’s in our scope for this project?
  3. Data Preprocessing — Getting from audio to image. How do we represent an audio clip with spectrograms?
  4. Gathering Data — The first step in execution. How do we get our dataset for this problem?
  5. Training a Model — Usually the most exciting part. How do we get from data to a predictive model?
  6. Demo and Conclusion

1) Research & Background — Getting Familiar with the Task

The starting point of this project was A Review of Automatic Drum Transcription (CW Wu, 2018), which surveys where recent attempts at ADT excelled and where they fell short.

In point form, here’s what I picked up:

  • Music Style Elements are Difficult to Recognise in Audio — Complex musical style elements such as accents and subtle techniques are usually included in transcriptions but are often ambiguous in audio form.
  • Polyphony Complicates Things — Audio tracks with other instruments (which is usually the case with music) are significantly more difficult to transcribe. The presence of other instruments distorts the data point being classified.
  • Drum Notes are Played Together — “drum sounds are usually superimposed on top of each other”, meaning a sound could be a combination of two drum components (e.g., hi-hat + kick drum), which causes ambiguity in classification.
  • RNN vs. CNN, or both? — The two main approaches are sequential models and image recognition models. The choice of path determines which tasks are required to get from dataset to model training; most recently, a mix of both has done considerably well even in polyphonic contexts (C. Southall, 2017).

In drum notation, we have superimposed notes and complex elements like flams and triplets etc. (Image by Author)

2) Method & Approach

Considering the existing approaches and pitfalls, we will be using a CNN approach to the problem. This means that for each drum hit, we will extract an image containing features of the audio and use it to finetune an InceptionResNetV2 model using Keras.

The end-to-end approach for this project looks like this:

Rough sketch of method made during planning. (Image by Author)

It is important to point out that getting from the original music track to a model and predictions requires five separate tasks:

  1. Audio Separation — This tackles the polyphony problem described above. Recent advances in Music Source Separation have made it possible to eliminate a large chunk of our issue with multi-instrumental tracks. By using tools like mvsep, we can extract the drum portion of a piece of music easily and for free.
  2. Onset Detection — The task of locating each drum hit in a drum track so that we can classify each hit individually.
  3. Feature Extraction — A very common method of audio feature extraction is the Mel Spectrogram. Essentially, this takes an audio clip, performs a Short-Time Fourier Transform to get the frequency spectrum, and rescales it to the Mel scale to represent the human perception of frequencies. The result is an image that “represents” the audio clip as “seen” by the listener.
  4. Data Labelling — We need to generate our own dataset for this task. Common datasets like ENST Drums are used across academic research, but if we’re going to make predictions on source-separated drum tracks, we had better train on source-separated drum tracks.
  5. Model Training — The final step is training the model. We will be finetuning an InceptionResNetV2 model using Keras, and feeding in our own labeled dataset.

3) Data Preprocessing — Separate, Detect, Extract

This part consists of taking an audio signal as input and extracting useful features from it in the form of an image. It covers steps 1–3 of our method: separation, onset detection, and feature extraction.

Audio Source Separation — Extracting the drums

A simple way to take the drums out of a music file is through mvsep, a free tool to separate musical components from a music file. Simply upload the audio file and select the model to separate with. While exploring the different model options, I found that the Demucs4 HT model (developed by Facebook Research) sounds the best for drum separation.

mvsep, a free tool to separate musical components from a music file (Image by Author)

Onset Detection — Detecting each hit

The librosa library makes this straightforward with the librosa.onset.onset_detect function. Given that drum beats are usually uniform, we can also use librosa.beat.beat_track to get a click track that best fits the drum track. Below are both methods and what the output sounds like.

import matplotlib.pyplot as plt
from IPython.display import Audio, display
import librosa, librosa.display

# AUDIO_FILE_PATH (the separated drum track) and bpm (a rough tempo guess)
# are assumed to be defined beforehand

# load in the audio file, first 30 seconds
samples, sr = librosa.load(AUDIO_FILE_PATH, duration=30)

# get the timestamps of each hit or each beat
onset_times = librosa.onset.onset_detect(y=samples, sr=sr, units='time')
beat_times = librosa.beat.beat_track(y=samples, sr=sr,
                                     start_bpm=bpm, units='time')[1]

# get the click tracks and listen to them on top of the audio
clicks = librosa.clicks(times=onset_times, sr=sr, length=len(samples))
display(Audio((samples + clicks), rate=sr))
clicks = librosa.clicks(times=beat_times, sr=sr, length=len(samples))
display(Audio((samples + clicks), rate=sr))

# plot the waveform, the detected onsets, and the beat times
fig, ax = plt.subplots(nrows=3, sharex=True, figsize=(10, 5))
# waveshow replaces waveplot, which was removed in librosa 0.10
librosa.display.waveshow(samples, sr=sr, offset=10, ax=ax[0])
ax[1].vlines(onset_times + 10, -1, 1)
ax[2].vlines(beat_times + 10, -1, 1)
(i) the wave plot of the audio clip (from 10 to 20s) (ii) Onsets detected (iii) Beat times

Here’s what they sound like:

Onset detection — Each click represents an onset detected
Beat track — Each click is regularly spaced to best estimate the tempo of the beat

For our purposes, it makes more sense to go with Onset Detection, as we would like to classify each hit, whereas segmenting by the beat might mean that each window contains multiple hits.

Since the window size now varies from hit to hit, we need to regularise it by either trimming or padding each audio clip to a fixed length.
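
A minimal sketch of this regularisation step, assuming each hit is cut from its onset to the next onset and then forced to a fixed number of samples. The helper names (slice_hits, fix_length) and the 0.25 s window are illustrative assumptions, not values from the project, and plain Python lists are used so the example runs without extra dependencies.

```python
def fix_length(clip, size):
    """Trim `clip` to `size` samples, or right-pad it with zeros (silence)."""
    if len(clip) >= size:
        return clip[:size]
    return clip + [0.0] * (size - len(clip))


def slice_hits(samples, onset_times, sr, window_seconds=0.25):
    """Cut one fixed-length clip per onset.

    Each clip starts at its onset and ends at the next onset (or the end of
    the track), then is trimmed or padded to `window_seconds`.
    """
    size = int(round(window_seconds * sr))
    starts = [int(round(t * sr)) for t in onset_times]
    ends = starts[1:] + [len(samples)]
    return [fix_length(samples[s:e], size) for s, e in zip(starts, ends)]


# toy example: a 1-second "track" at 100 Hz with three onsets
sr = 100
track = [float(i) for i in range(100)]
hits = slice_hits(track, onset_times=[0.0, 0.1, 0.8], sr=sr)
print([len(h) for h in hits])  # every clip has 25 samples
```

With NumPy arrays, as returned by librosa.load, librosa.util.fix_length covers the same trim-or-pad job.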

Mel Spectrograms — Extracting features

Mel Spectrograms (Image by Author)

Spectrograms are image data that represent an audio clip in a meaningful way in terms of frequency (pitch) and intensity (loudness) across time. The “Mel” portion further transforms the spectrogram into a representation that follows the human perception of pitch. Combining the two lets us turn an audio clip into a visual form on which we can train Convolutional Neural Networks for classification.
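
To make the “Mel” part concrete, here is a small sketch of one common Hz-to-mel conversion, the HTK-style formula m = 2595 · log10(1 + f / 700). This is an illustration only: librosa's default mel filterbank uses a slightly different (Slaney-style) variant, and the function names below are illustrative.

```python
import math


def hz_to_mel(freq_hz):
    """Convert a frequency in Hz to mels (HTK-style formula)."""
    return 2595.0 * math.log10(1.0 + freq_hz / 700.0)


def mel_to_hz(mel):
    """Inverse of hz_to_mel."""
    return 700.0 * (10.0 ** (mel / 2595.0) - 1.0)


# equal steps in Hz shrink on the mel scale as frequency goes up
low_step = hz_to_mel(200) - hz_to_mel(100)
high_step = hz_to_mel(8100) - hz_to_mel(8000)
print(round(low_step, 1), round(high_step, 1))
```

The same 100 Hz gap counts for far more mels at the low end than at the high end, which mirrors how a listener finds low-frequency differences easier to hear.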

For more details on the intricacies and how they work, Ketan Doshi has a brilliant series of articles that explains it in simple terms:

Audio Deep Learning Made Simple (Part 2): Why Mel Spectrograms perform better

A Gentle Guide to processing audio in Python. What are Mel Spectrograms and how to generate them, in Plain English.

In code form for our purposes, we can use librosa to get the Mel Spectrogram image array directly:

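import numpy as np
import librosa
from sklearn.preprocessing import MinMaxScaler


def get_mel_spectrogram(samples: np.ndarray, sr: int = 44100,
                        target_shape: tuple = (256, 256)) -> np.ndarray:
    """
    :param samples (np.ndarray): samples array of the audio
    :param sr (int): sample rate used for the samples
    :param target_shape (tuple): (n_mels, n_frames) of the output image;
        the (256, 256) default is an assumption matching the model input
    :return mel_spectrogram (np.ndarray): array containing
        melspectrogram features in decibels, scaled to [0, 1]
    """
    # pick a hop length that yields roughly target_shape[0] frames
    hop_length = len(samples) // target_shape[0]

    # get mel spectrogram image data
    mel_features = librosa.feature.melspectrogram(
        y=samples, sr=sr, hop_length=hop_length, n_mels=target_shape[0])

    # clip the array to fit our target image shape
    mel_features = mel_features[:, :target_shape[1]]

    # convert to decibels and normalize the image for efficiency
    # (MinMaxScaler rescales each column, i.e. each time frame, independently)
    mel_in_db = librosa.power_to_db(mel_features, ref=np.max)
    scaler = MinMaxScaler(feature_range=(0, 1))
    return scaler.fit_transform(mel_in_db)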

Now that we’re able to go from a music clip to a Mel Spectrogram for each detected drum hit, we can begin to create our own dataset for the problem.
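
The last two steps of the function, decibel conversion and scaling, can be sketched in plain Python to show what the image values mean. This is a simplified illustration, not librosa's or scikit-learn's implementation: power_to_db here omits librosa's default 80 dB floor below the peak, and min_max_columns mirrors the fact that MinMaxScaler rescales each column (here, each time frame) independently.

```python
import math


def power_to_db(power, ref, amin=1e-10):
    """Decibels relative to `ref`: 10 * log10(power / ref), with a small floor."""
    return 10.0 * math.log10(max(power, amin)) - 10.0 * math.log10(max(ref, amin))


def min_max_columns(matrix):
    """Rescale every column of a 2D list to [0, 1] independently."""
    scaled_columns = []
    for col in zip(*matrix):
        lo, hi = min(col), max(col)
        span = (hi - lo) or 1.0  # constant column: avoid dividing by zero
        scaled_columns.append([(v - lo) / span for v in col])
    return [list(row) for row in zip(*scaled_columns)]


# toy "spectrogram": 3 mel bands (rows) x 2 time frames (columns), power values
power_spec = [[1.0, 0.01],
              [0.1, 0.001],
              [0.01, 0.0001]]
peak = max(max(row) for row in power_spec)
db_spec = [[power_to_db(v, ref=peak) for v in row] for row in power_spec]
image = min_max_columns(db_spec)
print(image)
```

Because each time frame is scaled on its own, a quiet frame and a loud frame both end up spanning the full 0 to 1 range; scaling the whole array with a single minimum and maximum would be an alternative that preserves loudness differences between frames.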

4) Gathering Data — Labelling our own data

The complicated bit is over, and now comes the tedious part of creating our own dataset.

The reason why we’re not using public datasets like ENST Drums is that we want to train on the kind of data the model will see at prediction time. Datasets like ENST Drums are professionally recorded with minimal noise, whereas the tracks we’ll be making predictions on are noisy as a result of imperfections from audio source separation.

In order to create our own dataset, we’ll need to label it ourselves. Here comes pigeon, a package that allows us to quickly set up our own data labeling sessions.

pigeon — A tool to annotate data on Jupyter (Image by Anastasis Germanidis)

Since the module only supported image and text data, we needed to fork it to add an audio option. After adding some basic features to the code, the result is the following:

forked pigeon — Adapted tool for labeling audio data (Image by Author)

The result is a JSON file for each labeling session, containing the label of each hit for a selected drum track:

Example data generated during labeling sessions (Image by Author)

These JSON files, along with the drum track paths, can then be converted to images with labels assigned to them, which we can finally train a CNN on. But first, we need to collect a lot of labels.

5) Model Training — Preprocess and Finetune

After 5 hours of audio labeling, with a total of 4513 labeled audio clips across 10 different drum tracks, we are now ready to train our first model.

Model training process flow (Image by Author)

Model training consists of 2 main steps:

  1. Data Preprocessing — This time, preprocessing means getting an audio-to-label dataset, splitting the data into train, validation, and test sets, rebalancing the train set through audio augmentation, and finally saving the images into a local folder
  2. Model Finetuning — Taking a pre-trained image classification model and adding layers on top for our use cases. We will be finetuning an InceptionResNetV2 Model.

Data Preprocessing — Creating a trainable dataset.

To get from the JSON files to a valid dataset for training, the following is done:

Process flow for getting from JSON to a dataset (Image by Author)

  1. JSON to Audio — A simple process of taking the JSON file, cutting the individual samples out of the drum track with onset detection, and assigning each one to the class we previously labeled it as.
  2. Train/Validation/Test Split — Do a stratified train/val/test split to preserve the original class distribution of the dataset. It is important that we validate, test, and predict on data of the same distribution, which our original data best approximates.
  3. Upsampling the Train Set with Audio Augmentation — Class imbalance exists and affects the model's ability to account for less frequent classes. Upsampling gives the model more training examples to learn from, while augmentation helps synthesize more examples within the sample space that we did not capture during data collection.

Final labelled dataset, clearly imbalanced (Image by Author)
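
The stratified split in step 2 can be sketched without any library, assuming a 70/15/15 split (the article does not state the actual ratios) and hypothetical label counts. In practice, scikit-learn's train_test_split with its stratify argument does the same job.

```python
import random
from collections import Counter, defaultdict


def stratified_split(labels, fractions=(0.7, 0.15, 0.15), seed=0):
    """Return (train, val, test) index lists that keep each class's share."""
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    rng = random.Random(seed)
    train, val, test = [], [], []
    for indices in by_class.values():
        rng.shuffle(indices)
        n_train = round(fractions[0] * len(indices))
        n_val = round(fractions[1] * len(indices))
        train += indices[:n_train]
        val += indices[n_train:n_train + n_val]
        test += indices[n_train + n_val:]  # whatever is left goes to test
    return train, val, test


# hypothetical, imbalanced label list (not the project's real counts)
labels = ["Hihat"] * 100 + ["Snare"] * 60 + ["Crash"] * 20
train, val, test = stratified_split(labels)
print(Counter(labels[i] for i in train))
```

Splitting class by class is what keeps a rare class from vanishing from the validation or test set by chance.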

Audiomentations is a popular package that can help with augmenting audio data. It is extremely important that we do not apply data augmentation techniques we typically perform on image data, as transformations like rotation, stretching, and masking do not inherently have the same meaning in the audio space.

Augmentations are done on the audio level, not the image (Image by Author)
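
To make the upsampling step concrete, the sketch below picks which training clips to augment so that every class ends up as large as the biggest one. The class counts are made up for illustration, and matching the majority class is one possible target, not necessarily the one used in this project.

```python
import random
from collections import Counter


def oversample_indices(labels, seed=0):
    """Pick extra training indices so every class reaches the majority count.

    The returned indices point at clips that should be augmented (for example
    with audiomentations) and added to the training set as new examples.
    """
    counts = Counter(labels)
    target = max(counts.values())
    rng = random.Random(seed)
    extra = []
    for cls, count in counts.items():
        pool = [i for i, label in enumerate(labels) if label == cls]
        extra += rng.choices(pool, k=target - count)  # sample with replacement
    return extra


# made-up class counts for illustration
train_labels = ["Hihat"] * 50 + ["Kick Drum"] * 30 + ["Ride"] * 5
extra = oversample_indices(train_labels)
balanced = Counter(train_labels) + Counter(train_labels[i] for i in extra)
print(balanced)
```

Only the training set is resampled this way; the validation and test sets keep their original class distribution.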

Audio augmentations are applied with random probabilities; here we’ve applied augmentations to 50% of our training data.

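import numpy as np
from audiomentations import (AddGaussianNoise, Compose, Gain, PitchShift,
                             Shift, TimeMask)


def apply_augmentation(samples: np.ndarray, sr: int = 44100) -> np.ndarray:
    """
    :param samples (np.ndarray): samples array of the audio
    :param sr (int): sample rate of the audio (parameter assumed here)
    :return augmented (np.ndarray): np.ndarray containing samples
        of augmented audio
    """
    # each transform fires with probability p; only Gain(p=0.5) is from the
    # original, the other arguments are assumed library defaults
    gaussian_noise = AddGaussianNoise(p=0.5)
    time_mask = TimeMask(p=0.5)
    time_shift = Shift(p=0.5)
    pitch_shift = PitchShift(p=0.5)
    gain = Gain(p=0.5)
    augmenter = Compose(
        [time_shift, gain, pitch_shift, time_mask, gaussian_noise])
    return augmenter(samples=samples, sample_rate=sr)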
4. Saving the Images — After augmentations and oversampling have been completed on the training data, we save the Train, Val, and Test sets into a local directory sorted by class, ready for the model to be finetuned.

Model Finetuning in Keras — InceptionResNetV2

Keras provides a range of pre-trained models for computer vision tasks; this makes it easy for us to adapt the model for our purposes — which we call finetuning a model.

The benefit of doing so is that the pre-trained models already have capabilities, such as edge detection and curvature detection, that could be useful for our tasks.

Computer Vision Models available on Keras (Image by Author)

The process through code is quite simple and standard. We first get the model from Keras and freeze the original weights from the pre-trained model. This is so that we can feed in our images, get the last convolution layer flattened, and use it as an input to our own dense neural network.

The hyperparameters below, such as the number of layers and the number of neurons per hidden layer, were tuned via experiments.

from tensorflow.keras import models, layers, optimizers
from tensorflow.keras.applications import InceptionResNetV2
from tensorflow.keras.preprocessing.image import ImageDataGenerator

import mlflow
import mlflow.keras


def get_model(path=None):
    if path is None:
        # load the pre-trained convolutional base without its classifier head
        conv_base = InceptionResNetV2(weights="imagenet",
                                      include_top=False,
                                      input_shape=(256, 256, 3))
        # make it so the conv_base is not trainable
        conv_base.trainable = False
        # add more layers on top of the Inception model
        model = models.Sequential()
        model.add(conv_base)
        model.add(layers.Flatten())  # flatten the last convolution output
        model.add(layers.Dense(2048, activation='relu'))
        model.add(layers.Dense(1024, activation='relu'))
        model.add(layers.Dense(512, activation='relu'))
        model.add(layers.Dense(256, activation='relu'))
        model.add(layers.Dense(128, activation='relu'))
        model.add(layers.Dense(6, activation='sigmoid'))  # 6 classes
    else:
        # otherwise load a previously saved model from disk
        model = models.load_model(path)
    return model

Training the model can be done with the following:


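# TRAIN_DIR, VAL_DIR, and TEST_DIR are placeholder names for the class-sorted
# image folders saved in the previous step; EPOCHS is likewise a placeholder
train_datagen = ImageDataGenerator(rescale=1./255)
val_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory(
    TRAIN_DIR, target_size=(256, 256), class_mode='categorical')
validation_generator = val_datagen.flow_from_directory(
    VAL_DIR, target_size=(256, 256), class_mode='categorical')
test_generator = test_datagen.flow_from_directory(
    TEST_DIR, target_size=(256, 256), class_mode='categorical',
    shuffle=False)

# initialise and build CNN model based on InceptionResNetV2
model = get_model()
# the compile settings below are assumptions, not the tuned values
model.compile(optimizer=optimizers.Adam(), loss='binary_crossentropy',
              metrics=['accuracy'])
history = model.fit(train_generator, validation_data=validation_generator,
                    epochs=EPOCHS)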

Model Evaluation — Evaluation against the test set

After multiple experiments with hyperparameter tuning, the best model was picked out. Accuracy and F1 score were the evaluation metrics considered during model selection. For predicted labels during evaluation, our only option was to take the argmax as the label, but realistically, during inference, we could use a threshold to provide multilabel outputs.
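
The difference between the two decoding strategies can be sketched as follows. The class order and the 0.5 threshold are assumptions for illustration; in Keras, the actual order comes from test_generator.class_indices.

```python
CLASSES = ["Crash", "Hihat", "Kick Drum", "Ride", "Snare", "Toms"]  # assumed order


def argmax_label(probs):
    """Single-label decoding: keep only the most probable class."""
    best = max(range(len(probs)), key=lambda i: probs[i])
    return CLASSES[best]


def threshold_labels(probs, threshold=0.5):
    """Multilabel decoding: keep every class whose score clears the threshold."""
    return [cls for cls, p in zip(CLASSES, probs) if p >= threshold]


# sigmoid outputs for one clip where a hi-hat and a kick drum land together
probs = [0.05, 0.91, 0.78, 0.10, 0.02, 0.01]
print(argmax_label(probs))      # Hihat
print(threshold_labels(probs))  # ['Hihat', 'Kick Drum']
```

Thresholding only makes sense because the output layer uses independent sigmoid units, so several classes can score highly at the same time.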

import numpy as np
import pandas as pd
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.metrics import confusion_matrix, f1_score, accuracy_score
import mlflow
import mlflow.keras

import matplotlib.pyplot as plt
import seaborn as sns


def get_confusion_matrix(model, test_generator):
    # predicted class = index of the highest output score
    probs = model.predict(test_generator)
    y_pred = np.argmax(probs, axis=1)
    m = confusion_matrix(test_generator.classes, y_pred)
    class_names = list(test_generator.class_indices.keys())
    df = pd.DataFrame(m, index=class_names, columns=class_names)
    plt.figure(figsize=(10, 7))
    return sns.heatmap(df, annot=True, cmap='Blues', fmt='g',
                       xticklabels=True, yticklabels=True)


# BEST_MODEL_PATH and TEST_DIR are placeholders for the saved model and the
# class-sorted test image folder
model = mlflow.keras.load_model(BEST_MODEL_PATH)

# test data
test_datagen = ImageDataGenerator(rescale=1. / 255)
test_generator = test_datagen.flow_from_directory(
    TEST_DIR,
    target_size=(256, 256),
    class_mode='categorical', shuffle=False)

probs = model.predict(test_generator)
y_pred = np.argmax(probs, axis=1)
print(f1_score(test_generator.classes, y_pred, average='weighted'))
print(accuracy_score(test_generator.classes, y_pred))
get_confusion_matrix(model, test_generator)
Test Accuracy — 0.907 // Test F1-Score — 0.908 (Image by Author)
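
For reference, the weighted F1 score reported above averages the per-class F1 scores, weighting each class by its number of true examples. A small sketch on a made-up 3-class confusion matrix (rows are true classes, columns are predictions, matching the layout of scikit-learn's confusion_matrix):

```python
def weighted_f1(matrix):
    """Weighted-average F1 from a square confusion matrix (rows = true class)."""
    n = len(matrix)
    total = sum(sum(row) for row in matrix)
    score = 0.0
    for k in range(n):
        true_pos = matrix[k][k]
        support = sum(matrix[k])                         # true examples of k
        predicted = sum(matrix[r][k] for r in range(n))  # predicted as k
        precision = true_pos / predicted if predicted else 0.0
        recall = true_pos / support if support else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        score += f1 * support / total
    return score


def accuracy(matrix):
    """Share of examples on the diagonal of the confusion matrix."""
    total = sum(sum(row) for row in matrix)
    return sum(matrix[k][k] for k in range(len(matrix))) / total


# made-up confusion matrix, not the project's results
toy = [[8, 2, 0],
       [1, 9, 0],
       [0, 0, 10]]
print(round(accuracy(toy), 3), round(weighted_f1(toy), 3))
```

On an imbalanced test set, the weighted average lets frequent classes dominate the score, so the per-class rows of the confusion matrix are still worth reading individually.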

A large part of model training is experimenting with hyperparameters. This particular model fared better, but training it took 5 hours.

6) Demo and Conclusion

To demo the results, a Streamlit App was created locally to transcribe audio based on a YouTube link. By feeding in any YouTube link containing a drums-only track, predictions will be generated from the model.

Demo on Drums Transcriber (Image by Author)

Here’s the GitHub Repo for the project:

GitHub — yoshi-man/DrumTranscriber

This package helps users transcribe drum audio hits into 6 classes — Hihat, Crash, Kick Drum, Snare, Ride, and Toms…

Wrapping Up…

In this article, we went from raw audio signals to extracting the drum track, isolating each drum hit, and training a model with our own labeled data to classify each drum hit. In this process, we’ve used the following:

  1. mvsep to separate and extract the drum tracks
  2. librosa to work with audio, including onset detection and mel spectrograms
  3. pigeon to create our own data labeling session
  4. audiomentations to perform audio augmentations for oversampling our training data
  5. scikit-learn, mlflow, and Keras to track and finetune an InceptionResNetV2 model to classify each drum hit
  6. streamlit to demo our app

Next Steps — Beyond this

Of course, transcription is the task of getting from audio to an actual transcribed musical notation. We’ve stopped short of this as it requires more effort to get there, which I hope to get to in future posts.

Thanks for reading!

Appendix — More to check out on this topic

Research articles that were also useful:

In addition to academic approaches to the problem, the links below were crucial in catching up on understanding the general toolkits and methods around audio classification:

Audio Deep Learning Made Simple: Sound Classification, step-by-step

An end-to-end example and architecture for Audio Deep Learning’s foundational application scenario, in Plain English.

YouTube Playlists:
