Skip to content

Commit d7884ef

Browse files
added checkpointer.wait()
1 parent 33f4e66 commit d7884ef

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

keras/src/callbacks/orbax_checkpoint.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -468,10 +468,14 @@ def wait_until_finished(self):
468468
checkpoints if there might be pending save operations.
469469
"""
470470
# Wait for any async operations to complete
471-
while self.checkpointer.is_saving_in_progress():
472-
import time
471+
try:
472+
self.checkpointer.wait()
473+
except AttributeError:
474+
# Fallback for older Orbax versions that don't have wait() method
475+
while self.checkpointer.is_saving_in_progress():
476+
import time
473477

474-
time.sleep(0.1)
478+
time.sleep(0.1)
475479

476480
def _restore_model_state_from_full_tree(self, state_tree, model=None):
477481
"""Restore model state from full state tree (V1 format)."""

0 commit comments

Comments
 (0)