-import torch
-import torch.nn as nn
-import time
-from fms.utils.tokenizers import BaseTokenizer
-from aiu_fms_testing_utils.utils.aiu_setup import dprint
+# Standard
 from typing import Optional, List, Tuple
-import os
-import requests
 import json
+import os
 import random
+import requests
+import time
+
+# Third Party
+from aiu_fms_testing_utils.utils.aiu_setup import dprint
+from fms.utils.tokenizers import BaseTokenizer
+import torch
+import torch.nn as nn

-def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int, compile_dynamic_sendnn=False, use_cache: bool = True, **extra_kwargs):
+
+def warmup_model(
+    model: nn.Module,
+    input_ids: torch.Tensor,
+    max_new_tokens: int,
+    compile_dynamic_sendnn: bool = False,
+    use_cache: bool = True,
+    **extra_kwargs
+):
     import torch_sendnn
     attention_specific_kwargs = {}
     attn_name = extra_kwargs["attn_name"]
@@ -19,7 +30,7 @@ def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int,
         # TODO: Add a unified generation dependent on attn_type
         from fms.utils.generation import generate
         attention_specific_kwargs["contiguous_cache"] = True
-
+
     dprint("AIU warmup")
     pt_compile_model_time = time.time()

@@ -31,12 +42,23 @@ def warmup_model(model: nn.Module, input_ids: torch.Tensor, max_new_tokens: int,
     _max_new_tokens = 2
     # always warmup with batch size 2 when using attn_type=paged
     if "paged" in attn_name:
-        _warmup_input_ids, _extra_kwargs = adjust_inputs_to_batch(input_ids, **extra_kwargs)
+        _warmup_input_ids, _extra_kwargs = adjust_inputs_to_batch(
+            input_ids,
+            **extra_kwargs,
+        )

     extra_kwargs = {**_extra_kwargs, "only_last_token": "paged" not in attn_name}

     with torch_sendnn.warmup_mode():
-        generate(model, _warmup_input_ids, max_new_tokens=_max_new_tokens, do_sample=False, use_cache=use_cache, extra_kwargs=extra_kwargs, **attention_specific_kwargs)
+        generate(
+            model,
+            _warmup_input_ids,
+            max_new_tokens=_max_new_tokens,
+            do_sample=False,
+            use_cache=use_cache,
+            extra_kwargs=extra_kwargs,
+            **attention_specific_kwargs,
+        )
     pt_compile_model_time = time.time() - pt_compile_model_time
     dprint(f"PT compile complete, took {pt_compile_model_time:.3f}s")

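For reference, a minimal sketch (not part of this commit) of how the reformatted warmup_model might be driven; the model, input_ids, and the "sdpa_causal" attention name below are illustrative assumptions:

    # Hypothetical driver code; assumes `model` and `input_ids` were built elsewhere.
    # warmup_model reads extra_kwargs["attn_name"], so an attention name must be passed;
    # "sdpa_causal" is only an example value, and names containing "paged" take the
    # batch-size-2 warmup path shown in the diff above.
    warmup_model(
        model,
        input_ids,
        max_new_tokens=128,
        compile_dynamic_sendnn=False,
        use_cache=True,
        attn_name="sdpa_causal",
    )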
@@ -52,17 +74,17 @@ def __download_file(url, filename):
     try:
         response = requests.get(url, stream=True)
         response.raise_for_status()
-
+
         with open(filename, 'wb') as file:
             for chunk in response.iter_content(chunk_size=8192):
                 file.write(chunk)
         print(f"Successfully downloaded {filename}")
-
+
     except requests.exceptions.RequestException as e:
         print(f"An error occurred: {e}")

 def __sample_requests(
-    prompt_list: List[str],
+    prompt_list: List[str],
     num_requests: int,
     tokenizer: BaseTokenizer,
     prompt_length_min: int = 32,
@@ -82,16 +104,14 @@ def __sample_requests(
         # Tokenize the prompts and completions.
         prompt = prompt_list[i]
         prompt_token_ids = ids_for_prompt(prompt, tokenizer)
-
+
         prompt_len = len(prompt_token_ids)
         if prompt_len < prompt_length_min or prompt_len > prompt_length_max:
             # Prune too short or too long sequences.
             continue
         filtered_dataset.append((prompt, prompt_len))
-
-    return filtered_dataset
-

+    return filtered_dataset

 def sample_sharegpt_requests(
     dataset_path: str,
@@ -111,15 +131,22 @@ def sample_sharegpt_requests(
     # Filter out the conversations with less than 2 turns.
     dataset = [data for data in dataset if len(data["conversations"]) >= 2]
     dataset = [data["conversations"][0]["value"] for data in dataset]
-
-    return __sample_requests(dataset, num_requests, tokenizer, prompt_length_min, prompt_length_max, seed)
+
+    return __sample_requests(
+        dataset,
+        num_requests,
+        tokenizer,
+        prompt_length_min,
+        prompt_length_max,
+        seed,
+    )

 def sample_squad_v2_qa_requests(
     dataset_path: str,
-    num_requests: int,
-    tokenizer: BaseTokenizer,
-    prompt_length_min: int = 32,
-    prompt_length_max: int = 64,
+    num_requests: int,
+    tokenizer: BaseTokenizer,
+    prompt_length_min: int = 32,
+    prompt_length_max: int = 64,
     seed: Optional[int] = None
 ) -> List[Tuple[str, int]]:
     from datasets import load_dataset
@@ -128,10 +155,14 @@ def sample_squad_v2_qa_requests(
         ds = load_dataset(dataset_path)['train']
     else:
         ds = load_dataset("rajpurkar/squad_v2", cache_dir=dataset_path)['train']
-
-
-    ds = [f"{data['context']}\n{data['question']}" for data in ds]

-    return __sample_requests(ds, num_requests, tokenizer, prompt_length_min, prompt_length_max, seed)
-
+    ds = [f"{data['context']}\n{data['question']}" for data in ds]

+    return __sample_requests(
+        ds,
+        num_requests,
+        tokenizer,
+        prompt_length_min,
+        prompt_length_max,
+        seed,
+    )
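A small sketch of how the prompt-sampling helpers could be exercised; the tokenizer constructor, dataset path, and request count are placeholders rather than anything defined in this change, and the keyword names for sample_sharegpt_requests are assumed to mirror sample_squad_v2_qa_requests (fms.utils.tokenizers.get_tokenizer is assumed to be available):

    # Hypothetical usage: sample 16 ShareGPT prompts whose token length falls in [32, 64].
    from fms.utils import tokenizers

    tokenizer = tokenizers.get_tokenizer("/path/to/tokenizer")  # placeholder path
    sampled = sample_sharegpt_requests(
        "/path/to/ShareGPT_V3_unfiltered_cleaned_split.json",  # placeholder dataset path
        num_requests=16,
        tokenizer=tokenizer,
        prompt_length_min=32,
        prompt_length_max=64,
        seed=0,
    )
    for prompt, prompt_len in sampled:
        print(prompt_len, prompt[:40])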