55import random
66import requests
77import time
8+ import bisect
89
910# Third Party
1011
@@ -166,12 +167,83 @@ def _merge_enforce_keep_heterogeneous(
166167 )
167168 elif len (final_list ) < batch_size :
168169 warnings .warn (
169- f"Requested { batch_size = } , than possible combined list. Will return smaller list than batch size" ,
170+ f"Requested { batch_size = } , is greater than possible combined list. Will return smaller list than batch size" ,
170171 stacklevel = 2 ,
171172 )
172173 return final_list
173174
174175
def _get_truncation_size(
    dataset_size_and_count: dict[int, int], enforce_sizes: List[int]
) -> List[Tuple[int, int]]:
    """
    Plan which prompt lengths to truncate from to satisfy ``enforce_sizes``.

    Given a list of sizes to enforce and a dictionary of sizes that exist and
    their counts, find out which sizes are not possible and create a truncation
    list which will grab from the next larger size in order to enforce that size.
    If there are no larger sizes, try to take the largest from the dataset.

    NOTE: counts in ``dataset_size_and_count`` are decremented in place as sizes
    are claimed; pass a copy if the original counts are still needed.

    Args:
        dataset_size_and_count (dict[int, int]): Possible sizes in the dataset
            mapped to how many prompts of that size remain available.
        enforce_sizes (List[int]): Sizes which must be enforced.

    Returns:
        List[Tuple[int, int]]: Tuples whose first int is the size to truncate
        to, and whose second int is the prompt length to grab from.

    Raises:
        ValueError: If an enforced size cannot be satisfied because every
            candidate prompt length is exhausted or the dataset is empty.
    """
    truncation_list: List[Tuple[int, int]] = []
    sorted_sizes_in_dataset: List[int] = sorted(dataset_size_and_count.keys())

    # sort for consistent results where user mixes order of enforce_sizes
    for size_to_enforce in sorted(enforce_sizes):
        # index of the smallest dataset size that is >= size_to_enforce
        found_idx = bisect.bisect_left(sorted_sizes_in_dataset, size_to_enforce)
        truncation_size = None

        # if valid search found
        if found_idx < len(sorted_sizes_in_dataset):
            while found_idx < len(sorted_sizes_in_dataset):
                # reset the candidate to the new found_idx
                candidate = sorted_sizes_in_dataset[found_idx]
                # Have to check if this prompt length is available with the count
                if dataset_size_and_count[candidate] > 0:
                    # if count is > 0 then decrement the count as it no longer
                    # can be used for future prompts
                    dataset_size_and_count[candidate] -= 1
                    truncation_size = candidate
                    break
                # if prompt length is not available, increment to see if the
                # next larger prompt is available
                found_idx += 1

            if truncation_size is None:
                raise ValueError(
                    f"We've exhausted all possible truncation sizes, please increase max_prompt_len or remove {size_to_enforce=}"
                )
            truncation_list.append((size_to_enforce, truncation_size))
        else:
            # this occurs when size_to_enforce is outside of the max range of dataset
            if not sorted_sizes_in_dataset:
                # FIX: an empty dataset previously fell through silently here,
                # dropping the enforced size with no warning or error; fail
                # loudly instead, consistent with the other unsatisfiable paths.
                raise ValueError(
                    f"{size_to_enforce=} cannot be enforced because the dataset has no available prompt sizes."
                )
            # try to grab the largest size from the end of the sorted list if it
            # is available, otherwise throw an error
            truncation_size = sorted_sizes_in_dataset[-1]
            if dataset_size_and_count[truncation_size] > 0:
                truncation_list.append((size_to_enforce, truncation_size))
                dataset_size_and_count[truncation_size] -= 1
            else:
                raise ValueError(
                    f"{size_to_enforce=} is larger than largest sample and not available."
                )
    return truncation_list
233+
234+
def _remove_list_from_list(main_list, list_to_remove):
    """
    Remove the first occurrence of each item in ``list_to_remove`` from
    ``main_list``, mutating ``main_list`` in place.

    Items not present in ``main_list`` are ignored; duplicates in
    ``main_list`` beyond the first occurrence are kept.

    Args:
        main_list: List to remove items from (modified in place).
        list_to_remove: Items to remove, one occurrence each.

    Returns:
        The same ``main_list`` object, after removal.
    """
    for candidate in list_to_remove:
        try:
            main_list.remove(candidate)
        except ValueError:
            # candidate was not present; nothing to remove
            pass
    return main_list
