Skip to content

Commit f6c9a8b

Browse files
authored
Merge pull request #151 from foundation-model-stack/fix_test_scripts_assertions
fixed test_scripts program assertions
2 parents 99e6bd1 + adc276e commit f6c9a8b

File tree

1 file changed

+33
-12
lines changed

1 file changed

+33
-12
lines changed

tests/models/test_scripts.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from pathlib import Path
77
import itertools
88
import math
9+
from aiu_fms_testing_utils.utils.paged import get_programs_prompts, ProgramCriteria
910

1011
FMS_DIR = Path(__file__).parent
1112
AIU_FMS_DIR = os.path.join(FMS_DIR, "../../../aiu-fms-testing-utils/")
@@ -291,28 +292,48 @@ def test_dpp_script(
291292
)
292293
print(result_text)
293294
with open(os.environ["DT_PROG_CRITERIA_FILEPATH"], "r") as f:
294-
program_criteria_list = json.load(f)["programs"]
295+
program_criteria_json_list = json.load(f)["programs"]
296+
program_criteria_list = []
297+
for i, d in enumerate(program_criteria_json_list):
298+
program_criteria_list.append(
299+
ProgramCriteria(
300+
i,
301+
d["max_batch"],
302+
d["max_tkv"],
303+
d["batch_granularity"],
304+
d["tkv_granularity"],
305+
)
306+
)
295307

296308
if programs is None:
297309
program_assertions = [i for i in range(len(program_criteria_list))]
298310
shape_assertions = [">=0", ">=0"]
299311
else:
312+
program_map = get_programs_prompts(
313+
program_criteria_list,
314+
multiple=64,
315+
max_batch_size=2,
316+
max_tkv=512,
317+
program_cycles=max_new_tokens,
318+
)
300319
programs_split = programs.split(":")
301320
program_ids_str = programs_split[0]
302321
shape_assertions = [
303322
f">={_}" if _.isnumeric() else _ for _ in programs_split[1].split(",")
304323
]
305-
match_number = r"\d+"
306-
valid_program_assertions = [
307-
f">={re.search(match_number, _).group()}" for _ in shape_assertions
308-
]
309-
# need to add 1 for tkv as that is the first decode
310-
program_assertions = [
311-
i
312-
for i, p in enumerate(program_criteria_list)
313-
if eval(f"p['max_batch']{valid_program_assertions[0]}")
314-
and eval(f"p['max_tkv']{valid_program_assertions[1]}+1")
315-
]
324+
325+
program_assertions = []
326+
for program_id_seq, shapes in program_map.items():
327+
if any(
328+
(
329+
eval(
330+
f"shape[0]{shape_assertions[0]} and shape[1]{shape_assertions[1]}"
331+
)
332+
for shape in shapes
333+
)
334+
):
335+
program_assertions.append(program_id_seq[0].program_id)
336+
316337
if program_ids_str == "?":
317338
program_assertions = program_assertions[:1]
318339
elif program_ids_str.isnumeric():

0 commit comments

Comments
 (0)