From 707210de71f5753e6a6332e7c91abb5aea7db3ab Mon Sep 17 00:00:00 2001 From: Andrii Grynenko Date: Mon, 10 Nov 2025 16:32:47 -0800 Subject: [PATCH 1/2] Use save_weights_and_get_sampling_client_async in rl/train.py --- tinker_cookbook/rl/train.py | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/tinker_cookbook/rl/train.py b/tinker_cookbook/rl/train.py index 9d81b5e..10857e8 100644 --- a/tinker_cookbook/rl/train.py +++ b/tinker_cookbook/rl/train.py @@ -686,14 +686,17 @@ async def save_checkpoint_and_get_sampling_client( ) -> tuple[tinker.SamplingClient, dict[str, Any]]: metrics = {} with timed("save_checkpoint", metrics): - path_dict = await checkpoint_utils.save_checkpoint_async( - training_client=training_client, - name=f"{i_batch:06d}", - log_path=log_path, - loop_state={"batch": i_batch}, - kind="both" if (i_batch > start_batch and i_batch % save_every == 0) else "sampler", - ) - return training_client.create_sampling_client(path_dict["sampler_path"]), metrics + if i_batch > start_batch and i_batch % save_every == 0: + path_dict = await checkpoint_utils.save_checkpoint_async( + training_client=training_client, + name=f"{i_batch:06d}", + log_path=log_path, + loop_state={"batch": i_batch}, + kind="both", + ) + return training_client.create_sampling_client(path_dict["sampler_path"]), metrics + else: + return await training_client.save_weights_and_get_sampling_client_async(), metrics @scope From 8d027c691df4dc916f1d4f7b1ee9444962f01dcd Mon Sep 17 00:00:00 2001 From: Andrii Grynenko Date: Tue, 11 Nov 2025 14:32:10 -0800 Subject: [PATCH 2/2] Update tinker dependency version in pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b7601b7..604788d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ dependencies = [ "numpy", "rich", "termcolor", - "tinker", + "tinker>=0.3.0", "torch", "transformers", "blobfile",