Skip to content

Commit aeb7c9d

Browse files
authored
Merge pull request #59 from foundation-model-stack/skip_assertions
added an option to skip assertions in validation
2 parents ec48353 + 5370e26 commit aeb7c9d

File tree

1 file changed

+19
-6
lines changed

1 file changed

+19
-6
lines changed

tests/models/test_decoders.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
USE_MICRO_MODELS = os.environ.get("FMS_TEST_SHAPES_USE_MICRO_MODELS", "1") == "1"
4949
USE_DISTRIBUTED = os.environ.get("FMS_TEST_SHAPES_DISTRIBUTED", "0") == "1"
5050
FORCE_VALIDATION_LEVEL_1 = os.environ.get("FMS_TEST_SHAPES_FORCE_VALIDATION_LEVEL_1", "0") == "1"
51+
skip_assertions = os.environ.get("FMS_TEST_SHAPES_SKIP_ASSERTIONS", {})
5152
validation_info_dir = os.environ.get(
5253
"FMS_TEST_SHAPES_VALIDATION_INFO_DIR", "/tmp/models/validation_info"
5354
)
@@ -99,6 +100,16 @@
99100
if isinstance(common_max_new_tokens, str):
100101
common_max_new_tokens = [int(mnt) for mnt in common_max_new_tokens.split(",")]
101102

103+
# pass metrics to skip as a comma separated list (ce,mean_diff)
104+
if isinstance(skip_assertions, str):
105+
_skip_assertions = []
106+
for metric in skip_assertions.split(","):
107+
metric = metric.lower()
108+
if metric not in {"ce", "mean_diff"}:
109+
pytest.fail("FMS_TEST_SHAPES_SKIP_ASSERTIONS can only accept metrics ce and mean_diff")
110+
_skip_assertions.append(metric)
111+
skip_assertions = set(_skip_assertions)
112+
102113
common_shapes = list(
103114
itertools.product(
104115
common_model_paths,
@@ -521,12 +532,14 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
521532
ce_failure_rate = len(ce_fail_responses_list) / total_tokens
522533
dprint(f"mean diff failure rate: {diff_failure_rate}")
523534
dprint(f"cross entropy loss failure rate: {ce_failure_rate}")
524-
assert diff_failure_rate < failure_rate_threshold, (
525-
f"failure rate for mean diff was too high: {diff_failure_rate}"
526-
)
527-
assert ce_failure_rate < failure_rate_threshold, (
528-
f"failure rate for cross entropy loss was too high: {ce_failure_rate}"
529-
)
535+
if "mean_diff" not in skip_assertions:
536+
assert diff_failure_rate < failure_rate_threshold, (
537+
f"failure rate for mean diff was too high: {diff_failure_rate}"
538+
)
539+
if "ce" not in skip_assertions:
540+
assert ce_failure_rate < failure_rate_threshold, (
541+
f"failure rate for cross entropy loss was too high: {ce_failure_rate}"
542+
)
530543

531544
print("passed validation level 1")
532545
else:

0 commit comments

Comments
 (0)