Skip to content

Commit af8e673

Browse files
committed
Fix random_numbers shape
1 parent 16e683a commit af8e673

File tree

2 files changed

+24
-33
lines changed

2 files changed

+24
-33
lines changed

QEfficient/transformers/models/modeling_auto.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -718,12 +718,7 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel):
718718
]
719719
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
720720

721-
def __init__(
722-
self,
723-
model,
724-
qaic_config: Optional[dict] = None,
725-
**kwargs
726-
):
721+
def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs):
727722
"""
728723
Initializes the language decoder component for multimodal models.
729724
@@ -732,7 +727,7 @@ def __init__(
732727
model : nn.Module
733728
The full HuggingFace multimodal model from which the language decoder is extracted.
734729
qaic_config : dict, optional
735-
A dictionary for QAIC-specific configurations.
730+
A dictionary for QAIC-specific configurations.
736731
Only the following keys are supported by the text model of the dual QPC multimodal model:
737732
- **include_sampler** (bool): If True, enables on-device sampling of next tokens.
738733
- **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
@@ -773,7 +768,9 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt
773768
Path to the generated ONNX graph file for the language decoder.
774769
"""
775770
if self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False):
776-
inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs(inputs, output_names, dynamic_axes)
771+
inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs(
772+
inputs, output_names, dynamic_axes
773+
)
777774
return self._export(
778775
inputs, output_names, dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights
779776
)
@@ -804,7 +801,7 @@ def get_sampling_inputs_and_outputs(
804801
sampling-related parameters.
805802
"""
806803
bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
807-
804+
808805
assert "logits" in output_names, "logits must be part of the output names to suport on-device sampling"
809806

810807
logits_index = output_names.index("logits")
@@ -856,7 +853,7 @@ def get_sampling_inputs_and_outputs(
856853
example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS
857854
dynamic_axes["min_ps"] = {0: "batch_size"}
858855

859-
example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float)
856+
example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float)
860857
dynamic_axes["random_numbers"] = {0: "batch_size"}
861858

862859
return example_inputs, output_names, dynamic_axes
@@ -2066,7 +2063,7 @@ def from_pretrained(
20662063
pretrained_model_name_or_path: str,
20672064
kv_offload: Optional[bool] = None,
20682065
qaic_config: Optional[dict] = None,
2069-
**kwargs
2066+
**kwargs,
20702067
):
20712068
"""
20722069
Load a QEfficient image-text-to-text model from a pretrained HuggingFace model or local path.
@@ -2080,7 +2077,7 @@ def from_pretrained(
20802077
If False, uses the single QPC approach (entire model in one QPC).
20812078
If None, the default behavior of the internal classes is used (typically dual QPC).
20822079
qaic_config : dict, optional
2083-
A dictionary for QAIC-specific configurations.
2080+
A dictionary for QAIC-specific configurations.
20842081
Only the following keys are supported by the text model of the dual QPC multimodal model:
20852082
- **include_sampler** (bool): If True, enables on-device sampling of next tokens.
20862083
- **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
@@ -2116,11 +2113,11 @@ def from_pretrained(
21162113
qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path
21172114
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
21182115
return cls(
2119-
model,
2120-
kv_offload=kv_offload,
2121-
pretrained_model_name_or_path=pretrained_model_name_or_path,
2122-
qaic_config=qaic_config,
2123-
**kwargs
2116+
model,
2117+
kv_offload=kv_offload,
2118+
pretrained_model_name_or_path=pretrained_model_name_or_path,
2119+
qaic_config=qaic_config,
2120+
**kwargs,
21242121
)
21252122

21262123

@@ -2327,7 +2324,7 @@ def from_pretrained(
23272324
kv_offload=kv_offload,
23282325
pretrained_model_name_or_path=pretrained_model_name_or_path,
23292326
qaic_config=qaic_config,
2330-
**kwargs
2327+
**kwargs,
23312328
)
23322329
return cls(
23332330
model,
@@ -2519,7 +2516,7 @@ def get_sampling_inputs_and_outputs(
25192516
example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS
25202517
dynamic_axes["min_ps"] = {0: "batch_size"}
25212518

2522-
example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float)
2519+
example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float)
25232520
dynamic_axes["random_numbers"] = {0: "batch_size"}
25242521

25252522
return example_inputs, output_names, dynamic_axes

QEfficient/transformers/sampler/sampler.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ class SamplerOutput(ModelOutput):
2424

2525
probs: torch.FloatTensor = None
2626
next_tokens: torch.IntTensor = None
27-
vision_embeds: Optional[torch.FloatTensor] = None # For VLMs
28-
image_idx: Optional[torch.IntTensor] = None # for VLMs
27+
vision_embeds: Optional[torch.FloatTensor] = None # For VLMs
28+
image_idx: Optional[torch.IntTensor] = None # for VLMs
2929
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
3030
past_repetition_penalty_buffer: Optional[torch.Tensor] = None
3131
past_presence_penalty_buffer: Optional[torch.Tensor] = None
@@ -176,19 +176,14 @@ def sampler_forward(
176176
"""
177177
if vision_embeds is not None:
178178
logits, vision_embeds, image_idx, past_key_values = self.old_forward(
179-
input_ids=input_ids,
180-
vision_embeds=vision_embeds,
181-
position_ids=position_ids,
182-
image_idx=image_idx,
183-
past_key_values=past_key_values
184-
)
185-
outputs = dict(
186-
logits=logits,
179+
input_ids=input_ids,
187180
vision_embeds=vision_embeds,
181+
position_ids=position_ids,
188182
image_idx=image_idx,
189-
past_key_values=past_key_values
183+
past_key_values=past_key_values,
190184
)
191-
if position_ids.dim() == 3: # For models using m-rope
185+
outputs = dict(logits=logits, vision_embeds=vision_embeds, image_idx=image_idx, past_key_values=past_key_values)
186+
if position_ids.dim() == 3: # For models using m-rope
192187
position_ids = position_ids[0]
193188
else:
194189
outputs = self.old_forward(
@@ -322,9 +317,8 @@ def sampler_forward(
322317
) # (batch_size, spec_length, vocab_size)
323318

324319
# Random Sampling
325-
topk_probs_asc = torch.softmax(topk_values_asc, dim=1) # (batch_size * spec_length, max_top_k_ids)
326320
gumbel_noise = -torch.log(-torch.log(random_numbers.repeat(spec_length, 1))) # Gumbel-Max Trick
327-
y = topk_probs_asc + gumbel_noise
321+
y = topk_values_asc + gumbel_noise # (batch_size * spec_length, max_top_k_ids)
328322
random_samples_indices = torch.argmax(y, dim=1, keepdim=True)
329323
random_samples = torch.gather(topk_indices_asc, 1, random_samples_indices) # (batch_size * spec_length, 1)
330324

0 commit comments

Comments
 (0)