Skip to content

Commit 6d9311b

Browse files
Use save_weights_and_get_sampling_client_async in rl/train.py (#84)
1 parent b58c178 commit 6d9311b

File tree

2 files changed

+12
-9
lines changed

2 files changed

+12
-9
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ dependencies = [
1313
"numpy",
1414
"rich",
1515
"termcolor",
16-
"tinker",
16+
"tinker>=0.3.0",
1717
"torch",
1818
"transformers",
1919
"blobfile",

tinker_cookbook/rl/train.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -686,14 +686,17 @@ async def save_checkpoint_and_get_sampling_client(
686686
) -> tuple[tinker.SamplingClient, dict[str, Any]]:
687687
metrics = {}
688688
with timed("save_checkpoint", metrics):
689-
path_dict = await checkpoint_utils.save_checkpoint_async(
690-
training_client=training_client,
691-
name=f"{i_batch:06d}",
692-
log_path=log_path,
693-
loop_state={"batch": i_batch},
694-
kind="both" if (i_batch > start_batch and i_batch % save_every == 0) else "sampler",
695-
)
696-
return training_client.create_sampling_client(path_dict["sampler_path"]), metrics
689+
if i_batch > start_batch and i_batch % save_every == 0:
690+
path_dict = await checkpoint_utils.save_checkpoint_async(
691+
training_client=training_client,
692+
name=f"{i_batch:06d}",
693+
log_path=log_path,
694+
loop_state={"batch": i_batch},
695+
kind="both",
696+
)
697+
return training_client.create_sampling_client(path_dict["sampler_path"]), metrics
698+
else:
699+
return await training_client.save_weights_and_get_sampling_client_async(), metrics
697700

698701

699702
@scope

0 commit comments

Comments
 (0)