diff --git a/keras/api/_tf_keras/keras/callbacks/__init__.py b/keras/api/_tf_keras/keras/callbacks/__init__.py index 4e165cddb6a8..ce5f900d80f5 100644 --- a/keras/api/_tf_keras/keras/callbacks/__init__.py +++ b/keras/api/_tf_keras/keras/callbacks/__init__.py @@ -19,6 +19,9 @@ from keras.src.callbacks.model_checkpoint import ( ModelCheckpoint as ModelCheckpoint, ) +from keras.src.callbacks.orbax_checkpoint import ( + OrbaxCheckpoint as OrbaxCheckpoint, +) from keras.src.callbacks.progbar_logger import ProgbarLogger as ProgbarLogger from keras.src.callbacks.reduce_lr_on_plateau import ( ReduceLROnPlateau as ReduceLROnPlateau, diff --git a/keras/api/callbacks/__init__.py b/keras/api/callbacks/__init__.py index 4e165cddb6a8..ce5f900d80f5 100644 --- a/keras/api/callbacks/__init__.py +++ b/keras/api/callbacks/__init__.py @@ -19,6 +19,9 @@ from keras.src.callbacks.model_checkpoint import ( ModelCheckpoint as ModelCheckpoint, ) +from keras.src.callbacks.orbax_checkpoint import ( + OrbaxCheckpoint as OrbaxCheckpoint, +) from keras.src.callbacks.progbar_logger import ProgbarLogger as ProgbarLogger from keras.src.callbacks.reduce_lr_on_plateau import ( ReduceLROnPlateau as ReduceLROnPlateau, diff --git a/keras/src/callbacks/__init__.py b/keras/src/callbacks/__init__.py index 427c4f6da95f..c62aed69ee63 100644 --- a/keras/src/callbacks/__init__.py +++ b/keras/src/callbacks/__init__.py @@ -8,6 +8,7 @@ from keras.src.callbacks.learning_rate_scheduler import LearningRateScheduler from keras.src.callbacks.model_checkpoint import ModelCheckpoint from keras.src.callbacks.monitor_callback import MonitorCallback +from keras.src.callbacks.orbax_checkpoint import OrbaxCheckpoint from keras.src.callbacks.progbar_logger import ProgbarLogger from keras.src.callbacks.reduce_lr_on_plateau import ReduceLROnPlateau from keras.src.callbacks.remote_monitor import RemoteMonitor diff --git a/keras/src/callbacks/orbax_checkpoint.py b/keras/src/callbacks/orbax_checkpoint.py new file mode 100644 index 
# Source file: keras/src/callbacks/orbax_checkpoint.py
import time
import warnings

import numpy as np

from keras.src import backend
from keras.src import tree
from keras.src.api_export import keras_export
from keras.src.callbacks.monitor_callback import MonitorCallback
from keras.src.utils.io_utils import print_msg
from keras.src.utils.module_utils import ocp

# `ocp` is a lazily loaded module handle: `ocp.Context`, `ocp.training`, ...
# are only resolved once `ocp.initialize()` has run or an attribute is used.


def _to_python_scalar(value):
    """Return `value.item()` for NumPy scalars / 0-d arrays, else `value`."""
    is_zero_dim_array = isinstance(value, np.ndarray) and value.ndim == 0
    if is_zero_dim_array or isinstance(value, np.generic):
        return value.item()
    return value


def _get_state_tree(model):
    """Return the model state as a nested tree ready to hand to Orbax.

    On the JAX backend with JAX >= 0.7.0 the tree returned by
    `model.get_state_tree()` is passed through untouched, so no array
    conversion happens. In every other case the tree is requested with
    `value_format="numpy_array"` and NumPy scalars / 0-d arrays in it are
    replaced by plain Python scalars.

    Args:
        model: Object exposing `get_state_tree()`.

    Returns:
        The nested state tree.
    """
    if backend.backend() == "jax":
        import jax
        from packaging import version

        if version.parse(jax.__version__) >= version.parse("0.7.0"):
            return model.get_state_tree()

    numpy_tree = model.get_state_tree(value_format="numpy_array")
    return tree.map_structure(_to_python_scalar, numpy_tree)


@keras_export("keras.callbacks.OrbaxCheckpoint")
class OrbaxCheckpoint(MonitorCallback):
    """Callback that checkpoints model state with Orbax.

    The API mirrors `ModelCheckpoint`. The model's variables (and, unless
    `save_weights_only=True`, the optimizer variables) are written through an
    Orbax `Checkpointer`. With `save_on_background=True` the save is issued
    with `save_pytree_async`, otherwise with the blocking `save_pytree`.

    Checkpoints are identified by an integer step: the epoch index when
    `save_freq="epoch"`, or the cumulative number of training batches seen
    when `save_freq` is an integer.

    Example:

    ```python
    model.compile(loss=..., optimizer=...,
                  metrics=['accuracy'])

    EPOCHS = 10
    checkpoint_dir = '/tmp/ckpt'
    orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
        directory=checkpoint_dir,
        monitor='val_accuracy',
        mode='max',
        save_best_only=True)

    # Model is saved at the end of every epoch, if it's the best seen so far.
    model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])

    # Alternatively, save checkpoints every N batches -
    orbax_checkpoint_callback = keras.callbacks.OrbaxCheckpoint(
        directory=checkpoint_dir,
        save_freq=100)  # Save every 100 batches

    model.fit(epochs=EPOCHS, callbacks=[orbax_checkpoint_callback])
    ```

    Args:
        directory: string, path to the directory where to save the checkpoints.
        monitor: The metric name to monitor (e.g., 'val_loss').
        verbose: Verbosity mode, 0 or 1.
        save_best_only: if `save_best_only=True`, it only saves when the model
            is considered the "best" based on the monitored quantity.
        save_weights_only: if `save_weights_only=True`, only the model's
            variables are saved. Otherwise the optimizer variables are saved
            as well. Defaults to False.
        mode: one of {'auto', 'min', 'max'}. Used with `save_best_only`.
        save_freq: `'epoch'` or a positive integer number of batches between
            two saves.
        max_to_keep: Integer, number of most recent checkpoints to keep
            (passed to Orbax's `LatestN` preservation policy). If None, no
            preservation policy is configured. Defaults to 5.
        save_on_background: Boolean, whether to save asynchronously in the
            background. Defaults to True.
        initial_value_threshold: Floating point initial "best" value for the
            monitor, used with `save_best_only`.

    Raises:
        ImportError: if the `orbax-checkpoint` package cannot be imported.
        ValueError: if `save_freq` is neither `'epoch'` nor a positive integer.
    """

    def __init__(
        self,
        directory,
        monitor="val_loss",
        verbose=0,
        save_best_only=False,
        save_weights_only=False,
        mode="auto",
        save_freq="epoch",
        max_to_keep=5,
        save_on_background=True,
        initial_value_threshold=None,
    ):
        # Fail early, with the install hint, when Orbax is not importable.
        ocp.initialize()

        # MonitorCallback owns `monitor`, `mode`, `best` and `monitor_op`.
        super().__init__(monitor, mode, initial_value_threshold)

        # `bool` is a subclass of `int`, and a non-positive frequency would
        # trigger a save on every batch, so both are rejected explicitly.
        if save_freq != "epoch" and (
            isinstance(save_freq, bool)
            or not isinstance(save_freq, int)
            or save_freq <= 0
        ):
            raise ValueError(
                f"Unrecognized save_freq: {save_freq!r}. Expected 'epoch' "
                "or a positive integer number of batches."
            )

        self.directory = directory
        self.verbose = verbose
        self.save_best_only = save_best_only
        self.save_weights_only = save_weights_only
        self.save_freq = save_freq
        self.max_to_keep = max_to_keep
        self.save_on_background = save_on_background
        self._batches_seen_since_last_saving = 0
        self._last_batch_seen = 0
        self._current_epoch = 0  # Last epoch index seen by `on_epoch_end`.
        self._total_batches_seen = 0  # Cumulative count, used as save step.

        # --- Orbax Checkpointer setup (V1 API) ---
        preservation_policy = None
        if max_to_keep is not None:
            policies = ocp.training.preservation_policies
            preservation_policy = policies.AnyPreservationPolicy(
                [policies.LatestN(max_to_keep)]
            )

        # NOTE(review): the checkpointer is closed in `on_train_end`; whether
        # a closed Orbax Checkpointer can be reused by a second `fit()` call
        # with the same callback instance should be confirmed.
        self.checkpointer = ocp.training.Checkpointer(
            directory=directory,
            preservation_policy=preservation_policy,
        )

    def _should_save_on_batch(self, batch):
        """Update the batch counters and report whether a save is due.

        Always returns False when `save_freq="epoch"`. Otherwise returns True
        once at least `save_freq` batches have been counted since the last
        time this method returned True.
        """
        if self.save_freq == "epoch":
            return False

        if batch <= self._last_batch_seen:
            # The batch index did not advance: a new epoch started (or this
            # is the very first batch), so `batch + 1` batches have elapsed.
            add_batches = batch + 1
        else:
            add_batches = batch - self._last_batch_seen
        self._batches_seen_since_last_saving += add_batches
        self._last_batch_seen = batch
        self._total_batches_seen += add_batches

        if self._batches_seen_since_last_saving >= self.save_freq:
            self._batches_seen_since_last_saving = 0
            return True
        return False

    def _is_new_best(self, logs, where):
        """Return True, and record the value, if the monitor improved.

        Args:
            logs: Metrics dict passed to the Keras hook (may be None).
            where: Human-readable position (e.g. `"epoch 3"`) used in the
                warning emitted when the monitored metric is missing.
        """
        # `monitor_op` stays None until `_set_monitor_op()` runs. Batch-level
        # saves can get here before the first `on_epoch_end`, so make sure it
        # exists before `_is_improvement` uses it.
        if self.monitor_op is None:
            self._set_monitor_op()

        current = logs.get(self.monitor) if logs else None
        if current is None:
            warnings.warn(
                f"Can save best model only with {self.monitor} available, "
                f"skipping save at {where}.",
                # 3 = skip this helper and the hook that called it.
                stacklevel=3,
            )
            return False
        if not self._is_improvement(current, self.best):
            return False
        self.best = current
        return True

    def _save_checkpoint(self, step, logs=None):
        """Write a checkpoint for `step`.

        The saved pytree always contains `"trainable_variables"`. It also
        contains `"non_trainable_variables"` when the model has any, and
        `"optimizer_variables"` unless `save_weights_only` is True.

        Args:
            step: Integer step under which Orbax stores the checkpoint.
            logs: Unused; accepted for symmetry with the Keras hooks.
        """
        state_tree = _get_state_tree(self.model)

        # Nested structures are stored as-is so layer names are preserved.
        composite_state = {
            "trainable_variables": state_tree["trainable_variables"],
        }

        # Without the non-trainable variables a restore would leave part of
        # the weights at their freshly initialised values. Skipped when empty.
        non_trainable = state_tree.get("non_trainable_variables")
        if non_trainable:
            composite_state["non_trainable_variables"] = non_trainable

        if not self.save_weights_only and "optimizer_variables" in state_tree:
            composite_state["optimizer_variables"] = state_tree[
                "optimizer_variables"
            ]

        if self.verbose > 0:
            save_kind = "async" if self.save_on_background else "blocking"
            print_msg(
                f"OrbaxCheckpoint: Triggering {save_kind} save for step "
                f"{step}..."
            )

        # `ocp.Context()` with no arguments: Orbax defaults.
        with ocp.Context():
            if self.save_on_background:
                self.checkpointer.save_pytree_async(step, composite_state)
            else:
                self.checkpointer.save_pytree(step, composite_state)

    def on_train_batch_end(self, batch, logs=None):
        """Save every `save_freq` batches when `save_freq` is an integer."""
        if not self._should_save_on_batch(batch):
            return
        if self.save_best_only and not self._is_new_best(
            logs, f"batch {batch}"
        ):
            return
        # The cumulative batch count is the Orbax step for batch-level saves.
        self._save_checkpoint(step=self._total_batches_seen, logs=logs)

    def on_epoch_end(self, epoch, logs=None):
        """Save at the end of the epoch when `save_freq="epoch"`."""
        self._current_epoch = epoch
        if self.monitor_op is None:
            self._set_monitor_op()  # From MonitorCallback

        if self.save_freq != "epoch":
            return
        if self.save_best_only and not self._is_new_best(
            logs, f"epoch {epoch}"
        ):
            return
        # The epoch index is the Orbax step for epoch-level saves.
        self._save_checkpoint(step=epoch, logs=logs)

    def on_train_end(self, logs=None):
        """Close the checkpointer once training is over.

        Closing stays best-effort (training has already finished), but a
        failure is reported as a warning instead of being dropped silently.
        """
        try:
            self.checkpointer.close()
        except Exception as e:
            warnings.warn(
                "OrbaxCheckpoint: error while closing the checkpointer; "
                f"pending checkpoints may be incomplete: {e!r}",
                stacklevel=2,
            )

    def wait_until_finished(self):
        """Block until in-progress checkpoint saves have completed.

        Call this before reading checkpoints from disk if asynchronous saves
        may still be pending.
        """
        wait = getattr(self.checkpointer, "wait", None)
        if wait is not None:
            wait()
            return
        # Fallback for checkpointers that do not expose `wait()`: poll.
        while self.checkpointer.is_saving_in_progress():
            time.sleep(0.1)
# Source file: keras/src/callbacks/orbax_checkpoint_test.py
import os

import numpy as np
import pytest

from keras.src import layers
from keras.src import models
from keras.src import testing
from keras.src.callbacks.orbax_checkpoint import OrbaxCheckpoint
from keras.src.utils.module_utils import ocp

# `ocp` is a lazy module handle; only touch its attributes when Orbax is
# actually importable, otherwise the whole test class is skipped below.
load_pytree = ocp.load_pytree if ocp.available else None


def _to_numpy(value):
    """Convert a backend tensor/variable to something NumPy can compare.

    Objects without a `numpy` attribute are returned unchanged. Otherwise the
    `detach().cpu().numpy()` chain is tried first and plain `numpy()` is used
    when `detach` (or a later link of the chain) is missing.
    """
    if not hasattr(value, "numpy"):
        return value
    try:
        return value.detach().cpu().numpy()
    except AttributeError:
        return value.numpy()


@pytest.mark.skipif(
    not ocp.available,
    reason="OrbaxCheckpoint requires the 'orbax-checkpoint' package",
)
class OrbaxCheckpointTest(testing.TestCase):
    # ------------------------------------------------------------ helpers

    def _create_test_model(self):
        """Create a small compiled functional model (10 -> 5 -> 1)."""
        inputs = layers.Input(shape=(10,), name="input_layer")
        x = layers.Dense(5, name="dense_layer")(inputs)
        outputs = layers.Dense(1, name="output_layer")(x)
        model = models.Model(inputs, outputs, name="test_model")
        model.compile(optimizer="adam", loss="mse")
        return model

    def _create_dummy_data(self, num_samples=100):
        """Create random `(x, y)` arrays of shape (n, 10) and (n, 1)."""
        x = np.random.randn(num_samples, 10)
        y = np.random.randn(num_samples, 1)
        return x, y

    def _fit_and_wait(self, model, x, y, callback, **fit_kwargs):
        """Run `fit` with `callback`, then wait for pending saves."""
        model.fit(x, y, callbacks=[callback], verbose=0, **fit_kwargs)
        callback.wait_until_finished()

    def _assert_has_checkpoints(self, checkpoint_dir):
        """Assert that `checkpoint_dir` contains at least one entry."""
        entries = os.listdir(checkpoint_dir)
        self.assertGreater(
            len(entries),
            0,
            f"Should have checkpoint files, found {entries}",
        )

    def _assert_trees_close(self, expected_tree, actual_tree):
        """Recursively compare two nested dicts of variables."""
        for key, expected in expected_tree.items():
            if key not in actual_tree:
                self.fail(f"Key {key} missing in loaded state")
            actual = actual_tree[key]
            if isinstance(expected, dict):
                self._assert_trees_close(expected, actual)
            else:
                np.testing.assert_array_almost_equal(
                    _to_numpy(expected), _to_numpy(actual)
                )

    def _assert_weights_round_trip(self, subdir, **callback_kwargs):
        """Train one epoch, reload step 0 into a fresh model, compare."""
        model = self._create_test_model()
        x, y = self._create_dummy_data()

        checkpoint_dir = os.path.join(self.get_temp_dir(), subdir)
        callback = OrbaxCheckpoint(
            directory=checkpoint_dir, save_freq="epoch", **callback_kwargs
        )
        self._fit_and_wait(model, x, y, callback, epochs=1)
        trained_weights = model.get_weights()

        restored_model = self._create_test_model()
        # Epoch-level saves use the epoch index as the step, so "0" here.
        loaded_state = load_pytree(os.path.join(checkpoint_dir, "0"))
        restored_model.set_state_tree(
            {"trainable_variables": loaded_state["trainable_variables"]}
        )

        for expected, actual in zip(
            trained_weights, restored_model.get_weights()
        ):
            np.testing.assert_array_almost_equal(expected, actual)

    # -------------------------------------------------------------- tests

    @pytest.mark.requires_trainable_backend
    def test_save_freq_batch(self):
        """Test batch-level saving."""
        model = self._create_test_model()
        x, y = self._create_dummy_data(num_samples=50)

        checkpoint_dir = os.path.join(self.get_temp_dir(), "test_batch_freq")
        callback = OrbaxCheckpoint(directory=checkpoint_dir, save_freq=10)

        # 50 samples / batch_size 5 = 10 batches. With save_freq=10 the only
        # save happens after the 10th batch, and the step is the cumulative
        # batch count, i.e. 10.
        self._fit_and_wait(model, x, y, callback, epochs=1, batch_size=5)

        self._assert_has_checkpoints(checkpoint_dir)
        step_10_dir = os.path.join(checkpoint_dir, "10")
        self.assertTrue(
            os.path.exists(step_10_dir),
            f"Step 10 checkpoint should exist at {step_10_dir}",
        )

    @pytest.mark.requires_trainable_backend
    def test_directory_creation(self):
        """Test that checkpoint directory is created if it doesn't exist."""
        model = self._create_test_model()
        x, y = self._create_dummy_data()

        checkpoint_dir = os.path.join(
            self.get_temp_dir(), "test_create_dir", "subdir"
        )
        callback = OrbaxCheckpoint(directory=checkpoint_dir, save_freq="epoch")
        self._fit_and_wait(model, x, y, callback, epochs=1)

        self.assertTrue(
            os.path.exists(checkpoint_dir),
            "Checkpoint directory should be created",
        )

    @pytest.mark.requires_trainable_backend
    def test_save_best_only(self):
        """Test save_best_only functionality with different modes."""
        model = self._create_test_model()
        x, y = self._create_dummy_data(num_samples=100)

        # mode="min": save only when the training loss decreases.
        checkpoint_dir = os.path.join(self.get_temp_dir(), "test_save_best_min")
        callback = OrbaxCheckpoint(
            directory=checkpoint_dir,
            monitor="loss",
            save_best_only=True,
            mode="min",
            save_freq="epoch",
        )
        self._fit_and_wait(model, x, y, callback, epochs=5)
        self._assert_has_checkpoints(checkpoint_dir)

        # mode="max" on the same `loss` metric: save only when the loss
        # increases. The test only requires that at least one checkpoint
        # gets written.
        checkpoint_dir_max = os.path.join(
            self.get_temp_dir(), "test_save_best_max"
        )
        callback_max = OrbaxCheckpoint(
            directory=checkpoint_dir_max,
            monitor="loss",
            save_best_only=True,
            mode="max",
            save_freq="epoch",
        )
        self._fit_and_wait(model, x, y, callback_max, epochs=3)
        self._assert_has_checkpoints(checkpoint_dir_max)

    @pytest.mark.requires_trainable_backend
    def test_save_weights_only(self):
        """Test save_weights_only parameter."""
        model = self._create_test_model()
        x, y = self._create_dummy_data()

        # save_weights_only=True: optimizer variables are left out.
        checkpoint_dir_weights = os.path.join(
            self.get_temp_dir(), "test_weights_only"
        )
        callback_weights = OrbaxCheckpoint(
            directory=checkpoint_dir_weights,
            save_weights_only=True,
            save_freq="epoch",
        )
        self._fit_and_wait(model, x, y, callback_weights, epochs=1)
        self._assert_has_checkpoints(checkpoint_dir_weights)

        # save_weights_only=False (the default): optimizer state included.
        checkpoint_dir_full = os.path.join(
            self.get_temp_dir(), "test_full_save"
        )
        callback_full = OrbaxCheckpoint(
            directory=checkpoint_dir_full,
            save_weights_only=False,
            save_freq="epoch",
        )
        self._fit_and_wait(model, x, y, callback_full, epochs=1)
        self._assert_has_checkpoints(checkpoint_dir_full)

    @pytest.mark.requires_trainable_backend
    def test_save_freq_epoch(self):
        """Test save_freq='epoch' functionality."""
        model = self._create_test_model()
        x, y = self._create_dummy_data()

        checkpoint_dir = os.path.join(self.get_temp_dir(), "test_epoch_freq")
        # Blocking saves keep the three consecutive saves strictly ordered.
        callback = OrbaxCheckpoint(
            directory=checkpoint_dir,
            save_freq="epoch",
            save_on_background=False,
        )
        self._fit_and_wait(model, x, y, callback, epochs=3)

        checkpoint_files = os.listdir(checkpoint_dir)
        self.assertGreaterEqual(
            len(checkpoint_files),
            3,
            f"Should have at least 3 checkpoints, "
            f"found {len(checkpoint_files)}",
        )
        # One directory per epoch index.
        for epoch in (0, 1, 2):
            epoch_dir = os.path.join(checkpoint_dir, str(epoch))
            self.assertTrue(
                os.path.exists(epoch_dir),
                f"Epoch {epoch} checkpoint should exist",
            )

    @pytest.mark.requires_trainable_backend
    def test_max_to_keep(self):
        """Test training with a max_to_keep limit configured."""
        model = self._create_test_model()
        x, y = self._create_dummy_data()

        checkpoint_dir = os.path.join(self.get_temp_dir(), "test_max_keep")
        callback = OrbaxCheckpoint(
            directory=checkpoint_dir, save_freq="epoch", max_to_keep=2
        )
        self._fit_and_wait(model, x, y, callback, epochs=5)

        # NOTE(review): this bound equals the number of epochs, so it does
        # not verify that pruning down to `max_to_keep=2` happened. It is
        # kept loose because Orbax may keep more than `max_to_keep` entries
        # in some cases; tighten once the pruning timing is confirmed.
        checkpoint_files = os.listdir(checkpoint_dir)
        self.assertLessEqual(
            len(checkpoint_files),
            5,
            f"Should not have more than 5 checkpoints, "
            f"found {len(checkpoint_files)}",
        )

    @pytest.mark.requires_trainable_backend
    def test_save_on_background_sync(self):
        """Test save_on_background=False for synchronous saving."""
        model = self._create_test_model()
        x, y = self._create_dummy_data()

        checkpoint_dir = os.path.join(self.get_temp_dir(), "test_sync_save")
        callback = OrbaxCheckpoint(
            directory=checkpoint_dir,
            save_freq="epoch",
            save_on_background=False,
        )
        # Saves go through the blocking `save_pytree` path; training must
        # still run to completion.
        self._fit_and_wait(model, x, y, callback, epochs=2)
        self._assert_has_checkpoints(checkpoint_dir)

    def test_invalid_save_freq(self):
        """Test error handling for invalid save_freq parameter."""
        checkpoint_dir = os.path.join(self.get_temp_dir(), "test_invalid_freq")

        with self.assertRaises(ValueError):
            OrbaxCheckpoint(directory=checkpoint_dir, save_freq="invalid")

    @pytest.mark.requires_trainable_backend
    def test_initial_value_threshold(self):
        """Test initial_value_threshold parameter."""
        model = self._create_test_model()
        x, y = self._create_dummy_data()

        checkpoint_dir = os.path.join(self.get_temp_dir(), "test_threshold")
        callback = OrbaxCheckpoint(
            directory=checkpoint_dir,
            monitor="loss",
            save_best_only=True,
            mode="min",
            initial_value_threshold=1.0,
            save_freq="epoch",
        )
        # A save only happens if the loss drops below the 1.0 threshold, so
        # the directory may legitimately contain no checkpoint.
        self._fit_and_wait(model, x, y, callback, epochs=3)

        self.assertTrue(
            os.path.exists(checkpoint_dir), "Checkpoint directory should exist"
        )

    @pytest.mark.requires_trainable_backend
    def test_checkpoint_loading(self):
        """Test that saved checkpoints can be loaded and weights restored."""
        self._assert_weights_round_trip("test_loading")

    @pytest.mark.requires_trainable_backend
    def test_checkpoint_loading_weights_only(self):
        """Test loading checkpoints saved with save_weights_only=True."""
        self._assert_weights_round_trip(
            "test_loading_weights", save_weights_only=True
        )

    @pytest.mark.requires_trainable_backend
    def test_checkpoint_loading_with_optimizer_state(self):
        """Test loading checkpoints that include optimizer state."""
        model = self._create_test_model()
        x, y = self._create_dummy_data(num_samples=200)

        checkpoint_dir = os.path.join(
            self.get_temp_dir(), "test_loading_optimizer"
        )
        callback = OrbaxCheckpoint(
            directory=checkpoint_dir, save_freq="epoch", save_weights_only=False
        )
        self._fit_and_wait(model, x, y, callback, epochs=1)
        original_state_tree = model.get_state_tree()

        # `_create_test_model` already compiles with the same optimizer; one
        # short fit creates the optimizer variables so they can be restored.
        new_model = self._create_test_model()
        new_x, new_y = self._create_dummy_data(num_samples=10)
        new_model.fit(new_x, new_y, epochs=1, batch_size=5, verbose=0)

        loaded_state = load_pytree(os.path.join(checkpoint_dir, "0"))
        new_model.set_state_tree(
            {
                "trainable_variables": loaded_state["trainable_variables"],
                "optimizer_variables": loaded_state["optimizer_variables"],
            }
        )
        loaded_state_tree = new_model.get_state_tree()

        self._assert_trees_close(
            original_state_tree["trainable_variables"],
            loaded_state_tree["trainable_variables"],
        )
        self._assert_trees_close(
            original_state_tree["optimizer_variables"],
            loaded_state_tree["optimizer_variables"],
        )
# Source file: keras/src/utils/module_utils.py (two hunks of the diff).

# --- Hunk 1 ---------------------------------------------------------------
# `initialize` is a method of the lazy-module class (presumably `LazyModule`;
# the class header is outside the visible hunk). Inside that class it sits
# one indentation level deeper than shown here.
def initialize(self):
    """Import the wrapped module and store it on `self.module`.

    `"orbax.checkpoint.v1"` is special-cased: the parent package
    `orbax.checkpoint` is imported and its `v1` attribute is used.

    Raises:
        ImportError: with `self.import_error_msg` when the module cannot be
            imported, or when the installed `orbax.checkpoint` exposes no
            `v1` attribute.
    """
    try:
        if self.name == "orbax.checkpoint.v1":
            parent_module = importlib.import_module("orbax.checkpoint")
            module = getattr(parent_module, "v1", None)
            if module is None:
                # A bare attribute access would raise AttributeError, which
                # callers guarding on ImportError would not catch.
                raise ImportError("'orbax.checkpoint' has no attribute 'v1'")
        else:
            module = importlib.import_module(self.name)
    except ImportError as e:
        raise ImportError(self.import_error_msg) from e
    self.module = module


# --- Hunk 2 ---------------------------------------------------------------
# Module-level lazy handles at the end of the file.
dmtree = LazyModule("tree")
tf2onnx = LazyModule("tf2onnx")
grain = LazyModule("grain")
# Resolved through the special case in `initialize` above.
ocp = LazyModule(
    "orbax.checkpoint.v1",
    pip_name="orbax-checkpoint",
    import_error_msg=(
        "OrbaxCheckpoint requires the 'orbax-checkpoint' package. "
        "You can install it via pip install orbax-checkpoint"
    ),
)