@@ -247,6 +247,7 @@ def __load_validation_info(
247247 )
248248
249249 if os .path .exists (full_path ):
250+ dprint (f"cpu validation info found for seed={ seed } -- loading it" )
250251 return load_validation_information (full_path , "logits" , batch_size , tokenizer )
251252 else :
252253 return None
@@ -368,12 +369,12 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens):
368369 ** padding_kwargs ,
369370 )
370371
371- if save_validation_info_outputs :
372- cpu_validation_info .save (
373- __get_validation_info_full_path (
374- model_path , batch_size , seq_length , max_new_tokens , 0
372+ if save_validation_info_outputs :
373+ cpu_validation_info .save (
374+ __get_validation_info_full_path (
375+ model_path , batch_size , seq_length , max_new_tokens , 0
376+ )
375377 )
376- )
377378 cpu_static_tokens = cpu_validation_info .get_info ("tokens" )
378379 eos_indexes = __find_eos_index (
379380 cpu_static_tokens , tokenizer .eos_token_id , seq_length , max_new_tokens
@@ -430,15 +431,15 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
430431 attn_algorithm = "math" ,
431432 ** padding_kwargs ,
432433 )
433- dprint (
434- f"cpu validation info extracted for validation level 1 - iter={ i } "
435- )
436- if save_validation_info_outputs :
437- cpu_validation_info .save (
438- __get_validation_info_full_path (
439- model_path , batch_size , seq_length , max_new_tokens , i
440- )
434+ dprint (
435+ f"cpu validation info extracted for validation level 1 - iter={ i } "
441436 )
437+ if save_validation_info_outputs :
438+ cpu_validation_info .save (
439+ __get_validation_info_full_path (
440+ model_path , batch_size , seq_length , max_new_tokens , i
441+ )
442+ )
442443 cpu_static_tokens = cpu_validation_info .get_info ("tokens" )
443444 eos_indexes = __find_eos_index (
444445 cpu_static_tokens ,
0 commit comments