
Commit 409da24

Extend on-device sampling support for dual QPC VLMs
Signed-off-by: quic-xiyushi <xiyushi@qti.qualcomm.com>
1 parent: 04f1ad7

File tree

3 files changed: +164 additions, -18 deletions


QEfficient/transformers/models/modeling_auto.py

Lines changed: 120 additions & 2 deletions
```diff
@@ -721,20 +721,37 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel):
     ]
     _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
 
-    def __init__(self, model, **kwargs):
+    def __init__(
+        self,
+        model,
+        qaic_config: Optional[dict] = None,
+        **kwargs
+    ):
         """
         Initializes the language decoder component for multimodal models.
 
         Parameters
         ----------
         model : nn.Module
             The full HuggingFace multimodal model from which the language decoder is extracted.
+        qaic_config : dict, optional
+            A dictionary for QAIC-specific configurations.
+            Only the following keys are supported by the text model of the dual QPC multimodal model:
+                - **include_sampler** (bool): If True, enables on-device sampling of next tokens.
+                - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
+            Additional keys will be ignored.
         **kwargs :
             Additional keyword arguments passed to the base class constructor.
         """
         super().__init__(model, **kwargs)
         self.model = model.get_qeff_language_decoder()
         self.hash_params["qeff_auto_class"] = self.__class__.__name__
+        self.model.qaic_config = qaic_config
+        # ---Sampling---
+        # Note: SamplerTransform should be applied after all other transforms
+        # are done. The role of the sampler is to just add nodes at the output of the
+        # previous transform function.
+        self.model, _ = SamplerTransform.apply(self.model, qaic_config, **kwargs)
 
     def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True):
         """
```
```diff
@@ -758,10 +775,95 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt
         str
             Path to the generated ONNX graph file for the language decoder.
         """
+        if self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False):
+            inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs(inputs, output_names, dynamic_axes)
         return self._export(
             inputs, output_names, dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights
         )
 
+    def get_sampling_inputs_and_outputs(
+        self,
+        example_inputs: Dict[str, torch.Tensor],
+        output_names: List[str],
+        dynamic_axes: Dict[str, Dict[int, str]],
+    ):
+        """
+        Updates the example inputs, output names, and dynamic axes to include
+        parameters relevant for on-device sampling during ONNX export.
+
+        Parameters
+        ----------
+        example_inputs : Dict[str, torch.Tensor]
+            Current dictionary of example inputs.
+        output_names : List[str]
+            Current list of output names.
+        dynamic_axes : Dict[str, Dict[int, str]]
+            Current dictionary of dynamic axes configurations.
+
+        Returns
+        -------
+        Tuple[Dict[str, torch.Tensor], List[str], Dict[str, Dict[int, str]]]
+            Updated example inputs, output names, and dynamic axes including
+            sampling-related parameters.
+        """
+        bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
+
+        assert "logits" in output_names, "logits must be part of the output names to support on-device sampling"
+
+        logits_index = output_names.index("logits")
+        output_names[logits_index] = "next_tokens"
+
+        example_inputs["last_accepted_output_tokens"] = torch.zeros(
+            (bs, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN), dtype=torch.int64
+        )
+        dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"}
+
+        example_inputs["past_repetition_penalty_buffer"] = torch.zeros(
+            (bs, self.model.language_model.config.vocab_size), dtype=torch.bool
+        )
+        dynamic_axes["past_repetition_penalty_buffer"] = {
+            0: "batch_size",
+        }
+        output_names.append("past_repetition_penalty_buffer_RetainedState")
+
+        example_inputs["repetition_penalties"] = (
+            torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES
+        )
+        dynamic_axes["repetition_penalties"] = {0: "batch_size"}
+
+        example_inputs["past_presence_penalty_buffer"] = torch.zeros(
+            (bs, self.model.language_model.config.vocab_size), dtype=torch.bool
+        )
+        dynamic_axes["past_presence_penalty_buffer"] = {
+            0: "batch_size",
+        }
+        output_names.append("past_presence_penalty_buffer_RetainedState")
+
+        example_inputs["presence_penalties"] = (
+            torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES
+        )
+        dynamic_axes["presence_penalties"] = {0: "batch_size"}
+
+        example_inputs["temperatures"] = (
+            torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TEMPERATURES
+        )
+        dynamic_axes["temperatures"] = {0: "batch_size"}
+
+        max_top_k_ids = self.model.qaic_config.get("max_top_k_ids", constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS)
+        example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32)
+        dynamic_axes["top_ks"] = {0: "batch_size"}
+
+        example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TOP_PS
+        dynamic_axes["top_ps"] = {0: "batch_size"}
+
+        example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS
+        dynamic_axes["min_ps"] = {0: "batch_size"}
+
+        example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float)
+        dynamic_axes["random_numbers"] = {0: "batch_size"}
+
+        return example_inputs, output_names, dynamic_axes
+
     def compile(
         self,
         compile_dir,
```
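
To make the export contract above concrete, here is a hedged, self-contained sketch of the sampling tensors this helper registers, with QEfficient's export constants replaced by assumed stand-in values:

```python
import torch

# Assumed stand-ins for QEfficient's ONNX export constants (illustrative only).
BS, SEQ_LEN, VOCAB_SIZE, MAX_TOP_K_IDS = 1, 32, 32000, 512

example_inputs = {
    "last_accepted_output_tokens": torch.zeros((BS, SEQ_LEN), dtype=torch.int64),
    "past_repetition_penalty_buffer": torch.zeros((BS, VOCAB_SIZE), dtype=torch.bool),
    "repetition_penalties": torch.ones((BS, 1), dtype=torch.float),
    "past_presence_penalty_buffer": torch.zeros((BS, VOCAB_SIZE), dtype=torch.bool),
    "presence_penalties": torch.zeros((BS, 1), dtype=torch.float),
    "temperatures": torch.ones((BS, 1), dtype=torch.float),
    "top_ks": torch.randint(1, MAX_TOP_K_IDS, size=(BS, 1), dtype=torch.int32),
    "top_ps": torch.ones((BS, 1), dtype=torch.float),
    "min_ps": torch.ones((BS, 1), dtype=torch.float) * 0.1,
    "random_numbers": torch.rand((BS, 1), dtype=torch.float),
}

# Every sampling input is dynamic in the batch dimension; the accepted-token
# window is additionally dynamic in sequence length.
dynamic_axes = {name: {0: "batch_size"} for name in example_inputs}
dynamic_axes["last_accepted_output_tokens"][1] = "seq_len"

# In output_names, "logits" is renamed to "next_tokens", and the two penalty
# buffers gain "*_RetainedState" outputs so their state persists across calls.
```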
```diff
@@ -1499,6 +1601,8 @@ def __init__(
         """
         if kwargs.pop("full_batch_size", None):
             raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
+        if kwargs.pop("qaic_config", None):
+            raise NotImplementedError("On-device sampling is not supported for single QPC multimodal models yet.")
         super().__init__(model, **kwargs)
 
         # to handle internvl models
```
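
A hedged illustration of the guard just added: requesting sampling while forcing the single QPC path should raise immediately. The class name and import path are assumptions for illustration, not shown in this diff:

```python
# Hypothetical usage; assumes the QEFFAutoModelForImageTextToText entry point.
from QEfficient import QEFFAutoModelForImageTextToText

try:
    model = QEFFAutoModelForImageTextToText.from_pretrained(
        "<model-id>",      # placeholder checkpoint id
        kv_offload=False,  # single QPC: entire model in one binary
        qaic_config={"include_sampler": True},
    )
except NotImplementedError as err:
    # Expected: "On-device sampling is not supported for single QPC multimodal models yet."
    print(err)
```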
```diff
@@ -2023,6 +2127,7 @@ def from_pretrained(
         pretrained_model_name_or_path: str,
         kv_offload: Optional[bool] = None,
         continuous_batching: bool = False,
+        qaic_config: Optional[dict] = None,
         **kwargs,
     ):
         """
@@ -2036,6 +2141,12 @@ def from_pretrained(
             If True, uses the dual QPC approach (vision encoder KV offloaded).
             If False, uses the single QPC approach (entire model in one QPC).
             If None, the default behavior of the internal classes is used (typically dual QPC).
+        qaic_config : dict, optional
+            A dictionary for QAIC-specific configurations.
+            Only the following keys are supported by the text model of the dual QPC multimodal model:
+                - **include_sampler** (bool): If True, enables on-device sampling of next tokens.
+                - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
+            Additional keys will be ignored.
         **kwargs :
             Additional arguments passed to HuggingFace's ``from_pretrained``.
```
```diff
@@ -2063,11 +2174,14 @@ def from_pretrained(
             logger.warning("Updating low_cpu_mem_usage=False")
 
         kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
+        if qaic_config is not None:
+            qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path
         model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
         return cls(
             model,
             kv_offload=kv_offload,
             continuous_batching=continuous_batching,
+            qaic_config=qaic_config,
             pretrained_model_name_or_path=pretrained_model_name_or_path,
             **kwargs,
         )
```
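
Putting the pieces together, here is a hedged end-to-end sketch of enabling on-device sampling on the dual QPC path. The auto class and checkpoint id are assumptions for illustration; per the docstring above, only `include_sampler` and `max_top_k_ids` are read by the text model:

```python
from QEfficient import QEFFAutoModelForImageTextToText  # assumed entry point

model = QEFFAutoModelForImageTextToText.from_pretrained(
    "Qwen/Qwen2.5-VL-3B-Instruct",  # illustrative dual QPC VLM checkpoint
    kv_offload=True,                # dual QPC: vision encoder split from decoder
    qaic_config={
        "include_sampler": True,    # sample next tokens on the QAIC device
        "max_top_k_ids": 512,       # cap (<= vocab size) on top-k candidates
    },
)
# Per the hunk above, from_pretrained also stashes the checkpoint id into
# qaic_config["pretrained_model_name_or_path"] before wrapping the model.
```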
```diff
@@ -2273,7 +2387,11 @@ def from_pretrained(
 
         if model.__class__.__name__ in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP:
             return MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model.__class__.__name__](
-                model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
+                model,
+                kv_offload=kv_offload,
+                pretrained_model_name_or_path=pretrained_model_name_or_path,
+                qaic_config=qaic_config,
+                **kwargs
             )
         return cls(
             model,
```

QEfficient/transformers/models/pytorch_transforms.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -289,6 +289,7 @@
     QEffGrok1MultiHeadAttention,
 )
 from QEfficient.transformers.models.internvl.modeling_internvl import (
+    QEffInternDecoderWrapper,
     QEffInternVisionEmbeddings,
     QEffInternVLModel,
 )
@@ -392,6 +393,7 @@
     QEffQwen2_5_VLModel,
     QEffQwen2_5_VLTextModel,
     QEffQwen2_5_VLVisionAttention,
+    QEffQwen_2_5_vl_DecoderWrapper,
     QEffQwen_2_5_vl_ForConditionalGeneration,
 )
 from QEfficient.transformers.models.qwen3.modeling_qwen3 import (
```
```diff
@@ -707,10 +709,12 @@ class SamplerTransform:
         QEffGPTJForCausalLM,
         QEffGraniteForCausalLM,
         QEffGraniteMoeForCausalLM,
+        QEffInternDecoderWrapper,
         QEffLlamaForCausalLM,
         QEffMptForCausalLM,
         QEffPhi3ForCausalLM,
         QEffQwen2ForCausalLM,
+        QEffQwen_2_5_vl_DecoderWrapper,
     }
 
     @classmethod
```
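
SamplerTransform's body is not part of this diff. As context, here is a hedged sketch of the whitelist pattern the hunk extends: `apply()` rebinds `forward` only when the module's class appears in the mapping, so the two wrapper classes added above are what make the InternVL and Qwen2.5-VL dual QPC decoders eligible for sampling:

```python
import types

def sampler_forward_stub(self, *args, **kwargs):
    # Stand-in for QEfficient.transformers.sampler.sampler.sampler_forward.
    raise NotImplementedError

class SamplerTransformSketch:
    """Illustrative sketch only; not the exact QEfficient implementation."""

    _module_mapping = {"QEffInternDecoderWrapper", "QEffQwen_2_5_vl_DecoderWrapper"}

    @classmethod
    def apply(cls, model, qaic_config, **kwargs):
        if qaic_config is None or not qaic_config.get("include_sampler", False):
            return model, False  # sampling not requested: leave the model untouched
        if model.__class__.__name__ not in cls._module_mapping:
            raise NotImplementedError(f"No sampler support for {model.__class__.__name__}")
        model.old_forward = model.forward  # keep the original forward reachable
        model.forward = types.MethodType(sampler_forward_stub, model)
        return model, True
```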

QEfficient/transformers/sampler/sampler.py

Lines changed: 40 additions & 16 deletions
```diff
@@ -24,6 +24,8 @@ class SamplerOutput(ModelOutput):
 
     probs: torch.FloatTensor = None
     next_tokens: torch.IntTensor = None
+    vision_embeds: Optional[torch.FloatTensor] = None  # For VLMs
+    image_idx: Optional[torch.IntTensor] = None  # For VLMs
     past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
     past_repetition_penalty_buffer: Optional[torch.Tensor] = None
     past_presence_penalty_buffer: Optional[torch.Tensor] = None
@@ -122,6 +124,8 @@ def sampler_forward(
     top_ps: Optional[torch.Tensor] = None,
     min_ps: Optional[torch.Tensor] = None,
     random_numbers: Optional[torch.Tensor] = None,
+    vision_embeds: Optional[torch.Tensor] = None,
+    image_idx: Optional[torch.Tensor] = None,
 ) -> Union[Tuple, SamplerOutput]:
     r"""
     Perform the sampling of next tokens on the QAIC device (instead of the host)
```
```diff
@@ -170,20 +174,36 @@ def sampler_forward(
         Sampling parameter that represents the random seeds to use for random sampling.
         Must be in [-1, 1].
     """
-
-    outputs = self.old_forward(
-        input_ids=input_ids,
-        attention_mask=attention_mask,
-        position_ids=position_ids,
-        past_key_values=past_key_values,
-        batch_index=batch_index,
-        inputs_embeds=inputs_embeds,
-        use_cache=use_cache,
-        output_attentions=output_attentions,
-        output_hidden_states=output_hidden_states,
-        return_dict=return_dict,
-        cache_position=cache_position,
-    )
+    if vision_embeds is not None:
+        logits, vision_embeds, image_idx, past_key_values = self.old_forward(
+            input_ids=input_ids,
+            vision_embeds=vision_embeds,
+            position_ids=position_ids,
+            image_idx=image_idx,
+            past_key_values=past_key_values
+        )
+        outputs = dict(
+            logits=logits,
+            vision_embeds=vision_embeds,
+            image_idx=image_idx,
+            past_key_values=past_key_values
+        )
+        if position_ids.dim() == 3:  # For models using m-rope
+            position_ids = position_ids[0]
+    else:
+        outputs = self.old_forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            batch_index=batch_index,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+        )
 
     logits = outputs.get("logits", None)
     assert logits is not None, f"{self.model.__class__.__name__} does not return logits."
```
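
The `vision_embeds` branch above assumes the dual QPC VLM decoder wrapper returns a flat `(logits, vision_embeds, image_idx, past_key_values)` tuple rather than a ModelOutput, which is why the results are repacked into a dict before the shared sampling logic runs. A hedged toy of that calling convention, including the m-rope reduction:

```python
import torch

class ToyVLMDecoderWrapper:
    """Assumed calling convention for a dual QPC VLM language decoder."""

    def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values):
        bs, seq_len = input_ids.shape
        logits = torch.zeros(bs, seq_len, 32000)  # vocab size assumed
        return logits, vision_embeds, image_idx, past_key_values  # flat tuple

# Models using m-rope (e.g. Qwen2.5-VL) pass 3D position_ids shaped
# (3, batch, seq_len); the sampler keeps only the first rope section for its
# penalty bookkeeping, matching `position_ids = position_ids[0]` above.
position_ids = torch.arange(8).expand(3, 1, 8)
if position_ids.dim() == 3:
    position_ids = position_ids[0]  # -> shape (1, 8)
```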
```diff
@@ -230,7 +250,9 @@ def sampler_forward(
         return SamplerOutput(
             probs=None,
             next_tokens=greedy_samples.reshape(-1, spec_length, 1),  # Return sampled next tokens instead of logits
-            past_key_values=outputs.past_key_values,
+            vision_embeds=outputs.get("vision_embeds", None),
+            image_idx=outputs.get("image_idx", None),
+            past_key_values=outputs.get("past_key_values", None),
             past_repetition_penalty_buffer=past_repetition_penalty_buffer,
             past_presence_penalty_buffer=past_presence_penalty_buffer,
         )
@@ -314,7 +336,9 @@ def sampler_forward(
         return SamplerOutput(
             probs=probs,
             next_tokens=next_tokens,  # Return sampled next tokens instead of logits
-            past_key_values=outputs.past_key_values,
+            vision_embeds=outputs.get("vision_embeds", None),
+            image_idx=outputs.get("image_idx", None),
+            past_key_values=outputs.get("past_key_values", None),
             past_repetition_penalty_buffer=past_repetition_penalty_buffer,
             past_presence_penalty_buffer=past_presence_penalty_buffer,
         )
```
