|
13 | 13 |
|
14 | 14 | BATCH_SIZES = [0, 1, 2, 3, 4, 8] |
15 | 15 | ENFORCE_HETEROGENEOUS = [True, False] |
| 16 | +LEN_ENFORCE_SIZES = [0, 1, 2, 3, 4] |
16 | 17 | TRUNCATION = [True, False] |
17 | 18 | ENFORCE_TRUNCATION_SIZE = [ |
18 | 19 | [], |
|
58 | 59 | TOKENIZER = AutoTokenizer.from_pretrained("ibm-granite/granite-3.3-8b-instruct") |
59 | 60 |
|
60 | 61 |
|
| 62 | +def _replace_begin_mid_end( |
| 63 | + prompt_list: list[str], target_count: int = 1, target_length: int = 128 |
| 64 | +): |
| 65 | + """Replaces slots in the list with new of target length: |
| 66 | + - First `target_count` slots |
| 67 | + - Middle `target_count` slots |
| 68 | + - Last `target_count` slots |
| 69 | +
|
| 70 | + Args: |
| 71 | + prompt_list (list[str]): a list of dummy strings. |
| 72 | + target_count (int, optional): how many slots to replace. Defaults to 1. |
| 73 | + target_length (int, optional): how long the string will be. |
| 74 | + """ |
| 75 | + |
| 76 | + replacement_block = ["enforce" * target_length] * target_count |
| 77 | + |
| 78 | + if target_count >= 1: |
| 79 | + beginning = replacement_block + prompt_list[target_count:] |
| 80 | + mid = len(prompt_list) // 2 |
| 81 | + pointer = max(0, mid - target_count // 2) |
| 82 | + middle = ( |
| 83 | + prompt_list[:pointer] |
| 84 | + + replacement_block |
| 85 | + + prompt_list[pointer + target_count :] |
| 86 | + ) |
| 87 | + end = prompt_list[:-target_count] + replacement_block |
| 88 | + else: |
| 89 | + beginning = prompt_list |
| 90 | + middle = prompt_list |
| 91 | + end = prompt_list |
| 92 | + return (beginning, middle, end) |
| 93 | + |
| 94 | + |
61 | 95 | def _prepare_sub_sharegpt_dataset(prompt_length_min, prompt_length_max, tokenizer): |
62 | 96 | dataset_path = os.environ.get( |
63 | 97 | "SHARE_GPT_DATASET_PATH", os.path.expanduser("~/share_gpt.json") |
@@ -247,3 +281,64 @@ def test_get_truncation(enforce_truncation_size, available_sizes): |
247 | 281 | assert "size_to_enforce" in f"{e}" |
248 | 282 | except Exception as e: |
249 | 283 | pytest.fail(f"Unexpeced exception: {e}") |
| 284 | + |
| 285 | + |
| 286 | +ENFORCE_SIZES_COMBO = list(product(BATCH_SIZES, LEN_ENFORCE_SIZES)) |
| 287 | + |
| 288 | + |
| 289 | +@pytest.mark.parametrize("batch_size, target_count", ENFORCE_SIZES_COMBO) |
| 290 | +def test_enforce_sizes(batch_size, target_count): |
| 291 | + print(f"{batch_size=}, {target_count=}") |
| 292 | + base_text = "base" |
| 293 | + basic_seq_len = 64 |
| 294 | + prompt_list = [base_text * basic_seq_len] * batch_size |
| 295 | + enforce_len = 128 |
| 296 | + list_of_prompt_list = _replace_begin_mid_end(prompt_list, target_count, enforce_len) |
| 297 | + print(list_of_prompt_list) |
| 298 | + reference = None |
| 299 | + for prompt_list in list_of_prompt_list: |
| 300 | + try: |
| 301 | + prompts_and_sizes = __sample_requests( |
| 302 | + prompt_list, |
| 303 | + batch_size, |
| 304 | + TOKENIZER, |
| 305 | + 32, |
| 306 | + enforce_len, |
| 307 | + None, |
| 308 | + False, |
| 309 | + [enforce_len] * target_count, |
| 310 | + False, |
| 311 | + ) |
| 312 | + except ValueError as e: |
| 313 | + assert "is smaller than" in f"{e}" |
| 314 | + continue |
| 315 | + |
| 316 | + # Given this test case final batch size should equal returned prompts_and_sizes |
| 317 | + assert len(prompts_and_sizes) == batch_size |
| 318 | + if reference is None: |
| 319 | + reference = prompts_and_sizes.copy() |
| 320 | + # all different prompts should yield the same result (without seed it should be sorted) |
| 321 | + assert prompts_and_sizes == reference |
| 322 | + num_found = 0 |
| 323 | + for _, sizes in prompts_and_sizes: |
| 324 | + if sizes == 128: |
| 325 | + num_found += 1 |
| 326 | + # Verify that all inserted enforceable_sizes are found |
| 327 | + assert num_found == target_count |
| 328 | + |
| 329 | + try: |
| 330 | + half_batch_prompts_and_sizes = __sample_requests( |
| 331 | + prompt_list, |
| 332 | + batch_size // 2, |
| 333 | + TOKENIZER, |
| 334 | + 32, |
| 335 | + enforce_len, |
| 336 | + None, |
| 337 | + False, |
| 338 | + [enforce_len] * target_count, |
| 339 | + False, |
| 340 | + ) |
| 341 | + except ValueError as e: |
| 342 | + assert "is smaller than" in f"{e}" |
| 343 | + continue |
| 344 | + assert len(half_batch_prompts_and_sizes) == batch_size // 2 |
0 commit comments