@@ -721,12 +721,7 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel):
     ]
     _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
 
-    def __init__(
-        self,
-        model,
-        qaic_config: Optional[dict] = None,
-        **kwargs
-    ):
+    def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs):
         """
         Initializes the language decoder component for multimodal models.
 
@@ -735,7 +730,7 @@ def __init__(
         model : nn.Module
             The full HuggingFace multimodal model from which the language decoder is extracted.
         qaic_config : dict, optional
-            A dictionary for QAIC-specific configurations.
+            A dictionary for QAIC-specific configurations.
             Only the following keys are supported by the text model of the dual QPC multimodal model:
             - **include_sampler** (bool): If True, enables on-device sampling of next tokens.
            - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
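
The docstring above limits `qaic_config` to two sampler-related keys for the text model of the dual QPC setup. A minimal sketch of such a config dict, assuming only the documented keys; the value 512 for `max_top_k_ids` is an arbitrary illustration and must stay within the model's vocabulary size:

```python
# Hedged sketch: only the keys documented in the docstring above.
qaic_config = {
    "include_sampler": True,  # enable on-device sampling of next tokens
    "max_top_k_ids": 512,     # illustrative value; must be <= vocab size
}
```
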
@@ -776,7 +771,9 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt
             Path to the generated ONNX graph file for the language decoder.
         """
         if self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False):
-            inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs(inputs, output_names, dynamic_axes)
+            inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs(
+                inputs, output_names, dynamic_axes
+            )
         return self._export(
             inputs, output_names, dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights
         )
@@ -807,7 +804,7 @@ def get_sampling_inputs_and_outputs(
             sampling-related parameters.
         """
         bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
-
+
         assert "logits" in output_names, "logits must be part of the output names to suport on-device sampling"
 
         logits_index = output_names.index("logits")
@@ -859,7 +856,7 @@ def get_sampling_inputs_and_outputs(
         example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS
         dynamic_axes["min_ps"] = {0: "batch_size"}
 
-        example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float)
+        example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float)
         dynamic_axes["random_numbers"] = {0: "batch_size"}
 
         return example_inputs, output_names, dynamic_axes
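
The hunk above widens the dummy `random_numbers` export input from one value per sequence to one value per top-k candidate, i.e. shape (batch_size, max_top_k_ids). As an illustration only of why a per-candidate shape is useful, the sketch below shows one common way such uniform random numbers can drive top-k sampling (the Gumbel-max trick); it is not the library's actual on-device sampler, and `sample_top_k` is a hypothetical helper:

```python
import torch

def sample_top_k(logits: torch.Tensor, random_numbers: torch.Tensor, k: int) -> torch.Tensor:
    """Pick one token id per row from the top-k logits using Gumbel-max (illustrative)."""
    top_vals, top_idx = torch.topk(logits, k, dim=-1)                 # (batch, k)
    gumbel = -torch.log(-torch.log(random_numbers + 1e-10) + 1e-10)   # (batch, k) noise from U[0, 1)
    choice = torch.argmax(top_vals + gumbel, dim=-1, keepdim=True)    # (batch, 1) index into top-k
    return top_idx.gather(-1, choice)                                 # (batch, 1) sampled token ids
```

Only the shapes matter here: the exported graph may combine the per-candidate random numbers with the probabilities differently.
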
@@ -2142,7 +2139,7 @@ def from_pretrained(
             If False, uses the single QPC approach (entire model in one QPC).
             If None, the default behavior of the internal classes is used (typically dual QPC).
         qaic_config : dict, optional
-            A dictionary for QAIC-specific configurations.
+            A dictionary for QAIC-specific configurations.
             Only the following keys are supported by the text model of the dual QPC multimodal model:
             - **include_sampler** (bool): If True, enables on-device sampling of next tokens.
            - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
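
A hedged usage sketch for this `from_pretrained` path, wiring `kv_offload` and `qaic_config` together as described in the docstring above; the class name `QEFFAutoModelForImageTextToText` and the checkpoint placeholder are assumptions made for illustration:

```python
# Hedged sketch, not a verified end-to-end recipe.
from QEfficient import QEFFAutoModelForImageTextToText  # assumed entry point

model = QEFFAutoModelForImageTextToText.from_pretrained(
    "<model-card-or-local-path>",                        # placeholder checkpoint
    kv_offload=True,                                     # dual QPC: vision and language parts in separate QPCs
    qaic_config={"include_sampler": True, "max_top_k_ids": 512},
)
```
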
@@ -2181,7 +2178,7 @@ def from_pretrained(
             model,
             kv_offload=kv_offload,
             continuous_batching=continuous_batching,
-            qaic_config=qaic_config,
+            qaic_config=qaic_config,
             pretrained_model_name_or_path=pretrained_model_name_or_path,
             **kwargs,
         )
@@ -2391,7 +2388,7 @@ def from_pretrained(
             kv_offload=kv_offload,
             pretrained_model_name_or_path=pretrained_model_name_or_path,
             qaic_config=qaic_config,
-            **kwargs
+            **kwargs,
         )
         return cls(
             model,
@@ -2594,7 +2591,7 @@ def get_sampling_inputs_and_outputs(
         example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS
         dynamic_axes["min_ps"] = {0: "batch_size"}
 
-        example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float)
+        example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float)
         dynamic_axes["random_numbers"] = {0: "batch_size"}
 
         return example_inputs, output_names, dynamic_axes