 
 from aiu_fms_testing_utils.utils.aiu_setup import dprint, rank, world_size
 from transformers.tokenization_utils_base import PreTrainedTokenizerBase
+from aiu_fms_testing_utils.testing.utils import format_kwargs_to_string
 
 from fms.utils.generation import pad_input_ids
 import torch
@@ -85,7 +86,7 @@ def warmup_model(
         **extra_kwargs,
     )
 
-    extra_kwargs = {**_extra_kwargs, "only_last_token": "paged" not in attn_name}
+    extra_kwargs = {**_extra_kwargs, "last_n_tokens": 64 if "paged" in attn_name else 1}
 
     with stagger_region(stagger_update_lazyhandle):
        with torch_sendnn.warmup_mode():
@@ -421,8 +422,11 @@ def __sample_requests(
                 prompt_token_ids = tokenizer.encode(
                     prompt, add_special_tokens=False
                 )
+                # If we don't set clean_up_tokenization_spaces=False, encoding then decoding text might result in different lengths, which would break expected results from the sampler
                 truncated_prompt = tokenizer.decode(
-                    prompt_token_ids[:truncate_to_size], skip_special_tokens=True
+                    prompt_token_ids[:truncate_to_size],
+                    skip_special_tokens=True,
+                    clean_up_tokenization_spaces=False,
                 )
                 enforced_dataset.append((truncated_prompt, truncate_to_size))
                 enforce_sizes_with_truncation.remove(truncation_found)
@@ -479,6 +483,7 @@ def sample_rag_factoid_requests(
     enforce_sizes: List[int] = [],
     truncation: bool = False,
     pad_multiple: int = 64,
+    return_key: bool = False,
 ) -> List[Tuple[str, int]]:
     if not os.path.exists(dataset_path):
         print("error dataset does not exist")
@@ -489,7 +494,7 @@ def sample_rag_factoid_requests(
         for line in f:
             dataset.append(line)
 
-    return __sample_requests(
+    sample_request = __sample_requests(
         dataset,
         num_requests,
         tokenizer,
@@ -503,6 +508,24 @@ def sample_rag_factoid_requests(
         _cached_dataset_key=dataset_path,
     )
 
+    if return_key:
+        sample_key: str = format_kwargs_to_string(
+            dataset="rag_factoid",
+            num_requests=num_requests,
+            tokenizer=tokenizer.name_or_path.replace("/", "--"),
+            prompt_length_min=prompt_length_min,
+            prompt_length_max=prompt_length_max,
+            seed=seed,
+            enforce_heterogeneous=enforce_heterogeneous,
+            enforce_sizes=enforce_sizes,
+            truncate=truncation,
+            pad_multiple=pad_multiple,
+        )
+
+        return sample_request, sample_key
+    else:
+        return sample_request
+
 
 def sample_sharegpt_requests(
     dataset_path: str,
@@ -515,6 +538,7 @@ def sample_sharegpt_requests(
     enforce_sizes: List[int] | None = None,
     truncation: bool = False,
     pad_multiple: int = 64,
+    return_key: bool = False,
 ) -> List[Tuple[str, int]]:
     if not os.path.exists(dataset_path):
         print("downloading share-gpt dataset as it does not exist")
@@ -540,7 +564,7 @@ def sample_sharegpt_requests(
     dataset = [data for data in dataset if len(data["conversations"]) >= 2]
     dataset: List[str] = [data["conversations"][0]["value"] for data in dataset]
 
-    return __sample_requests(
+    sample_request = __sample_requests(
         dataset,
         num_requests,
         tokenizer,
@@ -554,6 +578,23 @@ def sample_sharegpt_requests(
         _cached_dataset_key=dataset_path,
     )
 
+    if return_key:
+        sample_key: str = format_kwargs_to_string(
+            dataset="sharegpt",
+            num_requests=num_requests,
+            tokenizer=tokenizer.name_or_path.replace("/", "--"),
+            prompt_length_min=prompt_length_min,
+            prompt_length_max=prompt_length_max,
+            seed=seed,
+            enforce_heterogeneous=enforce_heterogeneous,
+            enforce_sizes=enforce_sizes,
+            truncate=truncation,
+            pad_multiple=pad_multiple,
+        )
+        return sample_request, sample_key
+    else:
+        return sample_request
+
 
 def sample_squad_v2_qa_requests(
     dataset_path: str,
@@ -566,6 +607,7 @@ def sample_squad_v2_qa_requests(
     enforce_sizes: List[int] | None = None,
     truncation: bool = False,
     pad_multiple: int = 64,
+    return_key: bool = False,
 ) -> List[Tuple[str, int]]:
     from datasets import load_dataset
 
@@ -579,7 +621,7 @@ def sample_squad_v2_qa_requests(
 
     ds = [f"{data['context']}\n{data['question']}" for data in ds]
 
-    return __sample_requests(
+    sample_request = __sample_requests(
         ds,
         num_requests,
         tokenizer,
@@ -592,6 +634,23 @@ def sample_squad_v2_qa_requests(
         pad_multiple,
     )
 
+    if return_key:
+        sample_key: str = format_kwargs_to_string(
+            dataset="squad_v2",
+            num_requests=num_requests,
+            tokenizer=tokenizer.name_or_path.replace("/", "--"),
+            prompt_length_min=prompt_length_min,
+            prompt_length_max=prompt_length_max,
+            seed=seed,
+            enforce_heterogeneous=enforce_heterogeneous,
+            enforce_sizes=enforce_sizes,
+            truncate=truncation,
+            pad_multiple=pad_multiple,
+        )
+        return sample_request, sample_key
+    else:
+        return sample_request
+
 
 def prepare_inputs(
     batch_size, seq_length, tokenizer, ds_path, seed=0, ds_type="sharegpt"
0 commit comments