Skip to content

Commit 12a276f

Browse files
authored
Merge pull request #113 from kcirred/truncation
add truncation option to enforce_sizes to truncate prompts from a larger length to meet sizes that may not be available
2 parents dff0aa2 + 204f4ae commit 12a276f

File tree

2 files changed

+372
-53
lines changed

2 files changed

+372
-53
lines changed

aiu_fms_testing_utils/utils/__init__.py

Lines changed: 216 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import random
66
import requests
77
import time
8+
import bisect
89

910
# Third Party
1011

@@ -166,12 +167,83 @@ def _merge_enforce_keep_heterogeneous(
166167
)
167168
elif len(final_list) < batch_size:
168169
warnings.warn(
169-
f"Requested {batch_size=}, than possible combined list. Will return smaller list than batch size",
170+
f"Requested {batch_size=}, is greater than possible combined list. Will return smaller list than batch size",
170171
stacklevel=2,
171172
)
172173
return final_list
173174

174175

176+
def _get_truncation_size(
177+
dataset_size_and_count: dict[int, int], enforce_sizes: List[int]
178+
):
179+
"""
180+
Given a list of sizes to enforce and a dictionary of sizes that exists and their count,
181+
find out which sizes are not possible and create a new truncation list which will grab from
182+
the next larger size in order to enforce that size.
183+
If there are no larger sizes, try to take the largest from the dataset.
184+
185+
Args:
186+
dataset_size_and_count (Dict[int, int]): List of possible sizes and counts for the dataset
187+
enforce_sizes (List[int]): List of ints which sizes must be enforced
188+
189+
Returns:
190+
List[Tuple[int,int]]: a List of Tuples which have first int as size to truncate to, and second int as to prompt len to grab from
191+
"""
192+
truncation_list: List[Tuple[int, int]] = []
193+
sorted_sizes_in_dataset: List[int] = sorted(dataset_size_and_count.keys())
194+
# sort for consistent results where user mixes order of enforce_sizes
195+
enforce_sizes = sorted(enforce_sizes)
196+
197+
for size_to_enforce in enforce_sizes:
198+
found_idx = bisect.bisect_left(sorted_sizes_in_dataset, size_to_enforce)
199+
truncation_size = None
200+
201+
# if valid search found
202+
if found_idx < len(sorted_sizes_in_dataset):
203+
while found_idx < len(sorted_sizes_in_dataset):
204+
# reset the candidate to the new found_idx
205+
candidate = sorted_sizes_in_dataset[found_idx]
206+
# Have to check if this prompt length is available with the count
207+
if dataset_size_and_count[candidate] > 0:
208+
# if count is > 0 then decrement the count as it no longer can be used for future prompts
209+
dataset_size_and_count[candidate] -= 1
210+
truncation_size = candidate
211+
break
212+
# if prompt length is not avaible increment to see if the next larger prompt is available
213+
found_idx += 1
214+
215+
if truncation_size is None:
216+
raise ValueError(
217+
f"We've exhausted all possible truncation sizes, please increase max_prompt_len or remove {size_to_enforce=}"
218+
)
219+
truncation_list.append((size_to_enforce, truncation_size))
220+
else:
221+
# this occurs when size_to_enforce is outside of the max range of dataset
222+
if sorted_sizes_in_dataset:
223+
# try to grab the largest size from the end of sorted list if it is available otherwise throw error
224+
truncation_size = sorted_sizes_in_dataset[-1]
225+
if dataset_size_and_count[truncation_size] > 0:
226+
truncation_list.append((size_to_enforce, truncation_size))
227+
dataset_size_and_count[truncation_size] -= 1
228+
else:
229+
raise ValueError(
230+
f"{size_to_enforce=} is larger than largest sample and not available."
231+
)
232+
return truncation_list
233+
234+
235+
def _remove_list_from_list(main_list, list_to_remove):
236+
for item in list_to_remove:
237+
if item in main_list:
238+
main_list.remove(item)
239+
return main_list
240+
241+
242+
# Because we now require encoding the dataset, cache the datasets to make
243+
# second sample request quick
244+
__cached_encoded_datasets = {}
245+
246+
175247
def __sample_requests(
176248
prompt_list: List[str],
177249
num_requests: int,
@@ -180,97 +252,203 @@ def __sample_requests(
180252
prompt_length_max: int = 64,
181253
seed: Optional[int] = None,
182254
enforce_heterogeneous: bool = False,
183-
enforce_sizes: List[int] = [],
255+
enforce_sizes: List[int] | None = None,
256+
truncation: bool = False,
184257
pad_multiple: int = 64,
258+
_cached_dataset_key: Optional[str] = None,
185259
):
186260
"""
187-
Shuffles dataset, tokenizes the prompts and then filters
261+
Shuffles dataset, tokenizes the prompts and then filters.
188262
189263
Args:
190264
prompt_length_min (int): filters out prompts shorter than this value.
191265
prompt_length_max (int): filters out prompts larger than this value.
192266
enforce_sizes (List[int]): sample request will grab a prompt with this length if available.
193-
enforce_heterogeneous (bool): Pads all prompts within batch size to nearest multiple of 64.
267+
enforce_heterogeneous (bool): Pads all prompts within batch to nearest multiple of `pad_multiple`.
268+
However, if enforce_sizes is not empty, it will set enforce_heterogeneous to False.
194269
pad_multiple (int): Used only when enforce_heterogeneous is True or enforce_sizes is not empty, asserts that prompt_length would be padded to this multiple
195270
List[Tuple[str, int]]: a filtered dataset
271+
truncation (bool): If true will truncate to an enforced size if the size does not exist. Only to be used with enforce_sizes, otherwise
272+
will be ignored
273+
_cached_dataset_key (optional[str]): The key to the dataset if enabling caching of encoded datasets
274+
275+
Returns:
276+
List[Tuple[str, int]]
196277
"""
197278

279+
assert prompt_length_max >= prompt_length_min, (
280+
"Please enter valid prompt length max/min values"
281+
)
282+
283+
if enforce_sizes is None:
284+
enforce_sizes = []
285+
286+
if enforce_heterogeneous and enforce_sizes:
287+
warnings.warn(
288+
f"{enforce_heterogeneous=} and {enforce_sizes=}, these two are not designed to be used at the same time. Forcing enforce_heterogeneous to False"
289+
)
290+
enforce_heterogeneous = False
291+
198292
# Based on min/max prompt length, one can back out the number of possible heterogeneous values
199293
max_heterogeneous_combinations = (prompt_length_max // pad_multiple) - (
200294
(prompt_length_min - 1) // pad_multiple
201295
)
202296

203297
# Filter out sequences that are too long or too short
298+
dataset: List[Tuple[str, int]] = []
204299
filtered_dataset: List[Tuple[str, int]] = []
205300
enforced_dataset: List[Tuple[str, int]] = []
206301

207302
# To track sizes seen
208303
seen_sizes: List[int] = []
209304

305+
sample_size_counter: dict[int, int] = {}
306+
# first int is the size to truncate to, second int is size of text to grab from
307+
enforce_sizes_with_truncation: List[Tuple[int, int]] = []
308+
309+
if truncation and not enforce_sizes:
310+
warnings.warn(
311+
f"truncation and enforce_sizes should be used together, whereas {truncation=} and {enforce_sizes=}, hence no truncation will happen",
312+
stacklevel=2,
313+
)
314+
315+
if (
316+
_cached_dataset_key is not None
317+
and _cached_dataset_key in __cached_encoded_datasets
318+
):
319+
dataset = __cached_encoded_datasets[_cached_dataset_key]
320+
else:
321+
# Loop to create the filtered dataset
322+
for i in range(len(prompt_list)):
323+
# Tokenize the prompts and completions.
324+
prompt = prompt_list[i]
325+
prompt_token_ids = tokenizer.encode(prompt, return_tensors="pt").squeeze(0)
326+
327+
prompt_len = len(prompt_token_ids)
328+
329+
dataset.append((prompt, prompt_len))
330+
331+
dataset.sort(key=lambda tuple: tuple[1])
332+
__cached_encoded_datasets[_cached_dataset_key] = dataset
333+
334+
# only keep values that are required
335+
dataset = [
336+
r for r in dataset if r[1] >= prompt_length_min and r[1] <= prompt_length_max
337+
]
338+
339+
pad_size_dict: dict[int, int] = {}
340+
for _, prompt_len in dataset:
341+
pad_size_dict.setdefault(prompt_len, get_pad_size(prompt_len, pad_multiple))
342+
sample_size_counter[pad_size_dict[prompt_len]] = (
343+
sample_size_counter.get(pad_size_dict[prompt_len], 0) + 1
344+
)
345+
210346
if enforce_sizes:
211347
for size in enforce_sizes:
212348
# Check that enforced sizes fall within min/max range
213349
assert prompt_length_min <= size <= prompt_length_max, (
214350
f"Size {size} in enforced sizes not within {prompt_length_min=}, {prompt_length_max=}"
215351
)
352+
assert size % pad_multiple == 0, (
353+
"Enforce sizes must be a multiple of pad_multiple"
354+
)
216355
if len(enforce_sizes) > num_requests:
217356
raise ValueError(
218357
f"{num_requests=} which is smaller than {len(enforce_sizes)=}"
219358
)
220359

360+
if truncation:
361+
truncation_size_counter = sample_size_counter.copy()
362+
363+
# Allocate certain counts to enforce_sizes
364+
needs_truncation = []
365+
for size in enforce_sizes:
366+
if sample_size_counter.get(size, 0) > 0:
367+
sample_size_counter[size] -= 1
368+
else:
369+
needs_truncation.append(size)
370+
enforce_sizes = _remove_list_from_list(enforce_sizes, needs_truncation)
371+
372+
enforce_sizes_with_truncation = _get_truncation_size(
373+
truncation_size_counter, needs_truncation
374+
)
375+
221376
# Shuffle the dataset.
222377
if seed is not None:
223-
random.Random(seed).shuffle(prompt_list)
378+
random.Random(seed).shuffle(dataset)
224379

225-
for i in range(len(prompt_list)):
380+
for prompt, prompt_len in dataset:
226381
if len(filtered_dataset) == num_requests and not enforce_sizes:
227382
break
228383

229-
# Tokenize the prompts and completions.
230-
prompt = prompt_list[i]
231-
prompt_token_ids = tokenizer.encode(prompt, return_tensors="pt").squeeze(0)
232-
233-
prompt_len = len(prompt_token_ids)
234-
if prompt_len < prompt_length_min or prompt_len > prompt_length_max:
235-
# Prune too short or too long sequences.
236-
continue
237-
# This section is for enforce heterogeneous
384+
# NOTE: This section is for enforce heterogeneous, does not work with enforce_sizes
238385
if (
239386
enforce_heterogeneous
240387
and max_heterogeneous_combinations > len(filtered_dataset)
241388
and len(filtered_dataset) < num_requests
242389
):
243390
# for _, size in filtered_dataset:
244-
current_padded_size = get_pad_size(prompt_len, pad_multiple)
245-
246-
# If it's in the list of enforce_sizes it is enforced, can remove from list
247-
if current_padded_size in enforce_sizes:
248-
enforce_sizes.remove(current_padded_size)
249-
enforced_dataset.append((prompt, prompt_len))
391+
current_padded_size = pad_size_dict[prompt_len]
250392

251393
if current_padded_size not in seen_sizes:
252394
filtered_dataset.append((prompt, prompt_len))
253395
seen_sizes.append(current_padded_size)
254396
# Forcing search for enforce_sizes
255-
elif enforce_sizes:
256-
current_padded_size = get_pad_size(prompt_len, pad_multiple)
397+
elif enforce_sizes or enforce_sizes_with_truncation:
398+
current_padded_size = pad_size_dict[prompt_len]
399+
# if it is in the enforce_size list
257400
if current_padded_size in enforce_sizes:
258401
enforce_sizes.remove(current_padded_size)
259402
enforced_dataset.append((prompt, prompt_len))
403+
# NOTE: this should not be `elif` even though enforce_sizes and enforce_sizes_with_truncation
404+
# are mutually exclusive, because we allow the same prompt to be used in enforce_sizes_with_truncation
405+
# even if it is taken from enforce_sizes
406+
if enforce_sizes_with_truncation:
407+
truncation_found: Tuple[int, int] = next(
408+
(
409+
tup
410+
for tup in enforce_sizes_with_truncation
411+
if tup[1] == current_padded_size
412+
),
413+
None,
414+
)
415+
if truncation_found:
416+
truncate_to_size, _ = truncation_found
417+
prompt_token_ids = tokenizer.encode(
418+
prompt, add_special_tokens=False
419+
)
420+
truncated_prompt = tokenizer.decode(
421+
prompt_token_ids[:truncate_to_size], skip_special_tokens=True
422+
)
423+
enforced_dataset.append((truncated_prompt, truncate_to_size))
424+
enforce_sizes_with_truncation.remove(truncation_found)
425+
260426
# when not enforcing heterogeneous or when exhausted all possible prompt_lengths
261427
else:
262428
filtered_dataset.append((prompt, prompt_len))
263-
assert not enforce_sizes, "Enforce size should be empty if all lengths are captured"
429+
if enforce_sizes:
430+
warnings.warn(
431+
f"{enforce_sizes=} so these sizes were not enforced, consider setting truncation=True",
432+
stacklevel=2,
433+
)
434+
if enforce_sizes_with_truncation:
435+
warnings.warn(
436+
f"{enforce_sizes_with_truncation=} so not all sizes with truncation enforced",
437+
stacklevel=2,
438+
)
264439

265440
if num_requests > max_heterogeneous_combinations:
266441
print(
267-
f"There will be prompt size repeats because {num_requests=} while {max_heterogeneous_combinations=}"
442+
f"There may be prompt size repeats because {num_requests=} while {max_heterogeneous_combinations=}"
268443
)
269444
if enforced_dataset:
270445
filtered_dataset = _merge_enforce_keep_heterogeneous(
271446
enforced_dataset, filtered_dataset, num_requests
272447
)
273448

449+
if len(filtered_dataset) != num_requests:
450+
warnings.warn("Returning dataset not equal to number requested", stacklevel=2)
451+
274452
return filtered_dataset
275453

276454

@@ -282,7 +460,8 @@ def sample_sharegpt_requests(
282460
prompt_length_max: int = 64,
283461
seed: Optional[int] = None,
284462
enforce_heterogeneous: bool = False,
285-
enforce_sizes: List[int] = [],
463+
enforce_sizes: List[int] | None = None,
464+
truncation: bool = False,
286465
pad_multiple: int = 64,
287466
) -> List[Tuple[str, int]]:
288467
if not os.path.exists(dataset_path):
@@ -292,6 +471,9 @@ def sample_sharegpt_requests(
292471
dataset_path,
293472
)
294473

474+
if enforce_sizes is None:
475+
enforce_sizes = []
476+
295477
# Load the dataset.
296478
with open(dataset_path, encoding="utf-8") as f:
297479
dataset = json.load(f)
@@ -308,7 +490,9 @@ def sample_sharegpt_requests(
308490
seed,
309491
enforce_heterogeneous,
310492
enforce_sizes,
493+
truncation,
311494
pad_multiple,
495+
_cached_dataset_key=dataset_path,
312496
)
313497

314498

@@ -320,11 +504,15 @@ def sample_squad_v2_qa_requests(
320504
prompt_length_max: int = 64,
321505
seed: Optional[int] = None,
322506
enforce_heterogeneous: bool = False,
323-
enforce_sizes: List[int] = [],
507+
enforce_sizes: List[int] | None = None,
508+
truncation: bool = False,
324509
pad_multiple: int = 64,
325510
) -> List[Tuple[str, int]]:
326511
from datasets import load_dataset
327512

513+
if enforce_sizes is None:
514+
enforce_sizes = []
515+
328516
if os.path.exists(dataset_path):
329517
ds = load_dataset(dataset_path)["train"]
330518
else:
@@ -341,6 +529,7 @@ def sample_squad_v2_qa_requests(
341529
seed,
342530
enforce_heterogeneous,
343531
enforce_sizes,
532+
truncation,
344533
pad_multiple,
345534
)
346535

0 commit comments

Comments
 (0)