A Deep Dive into Distributed Checkpointing: Using Orbax with Torchax on TPUs

Last Updated on May 29, 2026 by Editorial Team

Author(s): Pratiksha Patnaik

Originally published on Towards AI.

Training large deep learning models is an exercise in managing risks. Hardware glitches, network drops, spot instance preemption, and sudden cloud infrastructure hiccups can instantly wipe out days of expensive training progress. This is why checkpointing systems are critical.

But as AI models grow, traditional checkpointing can create massive data traffic jams that slow everything down. If you run your PyTorch projects with Torchax, or are keeping an eye on Google’s upcoming native framework, TorchTPU, you need a saving system built for massive scale.

That is where Orbax comes in.

This deep dive focuses on how to use Orbax to build a high-performance, automatic recovery system, while also breaking down the inner workings of Torchax and Google’s new native TorchTPU framework.

⚡🌉 What is Torchax?

Historically, AI researchers faced a tough choice: write in PyTorch for its massive ecosystem and easy-to-use design, or use JAX for its blazing-fast speed on Google’s hardware.

Torchax bridges this gap. It lets you run standard, unmodified PyTorch models directly on top of JAX. By adding just a couple of lines of code, torchax.enable_globally() and .to('jax'), Torchax automatically translates your PyTorch code into a format JAX understands. This gives you JAX's massive hardware acceleration without making you rewrite your model from scratch.

Google is also building TorchTPU, an upcoming native framework that will make running PyTorch on their chips even more seamless. But whether you use Torchax today or TorchTPU tomorrow, scaling up your training means you need a rock-solid way to save your progress without slowing down your hardware. That is exactly why mastering Orbax checkpointing is so critical.

🛣️🔥 The Road Ahead: Enter TorchTPU

While Torchax relies on translation layers to bridge PyTorch and JAX, Google recently unveiled TorchTPU as the future of running PyTorch natively on their hardware. Built to completely eliminate those middleman translation layers, TorchTPU integrates directly into the core PyTorch runtime.

TorchTPU shifts to an “Eager First” philosophy. Instead of forcing you into complex, heavy compilation steps before your code can run, TorchTPU behaves like standard PyTorch. This means you get seamless, step-by-step debugging alongside a breakthrough Fused Eager Mode. This mode watches your PyTorch operations on the fly and automatically groups them together to boost hardware performance by 50% to over 100%. Under the hood, it translates standard PyTorch commands straight into optimized instructions for TPU clusters, entirely removing any need for JAX.

📦💾 What is Orbax Checkpointing?

Orbax is a high-performance checkpointing library built from the ground up for massive AI models.

Traditional PyTorch saving tools rely on torch.save, which forces your computer's CPU to pack your entire model into a single, giant file. When your model is huge, this creates a massive traffic jam that freezes your training loop. Orbax fixes this by breaking your data into organized, smart chunks that are built to move fast across large networks.

🛠️✨ Core Features of Orbax CheckpointManager:

  • True Asynchronous Writes (No More Waiting): Orbax hands off the job of saving to background worker threads. Instead of freezing your training loop while files slowly write to the cloud, your model resumes training almost instantly.
  • Flexible Resharding (Mix & Match Hardware): If you save your model while training on a cluster of 8 TPUs, Orbax lets you seamlessly load that exact same save file later onto 16, 32, or a completely different layout of chips without breaking your code.
  • Automatic Clean-up (Smart Storage): By simply setting a rule like max_to_keep=3, Orbax automatically deletes your oldest saves in the background. This keeps your storage organized and prevents sudden, massive cloud bills.
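
To make the first and third features concrete, here is a minimal, standard-library sketch of the two ideas: handing the write to a background thread and pruning old step folders. Everything in it (the `SketchCheckpointer` class, the `state.json` file, the folder names) is hypothetical and only illustrates the pattern; it is not how Orbax is implemented internally.

```python
import json
import shutil
import tempfile
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path


class SketchCheckpointer:
    """Toy stand-in for a checkpoint manager (hypothetical, not Orbax code)."""

    def __init__(self, root: Path, max_to_keep: int = 3):
        self.root = root
        self.max_to_keep = max_to_keep
        # A single worker keeps writes in submission order.
        self._pool = ThreadPoolExecutor(max_workers=1)
        self._pending = []

    def save(self, step: int, state: dict) -> None:
        # Snapshot the state now, then return immediately; the slow write
        # happens on the background thread while training continues.
        snapshot = json.dumps(state)
        self._pending.append(self._pool.submit(self._write, step, snapshot))

    def _write(self, step: int, snapshot: str) -> None:
        step_dir = self.root / f"step_{step}"
        step_dir.mkdir(parents=True, exist_ok=True)
        (step_dir / "state.json").write_text(snapshot)
        self._prune()

    def _prune(self) -> None:
        # Keep only the newest `max_to_keep` step folders.
        steps = self.all_steps()
        for old in steps[: max(0, len(steps) - self.max_to_keep)]:
            shutil.rmtree(self.root / f"step_{old}")

    def all_steps(self) -> list:
        return sorted(int(p.name.split("_")[-1]) for p in self.root.glob("step_*"))

    def wait_until_finished(self) -> None:
        # Block until every queued write has landed on disk.
        for future in self._pending:
            future.result()
        self._pending.clear()


root = Path(tempfile.mkdtemp())
ckpt = SketchCheckpointer(root, max_to_keep=3)
for step in range(100, 600, 100):
    ckpt.save(step, {"global_step": step})  # returns without waiting for disk
ckpt.wait_until_finished()
print(ckpt.all_steps())  # [300, 400, 500]
```

Five saves go in, but only the three newest folders remain, which is the behavior a rule like max_to_keep=3 describes.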

⚙️💻 Environment Setup

1. Initialize and Activate the Virtual Environment

python3 -m venv .venv
source .venv/bin/activate
pip install --upgrade pip

2. Install Torchax, Orbax, and GCSFS

pip install torchax orbax-checkpoint gcsfs

3. Install the CPU Version of PyTorch (Linux)

pip install torch --index-url https://download.pytorch.org/whl/cpu

4. Install JAX for TPUs

If deploying on Google Cloud TPU VM instances, ensure your JAX environment is properly linked to the Cloud TPU acceleration runtimes to prevent CPU execution fallbacks:

pip install -U "jax[tpu]" 
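
Before launching a long run, it can help to confirm that the packages above are importable and that JAX actually sees the TPU. The snippet below is a small sanity check, not an official tool; the `find_missing` helper and the `REQUIRED` list are names made up for this sketch, and the list assumes the top-level import names match the packages installed above.

```python
import importlib.util

# Top-level import names for the packages installed in the steps above.
# (Assumption: Orbax is importable under the top-level name "orbax".)
REQUIRED = ["torch", "torchax", "jax", "orbax", "gcsfs"]


def find_missing(packages):
    """Return the package names that cannot be found in this environment."""
    return [name for name in packages if importlib.util.find_spec(name) is None]


missing = find_missing(REQUIRED)

if __name__ == "__main__":
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        import jax

        # On a correctly configured TPU VM the listed devices should be TPUs;
        # CPU devices here mean JAX has fallen back to CPU execution.
        print("JAX devices:", jax.devices())
```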

📝🔄 Code Breakdown: Robust Save, Load, and Recovery

Let’s look at a practical example. The script below steps a basic PyTorch model through 5 epochs of 100 steps each and shows how Orbax saves (serializes) a checkpoint every 100 steps and loads (deserializes) the latest one when training is resumed. The forward and backward passes are left out so the checkpointing logic stays in focus.

import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchax
import orbax.checkpoint as ocp
import jax

# 1. Define the model
class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

def torchax_to_pure_jax(state_dict):
    """Extracts raw jax.Arrays from torchax.Tensors using the .jax() method."""
    return jax.tree_util.tree_map(
        lambda x: x.jax() if isinstance(x, torch.Tensor) and hasattr(x, 'jax') else x,
        state_dict
    )

# 2. Setup environment
torchax.enable_globally()
model = MyModel().to('jax')

# --- ORBAX NATIVE CONFIGURATION ---
checkpoint_dir = "<ADD_A_GCS_PATH>"

# Configures Orbax to track folders named 'step__100', 'step__200', etc.
# (the prefix "step_" is joined to the step number with another underscore).
options = ocp.CheckpointManagerOptions(
    max_to_keep=3,
    save_interval_steps=1,
    step_prefix="step_"
)

mgr = ocp.CheckpointManager(
    checkpoint_dir,
    item_handlers=ocp.StandardCheckpointHandler(),
    options=options
)

# 3. Define training boundaries
total_epochs = 5
steps_per_epoch = 100
total_steps = total_epochs * steps_per_epoch

# ==========================================================
# PHASE 1: ORBAX AUTO-DISCOVERY & WEIGHT UPDATE
# ==========================================================
print(f"Scanning checkpoint directory: {checkpoint_dir}")
latest_step = mgr.latest_step()

start_global_step = 0

if latest_step is not None:
    print(f"🌟 Found latest step with 'step_' prefix! -> Step: {latest_step}")
    print("Loading state dictionary via Orbax...")

    # Build a template with the same structure, shapes, and dtypes as the
    # saved state so Orbax knows what to restore into
    jax_weights_template = torchax_to_pure_jax(model.state_dict())
    state_template = {
        'weights': jax_weights_template,
        'epoch': 0,
        'global_step': 0
    }

    load_start = time.perf_counter()

    # Restore the saved dictionary containing raw JAX arrays
    restored_state = mgr.restore(latest_step, args=ocp.args.StandardRestore(state_template))
    restored_jax_weights = restored_state['weights']

    # Directly overwrite the layer parameters at the Torchax tensor buffer level.
    # NOTE: `_elem` is the private attribute behind torchax's `.jax()` accessor;
    # private names can change between releases, so verify it for your version.
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in restored_jax_weights:
                param._elem = restored_jax_weights[name]

    load_end = time.perf_counter()
    print(f"➔ Orbax Deserialization Time: {load_end - load_start:.4f} seconds.")

    # Cast to a plain int so the value can be used safely in range() below
    start_global_step = int(restored_state['global_step'])
    print(f"Successfully resumed training from Global Step: {start_global_step}\n")
else:
    print("No existing checkpoints found with 'step_' format. Starting a fresh training run.\n")


# ==========================================================
# PHASE 2: ORBAX TRAINING LOOP WITH SERIALIZATION
# ==========================================================
for global_step in range(start_global_step + 1, total_steps + 1):

    epoch = ((global_step - 1) // steps_per_epoch) + 1
    local_step = ((global_step - 1) % steps_per_epoch) + 1

    # (The forward pass, backward pass, and optimizer update would run here.)

    if global_step % 100 == 0:
        # Convert Torchax tensors to raw JAX arrays for serialization
        jax_weights_to_save = torchax_to_pure_jax(model.state_dict())

        state_to_save = {
            'weights': jax_weights_to_save,
            'epoch': epoch,
            'global_step': global_step
        }

        print(f"Saving checkpoint via Orbax at Step: {global_step}...")

        save_start = time.perf_counter()

        # Saves data directly to folders named 'step__100', 'step__200', etc.
        mgr.save(
            global_step,
            args=ocp.args.StandardSave(state_to_save)
        )

        # Block until the async write is finalized so the timing below covers
        # the full save. In a real run, skip this call here and let the write
        # overlap with the next training steps.
        mgr.wait_until_finished()

        save_end = time.perf_counter()
        print(f"➔ Orbax Serialization & Write Time: {save_end - save_start:.4f} seconds.")
        print(f"ℹ️ Currently tracked steps in manager: {mgr.all_steps()}\n")

mgr.close()
print("Training and checkpointing complete.")
[Figure: Checkpoint Saving]
[Figure: Checkpoint Loading]
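
The resume arithmetic in the loop is easy to get off by one, so here is the same epoch and local-step formula pulled out into a small standalone helper. The `locate` function is a name invented for this sketch; the formula itself is copied from the script above.

```python
def locate(global_step: int, steps_per_epoch: int = 100) -> tuple:
    """Map a 1-based global step to (epoch, local_step), both 1-based."""
    epoch = ((global_step - 1) // steps_per_epoch) + 1
    local_step = ((global_step - 1) % steps_per_epoch) + 1
    return epoch, local_step


# A restored global_step of N means training continues at step N + 1.
for restored in (0, 100, 300):
    first = restored + 1
    epoch, local_step = locate(first)
    print(f"restored={restored}: next step {first} is epoch {epoch}, local step {local_step}")
```

A checkpoint written at step 300 closes out epoch 3, so the resumed run picks up at the first step of epoch 4.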

📂 Deconstructing the Orbax Directory Anatomy

When Orbax saves a checkpoint (for example, inside my_checkpoints/step__100/), it doesn't just dump a single, giant, unreadable file. Instead, it uses a powerful storage engine called TensorStore to build a highly organized, database-like folder structure. This layout is specifically engineered to handle massive models spread across multiple computers without slowing down.

If you peek inside a step folder, here is what you will find:

my_checkpoints/step__100/
├── _METADATA
├── _sharding
├── array_metadatas/
├── commit_success.txt
├── d/
├── manifest.ocdbt
└── ocdbt.process_0/

Here is what those files actually are and why they exist:

1. _METADATA

This file holds high-level blueprint data (in JSON format) describing your checkpoint. Think of it as the master key: it tells Orbax exactly how your Python dictionary objects are structured so it knows how to unflatten and rebuild them back into your computer’s memory during a restore.

2. _sharding

One of Orbax’s biggest superpowers is its native support for parallel training. The _sharding file remembers exactly how your model's data was split up across your hardware chips (like an 8-core TPU). Because Orbax tracks this, you can safely save a model on an 8-chip setup and load it onto a completely different chip layout later.
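
Why does knowing the layout make resharding possible? If every shard is described as a range of rows in the full array, any new layout can work out which saved shards it needs to read. The sketch below shows that bookkeeping for a simple one-dimensional, evenly divided split; `shard_slices` and `sources_for` are hypothetical helpers written for this illustration, not Orbax functions.

```python
def shard_slices(dim_size: int, num_shards: int) -> list:
    """Evenly split [0, dim_size) into half-open (start, stop) row ranges."""
    if dim_size % num_shards:
        raise ValueError("this sketch assumes the dimension divides evenly")
    width = dim_size // num_shards
    return [(i * width, (i + 1) * width) for i in range(num_shards)]


def sources_for(target: tuple, saved: list) -> list:
    """Indices of the saved shards whose rows overlap the target shard."""
    t0, t1 = target
    return [i for i, (s0, s1) in enumerate(saved) if s0 < t1 and t0 < s1]


rows = 1024
saved_layout = shard_slices(rows, 8)   # layout at save time: 8 chips
new_layout = shard_slices(rows, 16)    # layout at restore time: 16 chips

print(new_layout[5], "reads from saved shards", sources_for(new_layout[5], saved_layout))
# (320, 384) reads from saved shards [2]
```

Going the other way works too: a 4-chip layout would have each chip read from two of the original eight shards.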

3. array_metadatas/

This folder contains tiny descriptor files for every single tensor layer in your model weights (like fc1.weight). It notes basic details like data types (e.g., float32) and shapes (e.g., [120, 784]). Orbax reads these tiny files first so your hardware can prepare the exact amount of memory needed before downloading the actual heavy data.
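
To see how far a shape and a data type get you, here is the arithmetic for the MyModel layers from the script above. The dictionary layout is made up for this example and is not Orbax's on-disk descriptor format; only the shapes and the float32 size are taken from the model.

```python
import math

# Bytes per element for a few common dtypes.
DTYPE_BYTES = {"float32": 4, "bfloat16": 2, "int8": 1}

# (shape, dtype) for each tensor in MyModel. nn.Linear stores its weight
# as [out_features, in_features].
descriptors = {
    "fc1.weight": ([120, 784], "float32"),
    "fc1.bias": ([120], "float32"),
    "fc2.weight": ([84, 120], "float32"),
    "fc2.bias": ([84], "float32"),
    "fc3.weight": ([10, 84], "float32"),
    "fc3.bias": ([10], "float32"),
}


def nbytes(shape, dtype):
    """Memory needed for one tensor: number of elements times element size."""
    return math.prod(shape) * DTYPE_BYTES[dtype]


sizes = {name: nbytes(shape, dtype) for name, (shape, dtype) in descriptors.items()}
total = sum(sizes.values())

print(sizes["fc1.weight"])  # 376320
print(total)                # 420856
```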

4. commit_success.txt

This file is crucial for preventing broken saves. Copying massive AI models to cloud storage takes time. Orbax writes commit_success.txt as the absolute final step after a save is 100% complete and verified. When your code looks for the latest checkpoint, it checks for this file. If a network glitch happens mid-save, this file will be missing, and Orbax will safely ignore the broken folder.
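
The "write the marker last" idea is simple enough to sketch with plain files. In the toy version below, `write_checkpoint` and `latest_committed_step` are hypothetical helpers, and the `fail_midway` flag simulates a crash before the save is finalized; it shows the pattern only and does not reproduce Orbax's actual commit logic.

```python
import tempfile
from pathlib import Path

MARKER = "commit_success.txt"


def write_checkpoint(root: Path, step: int, payload: bytes, fail_midway: bool = False) -> None:
    step_dir = root / f"step_{step}"
    step_dir.mkdir(parents=True)
    (step_dir / "data.bin").write_bytes(payload)
    if fail_midway:
        return  # simulate a crash before the save was finalized
    # The marker is written only after all the data is in place.
    (step_dir / MARKER).write_text("ok")


def latest_committed_step(root: Path):
    """Newest step whose folder contains the commit marker, or None."""
    steps = [int(p.parent.name.split("_")[-1]) for p in root.glob(f"step_*/{MARKER}")]
    return max(steps, default=None)


root = Path(tempfile.mkdtemp())
write_checkpoint(root, 100, b"weights-at-100")
write_checkpoint(root, 200, b"weights-at-200", fail_midway=True)

print(latest_committed_step(root))  # 100
```

The step 200 folder exists on disk, but because it never received its marker, the lookup falls back to step 100.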

5. d/

This is where the actual raw binary weight data lives. Instead of cramming everything into one file, Orbax streams values into separate data files. This allows multiple hardware chips to read and write their assigned chunks at the exact same time without locking each other out.

6. manifest.ocdbt & ocdbt.process_0/

These files form the core of Orbax’s underlying storage technology: OCDBT, which stands for Optionally-Cooperative Distributed B+Tree.

Think of manifest.ocdbt as an incredibly smart library index. Instead of forcing Orbax to read through a massive file to find a specific layer, it looks at this index, finds the exact byte location, and pulls just that layer instantly. This database layout allows for incredibly fast, partial restores without wasting time reading the rest of the model.
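
Here is the library-index idea in miniature. A plain Python dictionary stands in for the B+tree, and one flat file stands in for the data files, so this is a deliberately simplified assumption rather than the real OCDBT format; `write_with_index` and `read_one` are names invented for the sketch.

```python
import tempfile
from pathlib import Path


def write_with_index(path: Path, tensors: dict) -> dict:
    """Append each blob to one file and record its (offset, length)."""
    index = {}
    offset = 0
    with path.open("wb") as f:
        for name, blob in tensors.items():
            f.write(blob)
            index[name] = (offset, len(blob))
            offset += len(blob)
    return index


def read_one(path: Path, index: dict, name: str) -> bytes:
    """Seek straight to one entry and read only its bytes."""
    offset, length = index[name]
    with path.open("rb") as f:
        f.seek(offset)
        return f.read(length)


data_file = Path(tempfile.mkdtemp()) / "data.bin"
index = write_with_index(
    data_file,
    {"fc1.weight": b"A" * 16, "fc2.weight": b"B" * 8, "fc3.weight": b"C" * 4},
)

print(index["fc2.weight"])                       # (16, 8)
print(read_one(data_file, index, "fc2.weight"))  # b'BBBBBBBB'
```

Only 8 of the 28 bytes in the file are read to recover fc2.weight, which is the same reason a partial restore can skip the rest of a large model.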

🏁 Wrapping Up: What We Learned

Building and scaling massive deep learning models on TPUs doesn’t have to mean sacrificing performance for reliability. Here is the quick breakdown of what we covered:

  • Torchax bridges the gap today by letting you run standard PyTorch models directly on top of JAX’s blazing-fast runtime with just a couple of lines of code.
  • TorchTPU is Google’s upcoming native framework built from the ground up for their hardware, bringing a flexible “Eager First” design and huge performance jumps.
  • Orbax keeps training pauses to a minimum with background async writes, lets you restore a checkpoint onto a different hardware layout, and automatically keeps your storage folders clean.

Published via Towards AI

