diff --git a/pyproject.toml b/pyproject.toml index b7601b7..604788d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ dependencies = [ "numpy", "rich", "termcolor", - "tinker", + "tinker>=0.3.0", "torch", "transformers", "blobfile", diff --git a/tinker_cookbook/rl/train.py b/tinker_cookbook/rl/train.py index 9d81b5e..10857e8 100644 --- a/tinker_cookbook/rl/train.py +++ b/tinker_cookbook/rl/train.py @@ -686,14 +686,17 @@ async def save_checkpoint_and_get_sampling_client( ) -> tuple[tinker.SamplingClient, dict[str, Any]]: metrics = {} with timed("save_checkpoint", metrics): - path_dict = await checkpoint_utils.save_checkpoint_async( - training_client=training_client, - name=f"{i_batch:06d}", - log_path=log_path, - loop_state={"batch": i_batch}, - kind="both" if (i_batch > start_batch and i_batch % save_every == 0) else "sampler", - ) - return training_client.create_sampling_client(path_dict["sampler_path"]), metrics + if i_batch > start_batch and i_batch % save_every == 0: + path_dict = await checkpoint_utils.save_checkpoint_async( + training_client=training_client, + name=f"{i_batch:06d}", + log_path=log_path, + loop_state={"batch": i_batch}, + kind="both", + ) + return training_client.create_sampling_client(path_dict["sampler_path"]), metrics + else: + return await training_client.save_weights_and_get_sampling_client_async(), metrics @scope