Skip to content

Commit 4c1fdd1

Browse files
authored
Merge pull request #26 from foundation-model-stack/validation_info_save_when_not_found
fixed bug where cpu validation information is saved, even when loaded
2 parents 24b02da + 970d095 commit 4c1fdd1

File tree

1 file changed

+14
-13
lines changed

1 file changed

+14
-13
lines changed

tests/models/test_decoders.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,7 @@ def __load_validation_info(
247247
)
248248

249249
if os.path.exists(full_path):
250+
dprint(f"cpu validation info found for seed={seed} -- loading it")
250251
return load_validation_information(full_path, "logits", batch_size, tokenizer)
251252
else:
252253
return None
@@ -368,12 +369,12 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens):
368369
**padding_kwargs,
369370
)
370371

371-
if save_validation_info_outputs:
372-
cpu_validation_info.save(
373-
__get_validation_info_full_path(
374-
model_path, batch_size, seq_length, max_new_tokens, 0
372+
if save_validation_info_outputs:
373+
cpu_validation_info.save(
374+
__get_validation_info_full_path(
375+
model_path, batch_size, seq_length, max_new_tokens, 0
376+
)
375377
)
376-
)
377378
cpu_static_tokens = cpu_validation_info.get_info("tokens")
378379
eos_indexes = __find_eos_index(
379380
cpu_static_tokens, tokenizer.eos_token_id, seq_length, max_new_tokens
@@ -430,15 +431,15 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
430431
attn_algorithm="math",
431432
**padding_kwargs,
432433
)
433-
dprint(
434-
f"cpu validation info extracted for validation level 1 - iter={i}"
435-
)
436-
if save_validation_info_outputs:
437-
cpu_validation_info.save(
438-
__get_validation_info_full_path(
439-
model_path, batch_size, seq_length, max_new_tokens, i
440-
)
434+
dprint(
435+
f"cpu validation info extracted for validation level 1 - iter={i}"
441436
)
437+
if save_validation_info_outputs:
438+
cpu_validation_info.save(
439+
__get_validation_info_full_path(
440+
model_path, batch_size, seq_length, max_new_tokens, i
441+
)
442+
)
442443
cpu_static_tokens = cpu_validation_info.get_info("tokens")
443444
eos_indexes = __find_eos_index(
444445
cpu_static_tokens,

0 commit comments

Comments
 (0)