Skip to content

Commit e06e175

Browse files
committed
Fix random_numbers shape
Signed-off-by: quic-xiyushi <xiyushi@qti.qualcomm.com>
1 parent 409da24 commit e06e175

File tree

2 files changed

+19 / -28 lines changed

2 files changed

+19 / -28 lines changed

QEfficient/transformers/models/modeling_auto.py

Lines changed: 11 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -721,12 +721,7 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel):
721721
]
722722
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
723723

724-
def __init__(
725-
self,
726-
model,
727-
qaic_config: Optional[dict] = None,
728-
**kwargs
729-
):
724+
def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs):
730725
"""
731726
Initializes the language decoder component for multimodal models.
732727
@@ -735,7 +730,7 @@ def __init__(
735730
model : nn.Module
736731
The full HuggingFace multimodal model from which the language decoder is extracted.
737732
qaic_config : dict, optional
738-
A dictionary for QAIC-specific configurations.
733+
A dictionary for QAIC-specific configurations.
739734
Only the following keys are supported by the text model of the dual QPC multimodal model:
740735
- **include_sampler** (bool): If True, enables on-device sampling of next tokens.
741736
- **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
@@ -776,7 +771,9 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt
776771
Path to the generated ONNX graph file for the language decoder.
777772
"""
778773
if self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False):
779-
inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs(inputs, output_names, dynamic_axes)
774+
inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs(
775+
inputs, output_names, dynamic_axes
776+
)
780777
return self._export(
781778
inputs, output_names, dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights
782779
)
@@ -807,7 +804,7 @@ def get_sampling_inputs_and_outputs(
807804
sampling-related parameters.
808805
"""
809806
bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
810-
807+
811808
assert "logits" in output_names, "logits must be part of the output names to suport on-device sampling"
812809

813810
logits_index = output_names.index("logits")
@@ -859,7 +856,7 @@ def get_sampling_inputs_and_outputs(
859856
example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS
860857
dynamic_axes["min_ps"] = {0: "batch_size"}
861858

862-
example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float)
859+
example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float)
863860
dynamic_axes["random_numbers"] = {0: "batch_size"}
864861

865862
return example_inputs, output_names, dynamic_axes
@@ -2142,7 +2139,7 @@ def from_pretrained(
21422139
If False, uses the single QPC approach (entire model in one QPC).
21432140
If None, the default behavior of the internal classes is used (typically dual QPC).
21442141
qaic_config : dict, optional
2145-
A dictionary for QAIC-specific configurations.
2142+
A dictionary for QAIC-specific configurations.
21462143
Only the following keys are supported by the text model of the dual QPC multimodal model:
21472144
- **include_sampler** (bool): If True, enables on-device sampling of next tokens.
21482145
- **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
@@ -2181,7 +2178,7 @@ def from_pretrained(
21812178
model,
21822179
kv_offload=kv_offload,
21832180
continuous_batching=continuous_batching,
2184-
qaic_config=qaic_config,
2181+
qaic_config=qaic_config,
21852182
pretrained_model_name_or_path=pretrained_model_name_or_path,
21862183
**kwargs,
21872184
)
@@ -2391,7 +2388,7 @@ def from_pretrained(
23912388
kv_offload=kv_offload,
23922389
pretrained_model_name_or_path=pretrained_model_name_or_path,
23932390
qaic_config=qaic_config,
2394-
**kwargs
2391+
**kwargs,
23952392
)
23962393
return cls(
23972394
model,
@@ -2594,7 +2591,7 @@ def get_sampling_inputs_and_outputs(
25942591
example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS
25952592
dynamic_axes["min_ps"] = {0: "batch_size"}
25962593

2597-
example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float)
2594+
example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float)
25982595
dynamic_axes["random_numbers"] = {0: "batch_size"}
25992596

26002597
return example_inputs, output_names, dynamic_axes

QEfficient/transformers/sampler/sampler.py

Lines changed: 8 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -24,8 +24,8 @@ class SamplerOutput(ModelOutput):
2424

2525
probs: torch.FloatTensor = None
2626
next_tokens: torch.IntTensor = None
27-
vision_embeds: Optional[torch.FloatTensor] = None # For VLMs
28-
image_idx: Optional[torch.IntTensor] = None # for VLMs
27+
vision_embeds: Optional[torch.FloatTensor] = None # For VLMs
28+
image_idx: Optional[torch.IntTensor] = None # for VLMs
2929
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
3030
past_repetition_penalty_buffer: Optional[torch.Tensor] = None
3131
past_presence_penalty_buffer: Optional[torch.Tensor] = None
@@ -176,19 +176,14 @@ def sampler_forward(
176176
"""
177177
if vision_embeds is not None:
178178
logits, vision_embeds, image_idx, past_key_values = self.old_forward(
179-
input_ids=input_ids,
180-
vision_embeds=vision_embeds,
181-
position_ids=position_ids,
182-
image_idx=image_idx,
183-
past_key_values=past_key_values
184-
)
185-
outputs = dict(
186-
logits=logits,
179+
input_ids=input_ids,
187180
vision_embeds=vision_embeds,
181+
position_ids=position_ids,
188182
image_idx=image_idx,
189-
past_key_values=past_key_values
183+
past_key_values=past_key_values,
190184
)
191-
if position_ids.dim() == 3: # For models using m-rope
185+
outputs = dict(logits=logits, vision_embeds=vision_embeds, image_idx=image_idx, past_key_values=past_key_values)
186+
if position_ids.dim() == 3: # For models using m-rope
192187
position_ids = position_ids[0]
193188
else:
194189
outputs = self.old_forward(
@@ -322,9 +317,8 @@ def sampler_forward(
322317
) # (batch_size, spec_length, vocab_size)
323318

324319
# Random Sampling
325-
topk_probs_asc = torch.softmax(topk_values_asc, dim=1) # (batch_size * spec_length, max_top_k_ids)
326320
gumbel_noise = -torch.log(-torch.log(random_numbers.repeat(spec_length, 1))) # Gumbel-Max Trick
327-
y = topk_probs_asc + gumbel_noise
321+
y = topk_values_asc + gumbel_noise # (batch_size * spec_length, max_top_k_ids)
328322
random_samples_indices = torch.argmax(y, dim=1, keepdim=True)
329323
random_samples = torch.gather(topk_indices_asc, 1, random_samples_indices) # (batch_size * spec_length, 1)
330324

0 commit comments

Comments
 (0)