@@ -154,18 +154,18 @@ def __init__(
 
         # Set up save_decision_policy if not provided
         if save_decision_policy is None:
-            if save_freq == "epoch":
-                # For epoch-based saving, save every epoch
-                save_decision_policy = (
-                    ocp.training.save_decision_policies.FixedIntervalPolicy(1)
-                )
-            else:
-                # For batch-based saving, save every save_freq batches
-                save_decision_policy = (
-                    ocp.training.save_decision_policies.FixedIntervalPolicy(
-                        save_freq
-                    )
-                )
+            # Let Keras handle all save decisions - configure Checkpointer
+            # to save unconditionally when save_pytree/save_pytree_async
+            # is called
+            class _AlwaysSavePolicy(
+                ocp.training.save_decision_policies.SaveDecisionPolicy
+            ):
+                def should_save(
+                    self, current_step_info, previous_steps=None, context=None
+                ):
+                    return True
+
+            save_decision_policy = _AlwaysSavePolicy()
 
         # --- Orbax Checkpointer Setup (V1 API) ---
         # Map V0 options to V1 parameters
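The net effect of this hunk: the interval arithmetic moves out of Orbax and into the callback, with the Orbax-side policy reduced to a no-op that always approves. A dependency-free sketch of that control-flow change, using toy stand-ins rather than the real `ocp.training.save_decision_policies` types:

```python
# Toy stand-ins for illustration only; not the real Orbax classes.
class FixedIntervalPolicy:
    """Old flow: the saver's policy throttles saves to every n-th step."""

    def __init__(self, interval):
        self.interval = interval

    def should_save(self, step, previous_steps=None, context=None):
        return step % self.interval == 0


class AlwaysSavePolicy:
    """New flow: never veto; the caller fully owns the schedule."""

    def should_save(self, step, previous_steps=None, context=None):
        return True


def save_pytree(policy, step):
    # Simplified saver: consult the policy, then "save".
    if policy.should_save(step):
        print(f"saved step {step}")


# Before: call on every step and let the policy filter (save_freq=3).
old_policy = FixedIntervalPolicy(3)
for step in range(7):
    save_pytree(old_policy, step)  # saves steps 0, 3, 6

# After: Keras filters, and the policy always approves.
new_policy = AlwaysSavePolicy()
for step in range(7):
    if step % 3 == 0:  # decision made on the Keras side
        save_pytree(new_policy, step)  # saves steps 0, 3, 6
```

One consequence of moving the decision: the step numbers handed to `save_pytree` no longer have to line up with a fixed interval, so `save_freq="epoch"` and integer `save_freq` can share a single decision path in the callback.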
@@ -281,7 +281,8 @@ def _save_checkpoint(self, step, logs=None):
 
         # --- Save Logic (V1 API) ---
         # All processes participate in distributed checkpointing
-        # No wait loop needed. The Checkpointer handles overlapping saves.
+        # Checkpointer is configured to save unconditionally when
+        # save_pytree is called
         if self.verbose > 0:
             print_msg(
                 f"OrbaxCheckpoint: Triggering async save for step {step}..."
@@ -360,8 +361,8 @@ def on_epoch_end(self, epoch, logs=None):
 
         if should_save:
             # Use epoch number as the step for Orbax save
-            # The Checkpointer will decide if it *actually* saves
-            # based on its internal SaveDecisionPolicy.
+            # Keras has already made the save decision - Checkpointer will
+            # save unconditionally
             self._save_checkpoint(step=epoch, logs=logs)
 
     def on_train_end(self, logs=None):
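To show how the pieces fit together, a minimal sketch of the epoch-end path after this change; `ToyOrbaxCheckpoint` is a hypothetical stand-in, not the real callback, and the real class's other arguments and metric logic are omitted:

```python
# Hypothetical stand-in for the OrbaxCheckpoint callback.
class ToyOrbaxCheckpoint:
    def __init__(self, save_freq="epoch"):
        self.save_freq = save_freq

    def _save_checkpoint(self, step, logs=None):
        # Underlying Checkpointer always saves (see _AlwaysSavePolicy).
        print(f"unconditional save at step {step}")

    def on_epoch_end(self, epoch, logs=None):
        # Keras-side decision: with save_freq="epoch", save every epoch.
        should_save = self.save_freq == "epoch"
        if should_save:
            # The epoch number doubles as the Orbax step.
            self._save_checkpoint(step=epoch, logs=logs)


cb = ToyOrbaxCheckpoint(save_freq="epoch")
for epoch in range(3):
    cb.on_epoch_end(epoch)  # saves at steps 0, 1, 2
```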