@@ -86,73 +86,20 @@ def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torc
 def model_forward(model, x, input_pos):
     return model(x, input_pos)

-def speculative_decode(
-    model: Transformer,
-    draft_model: Transformer,
-    cur_token: torch.Tensor,
-    input_pos: int,
-    speculate_k: int,
-    **sampling_kwargs
-) -> torch.Tensor:
-    # draft model inference sequentially
-    device = cur_token.device
-    orig_input_pos = torch.tensor([input_pos], dtype=torch.int64, device=cur_token.device)
-    draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, **sampling_kwargs)
-
-    draft_tokens = torch.cat(draft_tokens)
-    # parallel inference on target model using draft tokens
-    target_logits = model_forward(
-        model,
-        torch.cat([cur_token.view(1), draft_tokens]).view(1, -1),
-        torch.arange(input_pos, input_pos + speculate_k + 1, device=cur_token.device)
-    )
-    target_probs = logits_to_probs(target_logits[0], **sampling_kwargs)
-    draft_probs = torch.stack(draft_probs)
-    # q: target prob, p: draft prob
-    # q >= p: always accept draft token
-    # q < p: q/p prob to accept draft token
-    p = draft_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
-    q = target_probs[torch.arange(0, speculate_k, device=device), draft_tokens]
-    accept_draft_prob = torch.minimum(torch.ones(()), q[:speculate_k] / p)
-    rejected_locations = (torch.rand_like(accept_draft_prob) > accept_draft_prob).nonzero()
-
-    if rejected_locations.shape[0] == 0:  # All draft tokens have been accepted
-        accept_length = speculate_k + 1
-        last_token = multinomial_sample_one_no_sync(target_probs[-1])
-        # fill last token into draft model
-        model_forward(
-            draft_model,
-            draft_tokens[-1].view(1, -1),
-            orig_input_pos + speculate_k,
-        )
-        return torch.cat([draft_tokens, last_token])
-    else:
-        accept_length = rejected_locations[0].item()
-        p = draft_probs[accept_length]
-        q = target_probs[accept_length]
-        new = q - p
-        new = torch.where(new > 0, new, 0.0)
-        new = new / new.sum()
-        next_token = multinomial_sample_one_no_sync(new)
-        return torch.cat([draft_tokens[:accept_length], next_token])
-
 @torch.no_grad()
 def generate(
     model: Transformer,
     prompt: torch.Tensor,
     max_new_tokens: int,
     *,
     interactive: bool,
-    draft_model: Transformer,
-    speculate_k: Optional[int] = 8,
     callback = lambda x: x,
     **sampling_kwargs
 ) -> torch.Tensor:
     """
     Takes a conditioning sequence (prompt) as input and continues to generate as many tokens as requested.
     """

-    is_speculative = draft_model is not None
     # create an empty tensor of the expected final shape and fill in the current tokens
     T = prompt.size(0)
     T_new = T + max_new_tokens
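For context, the `speculative_decode` helper removed above implements the standard speculative-sampling accept/reject rule: a draft token with draft probability p and target probability q is kept with probability min(1, q/p), and on the first rejection a replacement is drawn from the renormalized residual distribution max(q - p, 0). A minimal, self-contained sketch of just that rule (the helper name `accept_or_resample` and the toy distributions are hypothetical, and plain `torch.multinomial` stands in for the repo's `multinomial_sample_one_no_sync`):

```python
import torch

# Toy illustration of the acceptance rule used by the removed speculative_decode():
# keep a draft token with probability min(1, q/p); otherwise resample from the
# renormalized residual distribution max(q - p, 0).
def accept_or_resample(p_draft: torch.Tensor, q_target: torch.Tensor, token: int) -> int:
    accept_prob = torch.minimum(torch.ones(()), q_target[token] / p_draft[token])
    if torch.rand(()) < accept_prob:
        return token  # draft token accepted
    residual = torch.clamp(q_target - p_draft, min=0.0)
    residual = residual / residual.sum()  # renormalize the leftover probability mass
    return int(torch.multinomial(residual, 1))  # corrected sample

p = torch.tensor([0.7, 0.2, 0.1])  # draft model distribution over a 3-token vocab
q = torch.tensor([0.5, 0.4, 0.1])  # target model distribution
print(accept_or_resample(p, q, token=0))  # token 0 is accepted with prob 0.5/0.7
```

Accept/resample in this form leaves samples distributed exactly according to the target model, which is the key property of speculative decoding.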
@@ -162,11 +109,8 @@ def generate(
         max_seq_length = min(T_new, model.config.block_size)

     device, dtype = prompt.device, prompt.dtype
-    max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length
     with torch.device(device):
         model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
-        if is_speculative and draft_model is not model:
-            draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

     # create an empty tensor of the expected final shape and fill in the current tokens
     empty = torch.empty(T_new, dtype=dtype, device=device)
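A note on the cache sizing deleted in this hunk: with speculation enabled, one step writes the current token plus `speculate_k` drafts, and the target's parallel pass can emit one bonus token, so the KV cache needed `speculate_k + 1` extra slots beyond the requested sequence length. A worked example with illustrative (made-up) numbers:

```python
# Why the deleted line padded the cache: positions up to
# input_pos + speculate_k must exist for the target's parallel scoring pass.
T_new, block_size, speculate_k = 612, 2048, 8
max_seq_length = min(T_new, block_size)            # 612 without speculation
max_seq_length = max_seq_length + speculate_k + 1  # 621 with speculation
print(max_seq_length)
```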
@@ -175,37 +119,14 @@ def generate(
     input_pos = torch.arange(0, T, device=device)

     next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs)
-    if is_speculative:
-        prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs)
     seq[T] = next_token

     input_pos = torch.tensor([T], device=device, dtype=torch.int)
-    accept_counts = [0] * (speculate_k + 1)
-
-    if is_speculative:
-        input_pos = input_pos.item()  # for speculative decoding easier to keep on host
-        while input_pos < T_new - 1:
-            cur_token = next_token.view(())
-
-            next_tokens = speculative_decode(
-                model, draft_model, cur_token, input_pos, speculate_k, **sampling_kwargs
-            )

-            accept_counts[len(next_tokens) - 1] += 1
-            num_added = min(T_new - input_pos - 1, len(next_tokens))
-            seq[input_pos + 1 : input_pos + num_added + 1] = next_tokens[:num_added]
-            for i in next_tokens[:num_added,]:
-                callback(i)
-            input_pos = input_pos + num_added
-            next_token = next_tokens[-1]
-    else:
-        generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
-        seq[T + 1:] = torch.cat(generated_tokens)
+    generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
+    seq[T + 1:] = torch.cat(generated_tokens)

-    generate_stats = {
-        'accept_counts': accept_counts
-    }
-    return seq, generate_stats
+    return seq

 def encode_tokens(tokenizer, string, bos=True, device='cuda'):
     tokens = tokenizer.encode(string)
@@ -223,15 +144,6 @@ def _load_model(checkpoint_path, device, precision, use_tp):
         simple_quantizer = WeightOnlyBit8QuantHandler(model, torch.int8)
         model = simple_quantizer.convert_for_runtime()

-    if "int4" in str(checkpoint_path):
-        print("Using int4 quantization!")
-        path_comps = checkpoint_path.name.split(".")
-        assert path_comps[-2].startswith("g")
-        groupsize = int(path_comps[-2][1:])
-        from quantize import WeightOnlyInt4QuantHandler
-        simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
-        model = simple_quantizer.convert_for_runtime()
-
     checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
     model.load_state_dict(checkpoint, assign=True)

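The int4 branch removed here inferred the quantization groupsize from the checkpoint filename: the dot-separated component before the extension had to look like `g<int>`. A standalone sketch of that convention (the filename `model.g32.pth` is a hypothetical example matching the assertion in the deleted code):

```python
from pathlib import Path

def groupsize_from_name(checkpoint_path: Path) -> int:
    # Mirrors the deleted int4 branch: the component before the file
    # extension must start with "g", e.g. "model.g32.pth" -> 32.
    path_comps = checkpoint_path.name.split(".")
    assert path_comps[-2].startswith("g"), f"no groupsize tag in {checkpoint_path.name}"
    return int(path_comps[-2][1:])

print(groupsize_from_name(Path("model.g32.pth")))  # hypothetical name -> 32
```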
@@ -252,12 +164,10 @@ def main(
     max_new_tokens: int = 100,
     top_k: int = 200,
     temperature: float = 0.8,
-    checkpoint_path: Path = Path("checkpoints/meta-Transformer/Transformer-2-7b-chat-hf/model.pth"),
+    checkpoint_path: Path = Path("checkpoints/mistralai/Mixtral-8x7B-v0.1/model.pth"),
     compile: bool = True,
     compile_prefill: bool = False,
     profile: Optional[Path] = None,
-    draft_checkpoint_path: Optional[Path] = None,
-    speculate_k: int = 5,
     device = 'cuda',
 ) -> None:
     """Generates text samples based on a pre-trained Transformer model and tokenizer.
@@ -277,18 +187,12 @@ def main(

     print(f"Using device={device}")
     precision = torch.bfloat16
-    is_speculative = draft_checkpoint_path is not None
     is_chat = "chat" in str(checkpoint_path)

     print("Loading model ...")
     t0 = time.time()
     model = _load_model(checkpoint_path, device, precision, use_tp)

-    if is_speculative:
-        draft_model = _load_model(draft_checkpoint_path, device, precision, use_tp)
-    else:
-        draft_model = None
-
     device_sync(device=device)  # MKG
     print(f"Time to load model: {time.time() - t0:.02f} seconds")

@@ -299,14 +203,7 @@ def main(
     torch.manual_seed(1234)
     model_size = sum([p.numel() * p.dtype.itemsize for p in itertools.chain(model.parameters(), model.buffers())])
     if compile:
-        if is_speculative and use_tp:  # and ("cuda" in device):
-            torch._inductor.config.triton.cudagraph_trees = False  # Bug with cudagraph trees in this case
-        if model.config.moe:
-            torch._inductor.config.assert_indirect_indexing = False
-
-        if is_speculative:
-            global model_forward, logits_to_prob
-            model_forward = torch.compile(model_forward, mode="reduce-overhead", fullgraph=True)
+        torch._inductor.config.assert_indirect_indexing = False

         global decode_one_token, prefill
         decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
@@ -318,7 +215,6 @@ def main(

     aggregate_metrics = {
         'tokens_per_sec': [],
-        'accept_counts': [],
     }
     start = -1 if compile else 0

@@ -355,18 +251,15 @@ def callback(x):
             torch.profiler._utils._init_for_cuda_graphs()
             prof = torch.profiler.profile()
         with prof:
-            y, metrics = generate(
+            y = generate(
                 model,
                 encoded,
                 max_new_tokens,
-                draft_model=draft_model,
-                speculate_k=speculate_k,
                 interactive=interactive,
                 callback=callback,
                 temperature=temperature,
                 top_k=top_k,
             )
-        aggregate_metrics['accept_counts'].append(metrics['accept_counts'])
         if i == -1:
             print(f"Compilation time: {time.perf_counter() - t0:.2f} seconds")
             continue
@@ -387,12 +280,6 @@ def callback(x):
         aggregate_metrics['tokens_per_sec'].append(tokens_sec)
         print(f"Time for inference {i + 1}: {t:.02f} sec total, {tokens_sec:.02f} tokens/sec")
         print(f"Bandwidth achieved: {model_size * tokens_sec / 1e9:.02f} GB/s")
-    print("==========")
-    if is_speculative:
-        counts_aggregated = [sum(i) for i in zip(*aggregate_metrics['accept_counts'])]
-        acceptance_probs = [i / sum(counts_aggregated) for i in counts_aggregated]
-        print(f"Acceptance probs: {acceptance_probs}")
-        print(f"Mean Accepted: {sum([idx * i for idx, i in enumerate(counts_aggregated)]) / sum(counts_aggregated)}")

     print(f"Average tokens/sec: {torch.mean(torch.tensor(aggregate_metrics['tokens_per_sec'])).item():.2f}")
     print(f"Memory used: {torch.cuda.max_memory_reserved() / 1e9:.02f} GB")
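In the deleted reporting block, `accept_counts[n]` tallied how many speculation rounds produced n + 1 tokens, so "Mean Accepted" is a count-weighted average over round lengths. A toy check of that arithmetic (the counts below are made up):

```python
# Toy check of the deleted "Mean Accepted" arithmetic; index n counts the
# speculation rounds that produced n + 1 tokens overall.
counts_aggregated = [2, 3, 5]  # invented counts for speculate_k = 2
acceptance_probs = [c / sum(counts_aggregated) for c in counts_aggregated]
mean_accepted = sum(idx * c for idx, c in enumerate(counts_aggregated)) / sum(counts_aggregated)
print(acceptance_probs)  # [0.2, 0.3, 0.5]
print(mean_accepted)     # (0*2 + 1*3 + 2*5) / 10 = 1.3
```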
@@ -412,13 +299,10 @@ def callback(x):
     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
     parser.add_argument('--compile_prefill', action='store_true', help='Whether to compile the prefill (improves prefill perf, but higher compile times)')
     parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
-    parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.')
-    parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.')
     parser.add_argument('--device', type=str, default="cuda", help='device to use')

     args = parser.parse_args()
     main(
         args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
-        args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path,
-        args.speculate_k, args.device
+        args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.device
     )