Skip to content

Commit ccb9391

Browse files
JRosenkranz authored and flaviabeo committed
removed micro model validation thresholds -- use default unless they are trained
Signed-off-by: Joshua Rosenkranz <jmrosenk@us.ibm.com>
1 parent ba4ab5d commit ccb9391

File tree

1 file changed

+13
-24
lines changed

1 file changed

+13
-24
lines changed

tests/models/test_decoders.py

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -136,37 +136,21 @@
136136
# thresholds are chosen based on 1024 tokens per sequence
137137
# 1% error threshold rate between cpu fp32 and cuda fp16
138138
# if a model's failure thresholds do not exist in this dict, default to the default_metrics_threshold defined above
139-
# threshold key is (model_id, is_tiny_model)
139+
# threshold key is model_id
140140
fail_thresholds = {
141-
(LLAMA_3p1_8B_INSTRUCT, True): (
142-
3.7392955756187423,
143-
0.001, # FIXME: compute
144-
),
145-
(GRANITE_3p2_8B_INSTRUCT, True): (
146-
2.996668996810913,
147-
0.001, # FIXME: compute
148-
),
149-
(GRANITE_20B_CODE_INSTRUCT_8K, True): (
150-
3.7392955756187423, # FIXME: compute -- setting to micro llama 3.1 8b instruct
151-
0.001, # FIXME: compute
152-
),
153-
(LLAMA_3p1_70B_INSTRUCT, True): (
154-
3.8235735702514626,
155-
0.001, # FIXME: compute
156-
),
157-
(LLAMA_3p1_8B_INSTRUCT, False): (
141+
LLAMA_3p1_8B_INSTRUCT: (
158142
2.6994638133048965,
159143
0.00047589250549208347,
160144
),
161-
(GRANITE_3p2_8B_INSTRUCT, False): (
145+
GRANITE_3p2_8B_INSTRUCT: (
162146
2.3919514417648315,
163147
0.0005767398688476533,
164148
),
165-
(GRANITE_20B_CODE_INSTRUCT_8K, False): (
149+
GRANITE_20B_CODE_INSTRUCT_8K: (
166150
2.640706129074097,
167151
0.00034344267623964697,
168152
),
169-
(LLAMA_3p1_70B_INSTRUCT, False): (
153+
LLAMA_3p1_70B_INSTRUCT: (
170154
2.841279556751251,
171155
0.0044301633024588115,
172156
),
@@ -530,9 +514,14 @@ def _metric_calculator(r: torch.Tensor, t: torch.Tensor):
530514
# only consider those metrics captured prior to the eos
531515
level_1_metrics = __filter_before_eos(level_1_metrics, eos_indexes)
532516

533-
ce_threshold, diff_threshold = fail_thresholds.get(
534-
(model_path, USE_MICRO_MODELS), default_metrics_threshold
535-
)
517+
# if we do not have real model weights, use a default_metrics_threshold
518+
if USE_MICRO_MODELS and micro_model_path is None:
519+
ce_threshold, diff_threshold = default_metrics_threshold
520+
# if we have real weights, try to get the proper validation metrics threshold
521+
else:
522+
ce_threshold, diff_threshold = fail_thresholds.get(
523+
model_path, default_metrics_threshold
524+
)
536525

537526
# get all failed responses for each metric
538527
ce_fail_responses = filter_failed_level_1_cases(

0 commit comments

Comments
 (0)