Skip to content

Commit 178bc89

Browse files
authored
Merge pull request #121 from kcirred/sampling_mod
modification to enforce_size behavior to start accepting samples even before enforcing sizes when there is sufficient space
2 parents 4a74f0a + 16011ea commit 178bc89

File tree

2 files changed

+113
-3
lines changed

2 files changed

+113
-3
lines changed

aiu_fms_testing_utils/utils/__init__.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -378,7 +378,10 @@ def __sample_requests(
378378
random.Random(seed).shuffle(dataset)
379379

380380
for prompt, prompt_len in dataset:
381-
if len(filtered_dataset) == num_requests and not enforce_sizes:
381+
if (
382+
len(filtered_dataset) + len(enforced_dataset) == num_requests
383+
and not enforce_sizes
384+
):
382385
break
383386

384387
# NOTE: This section is for enforce heterogeneous, does not work with enforce_sizes
@@ -387,7 +390,6 @@ def __sample_requests(
387390
and max_heterogeneous_combinations > len(filtered_dataset)
388391
and len(filtered_dataset) < num_requests
389392
):
390-
# for _, size in filtered_dataset:
391393
current_padded_size = pad_size_dict[prompt_len]
392394

393395
if current_padded_size not in seen_sizes:
@@ -403,6 +405,7 @@ def __sample_requests(
403405
# NOTE: this should not be `elif` despite enforce_sizes and enforce_sizes_with_truncation
404406
# are mutually exclusive because we allow same prompt to be used in enforce_sizes_with_truncation
405407
# even if it is taken from enforce_sizes
408+
truncation_found = None
406409
if enforce_sizes_with_truncation:
407410
truncation_found: Tuple[int, int] = next(
408411
(
@@ -422,6 +425,16 @@ def __sample_requests(
422425
)
423426
enforced_dataset.append((truncated_prompt, truncate_to_size))
424427
enforce_sizes_with_truncation.remove(truncation_found)
428+
# This condition allows adding prompts to the final dataset as long as there is
429+
# sufficient space allocated for sizes that need to be enforced.
430+
if (
431+
not truncation_found
432+
and current_padded_size not in enforce_sizes
433+
and len(filtered_dataset) + len(enforced_dataset)
434+
< num_requests
435+
- (len(enforce_sizes) + len(enforce_sizes_with_truncation))
436+
):
437+
filtered_dataset.append((prompt, prompt_len))
425438

426439
# when not enforcing heterogeneous or when exhausted all possible prompt_lengths
427440
else:
@@ -441,10 +454,12 @@ def __sample_requests(
441454
print(
442455
f"There may be prompt size repeats because {num_requests=} while {max_heterogeneous_combinations=}"
443456
)
444-
if enforced_dataset:
457+
if enforced_dataset and enforce_heterogeneous:
445458
filtered_dataset = _merge_enforce_keep_heterogeneous(
446459
enforced_dataset, filtered_dataset, num_requests
447460
)
461+
elif enforced_dataset:
462+
filtered_dataset = enforced_dataset + filtered_dataset
448463

449464
if len(filtered_dataset) != num_requests:
450465
warnings.warn("Returning dataset not equal to number requested", stacklevel=2)

tests/utils/test_sampling.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
BATCH_SIZES = [0, 1, 2, 3, 4, 8]
1515
ENFORCE_HETEROGENEOUS = [True, False]
16+
LEN_ENFORCE_SIZES = [0, 1, 2, 3, 4]
1617
TRUNCATION = [True, False]
1718
ENFORCE_TRUNCATION_SIZE = [
1819
[],
@@ -58,6 +59,39 @@
5859
TOKENIZER = AutoTokenizer.from_pretrained("ibm-granite/granite-3.3-8b-instruct")
5960

6061

62+
def _replace_begin_mid_end(
63+
prompt_list: list[str], target_count: int = 1, target_length: int = 128
64+
):
65+
"""Replaces slots in the list with new of target length:
66+
- First `target_count` slots
67+
- Middle `target_count` slots
68+
- Last `target_count` slots
69+
70+
Args:
71+
prompt_list (list[str]): a list of dummy strings.
72+
target_count (int, optional): how many slots to replace. Defaults to 1.
73+
target_length (int, optional): how long the string will be.
74+
"""
75+
76+
replacement_block = ["enforce" * target_length] * target_count
77+
78+
if target_count >= 1:
79+
beginning = replacement_block + prompt_list[target_count:]
80+
mid = len(prompt_list) // 2
81+
pointer = max(0, mid - target_count // 2)
82+
middle = (
83+
prompt_list[:pointer]
84+
+ replacement_block
85+
+ prompt_list[pointer + target_count :]
86+
)
87+
end = prompt_list[:-target_count] + replacement_block
88+
else:
89+
beginning = prompt_list
90+
middle = prompt_list
91+
end = prompt_list
92+
return (beginning, middle, end)
93+
94+
6195
def _prepare_sub_sharegpt_dataset(prompt_length_min, prompt_length_max, tokenizer):
6296
dataset_path = os.environ.get(
6397
"SHARE_GPT_DATASET_PATH", os.path.expanduser("~/share_gpt.json")
@@ -247,3 +281,64 @@ def test_get_truncation(enforce_truncation_size, available_sizes):
247281
assert "size_to_enforce" in f"{e}"
248282
except Exception as e:
249283
pytest.fail(f"Unexpeced exception: {e}")
284+
285+
286+
ENFORCE_SIZES_COMBO = list(product(BATCH_SIZES, LEN_ENFORCE_SIZES))
287+
288+
289+
@pytest.mark.parametrize("batch_size, target_count", ENFORCE_SIZES_COMBO)
290+
def test_enforce_sizes(batch_size, target_count):
291+
print(f"{batch_size=}, {target_count=}")
292+
base_text = "base"
293+
basic_seq_len = 64
294+
prompt_list = [base_text * basic_seq_len] * batch_size
295+
enforce_len = 128
296+
list_of_prompt_list = _replace_begin_mid_end(prompt_list, target_count, enforce_len)
297+
print(list_of_prompt_list)
298+
reference = None
299+
for prompt_list in list_of_prompt_list:
300+
try:
301+
prompts_and_sizes = __sample_requests(
302+
prompt_list,
303+
batch_size,
304+
TOKENIZER,
305+
32,
306+
enforce_len,
307+
None,
308+
False,
309+
[enforce_len] * target_count,
310+
False,
311+
)
312+
except ValueError as e:
313+
assert "is smaller than" in f"{e}"
314+
continue
315+
316+
# Given this test case final batch size should equal returned prompts_and_sizes
317+
assert len(prompts_and_sizes) == batch_size
318+
if reference is None:
319+
reference = prompts_and_sizes.copy()
320+
# all different prompts should yield the same result (without seed it should be sorted)
321+
assert prompts_and_sizes == reference
322+
num_found = 0
323+
for _, sizes in prompts_and_sizes:
324+
if sizes == 128:
325+
num_found += 1
326+
# Verify that all inserted enforceable_sizes are found
327+
assert num_found == target_count
328+
329+
try:
330+
half_batch_prompts_and_sizes = __sample_requests(
331+
prompt_list,
332+
batch_size // 2,
333+
TOKENIZER,
334+
32,
335+
enforce_len,
336+
None,
337+
False,
338+
[enforce_len] * target_count,
339+
False,
340+
)
341+
except ValueError as e:
342+
assert "is smaller than" in f"{e}"
343+
continue
344+
assert len(half_batch_prompts_and_sizes) == batch_size // 2

0 commit comments

Comments
 (0)