Skip to content

Commit 5c7c921

Browse files
authored
Merge pull request #27 from foundation-model-stack/force_validation_level_1
Add option to force validation level 1 testing in test_decoders
2 parents 6e7e360 + c1d05de commit 5c7c921

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

tests/models/test_decoders.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
)
4545
USE_MICRO_MODELS = os.environ.get("FMS_TEST_SHAPES_USE_MICRO_MODELS", "1") == "1"
4646
USE_DISTRIBUTED = os.environ.get("FMS_TEST_SHAPES_DISTRIBUTED", "0") == "1"
47+
FORCE_VALIDATION_LEVEL_1 = os.environ.get("FMS_TEST_SHAPES_FORCE_VALIDATION_LEVEL_1", "0") == "1"
4748
validation_info_dir = os.environ.get(
4849
"FMS_TEST_SHAPES_VALIDATION_INFO_DIR", "/tmp/models/validation_info"
4950
)
@@ -394,9 +395,15 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens):
394395
aiu_validation_info.get_info("tokens"), cpu_static_tokens
395396
)
396397

398+
failed_validation_level_0 = len(failed_responses) != 0
399+
397400
# if level 0 fails validation, validate level 1
398-
if len(failed_responses) != 0:
399-
print("failed validation level 0, testing validation level 1")
401+
if FORCE_VALIDATION_LEVEL_1 or failed_validation_level_0:
402+
403+
if failed_validation_level_0:
404+
dprint("failed validation level 0, testing validation level 1")
405+
else:
406+
dprint("passed validation level 0, testing validation level 1")
400407

401408
# metric calculator based on the cross-entropy and mean diff for each decode step
402409
def _metric_calculator(r: torch.Tensor, t: torch.Tensor):

0 commit comments

Comments
 (0)