Gradient Checkpointing
To “scale” new heights in model training

Gradient Checkpointing (also known as the recomputation technique or activation checkpointing) is an approach that trades compute for memory. It is helpful when the available GPU memory is not enough to train a large model. It was originally published in 2016 (Chen et al., “Training Deep Nets with Sublinear Memory Cost”).

In “In Short #7 | What is Gradient Accumulation?”, we learned how to train a model with a large enough batch size despite limited GPU memory.

But what if the model is so large that we can’t fit even a batch size of 1?

Gradient checkpointing helps here by reducing the memory needed to hold activations during training. So even when a large model’s activations would otherwise overflow the GPU, we still have a silver lining.

It does this by NOT storing all the intermediate activations during the forward pass. Only a few are kept as checkpoints, and the rest are recomputed later, which saves precious memory.

Let’s take an example: a computation graph with A1 and A2 as the intermediate activations.

Instead of storing both A1 and A2, it computes A1 during the forward pass (A2 depends on it) but does not keep it in memory.

In PyTorch, running that part of the forward pass under torch.no_grad() ensures that its intermediate activations are not stored for the backward pass.

The discarded activations are recomputed only during the backward pass. This makes the backward pass slower, but it saves memory.
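The following hand-written sketch makes the A1/A2 example concrete. It assumes a hypothetical two-layer chain, A1 = tanh(x·W1) and A2 = tanh(A1·W2). The class name and shapes are illustrative only.

The forward pass keeps x, the weights, and A2, but throws A1 away. The backward pass recomputes A1 before using it. In practice you would rely on PyTorch's built-in torch.utils.checkpoint rather than writing gradients by hand.

```python
import torch


class TwoLayerRecompute(torch.autograd.Function):
    """Hypothetical toy chain: x -> A1 = tanh(x @ w1) -> A2 = tanh(A1 @ w2).

    A2 is kept (it acts as the checkpoint); A1 is dropped after the
    forward pass and recomputed in backward.
    """

    @staticmethod
    def forward(ctx, x, w1, w2):
        # Function.forward already runs with autograd disabled; the explicit
        # torch.no_grad() just makes the "don't record this" step visible.
        with torch.no_grad():
            a1 = torch.tanh(x @ w1)  # computed and used, then thrown away
            a2 = torch.tanh(a1 @ w2)
        ctx.save_for_backward(x, w1, w2, a2)  # note: a1 is NOT saved
        return a2

    @staticmethod
    def backward(ctx, grad_a2):
        x, w1, w2, a2 = ctx.saved_tensors
        a1 = torch.tanh(x @ w1)  # recompute the skipped activation
        # Chain rule by hand; d tanh(z)/dz = 1 - tanh(z)^2
        grad_z2 = grad_a2 * (1 - a2 ** 2)
        grad_w2 = a1.t() @ grad_z2
        grad_a1 = grad_z2 @ w2.t()
        grad_z1 = grad_a1 * (1 - a1 ** 2)
        grad_w1 = x.t() @ grad_z1
        grad_x = grad_z1 @ w1.t()
        return grad_x, grad_w1, grad_w2
```

Calling TwoLayerRecompute.apply(x, w1, w2) gives the same output and gradients as the plain two-layer computation. The difference is that A1 never occupies memory between the forward and backward passes.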

The slowdown is roughly 20%, but the memory cost for an n-layer network, as per the paper, is transformed as:

$$O(n) \;\rightarrow\; O(\sqrt{n})$$
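The square-root cost can be derived on the back of an envelope. This is a simplified model, not the paper's full analysis. It assumes a plain chain of n layers, every activation the same size, and checkpoints placed every n/k layers.

```latex
% Assumption: n equal-size layers, split into k segments of n/k layers each.
% Forward pass keeps only the k segment-boundary activations;
% backward pass recomputes one segment at a time (n/k activations live).
\begin{aligned}
\text{memory}(k) &\approx k + \frac{n}{k} \\
\frac{d}{dk}\left(k + \frac{n}{k}\right) &= 1 - \frac{n}{k^{2}} = 0
  \;\Longrightarrow\; k = \sqrt{n} \\
\text{memory}(\sqrt{n}) &\approx 2\sqrt{n} = O(\sqrt{n})
\end{aligned}
```

For example, with n = 100 layers and k = 10 segments, about 20 activations are held at once instead of 100. The price is that each segment's forward computation runs twice, which amounts to roughly one extra forward pass per training step.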

Which layers are checkpointed?

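The recomputation machinery is implemented in PyTorch and other deep-learning frameworks, but in PyTorch you still choose which parts to checkpoint. One idea from the paper recommends dropping the outputs of cheap operations (such as batch normalization, activation functions, and pooling) while keeping the outputs of expensive ones (such as convolutions).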

This way, the backward pass is not slowed down much, because the operations being recomputed are computationally cheap.
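The "keep expensive outputs, recompute cheap ones" idea can be applied in PyTorch with torch.utils.checkpoint. This is a minimal sketch that assumes PyTorch 2.x and a hypothetical conv → ReLU → max-pool layer. The convolution output is stored as usual, and the ReLU and pooling outputs are recomputed during backward.

BatchNorm is deliberately left out. In training mode, its running statistics would be updated a second time when the layer is recomputed.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ConvReLUPool(nn.Module):
    """Hypothetical layer: keep the conv output, recompute ReLU + max-pool."""

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1)
        # Cheap, stateless ops: safe to run twice.
        self.cheap = nn.Sequential(nn.ReLU(), nn.MaxPool2d(2))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = self.conv(x)  # expensive op: its output is kept as usual
        if self.training:
            # Only `y` (the input to the checkpointed function) is kept;
            # the ReLU and pooling activations are recomputed in backward.
            return checkpoint(self.cheap, y, use_reentrant=False)
        return self.cheap(y)
```

In eval mode there is no backward pass to save memory for, so the layer runs normally.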

How to implement it?

PyTorch provides a simple checkpoint API in torch.utils.checkpoint.
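Below is a minimal usage sketch, assuming PyTorch 2.x. Block, DeepModel, sequential_forward, and the sizes are hypothetical names chosen for illustration. checkpoint and checkpoint_sequential are the real torch.utils.checkpoint functions. use_reentrant=False is passed explicitly, the variant recent PyTorch documentation recommends.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint, checkpoint_sequential


class Block(nn.Module):
    """Hypothetical residual MLP block used only for illustration."""

    def __init__(self, dim: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.net(x)


class DeepModel(nn.Module):
    """Stack of blocks; each block boundary acts as a checkpoint."""

    def __init__(self, dim: int = 64, depth: int = 8):
        super().__init__()
        self.blocks = nn.ModuleList([Block(dim) for _ in range(depth)])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if self.training:
                # Store only this block's input; the activations inside the
                # block are recomputed when backward reaches it.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


def sequential_forward(model: nn.Sequential, x: torch.Tensor, segments: int) -> torch.Tensor:
    # Split an nn.Sequential into `segments` chunks and keep only the
    # activations at chunk boundaries.
    return checkpoint_sequential(model, segments, x, use_reentrant=False)
```

checkpoint_sequential applies the same idea to an nn.Sequential without writing the loop yourself. Fewer segments store fewer boundary activations, but each segment then holds more activations while it is being recomputed.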

TensorFlow users can check it out here.


Hope you enjoyed this!

