@@ -96,6 +96,8 @@ def __init__(
9696 self ,
9797 dataset_path : Optional [str ] = None ,
9898 random_seed : int = DEFAULT_SEED ,
99+ disable_shuffle : bool = False ,
100+ ** kwargs ,
99101 ) -> None :
100102 """
101103 Initialize the BenchmarkDataset with an optional dataset path and random
@@ -111,6 +113,7 @@ def __init__(
111113 # Set the random seed, ensuring that a None value is replaced with the
112114 # default seed.
113115 self .random_seed = random_seed if random_seed is not None else self .DEFAULT_SEED
116+ self .disable_shuffle = disable_shuffle
114117 self .data = None
115118
116119 def apply_multimodal_chat_transformation (
@@ -1044,7 +1047,8 @@ def load_data(self) -> None:
10441047 if "conversations" in entry and len (entry ["conversations" ]) >= 2
10451048 ]
10461049 random .seed (self .random_seed )
1047- random .shuffle (self .data )
1050+ if not getattr (self , "disable_shuffle" , False ):
1051+ random .shuffle (self .data )
10481052
10491053 def sample (
10501054 self ,
@@ -1175,6 +1179,11 @@ def add_dataset_parser(parser: FlexibleArgumentParser):
11751179 action = "store_true" ,
11761180 help = "Skip applying chat template to prompt for datasets that support it." ,
11771181 )
1182+ parser .add_argument (
1183+ "--disable-shuffle" ,
1184+ action = "store_true" ,
1185+ help = "Disable shuffling of dataset samples for deterministic ordering." ,
1186+ )
11781187
11791188 # group for dataset specific arguments
11801189 custom_group = parser .add_argument_group ("custom dataset options" )
@@ -1441,7 +1450,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
14411450 args .request_id_prefix = ""
14421451
14431452 if args .dataset_name == "custom" :
1444- dataset = CustomDataset (dataset_path = args .dataset_path )
1453+ dataset = CustomDataset (
1454+ dataset_path = args .dataset_path , disable_shuffle = args .disable_shuffle
1455+ )
14451456 input_requests = dataset .sample (
14461457 num_requests = args .num_prompts ,
14471458 tokenizer = tokenizer ,
@@ -1452,7 +1463,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
14521463 )
14531464
14541465 elif args .dataset_name == "sonnet" :
1455- dataset = SonnetDataset (dataset_path = args .dataset_path )
1466+ dataset = SonnetDataset (
1467+ dataset_path = args .dataset_path , disable_shuffle = args .disable_shuffle
1468+ )
14561469 # For the "sonnet" dataset, formatting depends on the backend.
14571470 if args .backend == "openai-chat" :
14581471 input_requests = dataset .sample (
@@ -1586,6 +1599,7 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
15861599 random_seed = args .seed ,
15871600 no_stream = args .no_stream ,
15881601 hf_name = args .hf_name ,
1602+ disable_shuffle = args .disable_shuffle ,
15891603 ).sample (
15901604 num_requests = args .num_prompts ,
15911605 tokenizer = tokenizer ,
@@ -1600,7 +1614,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
16001614 # For datasets that follow a similar structure, use a mapping.
16011615 dataset_mapping = {
16021616 "spec_bench" : lambda : SpecBench (
1603- dataset_path = args .dataset_path , category = args .spec_bench_category
1617+ dataset_path = args .dataset_path ,
1618+ category = args .spec_bench_category ,
1619+ disable_shuffle = args .disable_shuffle ,
16041620 ).sample (
16051621 num_requests = args .num_prompts ,
16061622 tokenizer = tokenizer ,
@@ -1609,7 +1625,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
16091625 no_oversample = args .no_oversample ,
16101626 ),
16111627 "sharegpt" : lambda : ShareGPTDataset (
1612- random_seed = args .seed , dataset_path = args .dataset_path
1628+ random_seed = args .seed ,
1629+ dataset_path = args .dataset_path ,
1630+ disable_shuffle = args .disable_shuffle ,
16131631 ).sample (
16141632 tokenizer = tokenizer ,
16151633 num_requests = args .num_prompts ,
@@ -1618,15 +1636,19 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
16181636 no_oversample = args .no_oversample ,
16191637 ),
16201638 "burstgpt" : lambda : BurstGPTDataset (
1621- random_seed = args .seed , dataset_path = args .dataset_path
1639+ random_seed = args .seed ,
1640+ dataset_path = args .dataset_path ,
1641+ disable_shuffle = args .disable_shuffle ,
16221642 ).sample (
16231643 tokenizer = tokenizer ,
16241644 num_requests = args .num_prompts ,
16251645 request_id_prefix = args .request_id_prefix ,
16261646 no_oversample = args .no_oversample ,
16271647 ),
16281648 "random" : lambda : RandomDataset (
1629- random_seed = args .seed , dataset_path = args .dataset_path
1649+ random_seed = args .seed ,
1650+ dataset_path = args .dataset_path ,
1651+ disable_shuffle = args .disable_shuffle ,
16301652 ).sample (
16311653 tokenizer = tokenizer ,
16321654 num_requests = args .num_prompts ,
@@ -1639,7 +1661,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
16391661 no_oversample = args .no_oversample ,
16401662 ),
16411663 "random-mm" : lambda : RandomMultiModalDataset (
1642- random_seed = args .seed , dataset_path = args .dataset_path
1664+ random_seed = args .seed ,
1665+ dataset_path = args .dataset_path ,
1666+ disable_shuffle = args .disable_shuffle ,
16431667 ).sample (
16441668 tokenizer = tokenizer ,
16451669 num_requests = args .num_prompts ,
@@ -1655,7 +1679,9 @@ def get_samples(args, tokenizer) -> list[SampleRequest]:
16551679 no_oversample = args .no_oversample ,
16561680 ),
16571681 "prefix_repetition" : lambda : PrefixRepetitionRandomDataset (
1658- random_seed = args .seed , dataset_path = args .dataset_path
1682+ random_seed = args .seed ,
1683+ dataset_path = args .dataset_path ,
1684+ disable_shuffle = args .disable_shuffle ,
16591685 ).sample (
16601686 tokenizer = tokenizer ,
16611687 num_requests = args .num_prompts ,
@@ -1733,7 +1759,8 @@ def load_data(self) -> None:
17331759 )
17341760
17351761 random .seed (self .random_seed )
1736- random .shuffle (self .data )
1762+ if not getattr (self , "disable_shuffle" , False ):
1763+ random .shuffle (self .data )
17371764
17381765 def sample (
17391766 self ,
@@ -1825,7 +1852,8 @@ def load_data(self) -> None:
18251852 self .data .append ({"prompt" : prompt })
18261853
18271854 random .seed (self .random_seed )
1828- random .shuffle (self .data )
1855+ if not getattr (self , "disable_shuffle" , False ):
1856+ random .shuffle (self .data )
18291857
18301858 def sample (self , ** kwargs ) -> list :
18311859 # leverage CustomDataset sample
@@ -2033,7 +2061,8 @@ def load_data(self) -> None:
20332061 split = self .dataset_split ,
20342062 streaming = self .load_stream ,
20352063 )
2036- self .data = self .data .shuffle (seed = self .random_seed )
2064+ if not getattr (self , "disable_shuffle" , False ):
2065+ self .data = self .data .shuffle (seed = self .random_seed )
20372066
20382067
20392068# -----------------------------------------------------------------------------
@@ -2849,7 +2878,8 @@ def _generate_exact_length_tokens(target_length: int) -> list[int]:
28492878 abs (token_mismatch_total ),
28502879 sign ,
28512880 )
2852- random .shuffle (requests )
2881+ if not getattr (self , "disable_shuffle" , False ):
2882+ random .shuffle (requests )
28532883 return requests
28542884
28552885
0 commit comments