
The GPU Bottleneck: Why Your Multi-GPU Training is Crawling (and How to Fix It!) 🚀
Last Updated on September 25, 2025 by Editorial Team
Author(s): ChalBe
Originally published on Towards AI.
tags: PyTorch | DistributedDataParallel (DDP) | Performance Optimization
So, you’ve assembled a beast of a machine with multiple GPUs, ready to conquer the world of deep learning. But when you kick off your training, it feels… underwhelming. You’ve got all this horsepower, but the progress bar inches along at a snail’s pace. What gives? It’s a super common problem, and it usually comes down to one thing: the GPUs are spending more time talking to each other than actually working! The culprit? The infamous communication bottleneck.
Think of it like a team project where everyone completes their part, but then they all have to gather in one room to discuss and combine their work. If that room is small and everyone is talking at once, progress grinds to a halt. This blog post will explain how to optimize this “team meeting” to unlock the full potential of your multi-GPU setup.
Understanding the Problem
The “Team Meeting” Problem in Distributed Training
During distributed training, each GPU calculates a set of gradients. To ensure the model updates consistently across all GPUs, these gradients must be collected, averaged, and then distributed back to every GPU. This is the All-Reduce operation, and it's the main source of the communication bottleneck.

Mathematically, if ∇L_w represents the gradients calculated by each worker w, the final aggregated gradient ∇L that every worker will use to update its model is the average of all individual gradients:

∇L = (1/W) · Σ_{w=1}^{W} ∇L_w

where W is the total number of workers (GPUs). This collective communication step ensures that all model replicas stay in sync with the same gradient information.
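To make that concrete, here is a minimal sketch of what this averaging step looks like if you were to do it by hand with torch.distributed after each backward pass (DDP performs the equivalent for you, far more efficiently). The helper name average_gradients is just for illustration, and the sketch assumes the process group has already been initialized:

import torch
import torch.distributed as dist

def average_gradients(model: torch.nn.Module) -> None:
    """Manually All-Reduce and average gradients across workers after backward().
    Assumes dist.init_process_group() has already been called (e.g. via torchrun)."""
    world_size = dist.get_world_size()
    for param in model.parameters():
        if param.grad is not None:
            # Sum this parameter's gradient across all W workers...
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            # ...then divide by W to get the average ∇L = (1/W) Σ_w ∇L_w
            param.grad /= world_size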
The time spent on All-Reduce can easily exceed the time spent on actual computation, especially with slower interconnects like PCIe. But don't worry, we have a few tricks up our sleeve!
The Three Pillars of Optimization
Let’s dive into three powerful techniques to slash communication overhead.
1. Gradient Accumulation:
This is a fantastic trick for using a much larger effective batch size than your GPU’s memory can handle. Think of it like a meticulous chef preparing a huge cake: they mix the ingredients one layer at a time, but they don’t bake it until every layer has been mixed. In our case, each GPU processes several mini-batches and accumulates their gradients locally, and only after a set number of steps do we trigger the All-Reduce (the “baking” step) and perform a single, larger update. That means fewer, but more meaningful, synchronizations.

In essence, this process involves two key steps and can be represented mathematically:
- Local Accumulation: Each worker processes several mini-batches, and the gradients from these individual backward passes are automatically summed in each parameter's .grad buffer.
Let N be the total number of mini-batches and K be the accumulation step size. For each worker (GPU) w, the accumulated gradient G_{w}^{\text{acc}} is the sum of the gradients from its local mini-batches:

G_{w}^{\text{acc}} = Σ_{i=1}^{K} ∇L(w_i, D_i)

where ∇L(w_i, D_i) is the gradient of the loss function with respect to the weights w_i for mini-batch D_i.
- One Big All-Reduce: After this local accumulation, a single All-Reduce operation combines the gradients from all workers to get the final gradient G^{\text{final}}:

G^{\text{final}} = (1/W) · Σ_{w=1}^{W} G_{w}^{\text{acc}}

where W is the number of workers.
Here is how you can implement gradient accumulation in a typical PyTorch training loop.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
# ... other imports ...

# Initialize model, criterion, and optimizer
model = nn.Linear(10, 1)
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Dummy data loader so the example is self-contained
data_loader = DataLoader(TensorDataset(torch.randn(256, 10), torch.randn(256, 1)), batch_size=32)

# Hyperparameters
accumulation_steps = 4

# Your training loop
for i, (inputs, labels) in enumerate(data_loader):
    # Perform a forward pass
    outputs = model(inputs)

    # Calculate the loss and normalize it by accumulation steps
    loss = criterion(outputs, labels) / accumulation_steps

    # Perform a backward pass to accumulate gradients in the .grad buffers
    loss.backward()

    # Update weights only after accumulating gradients for 'accumulation_steps'
    if (i + 1) % accumulation_steps == 0:
        # Step 1: Update model weights based on accumulated gradients
        optimizer.step()
        # Step 2: Clear gradients for the next accumulation cycle
        optimizer.zero_grad()
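One caveat worth flagging: the loop above is written for a single process. Under DDP, every loss.backward() would normally trigger an All-Reduce, so plain accumulation would not actually save any communication. DDP's no_sync() context manager skips the synchronization on the intermediate steps. Here is a rough sketch under the assumption that the model has been wrapped in DDP as ddp_model (the other names are reused from the loop above):

for i, (inputs, labels) in enumerate(data_loader):
    is_update_step = (i + 1) % accumulation_steps == 0
    if not is_update_step:
        # no_sync() suppresses DDP's gradient synchronization, so these
        # backward passes only accumulate gradients locally on each GPU.
        with ddp_model.no_sync():
            loss = criterion(ddp_model(inputs), labels) / accumulation_steps
            loss.backward()
    else:
        # On the last step of the cycle, backward() runs the single All-Reduce
        # over the locally accumulated gradients.
        loss = criterion(ddp_model(inputs), labels) / accumulation_steps
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()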
2. Gradient Compression:
If Gradient Accumulation is about reducing the frequency of communication, Gradient Compression is about reducing the size of each communication packet. By shrinking the gradients before they are sent over the network, we can significantly cut the time spent in the All-Reduce operation.
- Quantization: This is the most common form of compression. The core idea is to reduce the numerical precision of the gradients from 32-bit floating point (FP32) to 16-bit (FP16) or even 8-bit (INT8). This can slash the data volume by 50% to 75%.

Mathematically, the quantization of a floating-point number x can be expressed as:

Q(x) = round(x / S)

where S is a scaling factor that determines the range of quantized values. This maps the original floating-point values to a smaller set of integers or low-precision floating-point numbers. For example, using FP16 in PyTorch converts each element of the gradient from 32 bits to 16 bits.
Here is a conceptual example of how quantization works. Note that in a real distributed setup, a custom communication hook would be needed to send the scale and zero-point along with the quantized tensor.
import torch

def quantize_tensor_int8(tensor):
    """Quantizes a float tensor to 8-bit unsigned integers (uint8)."""
    min_val, max_val = tensor.min(), tensor.max()
    # Map the [min, max] range onto the 256 available uint8 levels.
    # The clamp guards against division by zero when min == max.
    scale = (max_val - min_val).clamp(min=1e-8) / 255
    zero_point = min_val
    quantized_tensor = torch.round((tensor - zero_point) / scale).to(torch.uint8)
    return quantized_tensor, scale, zero_point

def dequantize_tensor_int8(quantized_tensor, scale, zero_point):
    """Dequantizes an 8-bit unsigned integer tensor back to float."""
    return quantized_tensor.float() * scale + zero_point
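For FP16 quantization specifically, you don't have to write that bookkeeping yourself: PyTorch ships a built-in DDP communication hook that casts each gradient bucket to FP16 before the All-Reduce and back to FP32 afterwards. A minimal sketch, assuming the process group is already initialized and rank is the local GPU index (the wrapper function name is just for illustration):

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks

def wrap_with_fp16_compression(model: nn.Module, rank: int) -> DDP:
    """Wrap a model in DDP and enable built-in FP16 gradient compression.
    Assumes dist.init_process_group() has already been called (e.g. via torchrun)."""
    ddp_model = DDP(model.to(rank), device_ids=[rank])
    # Each gradient bucket is cast to FP16 before being sent (halving the bytes
    # on the wire) and cast back to FP32 after the All-Reduce completes.
    ddp_model.register_comm_hook(state=None, hook=default_hooks.fp16_compress_hook)
    return ddp_model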
- Sparsification: This is a more aggressive compression method that only transmits the most important, or “significant”, gradients while ignoring the tiny ones that are close to zero. This technique is based on a common observation: during training, most gradient values are very small.


Mathematically, this can be represented by applying a mask function M(⋅) to the gradient vector g:

M(g)_i = g_i if |g_i| ≥ τ, and 0 otherwise

where g_i is the i-th element of the gradient vector g, and τ is a predetermined threshold. Only when the absolute value of a gradient is greater than or equal to τ is it kept and transmitted. This reduces the amount of data that needs to be communicated, especially for models whose gradients are naturally sparse.
The following code demonstrates a conceptual implementation of sparsification using a communication hook in PyTorch:
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

def topk_sparsification_hook(state: object, bucket: dist.GradBucket) -> torch.futures.Future[torch.Tensor]:
    """
    This hook sparsifies the gradient bucket by keeping only the top k% of values.
    Note: it still all-reduces a dense, zero-filled tensor, so it illustrates the
    idea rather than real bandwidth savings (which would require sending only
    the surviving values and their indices).
    """
    tensor = bucket.buffer()  # flattened gradients for this bucket

    # Define the percentage of gradients to keep (e.g., 10%)
    k_percentage = 0.1
    k = max(1, int(tensor.numel() * k_percentage))

    # Find the top k values by magnitude
    topk_values, topk_indices = torch.topk(tensor.abs(), k)

    # Create a new dense tensor, initialized to zeros
    sparse_tensor = torch.zeros_like(tensor)

    # Copy only the top k original gradient values to the new tensor
    sparse_tensor.view(-1)[topk_indices] = tensor.view(-1)[topk_indices]

    # All-reduce the sparse tensor across all GPUs (asynchronously)
    fut = dist.all_reduce(sparse_tensor, op=dist.ReduceOp.SUM, async_op=True).get_future()

    def average_callback(fut):
        # After summing across all GPUs, average the result
        reduced_tensor = fut.wait()[0]
        reduced_tensor /= dist.get_world_size()
        return reduced_tensor

    return fut.then(average_callback)

# --- How to register the hook ---
# model = DDP(model)
# model.register_comm_hook(state=None, hook=topk_sparsification_hook)
3. Communication Overlapping: This is the ultimate optimization. Imagine your team members don’t wait for the meeting to start: as soon as one person finishes their part, they immediately start sharing it with others while they begin their next task.

In PyTorch, this means the All-Reduce for one layer’s gradients starts as soon as they have been computed, while the backward pass keeps calculating gradients for the remaining layers. This hides the communication latency, making it feel “free,” and it is handled automatically by PyTorch’s DistributedDataParallel (DDP). In effect, DDP turns the sequential “compute first, then communicate” flow into a parallel one: as soon as a bucket of gradients is ready, it starts shipping them to the other GPUs in the background while computation continues. The “team meeting” is split into many tiny, overlapping pieces, so it costs almost no extra wall-clock time.

The latency of a Ring All-Reduce operation, for example, is dominated by the number of communication rounds and the time of the slowest path in each round. It can be expressed as

T_{\text{AllReduce}} ≈ 2(N − 1) · T

where N is the number of nodes and T is the time of the slowest path (node pair) in a round. DDP minimizes this bottleneck by overlapping the communication time with the computation time of the next layer, making the total time for a backward pass more efficient.
The following code shows a basic DDP training loop. By simply wrapping your model in DDP, you get this key optimization right out of the box.
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import os

def setup_ddp(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)

def train_model():
    # torchrun sets RANK and WORLD_SIZE for each process
    rank = int(os.environ['RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    setup_ddp(rank, world_size)

    model = nn.Linear(10, 1).to(rank)
    ddp_model = DDP(model, device_ids=[rank])
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.01)

    print(f"[{rank}] Starting training...")
    for epoch in range(5):
        dummy_data = torch.randn(32, 10).to(rank)
        target = torch.randn(32, 1).to(rank)

        output = ddp_model(dummy_data)
        loss = loss_fn(output, target)

        optimizer.zero_grad()
        loss.backward()  # DDP automatically handles the All-Reduce here!
        optimizer.step()

        if rank == 0:
            print(f"Epoch {epoch} finished with loss: {loss.item():.4f}")

    dist.destroy_process_group()

if __name__ == '__main__':
    # You would run this with `torchrun --nproc_per_node=2 your_script.py`
    train_model()
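How much overlap you actually get depends on how DDP groups gradients into buckets. The bucket_cap_mb argument of the DDP constructor (25 MB by default) controls this: smaller buckets let communication start earlier in the backward pass, while larger buckets amortize per-message overhead. The right value depends on your model and interconnect, so treat the numbers below as a starting point to profile, not a recommendation:

# Sketch of tuning DDP's bucketing for better compute/communication overlap,
# reusing `model` and `rank` from the setup above.
ddp_model = DDP(
    model,
    device_ids=[rank],
    bucket_cap_mb=50,              # larger buckets can help on fast interconnects like NVLink
    gradient_as_bucket_view=True,  # let gradients alias the bucket to avoid an extra copy
)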
The Takeaway
The biggest mistake people make is thinking they need to write a ton of complex code to get distributed training right. The truth is, modern frameworks like PyTorch have done a lot of the heavy lifting for you. By properly setting up DistributedDataParallel and leveraging techniques like gradient accumulation, you're well on your way to building a truly optimized distributed training pipeline.
Here's to finally unleashing the true potential of your multi-GPU rig. Happy training!
Published via Towards AI