2 files changed: +12 -9 lines changed

Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@ dependencies = [
 13  13       "numpy",
 14  14       "rich",
 15  15       "termcolor",
 16      -    "tinker",
     16  +    "tinker>=0.3.0",
 17  17       "torch",
 18  18       "transformers",
 19  19       "blobfile",
Original file line number | Diff line number | Diff line change
@@ -686,14 +686,17 @@ async def save_checkpoint_and_get_sampling_client(
686 686   ) -> tuple[tinker.SamplingClient, dict[str, Any]]:
687 687       metrics = {}
688 688       with timed("save_checkpoint", metrics):
689     -         path_dict = await checkpoint_utils.save_checkpoint_async(
690     -             training_client=training_client,
691     -             name=f"{i_batch:06d}",
692     -             log_path=log_path,
693     -             loop_state={"batch": i_batch},
694     -             kind="both" if (i_batch > start_batch and i_batch % save_every == 0) else "sampler",
695     -         )
696     -         return training_client.create_sampling_client(path_dict["sampler_path"]), metrics
    689 +         if i_batch > start_batch and i_batch % save_every == 0:
    690 +             path_dict = await checkpoint_utils.save_checkpoint_async(
    691 +                 training_client=training_client,
    692 +                 name=f"{i_batch:06d}",
    693 +                 log_path=log_path,
    694 +                 loop_state={"batch": i_batch},
    695 +                 kind="both",
    696 +             )
    697 +             return training_client.create_sampling_client(path_dict["sampler_path"]), metrics
    698 +         else:
    699 +             return await training_client.save_weights_and_get_sampling_client_async(), metrics
697 700
698 701
699 702   @scope
You can’t perform that action at this time.
0 commit comments