diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index 60f60c768..c110b3ce5 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -721,7 +721,7 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel):
     ]
     _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]
 
-    def __init__(self, model, **kwargs):
+    def __init__(self, model, continuous_batching: bool = False, qaic_config: Optional[dict] = None, **kwargs):
         """
         Initializes the language decoder component for multimodal models.
 
@@ -729,12 +729,28 @@ def __init__(self, model, **kwargs):
         ----------
         model : nn.Module
             The full HuggingFace multimodal model from which the language decoder is extracted.
+        continuous_batching : bool, optional
+            If True, enables continuous batching mode for future compilation and execution.
+            This setting must be consistent across `from_pretrained` and `compile` calls. Default is False.
+        qaic_config : dict, optional
+            A dictionary for QAIC-specific configurations.
+            Only the following keys are supported by the text model of the dual QPC multimodal model:
+                - **include_sampler** (bool): If True, enables on-device sampling of next tokens.
+                - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
+            Additional keys will be ignored.
         **kwargs :
             Additional keyword arguments passed to the base class constructor.
         """
         super().__init__(model, **kwargs)
         self.model = model.get_qeff_language_decoder()
         self.hash_params["qeff_auto_class"] = self.__class__.__name__
+        self.continuous_batching = continuous_batching
+        self.model.qaic_config = qaic_config
+        # ---Sampling---
+        # Note: SamplerTransform should be applied after all other transforms
+        # are done. The role of the sampler is to just add nodes at the output of the
+        # previous transform function.
+        self.model, _ = SamplerTransform.apply(self.model, qaic_config, **kwargs)
 
     def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True):
         """
@@ -758,10 +774,98 @@ def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt
         str
             Path to the generated ONNX graph file for the language decoder.
         """
+        if self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False):
+            inputs, output_names, dynamic_axes = self.get_sampling_inputs_and_outputs(
+                inputs, output_names, dynamic_axes
+            )
         return self._export(
             inputs, output_names, dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights
         )
 
+    def get_sampling_inputs_and_outputs(
+        self,
+        example_inputs: Dict[str, torch.Tensor],
+        output_names: List[str],
+        dynamic_axes: Dict[str, Dict[int, str]],
+    ):
+        """
+        Updates the example inputs, output names, and dynamic axes to include
+        parameters relevant for on-device sampling during ONNX export.
+
+        Parameters
+        ----------
+        example_inputs : Dict[str, torch.Tensor]
+            Current dictionary of example inputs.
+        output_names : List[str]
+            Current list of output names.
+        dynamic_axes : Dict[str, Dict[int, str]]
+            Current dictionary of dynamic axes configurations.
+
+        Returns
+        -------
+        Tuple[Dict[str, torch.Tensor], List[str], Dict[str, Dict[int, str]]]
+            Updated example inputs, output names, and dynamic axes including
+            sampling-related parameters.
+ """ + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + + assert "logits" in output_names, "logits must be part of the output names to suport on-device sampling" + + logits_index = output_names.index("logits") + output_names[logits_index] = "next_tokens" + + example_inputs["last_accepted_output_tokens"] = torch.zeros( + (bs, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN), dtype=torch.int64 + ) + dynamic_axes["last_accepted_output_tokens"] = {0: "batch_size", 1: "seq_len"} + + example_inputs["past_repetition_penalty_buffer"] = torch.zeros( + (fbs if self.continuous_batching else bs, self.model.language_model.config.vocab_size), dtype=torch.bool + ) + dynamic_axes["past_repetition_penalty_buffer"] = { + 0: "full_batch_size" if self.continuous_batching else "batch_size", + } + output_names.append("past_repetition_penalty_buffer_RetainedState") + + example_inputs["repetition_penalties"] = ( + torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_REPETITION_PENALTIES + ) + dynamic_axes["repetition_penalties"] = {0: "batch_size"} + + example_inputs["past_presence_penalty_buffer"] = torch.zeros( + (fbs if self.continuous_batching else bs, self.model.language_model.config.vocab_size), dtype=torch.bool + ) + dynamic_axes["past_presence_penalty_buffer"] = { + 0: "full_batch_size" if self.continuous_batching else "batch_size", + } + output_names.append("past_presence_penalty_buffer_RetainedState") + + example_inputs["presence_penalties"] = ( + torch.zeros((bs, 1), dtype=torch.float) + constants.ONNX_EXPORT_EXAMPLE_PRESENCE_PENALTIES + ) + dynamic_axes["presence_penalties"] = {0: "batch_size"} + + example_inputs["temperatures"] = ( + torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TEMPERATURES + ) + dynamic_axes["temperatures"] = {0: "batch_size"} + + max_top_k_ids = self.model.qaic_config.get("max_top_k_ids", constants.ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS) + example_inputs["top_ks"] = torch.randint(1, max_top_k_ids, size=(bs, 1)).to(torch.int32) + dynamic_axes["top_ks"] = {0: "batch_size"} + + example_inputs["top_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_TOP_PS + dynamic_axes["top_ps"] = {0: "batch_size"} + + example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS + dynamic_axes["min_ps"] = {0: "batch_size"} + + example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float) + dynamic_axes["random_numbers"] = {0: "batch_size"} + + return example_inputs, output_names, dynamic_axes + def compile( self, compile_dir, @@ -882,7 +986,7 @@ def __init__( self.model = model self.config = model.config self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs) - self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs) + self.lang_model = QEffCausalLMForTextImageToTextModel(model, continuous_batching=continuous_batching, **kwargs) self.continuous_batching = continuous_batching self.input_shapes, self.output_names = None, None @@ -1499,6 +1603,8 @@ def __init__( """ if kwargs.pop("full_batch_size", None): raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") + if kwargs.pop("qaic_config", None): + raise NotImplementedError("On-device sampling is not supported for single QPC multimodal models yet.") super().__init__(model, **kwargs) # to handle internvl models @@ -2023,6 +2129,7 @@ def from_pretrained( pretrained_model_name_or_path: str, kv_offload: 
         continuous_batching: bool = False,
+        qaic_config: Optional[dict] = None,
         **kwargs,
     ):
         """
@@ -2036,6 +2143,12 @@ def from_pretrained(
             If True, uses the dual QPC approach (vision encoder KV offloaded).
             If False, uses the single QPC approach (entire model in one QPC).
             If None, the default behavior of the internal classes is used (typically dual QPC).
+        qaic_config : dict, optional
+            A dictionary for QAIC-specific configurations.
+            Only the following keys are supported by the text model of the dual QPC multimodal model:
+                - **include_sampler** (bool): If True, enables on-device sampling of next tokens.
+                - **max_top_k_ids** (int): Maximum number of top K tokens (<= vocab size) to consider during sampling.
+            Additional keys will be ignored.
         **kwargs :
             Additional arguments passed to HuggingFace's ``from_pretrained``.
 
@@ -2063,11 +2176,14 @@ def from_pretrained(
             logger.warning("Updating low_cpu_mem_usage=False")
 
         kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
+        if qaic_config is not None:
+            qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path
         model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
         return cls(
             model,
             kv_offload=kv_offload,
             continuous_batching=continuous_batching,
+            qaic_config=qaic_config,
             pretrained_model_name_or_path=pretrained_model_name_or_path,
             **kwargs,
         )
@@ -2273,7 +2389,11 @@ def from_pretrained(
         if model.__class__.__name__ in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP:
             return MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model.__class__.__name__](
-                model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
+                model,
+                kv_offload=kv_offload,
+                pretrained_model_name_or_path=pretrained_model_name_or_path,
+                qaic_config=qaic_config,
+                **kwargs,
             )
 
         return cls(
             model,
@@ -2476,7 +2596,7 @@ def get_sampling_inputs_and_outputs(
         example_inputs["min_ps"] = torch.ones((bs, 1), dtype=torch.float) * constants.ONNX_EXPORT_EXAMPLE_MIN_PS
         dynamic_axes["min_ps"] = {0: "batch_size"}
 
-        example_inputs["random_numbers"] = torch.rand((bs, 1), dtype=torch.float)
+        example_inputs["random_numbers"] = torch.rand((bs, max_top_k_ids), dtype=torch.float)
         dynamic_axes["random_numbers"] = {0: "batch_size"}
 
         return example_inputs, output_names, dynamic_axes
diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py
index 773ce178c..c750a8c66 100644
--- a/QEfficient/transformers/models/pytorch_transforms.py
+++ b/QEfficient/transformers/models/pytorch_transforms.py
@@ -289,6 +289,7 @@
     QEffGrok1MultiHeadAttention,
 )
 from QEfficient.transformers.models.internvl.modeling_internvl import (
+    QEffInternDecoderWrapper,
     QEffInternVisionEmbeddings,
     QEffInternVLModel,
 )
@@ -392,6 +393,7 @@
     QEffQwen2_5_VLModel,
     QEffQwen2_5_VLTextModel,
     QEffQwen2_5_VLVisionAttention,
+    QEffQwen_2_5_vl_DecoderWrapper,
     QEffQwen_2_5_vl_ForConditionalGeneration,
 )
 from QEfficient.transformers.models.qwen3.modeling_qwen3 import (
@@ -707,10 +709,12 @@ class SamplerTransform:
         QEffGPTJForCausalLM,
         QEffGraniteForCausalLM,
         QEffGraniteMoeForCausalLM,
+        QEffInternDecoderWrapper,
         QEffLlamaForCausalLM,
         QEffMptForCausalLM,
         QEffPhi3ForCausalLM,
         QEffQwen2ForCausalLM,
+        QEffQwen_2_5_vl_DecoderWrapper,
     }
 
     @classmethod
diff --git a/QEfficient/transformers/sampler/sampler.py b/QEfficient/transformers/sampler/sampler.py
index 96846e712..1075db784 100644
--- a/QEfficient/transformers/sampler/sampler.py
+++ b/QEfficient/transformers/sampler/sampler.py
@@ -24,6 +24,8 @@ class SamplerOutput(ModelOutput):
 
     probs: torch.FloatTensor = None
     next_tokens: torch.IntTensor = None
+    vision_embeds: Optional[torch.FloatTensor] = None  # For VLMs
+    image_idx: Optional[torch.IntTensor] = None  # For VLMs
     past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
     past_repetition_penalty_buffer: Optional[torch.Tensor] = None
     past_presence_penalty_buffer: Optional[torch.Tensor] = None
@@ -122,6 +124,8 @@ def sampler_forward(
     top_ps: Optional[torch.Tensor] = None,
     min_ps: Optional[torch.Tensor] = None,
    random_numbers: Optional[torch.Tensor] = None,
+    vision_embeds: Optional[torch.Tensor] = None,
+    image_idx: Optional[torch.Tensor] = None,
 ) -> Union[Tuple, SamplerOutput]:
     r"""
     Perform the sampling of next tokens on the QAIC device (instead of the host)
@@ -170,20 +174,35 @@ def sampler_forward(
             Sampling parameter that represents the random seeds to use for random sampling. Must be in [-1, 1].
     """
-
-    outputs = self.old_forward(
-        input_ids=input_ids,
-        attention_mask=attention_mask,
-        position_ids=position_ids,
-        past_key_values=past_key_values,
-        batch_index=batch_index,
-        inputs_embeds=inputs_embeds,
-        use_cache=use_cache,
-        output_attentions=output_attentions,
-        output_hidden_states=output_hidden_states,
-        return_dict=return_dict,
-        cache_position=cache_position,
-    )
+    if vision_embeds is not None:
+        forward_kwargs = dict(
+            input_ids=input_ids,
+            vision_embeds=vision_embeds,
+            position_ids=position_ids,
+            image_idx=image_idx,
+            past_key_values=past_key_values,
+        )
+        if batch_index is not None:
+            forward_kwargs["batch_index"] = batch_index
+
+        logits, vision_embeds, image_idx, past_key_values = self.old_forward(**forward_kwargs)
+        outputs = dict(logits=logits, vision_embeds=vision_embeds, image_idx=image_idx, past_key_values=past_key_values)
+        if position_ids.dim() == 3:  # For models using m-rope
+            position_ids = position_ids[0]
+    else:
+        outputs = self.old_forward(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            position_ids=position_ids,
+            past_key_values=past_key_values,
+            batch_index=batch_index,
+            inputs_embeds=inputs_embeds,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+            cache_position=cache_position,
+        )
 
     logits = outputs.get("logits", None)
     assert logits is not None, f"{self.model.__class__.__name__} does not return logits."
@@ -230,7 +249,9 @@ def sampler_forward(
         return SamplerOutput(
             probs=None,
             next_tokens=greedy_samples.reshape(-1, spec_length, 1),  # Return sampled next tokens instead of logits
-            past_key_values=outputs.past_key_values,
+            vision_embeds=outputs.get("vision_embeds", None),
+            image_idx=outputs.get("image_idx", None),
+            past_key_values=outputs.get("past_key_values", None),
             past_repetition_penalty_buffer=past_repetition_penalty_buffer,
             past_presence_penalty_buffer=past_presence_penalty_buffer,
         )
@@ -300,9 +321,8 @@ def sampler_forward(
     )  # (batch_size, spec_length, vocab_size)
 
     # Random Sampling
-    topk_probs_asc = torch.softmax(topk_values_asc, dim=1)  # (batch_size * spec_length, max_top_k_ids)
     gumbel_noise = -torch.log(-torch.log(random_numbers.repeat(spec_length, 1)))  # Gumbel-Max Trick
-    y = topk_probs_asc + gumbel_noise
+    y = topk_values_asc + gumbel_noise  # (batch_size * spec_length, max_top_k_ids)
     random_samples_indices = torch.argmax(y, dim=1, keepdim=True)
     random_samples = torch.gather(topk_indices_asc, 1, random_samples_indices)  # (batch_size * spec_length, 1)
@@ -314,7 +334,9 @@ def sampler_forward(
     return SamplerOutput(
         probs=probs,
         next_tokens=next_tokens,  # Return sampled next tokens instead of logits
-        past_key_values=outputs.past_key_values,
+        vision_embeds=outputs.get("vision_embeds", None),
+        image_idx=outputs.get("image_idx", None),
+        past_key_values=outputs.get("past_key_values", None),
         past_repetition_penalty_buffer=past_repetition_penalty_buffer,
         past_presence_penalty_buffer=past_presence_penalty_buffer,
     )
diff --git a/examples/on_device_sampling.py b/examples/on_device_sampling.py
index 00d8c2430..108e5390e 100644
--- a/examples/on_device_sampling.py
+++ b/examples/on_device_sampling.py
@@ -28,6 +28,7 @@ def main(args, **kwargs):
     if include_sampler is not None:
         return_pdfs = args.override_qaic_config.get("aic_return_pdfs", None) == "true"
         max_top_k_ids = int(args.override_qaic_config.get("max_top_k_ids", 512))
+        np.random.seed(int(args.random_number))
         sampling_params = {
             "repetition_penalties": np.array(args.repetition_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1),
             "presence_penalties": np.array(args.presence_penalty, dtype=np.float32).repeat(bs).reshape(-1, 1),
@@ -36,7 +37,9 @@ def main(args, **kwargs):
             "top_ks": np.array(args.top_k, dtype=np.int32).repeat(bs).reshape(-1, 1),
             "top_ps": np.array(args.top_p, dtype=np.float32).repeat(bs).reshape(-1, 1),
             "min_ps": np.array(args.min_p, dtype=np.float32).repeat(bs).reshape(-1, 1),
-            "random_numbers": np.array(args.random_number, dtype=np.float32).repeat(bs).reshape(-1, 1),
+            "random_numbers": np.tile(np.random.uniform(low=0.0, high=1.0, size=max_top_k_ids), (bs, 1)).astype(
+                np.float32
+            ),
         }
         qaic_config = {
             k: v
@@ -110,10 +113,10 @@ def main(args, **kwargs):
        --repetition-penalty 1.9 \
        --presence-penalty 0.8 \
        --temperature 0.67 \
-       --top-k 54720 \
+       --top-k 54 \
        --top-p 0.89 \
        --min-p 0.6 \
-       --random-number 0.26
+       --random-number 26
 
    2. For non-continuous batching:
 
        python3.10 examples/on_device_sampling.py \
@@ -130,10 +133,10 @@ def main(args, **kwargs):
        --repetition-penalty 1.9 \
        --presence-penalty 0.8 \
        --temperature 0.67 \
-       --top-k 54720 \
+       --top-k 54 \
        --top-p 0.89 \
        --min-p 0.6 \
-       --random-number 0.26
+       --random-number 26
     """
 
     parser = argparse.ArgumentParser(description="Run QEfficient model with On Device Sampling")
diff --git a/tests/transformers/sampler/test_sampler.py b/tests/transformers/sampler/test_sampler.py
index 9335e1d91..8d437eee8 100644
--- a/tests/transformers/sampler/test_sampler.py
+++ b/tests/transformers/sampler/test_sampler.py
@@ -211,7 +211,7 @@ def test_greedy_sampling(
             "top_ks": np.array(512, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1),
             "top_ps": np.array(1.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
             "min_ps": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
-            "random_numbers": np.array(0.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
+            "random_numbers": np.zeros((full_batch_size, 512), dtype=np.float32),
         },
     )
     model_wo_sampler_exec_info = model_wo_sampler.generate(
@@ -233,7 +233,6 @@ def test_greedy_sampling(
 
 
 @pytest.mark.on_qaic
-@pytest.mark.skip
 @pytest.mark.parametrize(
     "model, prompts, prefill_seq_len, ctx_len, generation_len, full_batch_size, spec_length",
     random_sampling_configs,
@@ -291,6 +290,7 @@ def test_random_sampling(
 
     # Generate texts from prompts
     tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model)
+    np.random.seed(0)
     model_w_sampler_exec_info = model_w_sampler.generate(
         tokenizer=tokenizer,
         prompts=prompts,
@@ -301,11 +301,13 @@ def test_random_sampling(
             "repetition_penalties": np.array(20.2, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
             "presence_penalties": np.array(10.5, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
             # "frequency_penalties": np.array(0.5, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
-            "temperatures": np.array(100.1, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
-            "top_ks": np.array(54720, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1),
+            "temperatures": np.array(4.0, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
+            "top_ks": np.array(512, dtype=np.int32).repeat(full_batch_size).reshape(-1, 1),
             "top_ps": np.array(0.89, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
             "min_ps": np.array(0.6, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
-            "random_numbers": np.array(0.26, dtype=np.float32).repeat(full_batch_size).reshape(-1, 1),
+            "random_numbers": np.tile(np.random.uniform(low=0.0, high=1.0, size=512), (full_batch_size, 1)).astype(
+                np.float32
+            ),
         },
     )
     model_wo_sampler_exec_info = model_wo_sampler.generate(
@@ -319,32 +321,32 @@ def test_random_sampling(
 
     # Compare generated texts
     golden_texts = {
-        "w_sampler": "Raymond and my favorite color, alongside reds or purples (I can’t have them both",
+        "w_sampler": "Aiden and I am a freelance writer who loves to explore the world. With over",
         "wo_sampler": "John Smith and I am a software engineer. I have been working in the industry for the past ",
     }
     golden_ids = {
         "w_sampler": [
             [
-                21380,
+                319,
+                3615,
                 322,
-                590,
-                25448,
-                2927,
-                29892,
-                19963,
-                2654,
-                29879,
-                470,
-                3708,
-                2701,
-                313,
-                29902,
-                508,
-                30010,
-                29873,
-                505,
-                963,
-                1716,
+                306,
+                626,
+                263,
+                3005,
+                295,
+                749,
+                9227,
+                1058,
+                12355,
+                267,
+                304,
+                26987,
+                278,
+                3186,
+                29889,
+                2973,
+                975,
             ]
         ],
         "wo_sampler": [
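
Note on the sampler.py hunk above: the random-sampling path now adds Gumbel noise to the ascending top-k logits (topk_values_asc) instead of to their softmax probabilities, which is the standard Gumbel-Max formulation: argmax over (logits + Gumbel(0, 1) noise) draws from the categorical distribution defined by softmax of those logits. This is also why random_numbers grows from one scalar per sequence to one uniform draw per top-k slot, shape (batch_size, max_top_k_ids). The snippet below is an illustrative check only, not part of the patch; it assumes nothing beyond stock PyTorch.

    # Illustrative check (not part of the patch): argmax(logits + Gumbel) reproduces
    # Categorical(softmax(logits)); argmax(softmax(logits) + Gumbel), the old code path,
    # does not.
    import torch

    torch.manual_seed(0)
    logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
    target = torch.softmax(logits, dim=0)

    n = 200_000
    u = torch.rand(n, logits.numel()).clamp_min(1e-9)  # uniform draws, like random_numbers
    gumbel = -torch.log(-torch.log(u))                  # Gumbel(0, 1) noise, as in sampler_forward

    freq_from_logits = torch.bincount((logits + gumbel).argmax(dim=1), minlength=4) / n
    freq_from_probs = torch.bincount((target + gumbel).argmax(dim=1), minlength=4) / n

    print(target)            # target categorical distribution
    print(freq_from_logits)  # matches target up to sampling noise
    print(freq_from_probs)   # much flatter than target, i.e. the bias the patch removes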
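To show how the new qaic_config plumbing in modeling_auto.py is meant to be driven, here is a hedged usage sketch for a dual-QPC multimodal model. The class name matches the from_pretrained entry point this diff touches, but the model card, the compile step, and the exact generate() arguments for passing sampling_params are assumptions carried over from the text-only flow in examples/on_device_sampling.py and tests/transformers/sampler/test_sampler.py.

    # Hypothetical sketch, not part of the patch: on-device sampling for a dual-QPC
    # multimodal model. "<hf-multimodal-model-card>" is a placeholder.
    import numpy as np

    from QEfficient import QEFFAutoModelForImageTextToText

    bs, max_top_k_ids = 1, 512
    qeff_model = QEFFAutoModelForImageTextToText.from_pretrained(
        "<hf-multimodal-model-card>",  # placeholder model card
        kv_offload=True,  # dual QPC: sampler nodes are added to the language decoder only
        qaic_config={"include_sampler": True, "max_top_k_ids": max_top_k_ids},
    )

    # Per-request sampling tensors, mirroring examples/on_device_sampling.py:
    # scalar knobs are (bs, 1); random_numbers carries one uniform draw per
    # top-k slot instead of a single scalar seed.
    np.random.seed(0)
    sampling_params = {
        "repetition_penalties": np.full((bs, 1), 1.9, dtype=np.float32),
        "presence_penalties": np.full((bs, 1), 0.8, dtype=np.float32),
        "temperatures": np.full((bs, 1), 0.67, dtype=np.float32),
        "top_ks": np.full((bs, 1), 54, dtype=np.int32),
        "top_ps": np.full((bs, 1), 0.89, dtype=np.float32),
        "min_ps": np.full((bs, 1), 0.6, dtype=np.float32),
        "random_numbers": np.random.uniform(size=(bs, max_top_k_ids)).astype(np.float32),
    }
    # qeff_model.compile(...) and qeff_model.generate(..., sampling_params=sampling_params)
    # would then follow the text-only example; the exact argument names are assumptions.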