|
# Test configuration read from the environment. The boolean flags are enabled
# only when the corresponding variable is exactly "1".
USE_MICRO_MODELS = os.environ.get("FMS_TEST_SHAPES_USE_MICRO_MODELS", "1") == "1"
USE_DISTRIBUTED = os.environ.get("FMS_TEST_SHAPES_DISTRIBUTED", "0") == "1"
FORCE_VALIDATION_LEVEL_1 = os.environ.get("FMS_TEST_SHAPES_FORCE_VALIDATION_LEVEL_1", "0") == "1"
# Raw value is a comma-separated string when the variable is set. When it is
# unset, fall back to an empty set (not a dict) so the name always ends up
# holding a container suitable for the `"metric" in skip_assertions` checks.
skip_assertions = os.environ.get("FMS_TEST_SHAPES_SKIP_ASSERTIONS", set())
validation_info_dir = os.environ.get(
    "FMS_TEST_SHAPES_VALIDATION_INFO_DIR", "/tmp/models/validation_info"
)
|
# A string value is a comma-separated list of integers; expand it into a list.
if isinstance(common_max_new_tokens, str):
    common_max_new_tokens = list(map(int, common_max_new_tokens.split(",")))
101 | 102 |
|
# Metrics whose failure-rate assertions should be skipped, passed as a
# comma-separated list in FMS_TEST_SHAPES_SKIP_ASSERTIONS (e.g. "ce,mean_diff").
# Only the string form is parsed; a non-string value is left untouched.
if isinstance(skip_assertions, str):
    # Tolerate surrounding whitespace and empty entries ("ce, mean_diff",
    # "ce," or an empty variable) instead of rejecting them as unknown metrics.
    _skip_assertions = {
        metric.strip().lower()
        for metric in skip_assertions.split(",")
        if metric.strip()
    }
    # Any name outside the supported metrics aborts with the same message.
    if not _skip_assertions <= {"ce", "mean_diff"}:
        pytest.fail(
            "FMS_TEST_SHAPES_SKIP_ASSERTIONS can only accept metrics ce and mean_diff"
        )
    skip_assertions = _skip_assertions
| 112 | + |
102 | 113 | common_shapes = list( |
103 | 114 | itertools.product( |
104 | 115 | common_model_paths, |
@@ -521,12 +532,14 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor): |
521 | 532 | ce_failure_rate = len(ce_fail_responses_list) / total_tokens |
522 | 533 | dprint(f"mean diff failure rate: {diff_failure_rate}") |
523 | 534 | dprint(f"cross entropy loss failure rate: {ce_failure_rate}") |
524 | | - assert diff_failure_rate < failure_rate_threshold, ( |
525 | | - f"failure rate for mean diff was too high: {diff_failure_rate}" |
526 | | - ) |
527 | | - assert ce_failure_rate < failure_rate_threshold, ( |
528 | | - f"failure rate for cross entropy loss was too high: {ce_failure_rate}" |
529 | | - ) |
| 535 | + if "mean_diff" not in skip_assertions: |
| 536 | + assert diff_failure_rate < failure_rate_threshold, ( |
| 537 | + f"failure rate for mean diff was too high: {diff_failure_rate}" |
| 538 | + ) |
| 539 | + if "ce" not in skip_assertions: |
| 540 | + assert ce_failure_rate < failure_rate_threshold, ( |
| 541 | + f"failure rate for cross entropy loss was too high: {ce_failure_rate}" |
| 542 | + ) |
530 | 543 |
|
531 | 544 | print("passed validation level 1") |
532 | 545 | else: |
|
0 commit comments