|
6 | 6 | from pathlib import Path |
7 | 7 | import itertools |
8 | 8 | import math |
| 9 | +from aiu_fms_testing_utils.utils.paged import get_programs_prompts, ProgramCriteria |
9 | 10 |
|
10 | 11 | FMS_DIR = Path(__file__).parent |
11 | 12 | AIU_FMS_DIR = os.path.join(FMS_DIR, "../../../aiu-fms-testing-utils/") |
@@ -291,28 +292,48 @@ def test_dpp_script( |
291 | 292 | ) |
292 | 293 | print(result_text) |
293 | 294 | with open(os.environ["DT_PROG_CRITERIA_FILEPATH"], "r") as f: |
294 | | - program_criteria_list = json.load(f)["programs"] |
| 295 | + program_criteria_json_list = json.load(f)["programs"] |
| 296 | + program_criteria_list = [] |
| 297 | + for i, d in enumerate(program_criteria_json_list): |
| 298 | + program_criteria_list.append( |
| 299 | + ProgramCriteria( |
| 300 | + i, |
| 301 | + d["max_batch"], |
| 302 | + d["max_tkv"], |
| 303 | + d["batch_granularity"], |
| 304 | + d["tkv_granularity"], |
| 305 | + ) |
| 306 | + ) |
295 | 307 |
|
296 | 308 | if programs is None: |
297 | 309 | program_assertions = [i for i in range(len(program_criteria_list))] |
298 | 310 | shape_assertions = [">=0", ">=0"] |
299 | 311 | else: |
| 312 | + program_map = get_programs_prompts( |
| 313 | + program_criteria_list, |
| 314 | + multiple=64, |
| 315 | + max_batch_size=2, |
| 316 | + max_tkv=512, |
| 317 | + program_cycles=max_new_tokens, |
| 318 | + ) |
300 | 319 | programs_split = programs.split(":") |
301 | 320 | program_ids_str = programs_split[0] |
302 | 321 | shape_assertions = [ |
303 | 322 | f">={_}" if _.isnumeric() else _ for _ in programs_split[1].split(",") |
304 | 323 | ] |
305 | | - match_number = r"\d+" |
306 | | - valid_program_assertions = [ |
307 | | - f">={re.search(match_number, _).group()}" for _ in shape_assertions |
308 | | - ] |
309 | | - # need to add 1 for tkv as that is the first decode |
310 | | - program_assertions = [ |
311 | | - i |
312 | | - for i, p in enumerate(program_criteria_list) |
313 | | - if eval(f"p['max_batch']{valid_program_assertions[0]}") |
314 | | - and eval(f"p['max_tkv']{valid_program_assertions[1]}+1") |
315 | | - ] |
| 324 | + |
| 325 | + program_assertions = [] |
| 326 | + for program_id_seq, shapes in program_map.items(): |
| 327 | + if any( |
| 328 | + ( |
| 329 | + eval( |
| 330 | + f"shape[0]{shape_assertions[0]} and shape[1]{shape_assertions[1]}" |
| 331 | + ) |
| 332 | + for shape in shapes |
| 333 | + ) |
| 334 | + ): |
| 335 | + program_assertions.append(program_id_seq[0].program_id) |
| 336 | + |
316 | 337 | if program_ids_str == "?": |
317 | 338 | program_assertions = program_assertions[:1] |
318 | 339 | elif program_ids_str.isnumeric(): |
|
0 commit comments