
Commit 5e14106

[test_sampling] added test for checking enforce_sizes
Signed-off-by: kcirred <16872435+kcirred@users.noreply.github.com>
1 parent 5d2c665 commit 5e14106

File tree

1 file changed: +95, -0 lines


tests/utils/test_sampling.py

Lines changed: 95 additions & 0 deletions
@@ -13,6 +13,7 @@

 BATCH_SIZES = [0, 1, 2, 3, 4, 8]
 ENFORCE_HETEROGENEOUS = [True, False]
+LEN_ENFORCE_SIZES = [0, 1, 2, 3, 4]
 TRUNCATION = [True, False]
 ENFORCE_TRUNCATION_SIZE = [
     [],
@@ -58,6 +59,39 @@
 TOKENIZER = AutoTokenizer.from_pretrained("ibm-granite/granite-3.3-8b-instruct")


+def _replace_begin_mid_end(
+    prompt_list: list[str], target_count: int = 1, target_length: int = 128
+):
+    """Replaces slots in the list with new strings of the target length:
+    - First `target_count` slots
+    - Middle `target_count` slots
+    - Last `target_count` slots
+
+    Args:
+        prompt_list (list[str]): a list of dummy strings.
+        target_count (int, optional): how many slots to replace. Defaults to 1.
+        target_length (int, optional): how long the replacement string will be. Defaults to 128.
+    """
+
+    replacement_block = ["enforce" * target_length] * target_count
+
+    if target_count >= 1:
+        beginning = replacement_block + prompt_list[target_count:]
+        mid = len(prompt_list) // 2
+        pointer = max(0, mid - target_count // 2)
+        middle = (
+            prompt_list[:pointer]
+            + replacement_block
+            + prompt_list[pointer + target_count :]
+        )
+        end = prompt_list[:-target_count] + replacement_block
+    else:
+        beginning = prompt_list
+        middle = prompt_list
+        end = prompt_list
+    return (beginning, middle, end)
+
+
 def _prepare_sub_sharegpt_dataset(prompt_length_min, prompt_length_max, tokenizer):
     dataset_path = os.environ.get(
         "SHARE_GPT_DATASET_PATH", os.path.expanduser("~/share_gpt.json")
@@ -247,3 +281,64 @@ def test_get_truncation(enforce_truncation_size, available_sizes):
             assert "size_to_enforce" in f"{e}"
         except Exception as e:
             pytest.fail(f"Unexpeced exception: {e}")
+
+
+ENFORCE_SIZES_COMBO = list(product(BATCH_SIZES, LEN_ENFORCE_SIZES))
+
+
+@pytest.mark.parametrize("batch_size, target_count", ENFORCE_SIZES_COMBO)
+def test_enforce_sizes(batch_size, target_count):
+    print(f"{batch_size=}, {target_count=}")
+    base_text = "base"
+    basic_seq_len = 64
+    prompt_list = [base_text * basic_seq_len] * batch_size
+    enforce_len = 128
+    list_of_prompt_list = _replace_begin_mid_end(prompt_list, target_count, enforce_len)
+    print(list_of_prompt_list)
+    reference = None
+    for prompt_list in list_of_prompt_list:
+        try:
+            prompts_and_sizes = __sample_requests(
+                prompt_list,
+                batch_size,
+                TOKENIZER,
+                32,
+                enforce_len,
+                None,
+                False,
+                [enforce_len] * target_count,
+                False,
+            )
+        except ValueError as e:
+            assert "is smaller than" in f"{e}"
+            continue
+
+        # Given this test case, the final batch size should equal the number of returned prompts_and_sizes
+        assert len(prompts_and_sizes) == batch_size
+        if reference is None:
+            reference = prompts_and_sizes.copy()
+        # All prompt-list variants should yield the same result (without a seed it should be sorted)
+        assert prompts_and_sizes == reference
+        num_found = 0
+        for _, sizes in prompts_and_sizes:
+            if sizes == 128:
+                num_found += 1
+        # Verify that all inserted enforceable_sizes are found
+        assert num_found == target_count
+
+        try:
+            half_batch_prompts_and_sizes = __sample_requests(
+                prompt_list,
+                batch_size // 2,
+                TOKENIZER,
+                32,
+                enforce_len,
+                None,
+                False,
+                [enforce_len] * target_count,
+                False,
+            )
+        except ValueError as e:
+            assert "is smaller than" in f"{e}"
+            continue
+        assert len(half_batch_prompts_and_sizes) == batch_size // 2
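
For context, the @pytest.mark.parametrize decorator above runs the test over the full Cartesian product of BATCH_SIZES and LEN_ENFORCE_SIZES. A minimal standalone sketch of that expansion follows; the constants are copied from the diff, and the itertools import is assumed since it is not shown in these hunks.

    # Minimal sketch of how ENFORCE_SIZES_COMBO expands; constants copied from the diff.
    from itertools import product

    BATCH_SIZES = [0, 1, 2, 3, 4, 8]
    LEN_ENFORCE_SIZES = [0, 1, 2, 3, 4]

    ENFORCE_SIZES_COMBO = list(product(BATCH_SIZES, LEN_ENFORCE_SIZES))
    print(len(ENFORCE_SIZES_COMBO))  # 30 (batch_size, target_count) pairs
    print(ENFORCE_SIZES_COMBO[:3])   # [(0, 0), (0, 1), (0, 2)]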
