20 commits
6328350
Added OrbaxCheckpoint for keras 3.0 for Data centric saving and resto…
amitsrivastava78 Oct 21, 2025
ca71da6
Fix unused variable in orbax checkpoint test
amitsrivastava78 Oct 22, 2025
4dfa903
fixed failing cases
amitsrivastava78 Oct 22, 2025
7742139
fixed review comments
amitsrivastava78 Oct 22, 2025
822396f
Improve OrbaxCheckpoint implementation
amitsrivastava78 Oct 24, 2025
61bd5e6
Fix code formatting and remove unused variable
amitsrivastava78 Oct 24, 2025
19d2495
Add OrbaxCheckpoint callback with conditional exports and improved te…
amitsrivastava78 Oct 24, 2025
b56dc7b
Improve OrbaxCheckpoint: preserve nested structures, enhance tests
amitsrivastava78 Oct 28, 2025
7722e30
Fixed review comments
amitsrivastava78 Oct 31, 2025
eb7855d
Migration to Orbax V1
amitsrivastava78 Nov 5, 2025
aaf6e20
Fix sklearn wrapper CI tests by marking pipeline consistency checks a…
amitsrivastava78 Nov 10, 2025
cd881dd
made distributed structure proper
amitsrivastava78 Nov 10, 2025
9417027
Fixed save decision between keras and orbax
amitsrivastava78 Nov 11, 2025
b7a0dff
Optimize Orbax checkpoint for JAX backend
amitsrivastava78 Nov 11, 2025
33f4e66
Optimize Orbax checkpoint for JAX backend with compatibility check
amitsrivastava78 Nov 11, 2025
d7884ef
added checkpointer.wait()
amitsrivastava78 Nov 12, 2025
13aec2e
Improve OrbaxCheckpoint callback with optimizations and cleanup
amitsrivastava78 Nov 13, 2025
a2938ea
Simplify OrbaxCheckpoint API to match ModelCheckpoint parity
amitsrivastava78 Nov 13, 2025
4d659f4
Removed the experimental import
amitsrivastava78 Nov 13, 2025
ce30b36
Add comprehensive OrbaxCheckpoint tests with loading verification
amitsrivastava78 Nov 14, 2025
3 changes: 3 additions & 0 deletions keras/api/_tf_keras/keras/callbacks/__init__.py
@@ -19,6 +19,9 @@
from keras.src.callbacks.model_checkpoint import (
    ModelCheckpoint as ModelCheckpoint,
)
from keras.src.callbacks.orbax_checkpoint import (
    OrbaxCheckpoint as OrbaxCheckpoint,
)
from keras.src.callbacks.progbar_logger import ProgbarLogger as ProgbarLogger
from keras.src.callbacks.reduce_lr_on_plateau import (
    ReduceLROnPlateau as ReduceLROnPlateau,
Expand Down
3 changes: 3 additions & 0 deletions keras/api/callbacks/__init__.py
@@ -19,6 +19,9 @@
from keras.src.callbacks.model_checkpoint import (
    ModelCheckpoint as ModelCheckpoint,
)
from keras.src.callbacks.orbax_checkpoint import (
    OrbaxCheckpoint as OrbaxCheckpoint,
)
from keras.src.callbacks.progbar_logger import ProgbarLogger as ProgbarLogger
from keras.src.callbacks.reduce_lr_on_plateau import (
    ReduceLROnPlateau as ReduceLROnPlateau,
Expand Down
1 change: 1 addition & 0 deletions keras/src/callbacks/__init__.py
@@ -8,6 +8,7 @@
from keras.src.callbacks.learning_rate_scheduler import LearningRateScheduler
from keras.src.callbacks.model_checkpoint import ModelCheckpoint
from keras.src.callbacks.monitor_callback import MonitorCallback
from keras.src.callbacks.orbax_checkpoint import OrbaxCheckpoint
from keras.src.callbacks.progbar_logger import ProgbarLogger
from keras.src.callbacks.reduce_lr_on_plateau import ReduceLROnPlateau
from keras.src.callbacks.remote_monitor import RemoteMonitor
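Together, these three registrations expose the new callback on the public API surface, alongside `ModelCheckpoint`. A minimal smoke check (a sketch; it assumes a Keras build that includes this PR and an installed `orbax-checkpoint` package):

```python
import keras

# The export chain above makes the class reachable from the public
# namespace.
assert hasattr(keras.callbacks, "OrbaxCheckpoint")

# Constructing it requires orbax to be importable; the path is
# illustrative.
callback = keras.callbacks.OrbaxCheckpoint(directory="/tmp/ckpt")
```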
294 changes: 294 additions & 0 deletions keras/src/callbacks/orbax_checkpoint.py
@@ -0,0 +1,294 @@
import warnings

import numpy as np

from keras.src import backend
from keras.src import tree
from keras.src.api_export import keras_export
from keras.src.callbacks.monitor_callback import (
    MonitorCallback,  # For metric monitoring logic
)
from keras.src.utils.io_utils import print_msg
from keras.src.utils.module_utils import ocp

# Context and AsyncOptions are accessed through the lazy-loaded ocp module


def _get_state_tree(model):
    """Get the complete model state as a nested tree structure."""
    # For the JAX backend, preserve native arrays if JAX >= 0.7.0 to avoid
    # unnecessary conversions. Otherwise, convert to numpy.
    did_numpy_conversion = False
    if backend.backend() == "jax":
        import jax
        from packaging import version

        # Check the JAX version directly (JAX 0.7.0+ supports better
        # array handling)
        if version.parse(jax.__version__) >= version.parse("0.7.0"):
            state_tree = model.get_state_tree()
        else:
            # Fall back to numpy conversion for older JAX versions
            state_tree = model.get_state_tree(value_format="numpy_array")
            did_numpy_conversion = True
    else:
        state_tree = model.get_state_tree(value_format="numpy_array")
        did_numpy_conversion = True

    # Convert numpy scalar types to Python types for Orbax compatibility.
    # This is only needed when a numpy conversion was performed.
    if did_numpy_conversion:

        def convert_scalars(obj):
            if isinstance(obj, np.ndarray) and obj.ndim == 0:
                # Convert 0-dimensional numpy arrays (scalars) to Python
                # types
                return obj.item()
            elif isinstance(obj, np.generic):
                # Convert numpy scalar types (like np.float32) to Python
                # types
                return obj.item()
            else:
                return obj

        return tree.map_structure(convert_scalars, state_tree)
    else:
        return state_tree


@keras_export("keras.callbacks.OrbaxCheckpoint")
class OrbaxCheckpoint(MonitorCallback):
    """Callback to save and load model state using Orbax, with an API
    similar to `ModelCheckpoint`.

    This callback saves the model's weights and optimizer state
    asynchronously using Orbax, allowing training to continue without
    blocking on I/O.

    Example:

    ```python
    model.compile(loss=..., optimizer=...,
                  metrics=['accuracy'])

    EPOCHS = 10
    checkpoint_dir = '/tmp/ckpt'
    orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
        directory=checkpoint_dir,
        monitor='val_accuracy',
        mode='max',
        save_best_only=True)

    # The model is saved at the end of every epoch, if it's the best seen
    # so far.
    model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])

    # Alternatively, save checkpoints every N batches:
    orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
        directory=checkpoint_dir,
        save_freq=100)  # Save every 100 batches

    model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])
    ```

    Args:
        directory: string, path to the directory where checkpoints will be
            saved.
        monitor: The metric name to monitor (e.g., 'val_loss').
        verbose: Verbosity mode, 0 or 1.
        save_best_only: if `save_best_only=True`, the model is only saved
            when it is considered the "best", based on the monitored
            quantity.
        save_weights_only: if `save_weights_only=True`, only the model's
            weights will be saved. Otherwise, both the weights and the
            optimizer state will be saved. Defaults to `False`.
        mode: one of {'auto', 'min', 'max'}. Used with `save_best_only`.
        save_freq: `'epoch'` or integer. Frequency at which to save
            checkpoints.
        max_to_keep: Integer, maximum number of recent checkpoints to keep.
            If `None`, keeps all. Defaults to 5.
        save_on_background: Boolean, whether to save asynchronously in the
            background. Defaults to `True`.
        initial_value_threshold: Floating point initial "best" value for
            the monitor, used with `save_best_only`.
    """

    def __init__(
        self,
        directory,
        monitor="val_loss",
        verbose=0,
        save_best_only=False,
        save_weights_only=False,
        mode="auto",
        save_freq="epoch",
        max_to_keep=5,
        save_on_background=True,
        initial_value_threshold=None,
    ):
        # Ensure orbax is available
        ocp.initialize()

        # Initialize MonitorCallback to handle the 'monitor', 'mode', and
        # 'best' logic
        super().__init__(monitor, mode, initial_value_threshold)

        self.directory = directory
        self.verbose = verbose
        self.save_best_only = save_best_only
        self.save_weights_only = save_weights_only
        self.save_freq = save_freq
        self.max_to_keep = max_to_keep
        self.save_on_background = save_on_background
        self._batches_seen_since_last_saving = 0
        self._last_batch_seen = 0
        self._current_epoch = 0  # Keep track of the current epoch
        self._total_batches_seen = 0  # Global batch counter for step tracking

        if self.save_freq != "epoch" and not isinstance(self.save_freq, int):
            raise ValueError(f"Unrecognized save_freq: {self.save_freq}")

        # --- Orbax Checkpointer Setup (V1 API) ---
        policies = []
        if max_to_keep is not None:
            policies.append(
                ocp.training.preservation_policies.LatestN(max_to_keep)
            )

        # Use AnyPreservationPolicy to combine them.
        preservation_policy = None
        if policies:
            preservation_policy = (
                ocp.training.preservation_policies.AnyPreservationPolicy(
                    policies
                )
            )

        # Create the V1 Checkpointer with direct parameter passing.
        # Orbax will handle directory creation on all processes as needed.
        self.checkpointer = ocp.training.Checkpointer(
            directory=directory,
            preservation_policy=preservation_policy,
        )

    def _should_save_on_batch(self, batch):
        """Check whether a checkpoint should be saved on this batch."""
        if self.save_freq == "epoch":
            return False

        if batch <= self._last_batch_seen:  # New epoch.
            add_batches = batch + 1
        else:
            add_batches = batch - self._last_batch_seen
        self._batches_seen_since_last_saving += add_batches
        self._last_batch_seen = batch
        self._total_batches_seen += add_batches

        if self._batches_seen_since_last_saving >= self.save_freq:
            self._batches_seen_since_last_saving = 0
            return True
        return False
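
    # Example of the counting above: with save_freq=100, batches 0..98
    # accumulate a count of 99, and batch 99 raises it to 100, triggering
    # a save and resetting the counter. When a new epoch restarts batch
    # numbering at 0 (batch <= _last_batch_seen), add_batches = batch + 1
    # keeps _total_batches_seen exact across the epoch boundary.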

    def _save_checkpoint(self, step, logs=None):
        """Save a checkpoint at the given step."""

        # --- Prepare Composite State (Backend-Agnostic) ---
        state_tree = _get_state_tree(self.model)

        # Save the nested state structures directly (preserving layer
        # names and structure)
        composite_state = {
            "trainable_variables": state_tree["trainable_variables"],
        }

        # Include the optimizer state unless save_weights_only is True
        if not self.save_weights_only and "optimizer_variables" in state_tree:
            composite_state["optimizer_variables"] = state_tree[
                "optimizer_variables"
            ]

        # --- Save Logic (V1 API) ---
        # All processes participate in distributed checkpointing. The
        # Checkpointer is configured to save unconditionally when
        # save_pytree is called.
        if self.verbose > 0:
            print_msg(
                f"OrbaxCheckpoint: Triggering async save for step {step}..."
            )

        # Use a single with statement. If context_options is empty,
        # Context() uses defaults.
        with ocp.Context():
            if self.save_on_background:
                self.checkpointer.save_pytree_async(step, composite_state)
            else:
                self.checkpointer.save_pytree(step, composite_state)

    def on_train_batch_end(self, batch, logs=None):
        if self._should_save_on_batch(batch):
            # Handle save_best_only logic for batch-level saving
            should_save = True
            if self.save_best_only:
                current = logs.get(self.monitor) if logs else None
                if current is None:
                    warnings.warn(
                        f"Can save best model only with {self.monitor} "
                        f"available, skipping save at batch {batch}.",
                        stacklevel=2,
                    )
                    should_save = False
                elif not self._is_improvement(current, self.best):
                    should_save = False
                else:
                    # Update the best value when there is an improvement
                    self.best = current

            if should_save:
                # Use the global batch count as the Orbax save step
                step = self._total_batches_seen
                self._save_checkpoint(step=step, logs=logs)

    def on_epoch_end(self, epoch, logs=None):
        self._current_epoch = epoch
        if self.monitor_op is None:
            self._set_monitor_op()  # From MonitorCallback

        # For save_freq="epoch", save at every epoch
        should_save = self.save_freq == "epoch"

        # Handle save_best_only logic
        if should_save and self.save_best_only:
            current = logs.get(self.monitor) if logs else None
            if current is None:
                warnings.warn(
                    f"Can save best model only with {self.monitor} "
                    f"available, skipping save at epoch {epoch}.",
                    stacklevel=2,
                )
                should_save = False
            elif not self._is_improvement(current, self.best):
                should_save = False
            else:
                # Update the best value when there is an improvement
                self.best = current

        if should_save:
            # Use the epoch number as the step for the Orbax save.
            # Keras has already made the save decision; the Checkpointer
            # saves unconditionally.
            self._save_checkpoint(step=epoch, logs=logs)

    def on_train_end(self, logs=None):
        # Close the Checkpointer to ensure all pending saves complete
        try:
            self.checkpointer.close()
        except Exception:
            pass  # Ignore errors during cleanup

    def wait_until_finished(self):
        """Wait for any in-progress checkpoint operations to complete.

        This method blocks until all asynchronous checkpoint save
        operations have completed. It should be called before attempting
        to load checkpoints if there might be pending save operations.
        """
        # Wait for any async operations to complete
        try:
            self.checkpointer.wait()
        except AttributeError:
            # Fall back for older Orbax versions that lack a wait() method
            import time

            while self.checkpointer.is_saving_in_progress():
                time.sleep(0.1)
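
Taken together, the hooks above give `ModelCheckpoint`-style behavior with asynchronous Orbax saves. A minimal end-to-end sketch of the save path (the model, data, and directory are illustrative assumptions, not part of the PR):

```python
import numpy as np
import keras

# Tiny illustrative model and data.
model = keras.Sequential([keras.Input(shape=(8,)), keras.layers.Dense(4)])
model.compile(optimizer="adam", loss="mse")
x = np.random.rand(64, 8).astype("float32")
y = np.random.rand(64, 4).astype("float32")

callback = keras.callbacks.OrbaxCheckpoint(
    directory="/tmp/orbax_ckpt",  # illustrative path
    save_freq="epoch",  # one checkpoint per epoch; step == epoch number
)
model.fit(x, y, epochs=2, callbacks=[callback])

# Saves run in the background by default (save_on_background=True), so
# block until they are durable before inspecting the directory.
callback.wait_until_finished()
```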
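Loading is left to the Orbax `Checkpointer` rather than the callback. Continuing the sketch above, a hedged restore example: it assumes the V1 `Checkpointer` exposes a `load_pytree(step)` counterpart to the `save_pytree` calls used in `_save_checkpoint`, and that `Model.set_state_tree` accepts the same nested structure that `get_state_tree` produced; verify both against your installed Orbax and Keras versions:

```python
# Reuse the checkpointer the callback constructed; this sidesteps the V1
# import path, which the PR resolves via keras.src.utils.module_utils.
callback.wait_until_finished()

# Assumption: load_pytree(step) returns the composite_state dict written
# by _save_checkpoint ("trainable_variables" plus, unless
# save_weights_only=True, "optimizer_variables").
restored = callback.checkpointer.load_pytree(1)  # step 1 == epoch 1

# Assumption: set_state_tree mirrors get_state_tree.
model.set_state_tree(
    {"trainable_variables": restored["trainable_variables"]}
)
```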