@@ -721,20 +721,37 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel):
     ]
     _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
 
-    def __init__(self, model, **kwargs):
+    def __init__(
+        self,
+        model,
+        qaic_config: Optional[dict] = None,
+        **kwargs
+    ):
         """
         Initializes the language decoder component for multimodal models.
 
         Parameters
         ----------
         model : nn.Module
             The full HuggingFace multimodal model from which the language decoder is extracted.
+        qaic_config : dict, optional
+            A dictionary for QAIC-specific configurations.
+            Only the following keys are supported by the text model of the dual QPC multimodal model:
+            - **include_sampler** (bool): If True, enables on-device sampling of next tokens.
+            - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
+            Additional keys will be ignored.
         **kwargs :
             Additional keyword arguments passed to the base class constructor.
         """
         super().__init__(model, **kwargs)
         self.model = model.get_qeff_language_decoder()
         self.hash_params["qeff_auto_class"] = self.__class__.__name__
+        self.model.qaic_config = qaic_config
+        # ---Sampling---
+        # Note: SamplerTransform should be applied after all other transforms
+        # are done. The role of the sampler is to just add nodes at the output of the
+        # previous transform function.
+        self.model, _ = SamplerTransform.apply(self.model, qaic_config, **kwargs)
 
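
For orientation, the kind of dictionary the new `qaic_config` parameter expects might look like the sketch below; the values are illustrative placeholders, and the comments restate what the constructor does with them:

    # Illustrative qaic_config for the dual-QPC language decoder (placeholder values).
    qaic_config = {
        "include_sampler": True,  # stored on the extracted decoder and checked again in export()
        "max_top_k_ids": 512,     # bound on top-K candidates; must not exceed the vocab size
        "unrelated_key": "noop",  # unsupported keys like this one are simply ignored
    }
    # __init__ stashes this dict on self.model.qaic_config and applies SamplerTransform last,
    # so the sampler only appends nodes to the outputs of the already-transformed graph.
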
     def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True):
         """
@@ -758,10 +775,95 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt
         str
             Path to the generated ONNX graph file for the language decoder.
         """
+        if self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False):
+            inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs(inputs, output_names, dynamic_axes)
         return self._export(
             inputs, output_names, dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights
         )
 
+    def get_sampling_inputs_and_outputs(
+        self,
+        example_inputs: Dict[str, torch.Tensor],
+        output_names: List[str],
+        dynamic_axes: Dict[str, Dict[int, str]],
+    ):
+        """
+        Updates the example inputs, output names, and dynamic axes to include
+        parameters relevant for on-device sampling during ONNX export.
+
+        Parameters
+        ----------
+        example_inputs : Dict[str, torch.Tensor]
+            Current dictionary of example inputs.
+        output_names : List[str]
+            Current list of output names.
+        dynamic_axes : Dict[str, Dict[int, str]]
+            Current dictionary of dynamic axes configurations.
+
+        Returns
+        -------
+        Tuple[Dict[str, torch.Tensor], List[str], Dict[str, Dict[int, str]]]
+            Updated example inputs, output names, and dynamic axes including
+            sampling-related parameters.
+        """
+        bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
+
+        assert "logits" in output_names, "logits must be part of the output names to support on-device sampling"
+
+        logits_index = output_names.index("logits")
+        output_names[logits_index] = "next_tokens"
+
+        example_inputs["last_accepted_output_tokens"] = torch.zeros(
+            (bs, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN), dtype=torch.int64
+        )
+        dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"}
+
+        example_inputs["past_repetition_penalty_buffer"] = torch.zeros(
+            (bs, self.model.language_model.config.vocab_size), dtype=torch.bool
+        )
+        dynamic_axes["past_repetition_penalty_buffer"] = {
+            0: "batch_size",
+        }
+        output_names.append("past_repetition_penalty_buffer_RetainedState")
+
+        example_inputs["repetition_penalties"] = (
+            torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES
+        )
+        dynamic_axes["repetition_penalties"] = {0: "batch_size"}
+
+        example_inputs["past_presence_penalty_buffer"] = torch.zeros(
+            (bs, self.model.language_model.config.vocab_size), dtype=torch.bool
+        )
+        dynamic_axes["past_presence_penalty_buffer"] = {
+            0: "batch_size",
+        }
+        output_names.append("past_presence_penalty_buffer_RetainedState")
+
+        example_inputs["presence_penalties"] = (
+            torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES
+        )
+        dynamic_axes["presence_penalties"] = {0: "batch_size"}
+
+        example_inputs["temperatures"] = (
+            torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TEMPERATURES
+        )
+        dynamic_axes["temperatures"] = {0: "batch_size"}
+
+        max_top_k_ids = self.model.qaic_config.get("max_top_k_ids", constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS)
+        example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32)
+        dynamic_axes["top_ks"] = {0: "batch_size"}
+
+        example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TOP_PS
+        dynamic_axes["top_ps"] = {0: "batch_size"}
+
+        example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS
+        dynamic_axes["min_ps"] = {0: "batch_size"}
+
+        example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float)
+        dynamic_axes["random_numbers"] = {0: "batch_size"}
+
+        return example_inputs, output_names, dynamic_axes
+
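
The method above follows one pattern per sampling control: add a batch-first example tensor, mark its batch axis dynamic, and, for the two penalty buffers, append a `_RetainedState` output. A self-contained sketch of the resulting spec is shown below; the batch size, vocab size, and control values are illustrative stand-ins for the real `constants.ONNX_EXPORT_EXAMPLE_*` defaults:

    # Illustrative reconstruction of the extra export spec built above (placeholder values).
    import torch

    bs, vocab_size, seq_len = 1, 32000, 32  # stand-ins, not the real export constants

    example_inputs = {
        "last_accepted_output_tokens": torch.zeros((bs, seq_len), dtype=torch.int64),
        "past_repetition_penalty_buffer": torch.zeros((bs, vocab_size), dtype=torch.bool),
        "repetition_penalties": torch.full((bs, 1), 1.2),
        "past_presence_penalty_buffer": torch.zeros((bs, vocab_size), dtype=torch.bool),
        "presence_penalties": torch.zeros((bs, 1)),
        "temperatures": torch.full((bs, 1), 0.8),
        "top_ks": torch.randint(1, 512, (bs, 1), dtype=torch.int32),
        "top_ps": torch.full((bs, 1), 0.9),
        "min_ps": torch.full((bs, 1), 0.05),
        "random_numbers": torch.rand((bs, 1)),
    }
    # Every sampling tensor is dynamic in its batch dimension; the token history also has a seq axis.
    dynamic_axes = {name: {0: "batch_size"} for name in example_inputs}
    dynamic_axes["last_accepted_output_tokens"][1] = "seq_len"
    # "logits" is renamed to "next_tokens", and the penalty buffers gain RetainedState outputs:
    output_names = [
        "next_tokens",
        "past_repetition_penalty_buffer_RetainedState",
        "past_presence_penalty_buffer_RetainedState",
    ]
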
     def compile(
         self,
         compile_dir,
@@ -1499,6 +1601,8 @@ def __init__(
         """
         if kwargs.pop("full_batch_size", None):
             raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
+        if kwargs.pop("qaic_config", None):
+            raise NotImplementedError("On-device sampling is not supported for single QPC multimodal models yet.")
         super().__init__(model, **kwargs)
 
         # to handle internvl models
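
In effect, this guard makes on-device sampling reachable only through the dual-QPC path. A hedged sketch of the expected behavior, assuming the public entry point is `QEFFAutoModelForImageTextToText` and using a placeholder model id:

    # Sketch only: the single-QPC path (kv_offload=False) is expected to reject qaic_config.
    from QEfficient import QEFFAutoModelForImageTextToText

    try:
        QEFFAutoModelForImageTextToText.from_pretrained(
            "<multimodal-model-id>",                # placeholder id
            kv_offload=False,                       # single QPC
            qaic_config={"include_sampler": True},  # triggers the guard above
        )
    except NotImplementedError as err:
        print(err)  # On-device sampling is not supported for single QPC multimodal models yet.
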
@@ -2023,6 +2127,7 @@ def from_pretrained(
         pretrained_model_name_or_path: str,
         kv_offload: Optional[bool] = None,
         continuous_batching: bool = False,
+        qaic_config: Optional[dict] = None,
         **kwargs,
     ):
         """
@@ -2036,6 +2141,12 @@ def from_pretrained(
             If True, uses the dual QPC approach (vision encoder KV offloaded).
             If False, uses the single QPC approach (entire model in one QPC).
             If None, the default behavior of the internal classes is used (typically dual QPC).
+        qaic_config : dict, optional
+            A dictionary for QAIC-specific configurations.
+            Only the following keys are supported by the text model of the dual QPC multimodal model:
+            - **include_sampler** (bool): If True, enables on-device sampling of next tokens.
+            - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
+            Additional keys will be ignored.
         **kwargs :
             Additional arguments passed to HuggingFace's ``from_pretrained``.
 
@@ -2063,11 +2174,14 @@ def from_pretrained(
             logger.warning("Updating low_cpu_mem_usage=False")
 
         kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
+        if qaic_config is not None:
+            qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path
         model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
         return cls(
             model,
             kv_offload=kv_offload,
             continuous_batching=continuous_batching,
+            qaic_config=qaic_config,
             pretrained_model_name_or_path=pretrained_model_name_or_path,
             **kwargs,
         )
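
As a usage sketch for the path above (dual QPC with on-device sampling), assuming this `from_pretrained` belongs to the public `QEFFAutoModelForImageTextToText` wrapper and using placeholder values throughout:

    # Hedged usage sketch; the model id and config values are placeholders.
    from QEfficient import QEFFAutoModelForImageTextToText

    qaic_config = {
        "include_sampler": True,  # export "next_tokens" instead of raw "logits"
        "max_top_k_ids": 512,     # illustrative; must be <= the model's vocab size
    }

    model = QEFFAutoModelForImageTextToText.from_pretrained(
        "<multimodal-model-id>",  # placeholder
        kv_offload=True,          # dual QPC: the language decoder shown earlier receives qaic_config
        qaic_config=qaic_config,
    )
    # from_pretrained also records the model id inside qaic_config under
    # "pretrained_model_name_or_path" before constructing the wrapper.
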
@@ -2273,7 +2387,11 @@ def from_pretrained(
 
         if model.__class__.__name__ in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP:
             return MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model.__class__.__name__](
-                model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
+                model,
+                kv_offload=kv_offload,
+                pretrained_model_name_or_path=pretrained_model_name_or_path,
+                qaic_config=qaic_config,
+                **kwargs
             )
         return cls(
             model,