Skip to content

Commit 9aedece

Browse files
authored
fix resume checkpoint conflict in frontend (#59)
1 parent c0179eb commit 9aedece

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

tinker_cookbook/rl/train.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ async def do_sync_training_with_stream_minibatch(
324324
"""
325325
# Initial sampling client
326326
sampling_client, _ = await save_checkpoint_and_get_sampling_client(
327-
training_client, start_batch, cfg.log_path, cfg.save_every
327+
training_client, start_batch, cfg.log_path, cfg.save_every, start_batch
328328
)
329329

330330
for i_batch in range(start_batch, end_batch):
@@ -680,6 +680,7 @@ async def save_checkpoint_and_get_sampling_client(
680680
i_batch: int,
681681
log_path: str,
682682
save_every: int,
683+
start_batch: int = 0,
683684
) -> tuple[tinker.SamplingClient, dict[str, Any]]:
684685
metrics = {}
685686
with timed("save_checkpoint", metrics):
@@ -688,7 +689,7 @@ async def save_checkpoint_and_get_sampling_client(
688689
name=f"{i_batch:06d}",
689690
log_path=log_path,
690691
loop_state={"batch": i_batch},
691-
kind="both" if (i_batch > 0 and i_batch % save_every == 0) else "sampler",
692+
kind="both" if (i_batch > start_batch and i_batch % save_every == 0) else "sampler",
692693
)
693694
return training_client.create_sampling_client(path_dict["sampler_path"]), metrics
694695

@@ -949,7 +950,7 @@ async def do_sync_training(
949950
"""Implements fully synchronous on-policy training"""
950951
# Initial sampling client
951952
sampling_client, _ = await save_checkpoint_and_get_sampling_client(
952-
training_client, start_batch, cfg.log_path, cfg.save_every
953+
training_client, start_batch, cfg.log_path, cfg.save_every, start_batch
953954
)
954955

955956
for i_batch in range(start_batch, end_batch):

0 commit comments

Comments
 (0)