@@ -67,21 +67,10 @@ def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tenso
     logits = model(x, input_pos)
     return sample(logits, **sampling_kwargs)
 
-def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, use_sdpa=False, callback=lambda _: _, **sampling_kwargs):
+def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs):
     new_tokens, new_probs = [], []
-    if not use_sdpa:
-        for i in range(num_new_tokens):
-            with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
-                next_token, next_prob = decode_one_token(
-                    model, cur_token, input_pos, **sampling_kwargs
-                )
-                input_pos += 1
-                new_tokens.append(next_token.clone())
-                callback(new_tokens[-1])
-                new_probs.append(next_prob.clone())
-                cur_token = next_token.view(1, -1)
-    else:
-        for i in range(num_new_tokens):
+    for i in range(num_new_tokens):
+        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
             next_token, next_prob = decode_one_token(
                 model, cur_token, input_pos, **sampling_kwargs
             )
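This hunk folds the two loop bodies into one, so the sdp_kernel context is now applied unconditionally while decoding. As a standalone sketch of what that context manager does (tensor shapes, dtypes, and names below are illustrative, not from the patch): inside the block, only the "math" decomposition of scaled_dot_product_attention is eligible, which lets torch.compile/Inductor codegen the attention itself rather than dispatch to an opaque flash or memory-efficient kernel.

import torch
import torch.nn.functional as F

# Illustrative (batch, heads, seq_len, head_dim) tensors on a CUDA device.
q = torch.randn(1, 8, 16, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn(1, 8, 16, 64, device="cuda", dtype=torch.bfloat16)
v = torch.randn(1, 8, 16, 64, device="cuda", dtype=torch.bfloat16)

# Same backend pinning as the patch: flash and mem-efficient kernels are
# disabled, so SDPA falls back to the decomposed "math" implementation,
# which Inductor can fuse into generated code.
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
    out = F.scaled_dot_product_attention(q, k, v)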
@@ -103,13 +92,12 @@ def speculative_decode(
     cur_token: torch.Tensor,
     input_pos: int,
     speculate_k: int,
-    use_sdpa=False,
     **sampling_kwargs
 ) -> torch.Tensor:
     # draft model inference sequentially
     device = cur_token.device
     orig_input_pos = torch.tensor([input_pos], dtype=torch.int64, device=cur_token.device)
-    draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, use_sdpa=use_sdpa, **sampling_kwargs)
+    draft_tokens, draft_probs = decode_n_tokens(draft_model, cur_token.view(1, -1), orig_input_pos.clone(), speculate_k, **sampling_kwargs)
 
     draft_tokens = torch.cat(draft_tokens)
     # parallel inference on target model using draft tokens
@@ -157,7 +145,6 @@ def generate(
     interactive: bool,
     draft_model: Transformer,
     speculate_k: Optional[int] = 8,
-    use_sdpa=False,
     callback=lambda x: x,
     **sampling_kwargs
 ) -> torch.Tensor:
@@ -201,7 +188,7 @@ def generate(
             cur_token = next_token.view(())
 
             next_tokens = speculative_decode(
-                model, draft_model, cur_token, input_pos, speculate_k, use_sdpa=use_sdpa, **sampling_kwargs
+                model, draft_model, cur_token, input_pos, speculate_k, **sampling_kwargs
             )
 
             accept_counts[len(next_tokens) - 1] += 1
@@ -212,7 +199,7 @@ def generate(
             input_pos = input_pos + num_added
             next_token = next_tokens[-1]
     else:
-        generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, use_sdpa=use_sdpa, callback=callback, **sampling_kwargs)
+        generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
         seq[T + 1:] = torch.cat(generated_tokens)
 
     generate_stats = {
@@ -271,7 +258,6 @@ def main(
     profile: Optional[Path] = None,
     draft_checkpoint_path: Optional[Path] = None,
     speculate_k: int = 5,
-    use_sdpa=False,
     device='cuda',
 ) -> None:
     """Generates text samples based on a pre-trained Transformer model and tokenizer.
@@ -374,7 +360,6 @@ def callback(x):
                 max_new_tokens,
                 draft_model=draft_model,
                 speculate_k=speculate_k,
-                use_sdpa=use_sdpa,
                 interactive=interactive,
                 callback=callback,
                 temperature=temperature,
@@ -428,12 +413,11 @@ def callback(x):
     parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
     parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.')
     parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.')
-    parser.add_argument('--use_sdpa', action='store_true', help='Whether to use SDPA')
     parser.add_argument('--device', type=str, default="cuda", help='device to use')
 
     args = parser.parse_args()
     main(
         args.prompt, args.interactive, args.num_samples, args.max_new_tokens, args.top_k,
         args.temperature, args.checkpoint_path, args.compile, args.compile_prefill, args.profile, args.draft_checkpoint_path,
-        args.speculate_k, args.use_sdpa, args.device
+        args.speculate_k, args.device
     )
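With the flag gone, callers simply stop threading use_sdpa through. A hypothetical call site against the simplified decode_n_tokens signature (variable names and sampling values are illustrative; model, next_token, and input_pos are assumed to exist as in generate() above):

# Hypothetical caller after the signature change.
generated_tokens, _ = decode_n_tokens(
    model,
    next_token.view(1, -1),      # shape (1, 1): the most recently sampled token
    input_pos,                   # int64 tensor holding the current cache position
    32,                          # num_new_tokens: illustrative decoding budget
    callback=lambda tok: None,   # per-token hook, e.g. for streaming output
    temperature=0.8, top_k=200,  # forwarded to sample() via **sampling_kwargs
)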