Skip to content

Commit 7722e30

Browse files
Fixed review comments
1 parent b56dc7b commit 7722e30

File tree

7 files changed

+78
-66
lines changed

7 files changed

+78
-66
lines changed

keras/src/backend/jax/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from keras.src.backend.jax.core import cast
1515
from keras.src.backend.jax.core import compute_output_spec
1616
from keras.src.backend.jax.core import cond
17+
from keras.src.backend.jax.core import convert_checkpoint_value
1718
from keras.src.backend.jax.core import convert_to_numpy
1819
from keras.src.backend.jax.core import convert_to_tensor
1920
from keras.src.backend.jax.core import device_scope

keras/src/backend/jax/core.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -572,3 +572,35 @@ def device_scope(device_name):
572572
else:
573573
jax_device = device_name
574574
return jax.default_device(jax_device)
575+
576+
577+
def convert_checkpoint_value(value, dtype, shape):
    """Convert a value for checkpoint restoration, keeping JAX arrays as-is.

    Values that are already JAX arrays are returned untouched (no cast, no
    reshape) so that whatever sharding/placement they carry is not
    disturbed. Every other value is converted to a JAX array with the
    requested dtype and shape.

    Args:
        value: The value to convert. Can be a JAX array, a NumPy array, or
            anything else `jnp.array` accepts (Python scalars, nested
            lists, ...).
        dtype: The target dtype for non-JAX inputs.
        shape: The target shape for non-JAX inputs. Must describe the same
            number of elements as `value`.

    Returns:
        `value` itself if it is already a JAX array; otherwise a new JAX
        array with the specified dtype and shape.
    """
    # Use a real type check. `hasattr(value, "__array_namespace__")` is also
    # true for NumPy >= 2.0 arrays, which would then be returned unconverted,
    # and matching on `str(type(value))` is brittle.
    if isinstance(value, jax.Array):
        # Already a JAX array: return as-is to preserve sharding.
        return value
    # NumPy arrays and all other array-likes take the same conversion path.
    return jnp.array(value, dtype=dtype).reshape(shape)

keras/src/backend/tensorflow/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from keras.src.backend.tensorflow.core import cast
1414
from keras.src.backend.tensorflow.core import compute_output_spec
1515
from keras.src.backend.tensorflow.core import cond
16+
from keras.src.backend.tensorflow.core import convert_checkpoint_value
1617
from keras.src.backend.tensorflow.core import convert_to_numpy
1718
from keras.src.backend.tensorflow.core import convert_to_tensor
1819
from keras.src.backend.tensorflow.core import device_scope

keras/src/backend/tensorflow/core.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,3 +696,23 @@ def __exit__(self, *args, **kwargs):
696696

697697
def device_scope(device_name):
698698
return tf.device(device_name)
699+
700+
701+
def convert_checkpoint_value(value, dtype, shape):
    """Convert a value for checkpoint restoration.

    For the TensorFlow backend, restored values are converted to NumPy
    arrays with the specified dtype and shape.

    Args:
        value: The value to convert. Can be a NumPy array or anything else
            `np.array` accepts (Python scalars, nested lists, ...).
        dtype: The target dtype.
        shape: The target shape. Must describe the same number of elements
            as `value`.

    Returns:
        A NumPy array with the specified dtype and shape. The data is a
        copy; it never shares memory with `value`.

    Raises:
        ValueError: If `value` cannot be reshaped to `shape`.
    """
    # `np.array` copies and casts ndarray and non-ndarray inputs alike, so
    # no type dispatch is needed. It also yields a base `np.ndarray` even
    # when `value` is an ndarray subclass.
    return np.array(value, dtype=dtype).reshape(shape)

keras/src/backend/torch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from keras.src.backend.torch.core import cast
3030
from keras.src.backend.torch.core import compute_output_spec
3131
from keras.src.backend.torch.core import cond
32+
from keras.src.backend.torch.core import convert_checkpoint_value
3233
from keras.src.backend.torch.core import convert_to_numpy
3334
from keras.src.backend.torch.core import convert_to_tensor
3435
from keras.src.backend.torch.core import device_scope

keras/src/backend/torch/core.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -730,3 +730,22 @@ def backward(ctx, grad_output):
730730
if not isinstance(grads, tuple):
731731
grads = (grads,)
732732
return (None,) + grads
733+
734+
735+
def convert_checkpoint_value(value, dtype, shape):
    """Convert a value for checkpoint restoration.

    For the PyTorch backend, restored values are converted to NumPy arrays
    with the specified dtype and shape.

    Args:
        value: The value to convert. Can be a NumPy array or anything else
            `np.array` accepts (Python scalars, nested lists, ...).
        dtype: The target dtype.
        shape: The target shape. Must describe the same number of elements
            as `value`.

    Returns:
        A NumPy array with the specified dtype and shape. The data is a
        copy; it never shares memory with `value`.

    Raises:
        ValueError: If `value` cannot be reshaped to `shape`.
    """
    # `np.array` copies and casts ndarray and non-ndarray inputs alike, so
    # no type dispatch is needed. It also yields a base `np.ndarray` even
    # when `value` is an ndarray subclass.
    return np.array(value, dtype=dtype).reshape(shape)

