-
Notifications
You must be signed in to change notification settings - Fork 19.7k
Added OrbaxCheckpoint for keras 3.0 for Data centric saving and restore #21762
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
amitsrivastava78
wants to merge
20
commits into
keras-team:master
Choose a base branch
from
amitsrivastava78:orbax-checkpoint-test-improvements
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+805
−1
Open
Changes from all commits
Commits
Show all changes
20 commits
Select commit
Hold shift + click to select a range
6328350
Added OrbaxCheckpoint for keras 3.0 for Data centric saving and resto…
amitsrivastava78 ca71da6
Fix unused variable in orbax checkpoint test
amitsrivastava78 4dfa903
fixed failing cases
amitsrivastava78 7742139
fixed review comments
amitsrivastava78 822396f
Improve OrbaxCheckpoint implementation
amitsrivastava78 61bd5e6
Fix code formatting and remove unused variable
amitsrivastava78 19d2495
Add OrbaxCheckpoint callback with conditional exports and improved te…
amitsrivastava78 b56dc7b
Improve OrbaxCheckpoint: preserve nested structures, enhance tests
amitsrivastava78 7722e30
Fixed review comments
amitsrivastava78 eb7855d
Migration to Orbax V1
amitsrivastava78 aaf6e20
Fix sklearn wrapper CI tests by marking pipeline consistency checks a…
amitsrivastava78 cd881dd
made distributed structure proper
amitsrivastava78 9417027
Fixed sav decision between keras and orbax
amitsrivastava78 b7a0dff
Optimize Orbax checkpoint for JAX backend
amitsrivastava78 33f4e66
Optimize Orbax checkpoint for JAX backend with compatibility check
amitsrivastava78 d7884ef
added checkpointer.wait()
amitsrivastava78 13aec2e
Improve OrbaxCheckpoint callback with optimizations and cleanup
amitsrivastava78 a2938ea
Simplify OrbaxCheckpoint API to match ModelCheckpoint parity
amitsrivastava78 4d659f4
Removed the experimental import
amitsrivastava78 ce30b36
Add comprehensive OrbaxCheckpoint tests with loading verification
amitsrivastava78 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,294 @@ | ||
| import warnings | ||
|
|
||
| import numpy as np | ||
|
|
||
| from keras.src import backend | ||
| from keras.src import tree | ||
| from keras.src.api_export import keras_export | ||
| from keras.src.callbacks.monitor_callback import ( | ||
| MonitorCallback, # For metric monitoring logic | ||
| ) | ||
| from keras.src.utils.io_utils import print_msg | ||
| from keras.src.utils.module_utils import ocp | ||
|
|
||
| # Context and AsyncOptions are accessed through the lazy-loaded ocp module | ||
|
|
||
|
|
||
| def _get_state_tree(model): | ||
| """Get the complete model state as a nested tree structure.""" | ||
| # For JAX backend, preserve native arrays if JAX >= 0.7.0 | ||
| # to avoid unnecessary conversions. Otherwise convert to numpy. | ||
| did_numpy_conversion = False | ||
| if backend.backend() == "jax": | ||
| import jax | ||
| from packaging import version | ||
|
|
||
| # Check JAX version directly (JAX 0.7.0+ supports better array handling) | ||
| if version.parse(jax.__version__) >= version.parse("0.7.0"): | ||
| state_tree = model.get_state_tree() | ||
| else: | ||
| # Fallback to numpy conversion for older JAX versions | ||
| state_tree = model.get_state_tree(value_format="numpy_array") | ||
| did_numpy_conversion = True | ||
| else: | ||
| state_tree = model.get_state_tree(value_format="numpy_array") | ||
| did_numpy_conversion = True | ||
|
|
||
| # Convert numpy scalar types to Python types for Orbax compatibility | ||
| # Only needed when we did numpy conversion | ||
| if did_numpy_conversion: | ||
|
|
||
| def convert_scalars(obj): | ||
| if isinstance(obj, np.ndarray) and obj.ndim == 0: | ||
| # Convert 0-dimensional numpy arrays (scalars) to Python types | ||
| return obj.item() | ||
| elif isinstance(obj, np.generic): | ||
| # Convert numpy scalar types (like np.float32) to Python types | ||
| return obj.item() | ||
| else: | ||
| return obj | ||
|
|
||
| return tree.map_structure(convert_scalars, state_tree) | ||
| else: | ||
| return state_tree | ||
|
|
||
|
|
||
| @keras_export("keras.callbacks.OrbaxCheckpoint") | ||
| class OrbaxCheckpoint(MonitorCallback): | ||
| """Callback to save and load model state using Orbax with a similar API to | ||
| ModelCheckpoint. | ||
|
|
||
| This callback saves the model's weights and optimizer state asynchronously | ||
| using Orbax, allowing training to continue without blocking for I/O. | ||
|
|
||
| Example: | ||
|
|
||
| ```python | ||
| model.compile(loss=..., optimizer=..., | ||
| metrics=['accuracy']) | ||
|
|
||
| EPOCHS = 10 | ||
| checkpoint_dir = '/tmp/ckpt' | ||
| orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint( | ||
| directory=checkpoint_dir, | ||
| monitor='val_accuracy', | ||
| mode='max', | ||
| save_best_only=True) | ||
|
|
||
| # Model is saved at the end of every epoch, if it's the best seen so far. | ||
| model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback]) | ||
|
|
||
| # Alternatively, save checkpoints every N batches - | ||
| orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint( | ||
| directory=checkpoint_dir, | ||
| save_freq=100) # Save every 100 batches | ||
|
|
||
| model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback]) | ||
| ``` | ||
|
|
||
| Args: | ||
| directory: string, path to the directory where to save the checkpoints. | ||
| monitor: The metric name to monitor (e.g., 'val_loss'). | ||
| verbose: Verbosity mode, 0 or 1. | ||
| save_best_only: if `save_best_only=True`, it only saves when the model | ||
| is considered the "best" based on the monitored quantity. | ||
| save_weights_only: if `save_weights_only=True`, only the model's weights | ||
| will be saved. Otherwise, both weights and optimizer state will be | ||
| saved. Defaults to False. | ||
| mode: one of {'auto', 'min', 'max'}. Used with `save_best_only`. | ||
| save_freq: `'epoch'` or integer. Frequency to save checkpoints. | ||
| max_to_keep: Integer, maximum number of recent checkpoints to keep. | ||
| If None, keeps all. Defaults to 5. | ||
| save_on_background: Boolean, whether to save asynchronously in the | ||
| background. Defaults to True. | ||
| initial_value_threshold: Floating point initial "best" value for the | ||
| monitor, used with `save_best_only`. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| directory, | ||
| monitor="val_loss", | ||
| verbose=0, | ||
| save_best_only=False, | ||
| save_weights_only=False, | ||
| mode="auto", | ||
| save_freq="epoch", | ||
| max_to_keep=5, | ||
| save_on_background=True, | ||
| initial_value_threshold=None, | ||
| ): | ||
| # Ensure orbax is available | ||
| ocp.initialize() | ||
|
|
||
| # Initialize MonitorCallback for handling 'monitor', 'mode', 'best' | ||
| # logic | ||
| super().__init__(monitor, mode, initial_value_threshold) | ||
|
|
||
| self.directory = directory | ||
| self.verbose = verbose | ||
| self.save_best_only = save_best_only | ||
| self.save_weights_only = save_weights_only | ||
| self.save_freq = save_freq | ||
| self.max_to_keep = max_to_keep | ||
| self.save_on_background = save_on_background | ||
| self._batches_seen_since_last_saving = 0 | ||
| self._last_batch_seen = 0 | ||
| self._current_epoch = 0 # Keep track of epoch | ||
| self._total_batches_seen = 0 # Global batch counter for step tracking | ||
|
|
||
| if self.save_freq != "epoch" and not isinstance(self.save_freq, int): | ||
| raise ValueError("Unrecognized save_freq") | ||
|
|
||
| # --- Orbax Checkpointer Setup (V1 API) --- | ||
| policies = [] | ||
| if max_to_keep is not None: | ||
| policies.append( | ||
| ocp.training.preservation_policies.LatestN(max_to_keep) | ||
| ) | ||
|
|
||
| # Use AnyPreservationPolicy to combine them. | ||
| preservation_policy = None | ||
| if policies: | ||
| preservation_policy = ( | ||
| ocp.training.preservation_policies.AnyPreservationPolicy( | ||
| policies | ||
| ) | ||
| ) | ||
|
|
||
| # Create the V1 Checkpointer with direct parameter passing | ||
| # Orbax will handle directory creation on all processes as needed | ||
| self.checkpointer = ocp.training.Checkpointer( | ||
| directory=directory, | ||
| preservation_policy=preservation_policy, | ||
| ) | ||
|
|
||
| def _should_save_on_batch(self, batch): | ||
| """Check if we should save on this batch.""" | ||
| if self.save_freq == "epoch": | ||
| return False | ||
|
|
||
| if batch <= self._last_batch_seen: # New epoch. | ||
| add_batches = batch + 1 | ||
| else: | ||
| add_batches = batch - self._last_batch_seen | ||
| self._batches_seen_since_last_saving += add_batches | ||
| self._last_batch_seen = batch | ||
| self._total_batches_seen += add_batches | ||
|
|
||
| if self._batches_seen_since_last_saving >= self.save_freq: | ||
| self._batches_seen_since_last_saving = 0 | ||
| return True | ||
| return False | ||
|
|
||
| def _save_checkpoint(self, step, logs=None): | ||
| """Save a checkpoint at the given step.""" | ||
|
|
||
| # --- Prepare Composite State (Backend-Agnostic) --- | ||
| state_tree = _get_state_tree(self.model) | ||
|
|
||
| # Save the nested state structures directly (preserving layer | ||
| # names and structure) | ||
| composite_state = { | ||
| "trainable_variables": state_tree["trainable_variables"], | ||
| } | ||
|
|
||
| # Include optimizer state unless save_weights_only is True | ||
| if not self.save_weights_only and "optimizer_variables" in state_tree: | ||
| composite_state["optimizer_variables"] = state_tree[ | ||
| "optimizer_variables" | ||
| ] | ||
|
|
||
| # --- Save Logic (V1 API) --- | ||
| # All processes participate in distributed checkpointing | ||
| # Checkpointer is configured to save unconditionally when | ||
| # save_pytree is called | ||
| if self.verbose > 0: | ||
| print_msg( | ||
| f"OrbaxCheckpoint: Triggering async save for step {step}..." | ||
| ) | ||
|
|
||
| # Use a single with statement. If context_options is empty, | ||
| # Context() uses defaults. | ||
| with ocp.Context(): | ||
| if self.save_on_background: | ||
| self.checkpointer.save_pytree_async(step, composite_state) | ||
| else: | ||
| self.checkpointer.save_pytree(step, composite_state) | ||
|
|
||
| def on_train_batch_end(self, batch, logs=None): | ||
| if self._should_save_on_batch(batch): | ||
| # Handle save_best_only logic for batch-level saving | ||
| should_save = True | ||
| if self.save_best_only: | ||
| current = logs.get(self.monitor) if logs else None | ||
| if current is None: | ||
| warnings.warn( | ||
| f"Can save best model only with {self.monitor} " | ||
| f"available, skipping save at batch {batch}.", | ||
| stacklevel=2, | ||
| ) | ||
| should_save = False | ||
| elif not self._is_improvement(current, self.best): | ||
| should_save = False | ||
| else: | ||
| # Update best value when there's improvement | ||
| self.best = current | ||
|
|
||
| if should_save: | ||
| # Use global batch count for Orbax save step | ||
| step = self._total_batches_seen | ||
| self._save_checkpoint(step=step, logs=logs) | ||
|
|
||
| def on_epoch_end(self, epoch, logs=None): | ||
| self._current_epoch = epoch | ||
| if self.monitor_op is None: | ||
| self._set_monitor_op() # From MonitorCallback | ||
|
|
||
| # For save_freq="epoch", save at every epoch | ||
| should_save = self.save_freq == "epoch" | ||
|
|
||
| # Handle save_best_only logic | ||
| if should_save and self.save_best_only: | ||
| current = logs.get(self.monitor) if logs else None | ||
| if current is None: | ||
| warnings.warn( | ||
| f"Can save best model only with {self.monitor} available, " | ||
| f"skipping save at epoch {epoch}.", | ||
| stacklevel=2, | ||
| ) | ||
| should_save = False | ||
| elif not self._is_improvement(current, self.best): | ||
| should_save = False | ||
| else: | ||
| # Update best value when there's improvement | ||
| self.best = current | ||
|
|
||
| if should_save: | ||
| # Use epoch number as the step for Orbax save | ||
| # Keras has already made the save decision - Checkpointer will | ||
| # save unconditionally | ||
| self._save_checkpoint(step=epoch, logs=logs) | ||
|
|
||
| def on_train_end(self, logs=None): | ||
| # Close the Checkpointer to ensure all pending saves complete | ||
| try: | ||
| self.checkpointer.close() | ||
| except Exception: | ||
| pass # Ignore errors during cleanup | ||
|
|
||
| def wait_until_finished(self): | ||
| """Wait for any in-progress checkpoint operations to complete. | ||
| This method blocks until all asynchronous checkpoint save operations | ||
| have completed. It should be called before attempting to load | ||
| checkpoints if there might be pending save operations. | ||
| """ | ||
| # Wait for any async operations to complete | ||
| try: | ||
| self.checkpointer.wait() | ||
| except AttributeError: | ||
| # Fallback for older Orbax versions that don't have wait() method | ||
| while self.checkpointer.is_saving_in_progress(): | ||
| import time | ||
|
|
||
| time.sleep(0.1) | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.