240+
241+
# Because we now require encoding the dataset, cache the datasets to make
# second sample request quick.
# Maps a caller-supplied cache key (e.g. the dataset path passed as
# _cached_dataset_key) to the fully tokenized dataset: a list of
# (prompt, token_length) tuples sorted by token length, built BEFORE any
# min/max length filtering so it can be reused across different filters.
__cached_encoded_datasets: dict = {}
245+
246+
175247def __sample_requests (
176248 prompt_list : List [str ],
177249 num_requests : int ,
@@ -180,97 +252,203 @@ def __sample_requests(
180252 prompt_length_max : int = 64 ,
181253 seed : Optional [int ] = None ,
182254 enforce_heterogeneous : bool = False ,
183- enforce_sizes : List [int ] = [],
255+ enforce_sizes : List [int ] | None = None ,
256+ truncation : bool = False ,
184257 pad_multiple : int = 64 ,
258+ _cached_dataset_key : Optional [str ] = None ,
185259):
186260 """
187- Shuffles dataset, tokenizes the prompts and then filters
261+ Shuffles dataset, tokenizes the prompts and then filters.
188262
189263 Args:
190264 prompt_length_min (int): filters out prompts shorter than this value.
191265 prompt_length_max (int): filters out prompts larger than this value.
192266 enforce_sizes (List[int]): sample request will grab a prompt with this length if available.
193- enforce_heterogeneous (bool): Pads all prompts within batch size to nearest multiple of 64.
267+ enforce_heterogeneous (bool): Pads all prompts within batch to nearest multiple of `pad_multiple`.
268+ However, if enforce_sizes is not empty, it will set enforce_heteogeneous to False.
194269 pad_multiple (int): Used only when enforce_heterogeneous is True or enforce_sizes is not empty, asserts that prompt_length would be padded to this multiple
195270 List[Tuple[str, int]]: a filtered dataset
271+ truncation (bool): If true will truncate to an enforced size if the size does not exist. Only to be used with enforce_sizes, otherwise
272+ will be ignored
273+ _cached_dataset_key (optional[str]): The key to the dataset if enabling caching of encoded datasets
274+
275+ Returns:
276+ List[Tuple[str, int]]
196277 """
197278
279+ assert prompt_length_max >= prompt_length_min , (
280+ "Please enter valid prompt length max/min values"
281+ )
282+
283+ if enforce_sizes is None :
284+ enforce_sizes = []
285+
286+ if enforce_heterogeneous and enforce_sizes :
287+ warnings .warn (
288+ f"{ enforce_heterogeneous = } and { enforce_sizes = } , these two are not designed to be used at the same time. Forcing enforce_heterogeneous to False"
289+ )
290+ enforce_heterogeneous = False
291+
198292 # Based on min/max prompt length, one can back out the number of possible heterogeneous values
199293 max_heterogeneous_combinations = (prompt_length_max // pad_multiple ) - (
200294 (prompt_length_min - 1 ) // pad_multiple
201295 )
202296
203297 # Filter out sequences that are too long or too short
298+ dataset : List [Tuple [str , int ]] = []
204299 filtered_dataset : List [Tuple [str , int ]] = []
205300 enforced_dataset : List [Tuple [str , int ]] = []
206301
207302 # To track sizes seen
208303 seen_sizes : List [int ] = []
209304
305+ sample_size_counter : dict [int , int ] = {}
306+ # first int is the size to truncate to, second int is size of text to grab from
307+ enforce_sizes_with_truncation : List [Tuple [int , int ]] = []
308+
309+ if truncation and not enforce_sizes :
310+ warnings .warn (
311+ f"truncation and enforce_sizes should be used together, whereas { truncation = } and { enforce_sizes = } , hence no truncation will happen" ,
312+ stacklevel = 2 ,
313+ )
314+
315+ if (
316+ _cached_dataset_key is not None
317+ and _cached_dataset_key in __cached_encoded_datasets
318+ ):
319+ dataset = __cached_encoded_datasets [_cached_dataset_key ]
320+ else :
321+ # Loop to check create filtered dataset
322+ for i in range (len (prompt_list )):
323+ # Tokenize the prompts and completions.
324+ prompt = prompt_list [i ]
325+ prompt_token_ids = tokenizer .encode (prompt , return_tensors = "pt" ).squeeze (0 )
326+
327+ prompt_len = len (prompt_token_ids )
328+
329+ dataset .append ((prompt , prompt_len ))
330+
331+ dataset .sort (key = lambda tuple : tuple [1 ])
332+ __cached_encoded_datasets [_cached_dataset_key ] = dataset
333+
334+ # only keep values that are required
335+ dataset = [
336+ r for r in dataset if r [1 ] >= prompt_length_min and r [1 ] <= prompt_length_max
337+ ]
338+
339+ pad_size_dict : dict [int , int ] = {}
340+ for _ , prompt_len in dataset :
341+ pad_size_dict .setdefault (prompt_len , get_pad_size (prompt_len , pad_multiple ))
342+ sample_size_counter [pad_size_dict [prompt_len ]] = (
343+ sample_size_counter .get (pad_size_dict [prompt_len ], 0 ) + 1
344+ )
345+
210346 if enforce_sizes :
211347 for size in enforce_sizes :
212348 # Check that enforced sizes fall within min/max range
213349 assert prompt_length_min <= size <= prompt_length_max , (
214350 f"Size { size } in enforced sizes not within { prompt_length_min = } , { prompt_length_max = } "
215351 )
352+ assert size % pad_multiple == 0 , (
353+ "Enforce sizes must be a multiple of pad_multiple"
354+ )
216355 if len (enforce_sizes ) > num_requests :
217356 raise ValueError (
218357 f"{ num_requests = } which is smaller than { len (enforce_sizes )= } "
219358 )
220359
360+ if truncation :
361+ truncation_size_counter = sample_size_counter .copy ()
362+
363+ # Allocate certain counts to enforce_sizes
364+ needs_truncation = []
365+ for size in enforce_sizes :
366+ if sample_size_counter .get (size , 0 ) > 0 :
367+ sample_size_counter [size ] -= 1
368+ else :
369+ needs_truncation .append (size )
370+ enforce_sizes = _remove_list_from_list (enforce_sizes , needs_truncation )
371+
372+ enforce_sizes_with_truncation = _get_truncation_size (
373+ truncation_size_counter , needs_truncation
374+ )
375+
221376 # Shuffle the dataset.
222377 if seed is not None :
223- random .Random (seed ).shuffle (prompt_list )
378+ random .Random (seed ).shuffle (dataset )
224379
225- for i in range ( len ( prompt_list )) :
380+ for prompt , prompt_len in dataset :
226381 if len (filtered_dataset ) == num_requests and not enforce_sizes :
227382 break
228383
229- # Tokenize the prompts and completions.
230- prompt = prompt_list [i ]
231- prompt_token_ids = tokenizer .encode (prompt , return_tensors = "pt" ).squeeze (0 )
232-
233- prompt_len = len (prompt_token_ids )
234- if prompt_len < prompt_length_min or prompt_len > prompt_length_max :
235- # Prune too short or too long sequences.
236- continue
237- # This section is for enforce heterogeneous
384+ # NOTE: This section is for enforce heterogeneous, does not work with enforce_sizes
238385 if (
239386 enforce_heterogeneous
240387 and max_heterogeneous_combinations > len (filtered_dataset )
241388 and len (filtered_dataset ) < num_requests
242389 ):
243390 # for _, size in filtered_dataset:
244- current_padded_size = get_pad_size (prompt_len , pad_multiple )
245-
246- # If it's in the list of enforce_sizes it is enforced, can remove from list
247- if current_padded_size in enforce_sizes :
248- enforce_sizes .remove (current_padded_size )
249- enforced_dataset .append ((prompt , prompt_len ))
391+ current_padded_size = pad_size_dict [prompt_len ]
250392
251393 if current_padded_size not in seen_sizes :
252394 filtered_dataset .append ((prompt , prompt_len ))
253395 seen_sizes .append (current_padded_size )
254396 # Forcing search for enforce_sizes
255- elif enforce_sizes :
256- current_padded_size = get_pad_size (prompt_len , pad_multiple )
397+ elif enforce_sizes or enforce_sizes_with_truncation :
398+ current_padded_size = pad_size_dict [prompt_len ]
399+ # if it is in the enforce_size list
257400 if current_padded_size in enforce_sizes :
258401 enforce_sizes .remove (current_padded_size )
259402 enforced_dataset .append ((prompt , prompt_len ))
403+ # NOTE: this should not be `elif` despite enforce_sizes and enforce_sizes_with_truncation
404+ # are mutually exclusive because we allow same prompt to be used in enforce_sizes_with_truncation
405+ # even if it is taken from enforce_sizes
406+ if enforce_sizes_with_truncation :
407+ truncation_found : Tuple [int , int ] = next (
408+ (
409+ tup
410+ for tup in enforce_sizes_with_truncation
411+ if tup [1 ] == current_padded_size
412+ ),
413+ None ,
414+ )
415+ if truncation_found :
416+ truncate_to_size , _ = truncation_found
417+ prompt_token_ids = tokenizer .encode (
418+ prompt , add_special_tokens = False
419+ )
420+ truncated_prompt = tokenizer .decode (
421+ prompt_token_ids [:truncate_to_size ], skip_special_tokens = True
422+ )
423+ enforced_dataset .append ((truncated_prompt , truncate_to_size ))
424+ enforce_sizes_with_truncation .remove (truncation_found )
425+
260426 # when not enforcing heterogeneous or when exhausted all possible prompt_lengths
261427 else :
262428 filtered_dataset .append ((prompt , prompt_len ))
263- assert not enforce_sizes , "Enforce size should be empty if all lengths are captured"
429+ if enforce_sizes :
430+ warnings .warn (
431+ f"{ enforce_sizes = } so these sizes were not enforced, consider setting truncation=True" ,
432+ stacklevel = 2 ,
433+ )
434+ if enforce_sizes_with_truncation :
435+ warnings .warn (
436+ f"{ enforce_sizes_with_truncation = } so not all sizes with truncation enforced" ,
437+ stacklevel = 2 ,
438+ )
264439
265440 if num_requests > max_heterogeneous_combinations :
266441 print (
267- f"There will be prompt size repeats because { num_requests = } while { max_heterogeneous_combinations = } "
442+ f"There may be prompt size repeats because { num_requests = } while { max_heterogeneous_combinations = } "
268443 )
269444 if enforced_dataset :
270445 filtered_dataset = _merge_enforce_keep_heterogeneous (
271446 enforced_dataset , filtered_dataset , num_requests
272447 )
273448
449+ if len (filtered_dataset ) != num_requests :
450+ warnings .warn ("Returning dataset not equal to number requested" , stacklevel = 2 )
451+
274452 return filtered_dataset
275453
276454
@@ -282,7 +460,8 @@ def sample_sharegpt_requests(
282460 prompt_length_max : int = 64 ,
283461 seed : Optional [int ] = None ,
284462 enforce_heterogeneous : bool = False ,
285- enforce_sizes : List [int ] = [],
463+ enforce_sizes : List [int ] | None = None ,
464+ truncation : bool = False ,
286465 pad_multiple : int = 64 ,
287466) -> List [Tuple [str , int ]]:
288467 if not os .path .exists (dataset_path ):
@@ -292,6 +471,9 @@ def sample_sharegpt_requests(
292471 dataset_path ,
293472 )
294473
474+ if enforce_sizes is None :
475+ enforce_sizes = []
476+
295477 # Load the dataset.
296478 with open (dataset_path , encoding = "utf-8" ) as f :
297479 dataset = json .load (f )
@@ -308,7 +490,9 @@ def sample_sharegpt_requests(
308490 seed ,
309491 enforce_heterogeneous ,
310492 enforce_sizes ,
493+ truncation ,
311494 pad_multiple ,
495+ _cached_dataset_key = dataset_path ,
312496 )
313497
314498
@@ -320,11 +504,15 @@ def sample_squad_v2_qa_requests(
320504 prompt_length_max : int = 64 ,
321505 seed : Optional [int ] = None ,
322506 enforce_heterogeneous : bool = False ,
323- enforce_sizes : List [int ] = [],
507+ enforce_sizes : List [int ] | None = None ,
508+ truncation : bool = False ,
324509 pad_multiple : int = 64 ,
325510) -> List [Tuple [str , int ]]:
326511 from datasets import load_dataset
327512
513+ if enforce_sizes is None :
514+ enforce_sizes = []
515+
328516 if os .path .exists (dataset_path ):
329517 ds = load_dataset (dataset_path )["train" ]
330518 else :
@@ -341,6 +529,7 @@ def sample_squad_v2_qa_requests(
341529 seed ,
342530 enforce_heterogeneous ,
343531 enforce_sizes ,
532+ truncation ,
344533 pad_multiple ,
345534 )
346535
0 commit comments