@@ -324,7 +324,7 @@ async def do_sync_training_with_stream_minibatch(
324324 """
325325 # Initial sampling client
326326 sampling_client , _ = await save_checkpoint_and_get_sampling_client (
327- training_client , start_batch , cfg .log_path , cfg .save_every
327+ training_client , start_batch , cfg .log_path , cfg .save_every , start_batch
328328 )
329329
330330 for i_batch in range (start_batch , end_batch ):
@@ -680,6 +680,7 @@ async def save_checkpoint_and_get_sampling_client(
680680 i_batch : int ,
681681 log_path : str ,
682682 save_every : int ,
683+ start_batch : int = 0 ,
683684) -> tuple [tinker .SamplingClient , dict [str , Any ]]:
684685 metrics = {}
685686 with timed ("save_checkpoint" , metrics ):
@@ -688,7 +689,7 @@ async def save_checkpoint_and_get_sampling_client(
688689 name = f"{ i_batch :06d} " ,
689690 log_path = log_path ,
690691 loop_state = {"batch" : i_batch },
691- kind = "both" if (i_batch > 0 and i_batch % save_every == 0 ) else "sampler" ,
692+ kind = "both" if (i_batch > start_batch and i_batch % save_every == 0 ) else "sampler" ,
692693 )
693694 return training_client .create_sampling_client (path_dict ["sampler_path" ]), metrics
694695
@@ -949,7 +950,7 @@ async def do_sync_training(
949950 """Implements fully synchronous on-policy training"""
950951 # Initial sampling client
951952 sampling_client , _ = await save_checkpoint_and_get_sampling_client (
952- training_client , start_batch , cfg .log_path , cfg .save_every
953+ training_client , start_batch , cfg .log_path , cfg .save_every , start_batch
953954 )
954955
955956 for i_batch in range (start_batch , end_batch ):