from collections import deque
from dataclasses import dataclass
from time import perf_counter
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import transformers
from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.utils import padding_check_and_fix
+from QEfficient.utils.constants import Constants
from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.sampler_utils import validate_sampler_inputs


@dataclass
@@ -322,6 +324,9 @@ def cloud_ai_100_exec_kv(
    automation=False,
    prompt_to_lora_id_mapping: Optional[List[int]] = None,
    is_tlm: bool = False,
+    include_sampler: bool = False,
+    return_pdfs: bool = False,
+    sampling_params: Optional[Dict[str, Any]] = None,
):
    """
    This method generates output until ``eos`` or ``generation_len`` by executing the compiled ``qpc`` on ``Cloud AI 100`` Hardware cards.
@@ -342,6 +347,15 @@ def cloud_ai_100_exec_kv(
        :write_io_dir (str): Path to write the input and output files. ``Defaults to None``.
        :automation (bool): If true, it prints input, output, and performance stats. ``Defaults to False``.
        :prompt_to_lora_id_mapping (List[int]): Mapping to associate prompts with their respective LoRA adapter.
+        :include_sampler (bool, default=False): Enable/Disable sampling of next tokens.
+        :return_pdfs (bool, default=False): Return probability distributions along with sampled
+            next tokens. For a Speculative Decoding Target Language Model, `return_pdfs`=True always.
+            Otherwise, `return_pdfs`=True for a Speculative Decoding Draft Language Model and
+            `return_pdfs`=False for a regular model.
+        :sampling_params (Dict[str, Any], default=None): A dictionary of sampling parameters supported by the QAIC backend.
+            The dictionary should contain the following keys:
+            `repetition_penalties`, `presence_penalties`, `temperatures`, `top_ks`, `top_ps`,
+            `min_ps`, and `random_numbers`. Each value should be a numpy array of shape (batch_size, 1).

    Returns:
        :CloudAI100ExecInfo: Object holding execution output and performance details.
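For reference, a minimal sketch of how `sampling_params` could be assembled for a single prompt (batch_size = 1). The key names follow the docstring above; the concrete values and dtypes are illustrative assumptions rather than requirements confirmed by the QAIC backend:

import numpy as np

# Each entry is a (batch_size, 1) numpy array, as described in the docstring above.
# Values and dtypes are illustrative assumptions.
sampling_params = {
    "repetition_penalties": np.array([[1.1]], dtype=np.float32),
    "presence_penalties": np.array([[0.0]], dtype=np.float32),
    "temperatures": np.array([[0.8]], dtype=np.float32),
    "top_ks": np.array([[40]], dtype=np.int32),
    "top_ps": np.array([[0.95]], dtype=np.float32),
    "min_ps": np.array([[0.05]], dtype=np.float32),
    "random_numbers": np.random.rand(1, 1).astype(np.float32),
}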
@@ -372,6 +386,9 @@ def cloud_ai_100_exec_kv(
        write_io_dir=write_io_dir,
        full_batch_size=full_batch_size,
        is_tlm=is_tlm,
+        include_sampler=include_sampler,
+        return_pdfs=return_pdfs,
+        sampling_params=sampling_params,
    )
    if full_batch_size is None:
        exec_info = [
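A hedged usage sketch of the updated entry point, reusing the `sampling_params` dictionary from the sketch above. The sampler-related keyword names match the signature in this diff; `tokenizer`, `qpc_path`, and `prompt` stand in for the existing arguments of `cloud_ai_100_exec_kv` and their values here are placeholders:

# Sketch only: assumes `tokenizer` is an already-loaded PreTrainedTokenizer(Fast)
# and that the QPC at this path was compiled with on-device sampling support.
exec_info = cloud_ai_100_exec_kv(
    tokenizer=tokenizer,
    qpc_path="/path/to/qpc",
    prompt=["My name is"],
    include_sampler=True,
    return_pdfs=False,
    sampling_params=sampling_params,
)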
@@ -411,14 +428,24 @@ def __init__(
        enable_debug_logs: bool = False,
        write_io_dir: Optional[str] = None,
        is_tlm: Optional[int] = None,
+        include_sampler: bool = False,
+        return_pdfs: bool = False,
+        sampling_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        self._ctx_len = ctx_len
        self._write_io_dir = write_io_dir
        self.is_tlm = is_tlm
+        self.return_pdfs = return_pdfs
+        self.sampling_params = sampling_params

        # Load QPC
        self._session = QAICInferenceSession(qpc_path, device_id, enable_debug_logs=enable_debug_logs)

+        # Validate sampler inputs for On-Device Sampling
+        self.include_sampler = validate_sampler_inputs(
+            session_inputs=set(self._session.input_names), include_sampler=include_sampler
+        )
+
        # Fetch the variables from the QPC
        self._vocab_size = self._fetch_vocab_size()  # Fetch Vocab size
        self.batch_size, self._prefill_seq_len = self._fetch_batch_size_prefill_seq_len()
@@ -523,10 +550,17 @@ def _fetch_vocab_size(
        Returns:
            vocab_size: The vocabulary size fetched from the session's allowed shapes.
        """
+        key = (
+            "probs"
+            if self.include_sampler and self.return_pdfs
+            else "next_tokens"
+            if self.include_sampler
+            else "logits"
+        )
        if self._session.allowed_shapes:
-            return [x[self._session.binding_index_map["logits"]] for x in self._session.allowed_shapes][0][1][2]
+            return [x[self._session.binding_index_map[key]] for x in self._session.allowed_shapes][0][1][2]

-        return self._session.bindings[self._session.binding_index_map["logits"]].dims[2]
+        return self._session.bindings[self._session.binding_index_map[key]].dims[2]
    def _fetch_generation_len(self, generation_len, max_gen_len):
        """
@@ -574,6 +608,13 @@ def prepare_decode_inputs(self):
        decode_inputs["position_ids"] = self.decode_pos_ids
        if self.batch_index is not None:
            decode_inputs["batch_index"] = self.batch_index
+        if self.include_sampler:
+            decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]
+            for op in Constants.SAMPLER_OPS:
+                if self.batch_index is not None:
+                    decode_inputs[op] = self.sampling_params[op][self.batch_index.flatten()]
+                else:
+                    decode_inputs[op] = self.sampling_params[op]

        if self._prompt_to_lora_id_mapping_decode:
            if self.full_batch_size:
@@ -589,21 +630,24 @@ def prepare_decode_inputs(self):

    def _fetch_next_token_id(self, outputs):
        """
-        Fetches the next token ID from the model's output logits.
-        The method identifies the token with the highest probability using argmax along the last dimension.
+        Fetches the next token ID from the model's output.
+
        Args:
-            outputs (dict): A dictionary containing the model's output logits. The key "logits" should map to a numpy array of shape (batch_size, sequence_length, vocab_size) or (batch_size, vocab_size).
+            outputs (dict): A dictionary containing the model's output.

        Returns:
            numpy.ndarray: An array of the next token IDs for each sequence in the batch.
        """
-        logits = outputs["logits"]
-        if len(logits.shape) == 2:
-            logits = np.expand_dims(logits, 1)
-
-        # Get output token
-        next_token_id = logits.argmax(2)
-        return next_token_id
+        if self.include_sampler:
+            if self.return_pdfs:
+                return outputs["probs"].argmax(2)
+            else:
+                return outputs["next_tokens"].reshape(outputs["next_tokens"].shape[0], outputs["next_tokens"].shape[1])
+        else:
+            logits = outputs["logits"]
+            if len(logits.shape) == 2:
+                logits = np.expand_dims(logits, 1)
+            return logits.argmax(2)

    def initialize_decode_inputs(self, num_prompts, execution_batch_size, max_gen_length):
        """
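For intuition, a small self-contained sketch of the three output layouts `_fetch_next_token_id` handles. The shapes are assumptions for illustration (batch of 2, one decode position, toy vocabulary of 8); in a real run only the key matching the `include_sampler`/`return_pdfs` configuration is present:

import numpy as np

batch_size, seq_len, vocab = 2, 1, 8  # toy sizes, not taken from a real QPC
probs = np.random.rand(batch_size, seq_len, vocab)                   # include_sampler and return_pdfs
next_tokens = np.random.randint(0, vocab, (batch_size, seq_len, 1))  # include_sampler only
logits = np.random.rand(batch_size, vocab)                           # default path; may be 2-D

print(probs.argmax(2).shape)                                                   # (2, 1)
print(next_tokens.reshape(next_tokens.shape[0], next_tokens.shape[1]).shape)  # (2, 1)
print(np.expand_dims(logits, 1).argmax(2).shape)                              # (2, 1)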
@@ -673,6 +717,23 @@ def run_prefill_for_all_inputs(self, prompt_queue, generation_len):

            _ = self.update_decode_input(outputs, position_ids, generation_len, decode_batch_id)

+    def _set_output_buffers(self, batch_size: int = 1, sequence_length: int = 1):
+        """
+        Sets the sizes of the output buffers.
+
+        Args:
+            batch_size (int): The batch size.
+        """
+        if self.include_sampler:
+            if self.return_pdfs:
+                probs_out_placeholder = np.zeros((batch_size, sequence_length, self._vocab_size), dtype=np.float32)
+                self._session.set_buffers({"probs": probs_out_placeholder})
+            next_tokens_out_placeholder = np.zeros((batch_size, sequence_length, 1), dtype=np.int64)
+            self._session.set_buffers({"next_tokens": next_tokens_out_placeholder})
+        else:
+            logits_out_placeholder = np.zeros((batch_size, sequence_length, self._vocab_size), dtype=np.float32)
+            self._session.set_buffers({"logits": logits_out_placeholder})
+
    def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None):
        """
        Runs prefill for a given prompt and generation length.
@@ -702,9 +763,8 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
        max_gen_len = self._ctx_len - position_ids.max()
        generation_len = self._fetch_generation_len(generation_len, max_gen_len)

-        # Set the prefill logic buffer
-        logits_out_placeholder = np.zeros((prefill_logit_bs, 1, self._vocab_size), dtype=np.float32)
-        self._session.set_buffers({"logits": logits_out_placeholder})
+        # Set the prefill output buffers
+        self._set_output_buffers(batch_size=prefill_logit_bs, sequence_length=1)

        inputs = self.tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
        inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
@@ -714,6 +774,13 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
            inputs["batch_index"] = decode_batch_id
        if self.is_tlm:
            inputs["num_logits_to_keep"] = np.zeros((1, 1))
+        if self.include_sampler:
+            inputs["last_accepted_output_tokens"] = inputs["input_ids"]
+            for op in Constants.SAMPLER_OPS:
+                if decode_batch_id is not None:
+                    inputs[op] = self.sampling_params[op][decode_batch_id.flatten()]
+                else:
+                    inputs[op] = self.sampling_params[op]

        if self._prompt_to_lora_id_mapping_prefill:
            if self.full_batch_size:
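A small sketch of the per-slot lookup performed above when continuous batching is active: `decode_batch_id` picks the row of each (full_batch_size, 1) sampling array that belongs to the slot currently being prefilled. `Constants.SAMPLER_OPS` is assumed to be the set of keys listed in the docstring earlier; only one representative array is shown:

import numpy as np

temperatures = np.array([[0.7], [0.8], [0.9], [1.0]], dtype=np.float32)  # (full_batch_size, 1)
decode_batch_id = np.array([[2]], dtype=np.int64)                        # slot being prefilled

print(temperatures[decode_batch_id.flatten()])  # [[0.9]] -> row for slot 2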
@@ -732,6 +799,8 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i
            chunk_inputs["position_ids"] = inputs["position_ids"][
                :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len
            ]
+            if self.include_sampler:
+                chunk_inputs["last_accepted_output_tokens"] = chunk_inputs["input_ids"]
            outputs = self._session.run(chunk_inputs)
            if self._write_io_dir is not None:
                write_io_files(inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False)
@@ -753,11 +822,12 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):

        """

-        # Set logits placeholder for decode
-        logits_out_placeholder = np.zeros(
-            (self.full_batch_size, self._decode_seq_len, self._vocab_size), dtype=np.float32
+        # Set output placeholders for decode
+        self._set_output_buffers(
+            batch_size=self.full_batch_size,
+            sequence_length=self._decode_seq_len,
        )
-        self._session.set_buffers({"logits": logits_out_placeholder})
+
        # Generate flag for tracking progress for each batch ID
        current_decode_ongoing = np.full((self.full_batch_size, 1), True)

@@ -775,10 +845,7 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
            outputs = self._session.run(decode_inputs)

            # Prepare inputs for next iteration
-            logits = outputs["logits"]
-            if len(logits.shape) == 2:
-                logits = np.expand_dims(logits, 1)
-            next_token_id = logits.argmax(2)
+            next_token_id = self._fetch_next_token_id(outputs)

            for decode_batch_id in range(self.full_batch_size):
                if (
@@ -800,7 +867,10 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
                        self.generated_ids[batch_id_map[decode_batch_id], 0] = new_token_id.squeeze(1)
                        generated_id_current_index[decode_batch_id] = 1

-                        self._session.set_buffers({"logits": logits_out_placeholder})
+                        self._set_output_buffers(
+                            batch_size=self.full_batch_size,
+                            sequence_length=self._decode_seq_len,
+                        )
                        decode_pause_time += perf_counter() - start

                        if self._prompt_to_lora_id_mapping_decode:
@@ -817,6 +887,8 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
                    self.generated_ids[batch_id_map[decode_batch_id], generated_id_current_index[decode_batch_id]] = (
                        next_token_id[decode_batch_id, -1]
                    )
+                    if self.include_sampler:
+                        decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]

                    generated_id_current_index[decode_batch_id] += 1

@@ -852,10 +924,12 @@ def run_decode(self, decode_inputs, generation_len, streamer: Optional[transform
                self._write_io_dir = None

            # Prepare inputs for next iteration
-            decode_inputs["input_ids"] = outputs["logits"].argmax(2)
+            decode_inputs["input_ids"] = self._fetch_next_token_id(outputs)
            decode_inputs["position_ids"][:, -1] += 1
            self.generated_ids[:, num_token] = decode_inputs["input_ids"][:, -1]
            finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id
+            if self.include_sampler:
+                decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]

            if finished_sequences.all():
                break
@@ -905,9 +979,22 @@ def __init__(
        enable_debug_logs: bool = False,
        write_io_dir: Optional[str] = None,
        is_tlm: bool = False,
+        include_sampler: bool = False,
+        return_pdfs: bool = False,
+        sampling_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        self._qaic_model = QEffTextGenerationBase(
-            tokenizer, qpc_path, full_batch_size, ctx_len, device_id, enable_debug_logs, write_io_dir, is_tlm
+            tokenizer=tokenizer,
+            qpc_path=qpc_path,
+            full_batch_size=full_batch_size,
+            ctx_len=ctx_len,
+            device_id=device_id,
+            enable_debug_logs=enable_debug_logs,
+            write_io_dir=write_io_dir,
+            is_tlm=is_tlm,
+            include_sampler=include_sampler,
+            return_pdfs=return_pdfs,
+            sampling_params=sampling_params,
        )
        self._full_batch_size = self._qaic_model.full_batch_size
        self._tokenizer = self._qaic_model.tokenizer