import argparse
+ import ast
import json
import os
import random

from aiu_fms_testing_utils.testing.validation import capture_level_1_metrics, extract_validation_information, LogitsExtractorHook, print_failed_cases, \
    validate_level_0, GoldenTokenHook, top_k_loss_calculator
- from aiu_fms_testing_utils.utils import ids_for_prompt
+ from aiu_fms_testing_utils.utils import ids_for_prompt, sample_sharegpt_requests
from fms.models import get_model
from fms.utils import tokenizers
from fms.utils.generation import pad_input_ids
    help="top k values per token to generate loss on",
    default=20
)
+ parser.add_argument(
+     "--num_test_tokens_per_sequence",
+     type=int,
+     help="number of tokens to generate in the test. For instance, if max_new_tokens=128 and num_test_tokens_per_sequence=256, data will be generated over 2 sample prompts per sequence. If not set, this will default to max_new_tokens",
+     default=None
+ )
+ parser.add_argument(
+     "--extra_get_model_kwargs",
+     nargs='*',
+     default={},
+     help="Use this to override model configuration values when getting the model. Example: --extra_get_model_kwargs nlayers=2,..."
+ )
args = parser.parse_args()

+ extra_get_model_kwargs = {}
+ for a in args.extra_get_model_kwargs:
+     a_split = a.split("=")
+     try:
+         extra_get_model_kwargs[a_split[0]] = ast.literal_eval(a_split[1])
+     except ValueError:
+         extra_get_model_kwargs[a_split[0]] = a_split[1]
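+ # e.g. (illustrative) "--extra_get_model_kwargs nlayers=2 tie_heads=True" parses to
+ # {"nlayers": 2, "tie_heads": True}; values that ast.literal_eval cannot parse, such as
+ # plain strings, are kept as raw strings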

- prefix = f"{args.variant.replace('/', '--')}_max-new-tokens-{args.max_new_tokens}_batch-size-{args.batch_size}_seq-length{args.min_pad_length}_dtype-{args.default_dtype}"
+ # this follows the same naming pattern as test_shapes, so results can be saved and re-used for quicker shape testing
+ prefix = f"{args.variant.replace('/', '--')}_max-new-tokens-{args.max_new_tokens}_batch-size-{args.batch_size}_seq-length-{args.min_pad_length}_dtype-{args.default_dtype}"
if os.path.exists(os.path.join(args.output_dir, f"{prefix}.prob_mean.csv")):
    print("skipping metric generation as it has already been done")
    exit(0)
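+ # e.g. (hypothetical values) --variant my/model --max_new_tokens 128 --batch_size 4 --min_pad_length 64
+ # --default_dtype fp16 gives the prefix my--model_max-new-tokens-128_batch-size-4_seq-length-64_dtype-fp16,
+ # and the run is skipped if <output_dir>/<prefix>.prob_mean.csv already exists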
    model_path=args.model_path,
    device_type="cuda",
    data_type=default_dtype,
+     **extra_get_model_kwargs,
)

- print("loaded cuda model")
-
cuda_model.eval()
+ print("loaded cuda model")

# prepare the cpu model (this is the reference)
cpu_model = get_model(
    model_path=args.model_path,
    device_type="cpu",
    data_type=torch.float32,
+     **extra_get_model_kwargs,
)
cpu_model.eval()
print("loaded cpu model")

- def sample_sharegpt_requests(
-     dataset_path: str,
-     num_requests: int,
-     tokenizer,
- ) -> List[Tuple[str, int, int, None]]:
-     # Load the dataset.
-     with open(dataset_path, encoding='utf-8') as f:
-         dataset = json.load(f)
-     # Filter out the conversations with less than 2 turns.
-     dataset = [data for data in dataset if len(data["conversations"]) >= 2]
-     # Only keep the first two turns of each conversation.
-     dataset = [(data["conversations"][0]["value"],
-                 data["conversations"][1]["value"]) for data in dataset]
-
-     # Shuffle the dataset.
-     random.Random(42).shuffle(dataset)
-
-     # Filter out sequences that are too long or too short
-     filtered_dataset: List[Tuple[str, int, int]] = []
-     for i in range(len(dataset)):
-         if len(filtered_dataset) == num_requests:
-             break
-
-         # Tokenize the prompts and completions.
-         prompt = dataset[i][0]
-         prompt_token_ids = ids_for_prompt(prompt, tokenizer)
-
-         prompt_len = len(prompt_token_ids)
-         if prompt_len < 32 or prompt_len > args.min_pad_length:
-             # Prune too short sequences.
-             continue
-         filtered_dataset.append((prompt, prompt_len))
-
-     return filtered_dataset
-
def find_eos_index(reference_tokens, eos_token_id):
    result = []
    for sentence in reference_tokens:
@@ -184,21 +171,17 @@ def filter_before_eos(l, filter_indexes):
    from itertools import groupby
    filtered_results = [list(g)[:filter_indexes[k]] for k, g in groupby(l, key=lambda x: x[0])]
    return [item for sublist in filtered_results for item in sublist]
-
- prompts_and_lens = sample_sharegpt_requests(args.sharegpt_path, args.batch_size, tokenizer)
- print(f"prompt_lengths: {[pl[1] for pl in prompts_and_lens]}")
- prompts = [ids_for_prompt(pl[0], tokenizer) for pl in prompts_and_lens]

- padding_length = args.min_pad_length
+ def __prepare_inputs(batch_size, seq_length, tokenizer, seed=0):
+     prompts_and_sizes = sample_sharegpt_requests(args.sharegpt_path, batch_size, tokenizer, seq_length // 2, seq_length, seed)
+     prompt_list = []
+     for prompt, _ in prompts_and_sizes:
+         prompt_list.append(ids_for_prompt(prompt, tokenizer))

- has_padding = args.batch_size > 1 or padding_length != 0
- max_len = max([len(prompt) for prompt in prompts])
+     input_ids, padding_kwargs = pad_input_ids(prompt_list, min_pad_length=seq_length)
+     return input_ids, padding_kwargs

- if has_padding:
-     ids, padding_kwargs = pad_input_ids(prompts, min_pad_length=padding_length)
- else:
-     ids = prompts
-     padding_kwargs = {}
+ ids, padding_kwargs = __prepare_inputs(args.batch_size, args.min_pad_length, tokenizer)
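+ # __prepare_inputs samples ShareGPT prompts and pads them up to min_pad_length so every batch
+ # shares one shape; the seq_length // 2 and seq_length arguments are assumed to be the min/max
+ # prompt-length bounds accepted by sample_sharegpt_requests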

# first test validation level 0
cpu_validation_info = extract_validation_information(
@@ -231,63 +214,88 @@ def filter_before_eos(l, filter_indexes):
if len(failed_responses) != 0:
    print_failed_cases(failed_responses, cpu_static_tokens, cuda_static_tokens, tokenizer)

- # generate aiu validation info
- cuda_validation_info = extract_validation_information(
-     cuda_model,
-     ids.to("cuda"),
-     args.max_new_tokens,
-     GoldenTokenHook(cpu_static_tokens, "cuda"),
-     only_last_token=True,
-     **{k: v.to("cuda") for k, v in padding_kwargs.items()}
- )
-
- print("extracted cuda validation information level 1")
-
- cross_entropy = lambda r, t: torch.nn.CrossEntropyLoss()(r, t.softmax(dim=1).to(dtype=torch.float32))
- prob_mean = lambda r, t: torch.mean((r.softmax(dim=1).to(dtype=torch.float32) / t.softmax(dim=1).to(dtype=torch.float32)) - 1.0)
- prob_std = lambda r, t: torch.std(r.softmax(dim=1).to(dtype=torch.float32) / t.softmax(dim=1).to(dtype=torch.float32))
- diff_mean = lambda r, t: torch.mean(r.softmax(dim=1).to(dtype=torch.float32) - t.softmax(dim=1).to(dtype=torch.float32))
-
def write_csv(l, path, metric):
    with open(path, 'w') as f:
        f.write(f'{metric}\n')
        for t in l:
            f.write(f"{t[2].item()}\n")
        f.close()

- prefix = f"{args.variant.replace('/', '--')}_max-new-tokens-{args.max_new_tokens}_batch-size-{args.batch_size}_seq-length{args.min_pad_length}_dtype-{args.default_dtype}"
-
- cpu_validation_info.save(os.path.join(args.output_dir, f"{prefix}.cpu_output_logits.out"))
- cuda_validation_info.save(os.path.join(args.output_dir, f"{prefix}.cuda_output_logits.out"))
+ num_test_tokens_per_sequence = args.num_test_tokens_per_sequence
+ if num_test_tokens_per_sequence is None:
+     num_test_tokens_per_sequence = args.max_new_tokens
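+ # e.g. max_new_tokens=128 with num_test_tokens_per_sequence=512 runs the loop below 4 times,
+ # re-sampling prompts with a different seed on each pass to cover more data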

- level_1_metrics = capture_level_1_metrics(
-     cpu_validation_info.get_info("logits"),
-     cuda_validation_info.get_info("logits"),
-     top_k_loss_calculator(args.topk_per_token, prob_mean),
- )
- loss_metrics = filter_before_eos(level_1_metrics, eos_indexes)
- write_csv(loss_metrics, os.path.join(args.output_dir, f"{prefix}.prob_mean.csv"), "prob_mean")
+ cross_entropy = lambda r, t: torch.nn.CrossEntropyLoss()(r, t.softmax(dim=1).to(dtype=torch.float32))
+ prob_mean = lambda r, t: torch.mean((r.softmax(dim=1).to(dtype=torch.float32) / t.softmax(dim=1).to(dtype=torch.float32)) - 1.0)
+ prob_std = lambda r, t: torch.std(r.softmax(dim=1).to(dtype=torch.float32) / t.softmax(dim=1).to(dtype=torch.float32))
+ diff_mean = lambda r, t: torch.mean(r.softmax(dim=1).to(dtype=torch.float32) - t.softmax(dim=1).to(dtype=torch.float32))
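+ # the loss lambdas compare softmax probabilities of the two logit sets (cross entropy, mean and std
+ # of the probability ratio, mean probability difference); top_k_loss_calculator is assumed to apply
+ # each one over the top --topk_per_token values per generated token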

- level_1_metrics = capture_level_1_metrics(
-     cpu_validation_info.get_info("logits"),
-     cuda_validation_info.get_info("logits"),
-     top_k_loss_calculator(args.topk_per_token, prob_std),
- )
- loss_metrics = filter_before_eos(level_1_metrics, eos_indexes)
- write_csv(loss_metrics, os.path.join(args.output_dir, f"{prefix}.prob_std.csv"), "prob_std")
+ prob_mean_metrics = []
+ prob_std_metrics = []
+ prob_diff_metrics = []
+ prob_ce_loss_metrics = []

- level_1_metrics = capture_level_1_metrics(
-     cpu_validation_info.get_info("logits"),
-     cuda_validation_info.get_info("logits"),
-     top_k_loss_calculator(args.topk_per_token, cross_entropy),
- )
- loss_metrics = filter_before_eos(level_1_metrics, eos_indexes)
- write_csv(loss_metrics, os.path.join(args.output_dir, f"{prefix}.ce.csv"), "ce")
+ prefix = f"{args.variant.replace('/', '--')}_max-new-tokens-{args.max_new_tokens}_batch-size-{args.batch_size}_seq-length-{args.min_pad_length}_dtype-{args.default_dtype}"

- level_1_metrics = capture_level_1_metrics(
-     cpu_validation_info.get_info("logits"),
-     cuda_validation_info.get_info("logits"),
-     top_k_loss_calculator(args.topk_per_token, diff_mean),
- )
- loss_metrics = filter_before_eos(level_1_metrics, eos_indexes)
- write_csv(loss_metrics, os.path.join(args.output_dir, f"{prefix}.diff_mean.csv"), "diff_mean")
+ for i in range(num_test_tokens_per_sequence // args.max_new_tokens):
+     ids, padding_kwargs = __prepare_inputs(args.batch_size, args.min_pad_length, tokenizer, i)
+
+     # only need to compute this once if we aren't generating more test data
+     if num_test_tokens_per_sequence > args.max_new_tokens:
+         cpu_validation_info = extract_validation_information(
+             cpu_model,
+             ids,
+             args.max_new_tokens,
+             LogitsExtractorHook(),
+             attn_algorithm="math",
+             **padding_kwargs
+         )
+         eos_indexes = find_eos_index(cpu_validation_info.get_info("tokens"), tokenizer.eos_token_id)
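+         # metric entries at positions after the first EOS in each reference sequence are dropped later by filter_before_eos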
+
+     # generate cuda validation info
+     cuda_validation_info = extract_validation_information(
+         cuda_model,
+         ids.to("cuda"),
+         args.max_new_tokens,
+         GoldenTokenHook(cpu_validation_info.get_info("tokens"), "cuda"),
+         only_last_token=True,
+         **{k: v.to("cuda") for k, v in padding_kwargs.items()}
+     )
+
+     print("extracted cuda validation information level 1")
+
+     cpu_validation_info.save(os.path.join(args.output_dir, f"{prefix}.cpu_validation_info.{i}.out"))
+     cuda_validation_info.save(os.path.join(args.output_dir, f"{prefix}.cuda_validation_info.{i}.out"))
+
+     level_1_metrics = capture_level_1_metrics(
+         cpu_validation_info.get_info("logits"),
+         cuda_validation_info.get_info("logits"),
+         top_k_loss_calculator(args.topk_per_token, prob_mean),
+     )
+     prob_mean_metrics.extend(filter_before_eos(level_1_metrics, eos_indexes))
+
+     level_1_metrics = capture_level_1_metrics(
+         cpu_validation_info.get_info("logits"),
+         cuda_validation_info.get_info("logits"),
+         top_k_loss_calculator(args.topk_per_token, prob_std),
+     )
+     prob_std_metrics.extend(filter_before_eos(level_1_metrics, eos_indexes))
+
+     level_1_metrics = capture_level_1_metrics(
+         cpu_validation_info.get_info("logits"),
+         cuda_validation_info.get_info("logits"),
+         top_k_loss_calculator(args.topk_per_token, cross_entropy),
+     )
+     prob_ce_loss_metrics.extend(filter_before_eos(level_1_metrics, eos_indexes))
+
+     level_1_metrics = capture_level_1_metrics(
+         cpu_validation_info.get_info("logits"),
+         cuda_validation_info.get_info("logits"),
+         top_k_loss_calculator(args.topk_per_token, diff_mean),
+     )
+     prob_diff_metrics.extend(filter_before_eos(level_1_metrics, eos_indexes))
+
+ write_csv(prob_mean_metrics, os.path.join(args.output_dir, f"{prefix}.prob_mean.csv"), "prob_mean")
+ write_csv(prob_std_metrics, os.path.join(args.output_dir, f"{prefix}.prob_std.csv"), "prob_std")
+ write_csv(prob_ce_loss_metrics, os.path.join(args.output_dir, f"{prefix}.ce.csv"), "ce")
+ write_csv(prob_diff_metrics, os.path.join(args.output_dir, f"{prefix}.diff_mean.csv"), "diff_mean")
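+ # illustrative invocation (script name and paths assumed):
+ # python generate_metrics.py --variant my/model --model_path /path/to/model \
+ #     --sharegpt_path /path/to/sharegpt.json --output_dir /tmp/metrics --default_dtype fp16 \
+ #     --max_new_tokens 128 --batch_size 4 --min_pad_length 64 --num_test_tokens_per_sequence 512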