|
136 | 136 | # thresholds are chosen based on 1024 tokens per sequence |
137 | 137 | # 1% error threshold rate between cpu fp32 and cuda fp16 |
138 | 138 | # if a models failure thresholds do not exist in this dict, default to the default_metrics_threshold defined above |
139 | | -# threshold key is (model_id, is_tiny_model) |
| 139 | +# threshold key is model_id |
140 | 140 | fail_thresholds = { |
141 | | - (LLAMA_3p1_8B_INSTRUCT, True): ( |
142 | | - 3.7392955756187423, |
143 | | - 0.001, # FIXME: compute |
144 | | - ), |
145 | | - (GRANITE_3p2_8B_INSTRUCT, True): ( |
146 | | - 2.996668996810913, |
147 | | - 0.001, # FIXME: compute |
148 | | - ), |
149 | | - (GRANITE_20B_CODE_INSTRUCT_8K, True): ( |
150 | | - 3.7392955756187423, # FIXME: compute -- setting to micro llama 3.1 8b instruct |
151 | | - 0.001, # FIXME: compute |
152 | | - ), |
153 | | - (LLAMA_3p1_70B_INSTRUCT, True): ( |
154 | | - 3.8235735702514626, |
155 | | - 0.001, # FIXME: compute |
156 | | - ), |
157 | | - (LLAMA_3p1_8B_INSTRUCT, False): ( |
| 141 | + LLAMA_3p1_8B_INSTRUCT: ( |
158 | 142 | 2.6994638133048965, |
159 | 143 | 0.00047589250549208347, |
160 | 144 | ), |
161 | | - (GRANITE_3p2_8B_INSTRUCT, False): ( |
| 145 | + GRANITE_3p2_8B_INSTRUCT: ( |
162 | 146 | 2.3919514417648315, |
163 | 147 | 0.0005767398688476533, |
164 | 148 | ), |
165 | | - (GRANITE_20B_CODE_INSTRUCT_8K, False): ( |
| 149 | + GRANITE_20B_CODE_INSTRUCT_8K: ( |
166 | 150 | 2.640706129074097, |
167 | 151 | 0.00034344267623964697, |
168 | 152 | ), |
169 | | - (LLAMA_3p1_70B_INSTRUCT, False): ( |
| 153 | + LLAMA_3p1_70B_INSTRUCT: ( |
170 | 154 | 2.841279556751251, |
171 | 155 | 0.0044301633024588115, |
172 | 156 | ), |
@@ -530,9 +514,14 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor): |
530 | 514 | # only consider those metrics captured prior to the eos |
531 | 515 | level_1_metrics = __filter_before_eos(level_1_metrics, eos_indexes) |
532 | 516 |
|
533 | | - ce_threshold, diff_threshold = fail_thresholds.get( |
534 | | - (model_path, USE_MICRO_MODELS), default_metrics_threshold |
535 | | - ) |
| 517 | + # if we do not have real model weights, use a default_metrics_threshold |
| 518 | + if USE_MICRO_MODELS and micro_model_path is None: |
| 519 | + ce_threshold, diff_threshold = default_metrics_threshold |
| 520 | + # if we have real weights, try and get the proper validation metrics threshold |
| 521 | + else: |
| 522 | + ce_threshold, diff_threshold = fail_thresholds.get( |
| 523 | + model_path, default_metrics_threshold |
| 524 | + ) |
536 | 525 |
|
537 | 526 | # get all failed responses for each metric |
538 | 527 | ce_fail_responses = filter_failed_level_1_cases( |
|
0 commit comments