keras/src/callbacks/orbax_checkpoint.py

Lines changed: 4 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import numpy as np
55

66
from keras.src import backend
7-
from keras.src import ops
87
from keras.src import tree
98
from keras.src.api_export import keras_export
109
from keras.src.callbacks.monitor_callback import (
@@ -33,11 +32,6 @@ def convert_scalars(obj):
3332
return tree.map_structure(convert_scalars, state_tree)
3433

3534

36-
def _flatten_state_tree_values(state_tree):
37-
"""Flatten nested state tree into a list of values in consistent order."""
38-
return tree.flatten(state_tree)
39-
40-
4135
def _reconstruct_state_tree_with_values(structure, values):
4236
"""Reconstruct state tree structure with provided values."""
4337
value_iter = iter(values)
@@ -62,64 +56,14 @@ def _reconstruct_value(obj):
6256
return np.array(value, dtype=obj.dtype)
6357
elif isinstance(obj, np.ndarray):
6458
# obj is a numpy array
65-
if isinstance(value, np.ndarray):
66-
return value.astype(obj.dtype).reshape(obj.shape)
67-
else:
68-
return np.array(value, dtype=obj.dtype).reshape(obj.shape)
59+
# Use backend-specific conversion that handles JAX arrays properly
60+
return backend.convert_checkpoint_value(value, obj.dtype, obj.shape)
6961
else:
7062
return value
7163

7264
return tree.map_structure(_reconstruct_value, structure)
7365

7466

75-
def _restore_legacy_format(
76-
checkpoint_data, target_model, save_optimizer_state, save_metrics_state
77-
):
78-
"""Restore from the old flat format for backward compatibility."""
79-
# Restore model weights
80-
if "model_weights" in checkpoint_data:
81-
model_weights_np = checkpoint_data["model_weights"]
82-
# Convert NumPy arrays back to backend tensors and assign to
83-
# model
84-
for i, weight_np in enumerate(model_weights_np):
85-
# Convert numpy array back to appropriate backend tensor
86-
weight_tensor = ops.convert_to_tensor(weight_np)
87-
target_model.weights[i].assign(weight_tensor)
88-
89-
# Restore optimizer state if available
90-
if "optimizer_state" in checkpoint_data and save_optimizer_state:
91-
optimizer_vars_np = checkpoint_data["optimizer_state"]
92-
# Only restore if the variable counts match
93-
if len(optimizer_vars_np) == len(target_model.optimizer.variables):
94-
# Convert NumPy arrays back to backend tensors and assign to
95-
# optimizer
96-
for i, var_np in enumerate(optimizer_vars_np):
97-
var_tensor = ops.convert_to_tensor(var_np)
98-
target_model.optimizer.variables[i].assign(var_tensor)
99-
100-
# Restore metrics state if available
101-
if (
102-
"metrics_state" in checkpoint_data
103-
and save_metrics_state
104-
and hasattr(target_model, "metrics")
105-
):
106-
metrics_vars_np = checkpoint_data["metrics_state"]
107-
metric_idx = 0
108-
for metric in target_model.metrics:
109-
if (
110-
hasattr(metric, "variables")
111-
and metric.variables
112-
and metric_idx < len(metrics_vars_np)
113-
):
114-
metric_vars_np = metrics_vars_np[metric_idx]
115-
# Restore metric variables
116-
for i, var_np in enumerate(metric_vars_np):
117-
if i < len(metric.variables):
118-
var_tensor = ops.convert_to_tensor(var_np)
119-
metric.variables[i].assign(var_tensor)
120-
metric_idx += 1
121-
122-
12367
@keras_export("keras.callbacks.OrbaxCheckpoint")
12468
class OrbaxCheckpoint(MonitorCallback):
12569
"""Callback to save and load model state using Orbax with a similar API to
@@ -574,14 +518,8 @@ def _restore_model_state(self, checkpoint_data, model=None):
574518
checkpoint_data["model_state"], target_model
575519
)
576520
else:
577-
# Fallback to legacy format
578-
_restore_legacy_format(
579-
checkpoint_data,
580-
target_model,
581-
self.save_optimizer_state,
582-
self.save_metrics_state,
583-
)
584-
return True
521+
# Unsupported checkpoint format
522+
return False
585523

586524
def _restore_from_nested_structures(self, checkpoint_data, target_model):
587525
"""Restore from the new nested structures format."""

0 commit comments

Comments
 (0)