Skip to content

Commit d2e3d98

Browse files
committed
Add configurable sample requests to prepare inputs
Signed-off-by: Flavia Beo <flavia.beo@ibm.com>
1 parent 28c44f8 commit d2e3d98

File tree

1 file changed

+33
-9
lines changed

1 file changed

+33
-9
lines changed

aiu_fms_testing_utils/utils/__init__.py

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -168,15 +168,39 @@ def sample_squad_v2_qa_requests(
168168
seed,
169169
)
170170

171-
def prepare_inputs(batch_size, seq_length, tokenizer, sharegpt_path, seed=0):
172-
prompts_and_sizes = sample_sharegpt_requests(
173-
sharegpt_path,
174-
batch_size,
175-
tokenizer,
176-
int(seq_length / 2),
177-
seq_length,
178-
seed,
179-
)
171+
def prepare_inputs(batch_size, seq_length, tokenizer, ds_path, seed=0, ds_type="sharegpt"):
172+
"""
173+
Prepare input IDs and padding kwargs for a batch of questions.
174+
175+
Args:
176+
batch_size (int): The number of questions in the batch.
177+
seq_length (int): The maximum length of the input sequence.
178+
tokenizer (Tokenizer): A tokenizer object to tokenize the questions.
179+
ds_path (str): The path to the dataset file.
180+
seed (int, optional): The random seed for reproducibility. Defaults to 0.
181+
ds_type (str, optional): The type of dataset to use. Can be "sharegpt" or any other supported dataset type. Defaults to "sharegpt".
182+
183+
Returns:
184+
tuple: A tuple containing the input IDs and padding kwargs.
185+
"""
186+
if not "sharegpt" in ds_type:
187+
prompts_and_sizes = sample_squad_v2_qa_requests(
188+
ds_path,
189+
batch_size,
190+
tokenizer,
191+
int(seq_length / 2),
192+
seq_length,
193+
seed,
194+
)
195+
else:
196+
prompts_and_sizes = sample_sharegpt_requests(
197+
ds_path,
198+
batch_size,
199+
tokenizer,
200+
int(seq_length / 2),
201+
seq_length,
202+
seed,
203+
)
180204
prompt_list = []
181205
for prompt, _ in prompts_and_sizes:
182206
prompt_list.append(ids_for_prompt(prompt, tokenizer))

0 commit comments

Comments
 (0)