 import numpy as np
 
 from keras.src import backend
-from keras.src import ops
 from keras.src import tree
 from keras.src.api_export import keras_export
 from keras.src.callbacks.monitor_callback import (
@@ -33,11 +32,6 @@ def convert_scalars(obj):
     return tree.map_structure(convert_scalars, state_tree)
 
 
-def _flatten_state_tree_values(state_tree):
-    """Flatten nested state tree into a list of values in consistent order."""
-    return tree.flatten(state_tree)
-
-
 def _reconstruct_state_tree_with_values(structure, values):
     """Reconstruct state tree structure with provided values."""
     value_iter = iter(values)
@@ -62,64 +56,14 @@ def _reconstruct_value(obj):
             return np.array(value, dtype=obj.dtype)
         elif isinstance(obj, np.ndarray):
             # obj is a numpy array
-            if isinstance(value, np.ndarray):
-                return value.astype(obj.dtype).reshape(obj.shape)
-            else:
-                return np.array(value, dtype=obj.dtype).reshape(obj.shape)
+            # Use backend-specific conversion that handles JAX arrays properly
+            return backend.convert_checkpoint_value(value, obj.dtype, obj.shape)
         else:
             return value
 
     return tree.map_structure(_reconstruct_value, structure)
 
 
-def _restore_legacy_format(
-    checkpoint_data, target_model, save_optimizer_state, save_metrics_state
-):
-    """Restore from the old flat format for backward compatibility."""
-    # Restore model weights
-    if "model_weights" in checkpoint_data:
-        model_weights_np = checkpoint_data["model_weights"]
-        # Convert NumPy arrays back to backend tensors and assign to
-        # model
-        for i, weight_np in enumerate(model_weights_np):
-            # Convert numpy array back to appropriate backend tensor
-            weight_tensor = ops.convert_to_tensor(weight_np)
-            target_model.weights[i].assign(weight_tensor)
-
-    # Restore optimizer state if available
-    if "optimizer_state" in checkpoint_data and save_optimizer_state:
-        optimizer_vars_np = checkpoint_data["optimizer_state"]
-        # Only restore if the variable counts match
-        if len(optimizer_vars_np) == len(target_model.optimizer.variables):
-            # Convert NumPy arrays back to backend tensors and assign to
-            # optimizer
-            for i, var_np in enumerate(optimizer_vars_np):
-                var_tensor = ops.convert_to_tensor(var_np)
-                target_model.optimizer.variables[i].assign(var_tensor)
-
-    # Restore metrics state if available
-    if (
-        "metrics_state" in checkpoint_data
-        and save_metrics_state
-        and hasattr(target_model, "metrics")
-    ):
-        metrics_vars_np = checkpoint_data["metrics_state"]
-        metric_idx = 0
-        for metric in target_model.metrics:
-            if (
-                hasattr(metric, "variables")
-                and metric.variables
-                and metric_idx < len(metrics_vars_np)
-            ):
-                metric_vars_np = metrics_vars_np[metric_idx]
-                # Restore metric variables
-                for i, var_np in enumerate(metric_vars_np):
-                    if i < len(metric.variables):
-                        var_tensor = ops.convert_to_tensor(var_np)
-                        metric.variables[i].assign(var_tensor)
-                metric_idx += 1
-
-
 @keras_export("keras.callbacks.OrbaxCheckpoint")
 class OrbaxCheckpoint(MonitorCallback):
     """Callback to save and load model state using Orbax with a similar API to
@@ -574,14 +518,8 @@ def _restore_model_state(self, checkpoint_data, model=None):
                 checkpoint_data["model_state"], target_model
             )
         else:
-            # Fallback to legacy format
-            _restore_legacy_format(
-                checkpoint_data,
-                target_model,
-                self.save_optimizer_state,
-                self.save_metrics_state,
-            )
-            return True
+            # Unsupported checkpoint format
+            return False
 
     def _restore_from_nested_structures(self, checkpoint_data, target_model):
         """Restore from the new nested structures format."""