From f00737f602fd37ac557fbdda4b8055bcf8331de9 Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Thu, 16 Oct 2025 23:25:29 -0700 Subject: [PATCH 01/41] Adding Compute-Context-Length(CCL) Signed-off-by: Vahid Janfaza --- QEfficient/cloud/infer.py | 12 + QEfficient/customop/ctx_scatter_gather.py | 16 +- QEfficient/customop/ctx_scatter_gather_cb.py | 18 +- .../generation/text_generation_inference.py | 76 +++++ QEfficient/transformers/cache_utils.py | 50 +-- .../models/gemma/modeling_gemma.py | 11 +- .../models/gemma2/modeling_gemma2.py | 17 +- .../models/gemma3/modeling_gemma3.py | 111 +++++-- .../models/granite/modeling_granite.py | 15 +- .../models/granitemoe/modeling_granitemoe.py | 9 + .../models/internvl/modeling_internvl.py | 100 ++++-- .../models/llama/modeling_llama.py | 11 +- .../models/llama4/modeling_llama4.py | 125 ++++++-- .../llama_swiftkv/modeling_llama_swiftkv.py | 28 +- .../models/llava/modeling_llava.py | 92 ++++-- .../models/llava_next/modeling_llava_next.py | 106 +++++-- .../models/mistral/modeling_mistral.py | 11 +- .../models/mistral3/modeling_mistral3.py | 86 ++++-- .../models/mixtral_moe/modeling_mixtral.py | 11 +- .../models/mllama/modeling_mllama.py | 98 ++++-- .../transformers/models/modeling_auto.py | 272 +++++++++++++++-- .../transformers/models/mpt/modeling_mpt.py | 11 +- .../models/olmo2/modeling_olmo2.py | 11 +- .../transformers/models/phi/modeling_phi.py | 17 +- .../transformers/models/phi3/modeling_phi3.py | 10 + .../models/qwen2/modeling_qwen2.py | 11 +- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 93 ++++-- .../models/qwen3/modeling_qwen3.py | 11 +- .../models/qwen3_moe/modeling_qwen3_moe.py | 11 +- .../models/starcoder2/modeling_starcoder2.py | 11 +- QEfficient/utils/check_ccl_specializations.py | 43 +++ examples/ccl_image_text_to_text_inference.py | 135 +++++++++ examples/ccl_llama4_example.py | 126 ++++++++ examples/ccl_mistral3_example.py | 120 ++++++++ examples/ccl_qwen2_5_vl_example.py | 189 ++++++++++++ examples/compute_context_length.py | 61 ++++ examples/gemma3_example/ccl_gemma3_mm.py | 119 ++++++++ .../ccl_granite_vision_inference.py | 127 ++++++++ .../ccl_granitemoe_inference.py | 40 +++ .../intern_example/ccl_internvl_inference.py | 286 ++++++++++++++++++ .../ccl_qwen3moe_inference.py | 42 +++ tests/transformers/test_comp_ctx_length.py | 193 ++++++++++++ 42 files changed, 2685 insertions(+), 257 deletions(-) create mode 100644 QEfficient/utils/check_ccl_specializations.py create mode 100644 examples/ccl_image_text_to_text_inference.py create mode 100644 examples/ccl_llama4_example.py create mode 100644 examples/ccl_mistral3_example.py create mode 100644 examples/ccl_qwen2_5_vl_example.py create mode 100644 examples/compute_context_length.py create mode 100644 examples/gemma3_example/ccl_gemma3_mm.py create mode 100644 examples/granite_example/ccl_granite_vision_inference.py create mode 100644 examples/granite_example/ccl_granitemoe_inference.py create mode 100644 examples/intern_example/ccl_internvl_inference.py create mode 100644 examples/qwen3moe_example/ccl_qwen3moe_inference.py create mode 100644 tests/transformers/test_comp_ctx_length.py diff --git a/QEfficient/cloud/infer.py b/QEfficient/cloud/infer.py index 814122b9d..fbff5b18b 100644 --- a/QEfficient/cloud/infer.py +++ b/QEfficient/cloud/infer.py @@ -340,6 +340,18 @@ def main( "--prompt-len", "--prompt_len", default=32, type=int, help="Sequence length for text generation." 
) parser.add_argument("--ctx-len", "--ctx_len", default=128, type=int, help="Context length for text generation.") + parser.add_argument( + "--comp-ctx-lengths-prefill", + type=lambda comp_ctx_lengths_prefill: [int(x) for x in comp_ctx_lengths_prefill.split(",")], + default=[512], + help="Define ccl list in csv format (e.g.,--comp-ctx-lengths 512,1024,2048).", + ) + parser.add_argument( + "--comp-ctx-lengths-decode", + type=lambda comp_ctx_lengths_decode: [int(x) for x in comp_ctx_lengths_decode.split(",")], + default=[2048], + help="Define ccl list in csv format (e.g.,--comp-ctx-lengths 512,1024,2048).", + ) parser.add_argument( "--mxfp6", "--mxfp6_matmul", diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py index c4f5a7bbd..269ccb0be 100644 --- a/QEfficient/customop/ctx_scatter_gather.py +++ b/QEfficient/customop/ctx_scatter_gather.py @@ -115,8 +115,14 @@ def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> tor @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) -def CtxGather(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT: - ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[3], axes=[0])) +def CtxGather( + data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32 +) -> onnxscript.FLOAT: + # Create a shape tensor based on comp_ctx_len + shape_tensor = ops.Concat(ops.Shape(data)[:2], ops.Reshape(comp_ctx_len, [1]), axis=0) + + # Directly use the shape tensor without validation + ctx_indices = ops.Expand(ctx_indices, shape_tensor) ctx_indices = ops.Unsqueeze(ctx_indices, [-1]) return ops.GatherND(data, ctx_indices, batch_dims=2) @@ -127,7 +133,7 @@ class CtxGatherFunc(torch.autograd.Function): """ @staticmethod - def forward(data: torch.Tensor, ctx_indices: torch.Tensor): + def forward(data: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int): batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1) head_indices = torch.arange(data.shape[1]).view(1, -1, 1) return data[batch_indices, head_indices, ctx_indices] @@ -137,5 +143,5 @@ def setup_context(ctx, inputs, outputs): pass @staticmethod - def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value: - return g.onnxscript_op(CtxGather, data, ctx_indices).setTypeAs(data) + def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int) -> torch.Value: + return g.onnxscript_op(CtxGather, data, ctx_indices, comp_ctx_len).setTypeAs(data) diff --git a/QEfficient/customop/ctx_scatter_gather_cb.py b/QEfficient/customop/ctx_scatter_gather_cb.py index 75d9a12ef..cc9693716 100644 --- a/QEfficient/customop/ctx_scatter_gather_cb.py +++ b/QEfficient/customop/ctx_scatter_gather_cb.py @@ -97,16 +97,20 @@ def symbolic( @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) def CtxGatherCB( - data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32 + data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32 ) -> onnxscript.FLOAT: batch_size = ops.Gather(ops.Shape(batch_index), [0]) num_heads = ops.Gather(ops.Shape(data), [1]) - ctx_len = ops.Gather(ops.Shape(data), [2]) + # using compute-context-length (CCL) instead of context-length to do gather process based on CCL and later do attention computations based on CCL as well. 
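+    # Example with illustrative sizes: if the allocated KV cache has ctx_len = 4096
+    # but comp_ctx_len = 1024 for the current specialization, the gather below only
+    # reads the first 1024 cache positions, so the downstream attention is computed
+    # over a 1024-wide window instead of the full 4096-entry cache.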
+ ctx_len = ops.Reshape(comp_ctx_len, [1]) # Expanded shape to create indices zero = ops.Constant(value_ints=[0]) one = ops.Constant(value_ints=[1]) - exp_shape = ops.Concat(batch_size, num_heads, ctx_len, one, axis=0) + # exp_shape = ops.Concat(batch_size, num_heads, ctx_len, one, axis=0) + exp_shape = ops.Concat( + ops.Reshape(batch_size, [1]), ops.Reshape(num_heads, [1]), ops.Reshape(ctx_len, [1]), one, axis=0 + ) # Create indices batch_idx = ops.Expand(ops.Unsqueeze(batch_index, [2, 3]), exp_shape) @@ -119,7 +123,7 @@ def CtxGatherCB( class CtxGatherFuncCB(torch.autograd.Function): @staticmethod - def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor): + def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int): batch_indices = batch_index.view(-1, 1, 1) head_indices = torch.arange(data.shape[1]).view(1, -1, 1) return data[batch_indices, head_indices, ctx_indices] @@ -129,8 +133,10 @@ def setup_context(ctx, inputs, outputs): pass @staticmethod - def symbolic(g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value) -> torch.Value: - return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices).setTypeAs(data) + def symbolic( + g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int + ) -> torch.Value: + return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices, comp_ctx_len).setTypeAs(data) @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index 6d04cf573..cf4b6aa27 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -318,6 +318,8 @@ def cloud_ai_100_exec_kv( prompts_txt_file_path: Optional[str] = None, device_id: Optional[List[int]] = None, generation_len: Optional[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, enable_debug_logs: bool = False, stream: bool = True, write_io_dir: Optional[str] = None, @@ -384,6 +386,8 @@ def cloud_ai_100_exec_kv( qpc_path=qpc_path, device_id=device_id, ctx_len=ctx_len, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, enable_debug_logs=enable_debug_logs, write_io_dir=write_io_dir, full_batch_size=full_batch_size, @@ -430,6 +434,8 @@ def __init__( qpc_path: str, full_batch_size: Optional[int] = None, ctx_len: Optional[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, device_id: Optional[List[int]] = None, enable_debug_logs: bool = False, write_io_dir: Optional[str] = None, @@ -439,6 +445,8 @@ def __init__( sampling_params: Optional[Dict[str, Any]] = None, ) -> None: self._ctx_len = ctx_len + self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill + self.comp_ctx_lengths_decode = comp_ctx_lengths_decode self._write_io_dir = write_io_dir self.is_tlm = is_tlm self.return_pdfs = return_pdfs @@ -797,7 +805,17 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i batch_lora_ids = [self._prompt_to_lora_id_mapping_prefill.popleft() for i in range(self.batch_size)] inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1) + if self.comp_ctx_lengths_prefill is not None: + self.list_of_comp_ctx_lengths_prefill = [np.zeros(length) 
for length in self.comp_ctx_lengths_prefill] + prefill_ccl_id = 0 + inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + for i in range(num_chunks): + if self.comp_ctx_lengths_prefill is not None: + if (i + 1) * self._prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id]: + prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1) + inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + chunk_inputs = inputs.copy() chunk_inputs["input_ids"] = inputs["input_ids"][ :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len @@ -816,6 +834,19 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i generation_len, ) + def initialize_ccl(self, decode_inputs): + self.list_of_comp_ctx_lengths_decode = [np.zeros(length) for length in self.comp_ctx_lengths_decode] + max_ccl_id = len(self.comp_ctx_lengths_decode) - 1 + max_position_id = np.max(decode_inputs["position_ids"]) + ccl_id_initial = 0 + ccl_id = ccl_id_initial + for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)): + if max_position_id < self.comp_ctx_lengths_decode[i]: + ccl_id = i + break + + return ccl_id, max_ccl_id + def run_continuous_batching_decode(self, prompt_queue, generation_len): """ Runs continuous batching decode for the given prompt queue and generation length. @@ -847,6 +878,10 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): # Prepare decode inputs inputs. decode_inputs = self.prepare_decode_inputs() + if self.comp_ctx_lengths_decode is not None: + ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs) + decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id] + while prompt_queue or current_decode_ongoing.any(): outputs = self._session.run(decode_inputs) @@ -884,6 +919,20 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): batch_id_map[decode_batch_id] ] + if self.comp_ctx_lengths_decode is not None: + ###Recalculate ccl_id based on position ids### + # Determine the maximum value of position_ids across all batch elements + max_position_id = np.max(decode_inputs["position_ids"]) + + # Update ccl_id and comp_ctx_lengths_decode based on the maximum position id + ccl_id_initial = 0 + ccl_id = ccl_id_initial + for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)): + if max_position_id < self.comp_ctx_lengths_decode[i]: + ccl_id = i + break + decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id] + else: current_decode_ongoing[decode_batch_id] = False else: @@ -896,6 +945,15 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): if self.include_sampler: decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"] + if self.comp_ctx_lengths_decode is not None: + # Update ccl_id and comp_ctx_lengths_decode based on the maximum position id + if ( + decode_inputs["position_ids"][decode_batch_id, -1] + >= self.comp_ctx_lengths_decode[ccl_id] - 1 + ): + ccl_id = min(ccl_id + 1, max_ccl_id) + decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id] + generated_id_current_index[decode_batch_id] += 1 return decode_pause_time @@ -922,7 +980,18 @@ def run_decode( self._session.set_buffers({"logits": logits_out_placeholder}) finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id num_token = 0 + + if self.comp_ctx_lengths_decode is not None: + ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs) + 
decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id] + + cache_index = np.max(decode_inputs["position_ids"]) for num_token in range(1, generation_len): + if self.comp_ctx_lengths_decode is not None: + if cache_index >= self.comp_ctx_lengths_decode[ccl_id] - 1: + ccl_id = min(ccl_id + 1, max_ccl_id) + decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id] + if streamer: streamer.put(decode_inputs["input_ids"][0]) outputs = self._session.run(decode_inputs) @@ -934,6 +1003,7 @@ def run_decode( # Prepare inputs for next iteration decode_inputs["input_ids"] = self._fetch_next_token_id(outputs) decode_inputs["position_ids"][:, -1] += 1 + cache_index += 1 self.generated_ids[:, num_token] = decode_inputs["input_ids"][:, -1] finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id if self.include_sampler: @@ -983,6 +1053,8 @@ def __init__( qpc_path: str, full_batch_size: Optional[int] = None, ctx_len: Optional[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, device_id: Optional[List[int]] = None, enable_debug_logs: bool = False, write_io_dir: Optional[str] = None, @@ -996,6 +1068,8 @@ def __init__( qpc_path=qpc_path, full_batch_size=full_batch_size, ctx_len=ctx_len, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, device_id=device_id, enable_debug_logs=enable_debug_logs, write_io_dir=write_io_dir, @@ -1007,6 +1081,8 @@ def __init__( self._full_batch_size = self._qaic_model.full_batch_size self._tokenizer = self._qaic_model.tokenizer self._ctx_len = ctx_len + self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill + self.comp_ctx_lengths_decode = comp_ctx_lengths_decode self._perf_metrics = None self._prompt_queue = None self._text_streamer = None diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index bbd937d52..0d123d25f 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -40,8 +40,9 @@ def read_only(self, cache_kwargs): k_out, v_out = self.keys, self.values position_ids = cache_kwargs.get("position_ids") batch_index = cache_kwargs.get("batch_index", None) - ctx_len = k_out.shape[2] - ctx_indices = torch.arange(ctx_len)[None, None, ...] + comp_ctx_len = cache_kwargs.get("CCL") + + ctx_indices = torch.arange(comp_ctx_len)[None, None, ...] 
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) invalid_mask = ctx_indices > gather_limit @@ -53,12 +54,11 @@ def read_only(self, cache_kwargs): ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) if batch_index is not None: - k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices) - v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices) + k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, comp_ctx_len) + v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, comp_ctx_len) else: - k_out = CtxGatherFunc.apply(k_out, ctx_indices) - v_out = CtxGatherFunc.apply(v_out, ctx_indices) - + k_out = CtxGatherFunc.apply(k_out, ctx_indices, comp_ctx_len) + v_out = CtxGatherFunc.apply(v_out, ctx_indices, comp_ctx_len) v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) return k_out, v_out @@ -121,6 +121,7 @@ def update( else: position_ids = cache_kwargs.get("position_ids") batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value form the kwargs + comp_ctx_len = cache_kwargs.get("CCL") # Scatter if batch_index is not None: @@ -137,8 +138,7 @@ def update( k_out, v_out = self.keys, self.values # Gather - ctx_len = k_out.shape[2] - ctx_indices = torch.arange(ctx_len)[None, None, ...] + ctx_indices = torch.arange(comp_ctx_len)[None, None, ...] gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) invalid_mask = ctx_indices > gather_limit @@ -149,11 +149,11 @@ def update( ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) if batch_index is not None: - k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices) - v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices) + k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, comp_ctx_len) + v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, comp_ctx_len) else: - k_out = CtxGatherFunc.apply(k_out, ctx_indices) - v_out = CtxGatherFunc.apply(v_out, ctx_indices) + k_out = CtxGatherFunc.apply(k_out, ctx_indices, comp_ctx_len) + v_out = CtxGatherFunc.apply(v_out, ctx_indices, comp_ctx_len) v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) return k_out, v_out @@ -392,6 +392,8 @@ def update( else: position_ids = cache_kwargs.get("position_ids") sliding_window_pattern = cache_kwargs.get("sliding_window_pattern") + comp_ctx_len = cache_kwargs.get("CCL") + is_sliding_layer = torch.tensor(bool((layer_idx + 1) % sliding_window_pattern)) layer_ctx_len = self.key_cache[layer_idx].shape[2] kv_position_ids = torch.where( @@ -417,20 +419,24 @@ def update( ctx_len = self.key_cache[layer_idx].shape[2] ctx_indices = torch.arange(ctx_len)[None, None, ...] 
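+        # The same CCL cap applies to sliding-window layers: both ctx_indices and
+        # rolling_indices are truncated to comp_ctx_len below before the gather.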
gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1) - invalid_mask = ctx_indices > gather_limit if torch.onnx.is_in_onnx_export(): invalid_idx_value = torch.iinfo(torch.int32).max else: invalid_idx_value = 0 + + ctx_indices = ctx_indices[:, :, :comp_ctx_len] + invalid_mask = ctx_indices > gather_limit + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) all_indices = torch.arange(layer_ctx_len) + kv_position_ids.max() + 1 rolling_indices = torch.where(all_indices > layer_ctx_len - 1, all_indices % layer_ctx_len, all_indices) + rolling_indices = rolling_indices[:comp_ctx_len] final_indices = torch.where( (is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), rolling_indices, ctx_indices ) - k_out = CtxGatherFunc.apply(k_out, final_indices) - v_out = CtxGatherFunc.apply(v_out, final_indices) + k_out = CtxGatherFunc.apply(k_out, final_indices, comp_ctx_len) + v_out = CtxGatherFunc.apply(v_out, final_indices, comp_ctx_len) ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out) return k_out, v_out @@ -492,6 +498,8 @@ def update( else: position_ids = cache_kwargs.get("position_ids") + comp_ctx_len = cache_kwargs.get("CCL") + is_sliding_layer = torch.tensor(bool(self.is_sliding[layer_idx])) # Update the position_ids to handle the sliding window @@ -519,21 +527,25 @@ def update( ctx_len = min(layer_ctx_len, k_out.shape[2]) ctx_indices = torch.arange(ctx_len)[None, None, ...] gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1) - invalid_mask = ctx_indices > gather_limit if torch.onnx.is_in_onnx_export(): invalid_idx_value = torch.iinfo(torch.int32).max else: invalid_idx_value = 0 + + ctx_indices = ctx_indices[:, :, :comp_ctx_len] + invalid_mask = ctx_indices > gather_limit + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) # Rolling indices for sliding window all_indices = torch.arange(layer_ctx_len) + kv_position_ids.max() + 1 rolling_indices = torch.where(all_indices > layer_ctx_len - 1, all_indices % layer_ctx_len, all_indices) + rolling_indices = rolling_indices[:comp_ctx_len] final_indices = torch.where( (is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), rolling_indices, ctx_indices ) - k_out = CtxGatherFunc.apply(k_out, final_indices) - v_out = CtxGatherFunc.apply(v_out, final_indices) + k_out = CtxGatherFunc.apply(k_out, final_indices, comp_ctx_len) + v_out = CtxGatherFunc.apply(v_out, final_indices, comp_ctx_len) ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out) return k_out, v_out diff --git a/QEfficient/transformers/models/gemma/modeling_gemma.py b/QEfficient/transformers/models/gemma/modeling_gemma.py index eea1e3898..4c64109d8 100644 --- a/QEfficient/transformers/models/gemma/modeling_gemma.py +++ b/QEfficient/transformers/models/gemma/modeling_gemma.py @@ -137,6 +137,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -153,7 +154,9 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, 
cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -186,6 +189,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -214,6 +218,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -243,6 +248,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -299,6 +305,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -334,6 +341,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -350,6 +358,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/gemma2/modeling_gemma2.py b/QEfficient/transformers/models/gemma2/modeling_gemma2.py index be3ba942d..85bba2989 100644 --- a/QEfficient/transformers/models/gemma2/modeling_gemma2.py +++ b/QEfficient/transformers/models/gemma2/modeling_gemma2.py @@ -144,6 +144,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -160,8 +161,16 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + "CCL": attention_mask.shape[-1], + } key_states, 
value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -194,6 +203,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -226,6 +236,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -266,6 +277,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -338,6 +350,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -381,6 +394,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -404,6 +418,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py index 20b7036fd..2e8494e8e 100644 --- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -6,7 +6,7 @@ # ----------------------------------------------------------------------------- import copy -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch from torch import nn @@ -215,6 +215,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -245,6 +246,8 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { "sin": sin, @@ -253,6 +256,7 @@ def forward( "position_ids": position_ids, "is_sliding": self.is_sliding, "sliding_window_pattern": self.config.sliding_window_pattern, + "CCL": attention_mask.shape[-1], } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -297,6 +301,7 @@ def forward( 
attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -323,6 +328,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -363,6 +369,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -429,6 +436,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -466,6 +474,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -525,6 +534,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, @@ -592,7 +602,9 @@ def __init__(self, model): self.config = self.model.config self.lm_head = self.model.lm_head - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + ): inputs_embeds = self.model.get_input_embeddings()(input_ids) B, N, C = inputs_embeds.shape selected = input_ids == self.model.config.image_token_index @@ -603,7 +615,11 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds) outputs = self.language_model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + use_cache=True, ) image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) @@ -620,7 +636,9 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEffGemma3DecoderWrapper(self) - def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_values): + def forward( + self, input_ids, position_ids, pixel_values, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + ): image_features = self.get_image_features(pixel_values=pixel_values) inputs_embeds = self.get_input_embeddings()(input_ids) B, N, C = inputs_embeds.shape @@ -632,7 +650,11 @@ def forward(self, input_ids, position_ids, pixel_values, image_idx, 
past_key_val image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds) outputs = self.language_model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + use_cache=True, ) image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) @@ -647,6 +669,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, + comp_ctx_lengths_prefill: List[int] = None, + comp_ctx_lengths_decode: List[int] = None, kv_offload: bool = False, **compiler_options, ): @@ -667,24 +691,55 @@ def get_specializations( "ctx_len": ctx_len, } ] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "sliding_window": self.language_model.config.sliding_window, - "img_size": img_size, - "mm_tokens_per_image": mm_tokens_per_image, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "sliding_window": self.language_model.config.sliding_window, - "img_size": img_size, - "mm_tokens_per_image": mm_tokens_per_image, - }, - ] + if comp_ctx_lengths_prefill and comp_ctx_lengths_decode: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang.append( + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "sliding_window": self.language_model.config.sliding_window, + "img_size": img_size, + "mm_tokens_per_image": mm_tokens_per_image, + } + ) + + for i in range(0, len(comp_ctx_lengths_decode)): + lang.append( + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "sliding_window": self.language_model.config.sliding_window, + "img_size": img_size, + "mm_tokens_per_image": mm_tokens_per_image, + } + ) + + else: + lang = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "sliding_window": self.language_model.config.sliding_window, + "img_size": img_size, + "mm_tokens_per_image": mm_tokens_per_image, + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "sliding_window": self.language_model.config.sliding_window, + "img_size": img_size, + "mm_tokens_per_image": mm_tokens_per_image, + }, + ] + specializations = {} if kv_offload: @@ -694,7 +749,7 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): # Define dynamic axes vision_dynamic_axes = {} lang_dynamic_axes = {} @@ -719,6 +774,9 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): ) lang_dynamic_axes[f"past_{kv}.{i}"] = apply_dynamic_axes + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + dynamic_axes = {} if kv_offload: dynamic_axes["vision"] = vision_dynamic_axes @@ -767,7 +825,7 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len): past_key_values.append(pkv) return past_key_values - def get_dummy_inputs(self, kv_offload: bool = False): + def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): if vis_cfg := 
getattr(self.config, "vision_config", None): img_size = getattr(vis_cfg, "image_size", 896) else: @@ -813,6 +871,9 @@ def get_dummy_inputs(self, kv_offload: bool = False): seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs diff --git a/QEfficient/transformers/models/granite/modeling_granite.py b/QEfficient/transformers/models/granite/modeling_granite.py index 2a2d47d6d..dd3d6c7f3 100644 --- a/QEfficient/transformers/models/granite/modeling_granite.py +++ b/QEfficient/transformers/models/granite/modeling_granite.py @@ -129,6 +129,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -145,8 +146,16 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + "CCL": attention_mask.shape[-1], + } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -171,6 +180,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -226,6 +236,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -267,6 +278,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -319,6 +331,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py index c085f6a5e..07031d7fc 100644 --- a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py +++ b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py @@ -123,6 +123,7 @@ def forward( position_embeddings: Tuple[torch.Tensor, torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, + 
comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, @@ -142,6 +143,8 @@ def forward( cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { "sin": sin, @@ -149,6 +152,7 @@ def forward( "cache_position": cache_position, "batch_index": batch_index, "position_ids": position_ids, + "CCL": attention_mask.shape[-1], } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -209,6 +213,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -286,6 +291,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -297,6 +303,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -492,6 +499,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -546,6 +554,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/internvl/modeling_internvl.py b/QEfficient/transformers/models/internvl/modeling_internvl.py index 38d0fe167..29e6ac9a4 100644 --- a/QEfficient/transformers/models/internvl/modeling_internvl.py +++ b/QEfficient/transformers/models/internvl/modeling_internvl.py @@ -5,6 +5,8 @@ # # ----------------------------------------------------------------------------- +from typing import List + import torch import torch.nn as nn import torch.nn.functional as F @@ -34,7 +36,9 @@ def __init__(self, model): self.config = self.model.language_model.config self.language_model = self.model.language_model - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + ): input_embeds = self.model.language_model.get_input_embeddings()(input_ids) B, N, C = input_embeds.shape image_input_embeds = input_embeds.reshape(B * N, C) @@ -55,7 +59,11 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, 
image_input_embeds) inputs_embeds = inputs_embeds.reshape(B, N, C) outputs = self.model.language_model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + use_cache=True, ) image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) return outputs.logits, vision_embeds, image_idx, outputs.past_key_values @@ -74,6 +82,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, + comp_ctx_lengths_prefill: List[int] = None, + comp_ctx_lengths_decode: List[int] = None, kv_offload: bool = False, **compiler_options, ): @@ -104,24 +114,54 @@ def get_specializations( "batched_num_patches": batch_size * num_patches, } ] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "num_patches": num_patches, - "img_size": img_size, - "vision_size": vision_size, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "num_patches": num_patches, - "img_size": img_size, - "vision_size": vision_size, - }, - ] + if comp_ctx_lengths_prefill and comp_ctx_lengths_decode: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang.append( + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "num_patches": num_patches, + "img_size": img_size, + "vision_size": vision_size, + } + ) + + for i in range(0, len(comp_ctx_lengths_decode)): + lang.append( + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "num_patches": num_patches, + "img_size": img_size, + "vision_size": vision_size, + } + ) + + else: + lang = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "num_patches": num_patches, + "img_size": img_size, + "vision_size": vision_size, + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "num_patches": num_patches, + "img_size": img_size, + "vision_size": vision_size, + }, + ] specializations = {} @@ -132,7 +172,7 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): # Define dynamic axes vision_dynamic_axes = {} lang_dynamic_axes = {} @@ -146,6 +186,9 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): for kv in ["key", "value"]: lang_dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + dynamic_axes = {} if kv_offload: dynamic_axes["vision"] = vision_dynamic_axes @@ -173,7 +216,7 @@ def get_output_names(self, kv_offload: bool = False): return lang_output_names return output_names - def get_dummy_inputs(self, kv_offload: bool = False): + def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): if vis_cfg := getattr(self.config, "vision_config", None): img_size = getattr(vis_cfg, "image_size", constants.INTERN_IMG_SIZE) else: @@ -234,6 +277,9 @@ def get_dummy_inputs(self, kv_offload: bool = False): for kv in ["key", "value"]: lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), 
dtype=torch.long) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs @@ -244,7 +290,9 @@ def get_dummy_inputs(self, kv_offload: bool = False): return inputs - def forward(self, input_ids, pixel_values, position_ids, image_idx, past_key_values): + def forward( + self, input_ids, pixel_values, position_ids, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + ): input_embeds = self.language_model.get_input_embeddings()(input_ids) vision_embeds = self.extract_feature(pixel_values) B, N, C = input_embeds.shape @@ -266,7 +314,11 @@ def forward(self, input_ids, pixel_values, position_ids, image_idx, past_key_val inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds) inputs_embeds = inputs_embeds.reshape(B, N, C) outputs = self.language_model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + use_cache=True, ) next_image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) image_idx = torch.where(image_idx < next_image_idx, next_image_idx, image_idx) diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index f2a68f80e..58d174270 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -132,6 +132,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -154,7 +155,9 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -187,6 +190,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -202,6 +206,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -229,6 +234,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -277,6 +283,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + 
comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -310,6 +317,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -326,6 +334,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py index 212fe16ae..82678e380 100644 --- a/QEfficient/transformers/models/llama4/modeling_llama4.py +++ b/QEfficient/transformers/models/llama4/modeling_llama4.py @@ -470,6 +470,7 @@ def forward( position_embeddings: Tuple[torch.Tensor, torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -503,6 +504,8 @@ def forward( if past_key_value is not None: chunk_position_ids = position_ids + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] if self.use_rope: chunk_position_ids = torch.where( @@ -510,7 +513,11 @@ def forward( ) # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"batch_index": batch_index, "position_ids": chunk_position_ids} + cache_kwargs = { + "batch_index": batch_index, + "position_ids": chunk_position_ids, + "CCL": attention_mask.shape[-1], + } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -543,6 +550,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = False, @@ -562,6 +570,7 @@ def forward( position_embeddings=position_embeddings, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -615,6 +624,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -682,6 +692,7 @@ def forward( attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -731,6 +742,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, 
List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -754,6 +766,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, @@ -836,7 +849,9 @@ def __init__(self, model): self.language_model = self.model.language_model self.config = self.model.config - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + ): inputs_embeds = self.model.language_model.get_input_embeddings()(input_ids) selected = input_ids == self.model.config.image_token_index indices1 = selected.to(torch.int64).cumsum(1) - 1 @@ -846,7 +861,11 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va image_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds) outputs = self.model.language_model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + use_cache=True, ) next_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) image_idx = torch.where(image_idx < next_idx, next_idx, image_idx) @@ -860,7 +879,9 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEffLlama4DecoderWrapper(self) - def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_values): + def forward( + self, input_ids, position_ids, pixel_values, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + ): inputs_embeds = self.language_model.get_input_embeddings()(input_ids) vision_feature_layer = self.config.vision_config.vision_feature_layer vision_feature_select_strategy = self.config.vision_config.vision_feature_select_strategy @@ -880,7 +901,11 @@ def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_val image_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds) outputs = self.language_model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + use_cache=True, ) next_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) image_idx = torch.where(image_idx < next_idx, next_idx, image_idx) @@ -892,6 +917,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, + comp_ctx_lengths_prefill: List[int] = None, + comp_ctx_lengths_decode: List[int] = None, kv_offload: bool = False, **compiler_options, ): @@ -941,28 +968,62 @@ def get_specializations( "img_size": img_size, } ] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "max_num_tiles": max_num_tiles, - "img_size": img_size, - "vision_size": vision_size, - "chunk_length": prefill_seq_len, - "chunk_ctx_len": chunk_ctx_len, - 
}, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "max_num_tiles": max_num_tiles, - "img_size": img_size, - "vision_size": vision_size, - "chunk_length": prefill_seq_len, - "chunk_ctx_len": chunk_ctx_len, - }, - ] + if comp_ctx_lengths_prefill is not None: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang.append( + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "max_num_tiles": max_num_tiles, + "img_size": img_size, + "vision_size": vision_size, + "chunk_length": prefill_seq_len, + "chunk_ctx_len": chunk_ctx_len, + } + ) + + for i in range(0, len(comp_ctx_lengths_decode)): + lang.append( + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "max_num_tiles": max_num_tiles, + "img_size": img_size, + "vision_size": vision_size, + "chunk_length": prefill_seq_len, + "chunk_ctx_len": chunk_ctx_len, + } + ) + + else: + lang = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "max_num_tiles": max_num_tiles, + "img_size": img_size, + "vision_size": vision_size, + "chunk_length": prefill_seq_len, + "chunk_ctx_len": chunk_ctx_len, + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "max_num_tiles": max_num_tiles, + "img_size": img_size, + "vision_size": vision_size, + "chunk_length": prefill_seq_len, + "chunk_ctx_len": chunk_ctx_len, + }, + ] specializations = {} @@ -973,7 +1034,7 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): # Define dynamic axes vision_dynamic_axes = {} lang_dynamic_axes = {} @@ -993,6 +1054,9 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): for kv in ["key", "value"]: lang_dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + dynamic_axes = {} if kv_offload: dynamic_axes["vision"] = vision_dynamic_axes @@ -1045,7 +1109,7 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len): past_key_values.append(pkv) return past_key_values - def get_dummy_inputs(self, kv_offload: bool = False): + def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): if vis_cfg := getattr(self.config, "vision_config", None): img_size = getattr(vis_cfg, "image_size", 336) else: @@ -1102,6 +1166,9 @@ def get_dummy_inputs(self, kv_offload: bool = False): for kv in ["key", "value"]: lang_inputs["past_key_values"][i].append(torch.zeros(past_key_values[0][0].shape, dtype=torch.float32)) + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index 9fd1ed782..5b36b1019 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -89,6 +89,7 @@ def forward( hidden_states: torch.Tensor, position_ids: torch.LongTensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, 
attention_mask: torch.Tensor = None, batch_index: Optional[torch.LongTensor] = None, ) -> torch.Tensor: @@ -105,8 +106,10 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] kv_seq_len = past_key_value.get_seq_length(self.layer_idx) - cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.read_only(self.layer_idx, cache_kwargs=cache_kwargs) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) @@ -155,6 +158,7 @@ def forward( hidden_states: torch.Tensor, position_ids: torch.Tensor, past_key_values, + comp_ctx_lengths, causal_mask, batch_index: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -166,6 +170,7 @@ def forward( hidden_states=hidden_states, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=causal_mask, batch_index=batch_index, ) @@ -201,11 +206,19 @@ def __init__(self, config: QEffLlamaSwiftKVConfig): self.norm_swiftkv = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) def _run_swiftkv_layers( - self, hidden_states: torch.Tensor, position_ids: torch.Tensor, past_key_values, causal_mask, batch_index + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + past_key_values, + comp_ctx_lengths, + causal_mask, + batch_index, ) -> torch.Tensor: for layer_idx in range(self.config.num_key_value_layers, self.config.num_hidden_layers): layer = self.layers[layer_idx] - hidden_states = layer(hidden_states, position_ids, past_key_values, causal_mask, batch_index) + hidden_states = layer( + hidden_states, position_ids, past_key_values, comp_ctx_lengths, causal_mask, batch_index + ) hidden_states = self.norm(hidden_states) return hidden_states, past_key_values @@ -289,6 +302,7 @@ def forward( input_ids: Optional[torch.Tensor], position_ids: torch.Tensor, past_key_values: List[torch.Tensor], + comp_ctx_lengths: Optional[torch.LongTensor], batch_index: Optional[torch.LongTensor] = None, ): inputs_embeds = self.embed_tokens(input_ids) @@ -328,6 +342,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=False, use_cache=True, @@ -373,7 +388,7 @@ def forward( causal_mask = causal_mask[torch.arange(bsz).reshape(-1, 1), :, last_pos_id, :] hidden_states, next_decoder_cache = self._run_swiftkv_layers( - hidden_states, position_ids, past_key_values, causal_mask, batch_index + hidden_states, position_ids, past_key_values, comp_ctx_lengths, causal_mask, batch_index ) # We can fill the orig_hidden_states with the processed hidden_states here but it's not needed as for next token prediction # we only need the last valid pos_indices hidden_states. 
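Note: the attention-side change above repeats across every decoder touched by this patch: when comp_ctx_lengths is passed, the causal mask is truncated to the active CCL bucket and the KV-cache update receives that width as "CCL", so attention only reads the first CCL cache positions. A minimal standalone sketch of that idea, assuming an additive float mask and a toy cache (SimpleKVCache and ccl_attention are illustrative names, not the repo's QEffDynamicCache or attention classes):

from typing import Optional, Tuple

import torch


class SimpleKVCache:
    """Toy per-layer KV cache, pre-allocated to the full context length."""

    def __init__(self, batch: int, heads: int, ctx_len: int, head_dim: int):
        self.key = torch.zeros(batch, heads, ctx_len, head_dim)
        self.value = torch.zeros(batch, heads, ctx_len, head_dim)

    def update(self, k, v, position_ids, ccl: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Scatter the new states at their positions (single shared position row for brevity) ...
        self.key[:, :, position_ids[0]] = k
        self.value[:, :, position_ids[0]] = v
        # ... but gather only the first `ccl` positions for this step's attention.
        return self.key[:, :, :ccl], self.value[:, :, :ccl]


def ccl_attention(
    query: torch.Tensor,                      # (batch, heads, q_len, head_dim)
    key_new: torch.Tensor,
    value_new: torch.Tensor,
    attention_mask: torch.Tensor,             # additive float mask, (batch, 1, q_len, ctx_len)
    position_ids: torch.Tensor,               # (batch, q_len)
    cache: SimpleKVCache,
    comp_ctx_lengths: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    if comp_ctx_lengths is not None:
        # Truncate the mask to the active compute-context-length bucket.
        attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
    ccl = attention_mask.shape[-1]            # plays the role of cache_kwargs["CCL"]
    key, value = cache.update(key_new, value_new, position_ids, ccl)
    scores = torch.matmul(query, key.transpose(-1, -2)) / query.shape[-1] ** 0.5
    scores = scores + attention_mask          # masked positions carry a large negative value
    return torch.matmul(torch.softmax(scores, dim=-1), value)

Because only the mask width changes between buckets, a single exported graph can be compiled into several specializations that differ only in comp_ctx_lengths.
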
@@ -405,9 +420,12 @@ def forward( input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: Optional[Union[List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, ): - hidden_states, output_past_key_values = self.model(input_ids, position_ids, past_key_values, batch_index) + hidden_states, output_past_key_values = self.model( + input_ids, position_ids, past_key_values, comp_ctx_lengths, batch_index + ) logits = self.lm_head(hidden_states) return CausalLMOutputWithPast( loss=None, diff --git a/QEfficient/transformers/models/llava/modeling_llava.py b/QEfficient/transformers/models/llava/modeling_llava.py index e260beb05..450fc79b6 100644 --- a/QEfficient/transformers/models/llava/modeling_llava.py +++ b/QEfficient/transformers/models/llava/modeling_llava.py @@ -5,6 +5,8 @@ # # ----------------------------------------------------------------------------- +from typing import List + import torch import torch.nn as nn import torch.utils.checkpoint @@ -51,7 +53,9 @@ def __init__(self, model): self.language_model = self.model.language_model self.lm_head = self.model.lm_head - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + ): inputs_embeds = self.model.get_input_embeddings()(input_ids) vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype) mask = input_ids == self.model.config.image_token_index @@ -65,6 +69,7 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, return_dict=True, ) @@ -83,7 +88,9 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEFFLlavaDecoderWrapper(self) - def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_values): + def forward( + self, input_ids, position_ids, pixel_values, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + ): inputs_embeds = self.get_input_embeddings()(input_ids) # Image features image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) @@ -109,6 +116,7 @@ def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_val inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, ) logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) @@ -120,7 +128,7 @@ def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_val image_idx = torch.where(image_idx < next_image_idx, next_image_idx, image_idx) return logits, pixel_values, image_idx, outputs.past_key_values - def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): + def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False, **kwargs): num_layers = self.config.text_config.num_hidden_layers num_key_value_heads = self.config.text_config.num_key_value_heads head_dim = self.config.text_config.hidden_size // self.config.text_config.num_attention_heads @@ -150,6 +158,10 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): ) ) lang_inputs["position_ids"] = torch.full(lang_inputs["position_ids"].shape, CTX_LEN - 1) + + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + 
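Note: the dummy comp_ctx_lengths tensor added above only pins the input's rank and dtype; its dim 0 is exported as a symbolic axis and later bound, per specialization, to each entry of the prefill/decode CCL lists. A hedged sketch of how those export artifacts and specializations line up (make_ccl_export_artifacts is a made-up helper name; the real logic lives in get_dummy_inputs, get_onnx_dynamic_axes and get_specializations):

from typing import Dict, List, Tuple

import torch


def make_ccl_export_artifacts(
    comp_ctx_lengths_prefill: List[int],
    comp_ctx_lengths_decode: List[int],
    prefill_seq_len: int,
    ctx_len: int,
    batch_size: int = 1,
) -> Tuple[Dict, Dict, List[Dict]]:
    # Dummy tensor only fixes rank/dtype; dim 0 stays symbolic ("comp_ctx_lengths").
    dummy_inputs = {"comp_ctx_lengths": torch.randint(0, 100, (40,), dtype=torch.long)}
    dynamic_axes = {"comp_ctx_lengths": {0: "comp_ctx_lengths"}}
    # One language specialization per CCL bucket: prefill buckets keep the prefill
    # seq_len, decode buckets use seq_len=1, and all share the same ctx_len.
    specializations = [
        {"batch_size": batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len, "comp_ctx_lengths": ccl}
        for ccl in comp_ctx_lengths_prefill
    ] + [
        {"batch_size": batch_size, "seq_len": 1, "ctx_len": ctx_len, "comp_ctx_lengths": ccl}
        for ccl in comp_ctx_lengths_decode
    ]
    return dummy_inputs, dynamic_axes, specializations
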
inputs = {} if kv_offload: @@ -166,6 +178,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, + comp_ctx_lengths_prefill: List[int] = None, + comp_ctx_lengths_decode: List[int] = None, kv_offload: bool = False, **compiler_options, ): @@ -187,24 +201,55 @@ def get_specializations( "img_size": img_size, } ] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "max_num_images": max_num_images, - "img_size": img_size, - "vision_size": vision_size, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "max_num_images": max_num_images, - "img_size": img_size, - "vision_size": vision_size, - }, - ] + + if comp_ctx_lengths_prefill and comp_ctx_lengths_decode: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang.append( + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + } + ) + + for i in range(0, len(comp_ctx_lengths_decode)): + lang.append( + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + } + ) + else: + lang = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + }, + ] + specializations = {} if kv_offload: @@ -214,7 +259,7 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): # Define dynamic axes num_layers = self.config.text_config.num_hidden_layers @@ -230,6 +275,9 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + dynamic_axes = {} if kv_offload: dynamic_axes["vision"] = vision_dynamic_axes diff --git a/QEfficient/transformers/models/llava_next/modeling_llava_next.py b/QEfficient/transformers/models/llava_next/modeling_llava_next.py index 2fa1d9234..b23073fa7 100755 --- a/QEfficient/transformers/models/llava_next/modeling_llava_next.py +++ b/QEfficient/transformers/models/llava_next/modeling_llava_next.py @@ -6,6 +6,8 @@ # ----------------------------------------------------------------------------- +from typing import List + import numpy as np import torch import torch.nn as nn @@ -123,7 +125,9 @@ def __init__(self, model): self.language_model = self.model.language_model self.lm_head = self.model.lm_head - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + ): inputs_embeds = self.model.get_input_embeddings()(input_ids) image_features = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype) mask = input_ids == self.config.image_token_index @@ -138,6 +142,7 @@ def 
forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, ) image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) @@ -154,7 +159,7 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEffLlavaNextDecoderWrapper(self) - def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): + def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False, **kwargs): num_layers = self.config.text_config.num_hidden_layers num_key_value_heads = self.config.text_config.num_key_value_heads head_dim = self.config.text_config.hidden_size // self.config.text_config.num_attention_heads @@ -217,6 +222,10 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): ) ) lang_inputs["position_ids"] = torch.full(lang_inputs["position_ids"].shape, constants.GRANITEVISION_CTX_LEN - 1) + + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs @@ -232,6 +241,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, + comp_ctx_lengths_prefill: List[int] = None, + comp_ctx_lengths_decode: List[int] = None, kv_offload: bool = False, **compiler_options, ): @@ -285,30 +296,67 @@ def get_specializations( "img_size": img_size, } ] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "image_size_height": image_size_height, - "image_size_width": image_size_width, - "num_patches": num_patches, - "max_num_images": max_num_images, - "img_size": img_size, - "vision_size": vision_size, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "image_size_height": image_size_height, - "image_size_width": image_size_width, - "num_patches": num_patches, - "max_num_images": max_num_images, - "img_size": img_size, - "vision_size": vision_size, - }, - ] + if comp_ctx_lengths_prefill is not None: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang.append( + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "image_size_height": image_size_height, + "image_size_width": image_size_width, + "num_patches": num_patches, + "max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + } + ) + + # Remaining elements use comp_ctx_lengths[1:] in a loop + for i in range(0, len(comp_ctx_lengths_decode)): + lang.append( + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "image_size_height": image_size_height, + "image_size_width": image_size_width, + "num_patches": num_patches, + "max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + } + ) + else: + lang = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "image_size_height": image_size_height, + "image_size_width": image_size_width, + "num_patches": num_patches, + "max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "image_size_height": image_size_height, + "image_size_width": image_size_width, + "num_patches": num_patches, + 
"max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + }, + ] + specializations = {} if kv_offload: specializations["vision"] = vision @@ -317,7 +365,7 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): # Define dynamic axes num_layers = self.config.text_config.num_hidden_layers vision_dynamic_axes = { @@ -332,6 +380,10 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): for i in range(num_layers): lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + dynamic_axes = {} if kv_offload: dynamic_axes["vision"] = vision_dynamic_axes diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index ca23cc144..30c73ae8b 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -140,6 +140,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, @@ -163,7 +164,9 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -196,6 +199,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -226,6 +230,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -256,6 +261,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -316,6 +322,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -354,6 +361,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: 
Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -377,6 +385,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/mistral3/modeling_mistral3.py b/QEfficient/transformers/models/mistral3/modeling_mistral3.py index 735eec9e5..a5f1301d2 100644 --- a/QEfficient/transformers/models/mistral3/modeling_mistral3.py +++ b/QEfficient/transformers/models/mistral3/modeling_mistral3.py @@ -5,7 +5,7 @@ # # ----------------------------------------------------------------------------- -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -106,6 +106,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, list[int]]] = None, use_cache: Optional[bool] = None, @@ -126,6 +127,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, @@ -166,7 +168,9 @@ def __init__(self, model): self.config = self.model.config self.language_model = self.model.language_model - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + ): inputs_embeds = self.model.get_input_embeddings()(input_ids) vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype) mask = input_ids == self.model.config.image_token_index @@ -179,6 +183,7 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va inputs_embeds=inputs_embeds_1, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, ) # Cast to int32 to avoid ONNXRT issue @@ -198,7 +203,9 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEFFMistral3DecoderWrapper(self) - def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_values): + def forward( + self, input_ids, position_ids, pixel_values, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + ): inputs_embeds = self.get_input_embeddings()(input_ids) image_sizes = torch.tensor([[pixel_values.shape[2], pixel_values.shape[3]]]).repeat(pixel_values.shape[0], 1) image_features = self.get_image_features( @@ -219,6 +226,7 @@ def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_val inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, ) # Cast to int32 to avoid ONNXRT issue logit_idx = position_ids.to(torch.int32).argmax(1, keepdim=True) @@ -230,7 +238,7 @@ def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_val return logits, pixel_values, image_idx, outputs.past_key_values - def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): 
+ def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False, **kwargs): inputs_shapes = {} inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) height = self.config.vision_config.image_size @@ -282,6 +290,9 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): for kv in ["key", "value"]: lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs @@ -298,6 +309,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, + comp_ctx_lengths_prefill: List[int] = None, + comp_ctx_lengths_decode: List[int] = None, kv_offload: bool = False, **compiler_options, ): @@ -323,22 +336,50 @@ def get_specializations( "vision_size": vision_size, } ] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "image_size": img_size, - "vision_size": vision_size, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "image_size": img_size, - "vision_size": vision_size, - }, - ] + if comp_ctx_lengths_prefill is not None: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang.append( + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "image_size": img_size, + "vision_size": vision_size, + } + ) + + # Remaining elements use comp_ctx_lengths[1:] in a loop + for i in range(0, len(comp_ctx_lengths_decode)): + lang.append( + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "image_size": img_size, + "vision_size": vision_size, + } + ) + else: + lang = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "image_size": img_size, + "vision_size": vision_size, + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "image_size": img_size, + "vision_size": vision_size, + }, + ] specializations = {} @@ -351,7 +392,7 @@ def get_specializations( lang[1].pop("vision_size") return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): # Define dynamic axes num_layers = self.config.text_config.num_hidden_layers @@ -368,6 +409,9 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + dynamic_axes = {} if kv_offload: dynamic_axes["vision"] = vision_dynamic_axes diff --git a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py index 9b9e3448a..6e61568ac 100644 --- a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py +++ b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py @@ -137,6 +137,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, **kwargs, ) 
-> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -159,7 +160,9 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -245,6 +248,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -282,6 +286,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -314,6 +319,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -375,6 +381,7 @@ def forward( position_ids=position_ids, batch_index=batch_index, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, output_router_logits=output_router_logits, use_cache=use_cache, cache_position=cache_position, @@ -412,6 +419,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -435,6 +443,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index cb24f1de4..2197bec91 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -177,6 +177,7 @@ def forward( hidden_states: torch.Tensor, cross_attention_states: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, use_cache: bool = None, @@ -249,6 +250,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, position_embeddings: torch.Tensor = None, use_cache: bool = False, @@ -278,9 +280,12 @@ def forward( query_states, key_states = 
qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] cache_kwargs = { "batch_index": batch_index, "position_ids": position_ids, + "CCL": attention_mask.shape[-1], } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -316,6 +321,7 @@ def forward( full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -350,6 +356,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -379,6 +386,7 @@ def forward( cross_attention_states: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, use_cache: bool = None, @@ -396,13 +404,15 @@ def forward( key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # if we have a new image + new tokens, we only computed key_states on that new image # we still update the cross key states, past_image, new_image. And use it! 
key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, - {"batch_index": batch_index, "position_ids": position_ids}, + {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]}, ) elif past_key_value is not None: key_states, value_states = ( @@ -448,6 +458,7 @@ def forward( full_text_row_masked_out_mask: Tuple[torch.Tensor, torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -461,6 +472,7 @@ def forward( attention_mask=cross_attention_mask, cross_attention_states=cross_attention_states, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, cache_position=cache_position, ) @@ -594,6 +606,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cross_attention_states: Optional[torch.FloatTensor] = None, cross_attention_mask: Optional[torch.Tensor] = None, @@ -658,6 +671,7 @@ def forward( full_text_row_masked_out_mask=full_text_row_masked_out_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, use_cache=use_cache, cache_position=cache_position, ) @@ -688,6 +702,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cross_attention_states: Optional[torch.LongTensor] = None, cross_attention_mask: Optional[torch.LongTensor] = None, @@ -707,6 +722,7 @@ def forward( cross_attention_mask=cross_attention_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, inputs_embeds=inputs_embeds, use_cache=use_cache, cache_position=cache_position, @@ -774,6 +790,7 @@ def forward( cross_attention_states: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, @@ -820,6 +837,7 @@ def forward( cross_attention_mask=cross_attention_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, use_cache=use_cache, inputs_embeds=inputs_embeds, cache_position=cache_position, @@ -853,6 +871,7 @@ def forward( cross_attention_states: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -869,6 +888,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, inputs_embeds=inputs_embeds, use_cache=use_cache, 
cache_position=cache_position, @@ -879,7 +899,7 @@ def forward( logits = self.lm_head(hidden_states).float() return logits, image_idx, outputs.past_key_values, pixel_values - def get_dummy_inputs(self, kv_offload: bool = False): + def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): BS = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE SEQ_LEN = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN CTX_LEN = constants.ONNX_EXPORT_CTX_LEN @@ -943,6 +963,10 @@ def get_dummy_inputs(self, kv_offload: bool = False): lang_inputs["past_key_values"] = lang_inputs["past_key_values"].to_legacy_cache() lang_inputs["position_ids"] = torch.full(lang_inputs["position_ids"].shape, CTX_LEN - 1) + + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + inputs = {} if kv_offload: @@ -959,6 +983,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, + comp_ctx_lengths_prefill: List[int] = None, + comp_ctx_lengths_decode: List[int] = None, kv_offload: bool = False, **compiler_options, ): @@ -973,22 +999,53 @@ def get_specializations( logger.warning("Setting `img_size=448` as it was neither passed nor found in vision_config") vision = [{"batch_size": batch_size, "max_num_images": max_num_images, "img_size": img_size}] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "max_num_images": max_num_images, - "img_size": img_size, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "max_num_images": max_num_images, - "img_size": img_size, - }, - ] + + if comp_ctx_lengths_prefill is not None: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang.append( + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "max_num_images": max_num_images, + "img_size": img_size, + } + ) + + # Remaining elements use comp_ctx_lengths[1:] in a loop + for i in range(0, len(comp_ctx_lengths_decode)): + lang.append( + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "max_num_images": max_num_images, + "img_size": img_size, + } + ) + + else: + lang = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "max_num_images": max_num_images, + "img_size": img_size, + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "max_num_images": max_num_images, + "img_size": img_size, + }, + ] + specializations = {} if kv_offload: @@ -998,7 +1055,7 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): txt_cfg = self.config.get_text_config() num_hidden_layers = txt_cfg.num_hidden_layers cross_attention_layers = txt_cfg.cross_attention_layers @@ -1023,6 +1080,9 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + dynamic_axes = {} if kv_offload: dynamic_axes["vision"] = vision_dynamic_axes diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 633a0b29d..5c6f67ddc 100644 --- 
a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -56,6 +56,7 @@ constants, get_padding_shape_from_config, ) +from QEfficient.utils.check_ccl_specializations import process_ccl_specializations from QEfficient.utils.logging_utils import logger @@ -877,6 +878,15 @@ def __init__( raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") self.model = model self.config = model.config + + self.comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) + self.comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) + ctx_len = kwargs.pop("ctx_len", None) + if self.comp_ctx_lengths_prefill: + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations( + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len + ) + self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs) self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs) self.input_shapes, self.output_names = None, None @@ -922,8 +932,25 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs): logger.warning("Updating low_cpu_mem_usage=False") kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) + + comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) + comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) + ctx_len = kwargs.pop("ctx_len", None) + + if comp_ctx_lengths_prefill: + comp_ctx_lengths_prefill, comp_ctx_lengths_decode = process_ccl_specializations( + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len + ) + model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) - return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) + return cls( + model, + pretrained_model_name_or_path=pretrained_model_name_or_path, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + **kwargs, + ) @property def onnx_path(self): @@ -978,8 +1005,8 @@ def export( List[str] A list containing the paths to the generated ONNX graph files for both components. 
""" - inputs = self.model.get_dummy_inputs(kv_offload=True) - dynamic_axes = self.model.get_onnx_dynamic_axes(kv_offload=True) + inputs = self.model.get_dummy_inputs(self.comp_ctx_lengths_decode, kv_offload=True) + dynamic_axes = self.model.get_onnx_dynamic_axes(self.comp_ctx_lengths_decode, kv_offload=True) output_names = self.model.get_output_names(kv_offload=True) self.vision_model.export( @@ -1083,6 +1110,8 @@ def compile( batch_size=batch_size, prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, + comp_ctx_lengths_prefill=self.comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=self.comp_ctx_lengths_decode, img_size=img_size, kv_offload=True, **compiler_options, @@ -1332,6 +1361,11 @@ def kv_offload_generate( lang_session.set_buffers(vision_outputs) + if self.comp_ctx_lengths_prefill is not None: + list_of_comp_ctx_lengths_prefill = [np.zeros(length) for length in self.comp_ctx_lengths_prefill] + prefill_ccl_id = 0 + lang_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + # Prepare inputs for prefill chunk_inputs = lang_inputs.copy() prefill_start = perf_counter() @@ -1339,6 +1373,13 @@ def kv_offload_generate( # Run prefill chunk_inputs = lang_inputs.copy() for i in range(num_chunks): + if ( + self.comp_ctx_lengths_prefill is not None + and (i + 1) * prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id] + ): + prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1) + chunk_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + chunk_inputs["input_ids"] = lang_inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len] chunk_inputs["position_ids"] = lang_inputs["position_ids"][ ..., i * prefill_seq_len : (i + 1) * prefill_seq_len @@ -1368,8 +1409,25 @@ def kv_offload_generate( streamer.put(lang_inputs["input_ids"][0]) # Decode loop + if self.comp_ctx_lengths_decode is not None: + max_ccl_id = len(self.comp_ctx_lengths_decode) - 1 + list_of_comp_ctx_lengths_decode = [np.zeros(length) for length in self.comp_ctx_lengths_decode] + max_position_id = np.max(lang_inputs["position_ids"]) + ccl_id_initial = 0 + ccl_id = ccl_id_initial + for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)): + if max_position_id < self.comp_ctx_lengths_decode[i]: + ccl_id = i + break + lang_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_decode[ccl_id] + decode_start = perf_counter() for num_token in range(1, generation_len): + if self.comp_ctx_lengths_decode is not None: + if max_position_id >= self.comp_ctx_lengths_decode[ccl_id] - 1: + ccl_id = min(ccl_id + 1, max_ccl_id) + lang_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_decode[ccl_id] + outputs = lang_session.run(lang_inputs) # Prepare inputs for next iteration @@ -1440,6 +1498,15 @@ def __init__( raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") super().__init__(model, **kwargs) + self.comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) + self.comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) + ctx_len = kwargs.pop("ctx_len", None) + + if self.comp_ctx_lengths_prefill: + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations( + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len + ) + # to handle internvl models if hasattr(self.model.config, "llm_config") and hasattr(self.model.config, "vision_config"): self.model.config.llm_config.use_cache = True @@ -1486,6 +1553,16 @@ def 
from_pretrained( logger.warning("Updating low_cpu_mem_usage=False") kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) + + comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) + comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) + ctx_len = kwargs.pop("ctx_len", None) + + if comp_ctx_lengths_prefill: + comp_ctx_lengths_prefill, comp_ctx_lengths_decode = process_ccl_specializations( + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len + ) + from transformers import AutoConfig config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True) @@ -1493,7 +1570,14 @@ def from_pretrained( config.vision_config.use_flash_attn = "false" model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, config, *args, **kwargs) - return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) + return cls( + model, + pretrained_model_name_or_path=pretrained_model_name_or_path, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + **kwargs, + ) def export( self, @@ -1515,8 +1599,8 @@ def export( str Path to the generated ONNX graph file. """ - inputs = self.model.get_dummy_inputs() - dynamic_axes = self.model.get_onnx_dynamic_axes() + inputs = self.model.get_dummy_inputs(self.comp_ctx_lengths_decode) + dynamic_axes = self.model.get_onnx_dynamic_axes(self.comp_ctx_lengths_decode) output_names = self.model.get_output_names() return self._export(inputs, output_names, dynamic_axes, export_dir=export_dir) @@ -1598,6 +1682,8 @@ def compile( batch_size=batch_size, prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, + comp_ctx_lengths_prefill=self.comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=self.comp_ctx_lengths_decode, img_size=img_size, **compiler_options, ) @@ -1782,12 +1868,24 @@ def cloud_ai_100_generate( inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) inputs["image_idx"] = np.array([[0]]) + if self.comp_ctx_lengths_prefill is not None: + list_of_comp_ctx_lengths_prefill = [np.zeros(length) for length in self.comp_ctx_lengths_prefill] + prefill_ccl_id = 0 + inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + qpc_session.activate() chunk_inputs = inputs.copy() prefill_start = perf_counter() # Run prefill for i in range(num_chunks): + if ( + self.comp_ctx_lengths_prefill is not None + and (i + 1) * prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id] + ): + prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1) + chunk_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + chunk_inputs["input_ids"] = inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len] chunk_inputs["position_ids"] = inputs["position_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len] outputs = qpc_session.run(chunk_inputs) @@ -1811,8 +1909,25 @@ def cloud_ai_100_generate( inputs.pop("pixel_values") # Decode loop + if self.comp_ctx_lengths_decode is not None: + list_of_comp_ctx_lengths_decode = [np.zeros(length) for length in self.comp_ctx_lengths_decode] + max_ccl_id = len(self.comp_ctx_lengths_decode) - 1 + max_position_id = np.max(inputs["position_ids"]) + ccl_id_initial = 0 + ccl_id = ccl_id_initial + for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)): + if max_position_id < self.comp_ctx_lengths_decode[i]: + ccl_id = i + break + inputs["comp_ctx_lengths"] = 
list_of_comp_ctx_lengths_decode[ccl_id] + decode_start = perf_counter() for num_token in range(1, generation_len): + if self.comp_ctx_lengths_decode is not None: + if max_position_id >= self.comp_ctx_lengths_decode[ccl_id] - 1: + ccl_id = min(ccl_id + 1, max_ccl_id) + inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_decode[ccl_id] + outputs = qpc_session.run(inputs) # Prepare inputs for next iteration inputs["input_ids"] = outputs["logits"].argmax(2) @@ -1950,6 +2065,9 @@ def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, **kwargs) Union[_QEffAutoModelForImageTextToTextDualQPC, _QEFFAutoModelForImageTextToTextSingleQPC] The wrapped model instance, configured for either dual or single QPC. """ + self.comp_ctx_lengths_prefill = kwargs.get("comp_ctx_lengths_prefill", None) + self.comp_ctx_lengths_decode = kwargs.get("comp_ctx_lengths_decode", None) + if kv_offload: return _QEffAutoModelForImageTextToTextDualQPC(model, **kwargs) else: @@ -1996,8 +2114,26 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) + + comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) + comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) + ctx_len = kwargs.pop("ctx_len", None) + + if comp_ctx_lengths_prefill: + comp_ctx_lengths_prefill, comp_ctx_lengths_decode = process_ccl_specializations( + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len + ) + model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) - return cls(model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) + return cls( + model, + kv_offload=kv_offload, + pretrained_model_name_or_path=pretrained_model_name_or_path, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + **kwargs, + ) MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP = { @@ -2096,6 +2232,15 @@ def __init__( self.model, transformed = SpDTransform.apply(self.model, qaic_config, **kwargs) self.is_tlm = transformed + self.comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) + self.comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) + ctx_len = kwargs.pop("ctx_len", None) + + if self.comp_ctx_lengths_prefill: + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations( + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len + ) + self.hash_params["qeff_auto_class"] = self.__class__.__name__ # ---Sampling--- @@ -2190,6 +2335,14 @@ def from_pretrained( kv_offload = kwargs.pop("kv_offload", None) + comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) + comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) + ctx_len = kwargs.pop("ctx_len", None) + if comp_ctx_lengths_prefill: + comp_ctx_lengths_prefill, comp_ctx_lengths_decode = process_ccl_specializations( + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len + ) + kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) if qaic_config is not None: @@ -2199,13 +2352,22 @@ def from_pretrained( if model.__class__.__name__ in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP: return 
MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model.__class__.__name__]( - model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs + model, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + kv_offload=kv_offload, + pretrained_model_name_or_path=pretrained_model_name_or_path, + **kwargs, ) return cls( model, continuous_batching=continuous_batching, qaic_config=qaic_config, pretrained_model_name_or_path=pretrained_model_name_or_path, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, **kwargs, ) @@ -2255,6 +2417,10 @@ def export(self, export_dir: Optional[str] = None) -> str: "input_ids": {0: "batch_size", 1: "seq_len"}, "position_ids": {0: "batch_size", 1: "seq_len"}, } + if self.comp_ctx_lengths_prefill is not None: + example_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + if len(kv_cache_shape) == 3: # For GPTBigCode arch the pkv is 3d pkv_dynamic_axes = { 0: "full_batch_size" if self.continuous_batching else "batch_size", @@ -2400,6 +2566,7 @@ def build_prefill_specialization( self, prefill_seq_len: int = 32, ctx_len: int = 128, + comp_ctx_lengths: Optional[int] = None, batch_size: int = 1, kv_cache_batch_size: Optional[int] = None, full_batch_size: Optional[int] = None, @@ -2431,6 +2598,9 @@ def build_prefill_specialization( "ctx_len": ctx_len, "num_logits_to_keep": 1 if self.is_tlm else None, } + if comp_ctx_lengths is not None: + spec["comp_ctx_lengths"] = comp_ctx_lengths + if self.continuous_batching: spec["full_batch_size"] = kv_cache_batch_size else: @@ -2443,6 +2613,7 @@ def build_decode_specialization( self, prefill_seq_len: int = 32, ctx_len: int = 128, + comp_ctx_lengths: Optional[int] = None, batch_size: int = 1, kv_cache_batch_size: Optional[int] = None, full_batch_size: Optional[int] = None, @@ -2472,7 +2643,7 @@ def build_decode_specialization( A dictionary defining the decode specialization, or None if it would be a duplicate of the prefill specialization (e.g., if prefill_seq_len is 1 and not continuous batching). """ - if prefill_seq_len == 1 and not self.continuous_batching: + if prefill_seq_len == 1 and not self.continuous_batching and comp_ctx_lengths is None: return None # Avoid duplication with prefill spec = { "batch_size": full_batch_size if self.continuous_batching else batch_size, @@ -2480,6 +2651,8 @@ def build_decode_specialization( "ctx_len": ctx_len, "num_logits_to_keep": (num_speculative_tokens + 1) if self.is_tlm else None, } + if comp_ctx_lengths is not None: + spec["comp_ctx_lengths"] = comp_ctx_lengths if self.continuous_batching: spec["full_batch_size"] = kv_cache_batch_size @@ -2494,6 +2667,8 @@ def compile( *, prefill_seq_len: int = 32, ctx_len: int = 128, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, batch_size: int = 1, full_batch_size: Optional[int] = None, kv_cache_batch_size: Optional[int] = None, @@ -2581,6 +2756,23 @@ def compile( If `prefill_seq_len` is less than `num_speculative_tokens + 1` for TLM models. 
""" + # For comp_ctx_lengths Disaggregated applications + if self.comp_ctx_lengths_prefill is None: + if comp_ctx_lengths_prefill is not None: + import ast + + if isinstance(comp_ctx_lengths_prefill, str): + try: + # Safely evaluate the string to a Python list for disaggregated input + self.comp_ctx_lengths_prefill = ast.literal_eval(comp_ctx_lengths_prefill) + self.comp_ctx_lengths_decode = ast.literal_eval(comp_ctx_lengths_decode) + + except (ValueError, SyntaxError): + raise ValueError("Invalid format for comp_ctx_lengths. Expected a list-like string.") + else: + self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill + self.comp_ctx_lengths_decode = comp_ctx_lengths_decode + # --- Validation --- if prefill_only is not None and not isinstance(prefill_only, bool): raise TypeError("`prefill_only` must be a boolean.") @@ -2611,26 +2803,58 @@ def compile( # --- Specializations --- specializations = [] if prefill_only is None or prefill_only or prefill_seq_len == 1: - specializations.append( - self.build_prefill_specialization( + if self.comp_ctx_lengths_prefill is not None: + # Adding elements from self.comp_ctx_lengths_prefill to prefill_specialization + for i in range(0, len(self.comp_ctx_lengths_prefill)): + specializations.append( + self.build_prefill_specialization( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + comp_ctx_lengths=self.comp_ctx_lengths_prefill[i], + batch_size=batch_size, + kv_cache_batch_size=kv_cache_batch_size, + full_batch_size=full_batch_size, + ) + ) + + else: + specializations.append( + self.build_prefill_specialization( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + batch_size=batch_size, + kv_cache_batch_size=kv_cache_batch_size, + full_batch_size=full_batch_size, + ) + ) + + if prefill_only is None or not prefill_only: + if self.comp_ctx_lengths_decode is not None: + # Adding elements from self.comp_ctx_lengths_decode to decode_specialization + for i in range(0, len(self.comp_ctx_lengths_decode)): + decode_spec = self.build_decode_specialization( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + comp_ctx_lengths=self.comp_ctx_lengths_decode[i], + batch_size=batch_size, + kv_cache_batch_size=kv_cache_batch_size, + full_batch_size=full_batch_size, + num_speculative_tokens=num_speculative_tokens, + ) + if decode_spec: + specializations.append(decode_spec) + + else: + decode_spec = self.build_decode_specialization( prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, batch_size=batch_size, kv_cache_batch_size=kv_cache_batch_size, full_batch_size=full_batch_size, + num_speculative_tokens=num_speculative_tokens, ) - ) - if prefill_only is None or not prefill_only: - decode_spec = self.build_decode_specialization( - prefill_seq_len=prefill_seq_len, - ctx_len=ctx_len, - batch_size=batch_size, - kv_cache_batch_size=kv_cache_batch_size, - full_batch_size=full_batch_size, - num_speculative_tokens=num_speculative_tokens, - ) - if decode_spec: - specializations.append(decode_spec) + if decode_spec: + specializations.append(decode_spec) # --- Compilation --- kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16" @@ -2708,6 +2932,8 @@ def generate( tokenizer, self.qpc_path, prompt=prompts, + comp_ctx_lengths_prefill=self.comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=self.comp_ctx_lengths_decode, device_id=device_id, generation_len=generation_len, automation=kwargs.pop("automation", False), diff --git a/QEfficient/transformers/models/mpt/modeling_mpt.py b/QEfficient/transformers/models/mpt/modeling_mpt.py index 9bf6a4422..16ca54051 100644 --- 
a/QEfficient/transformers/models/mpt/modeling_mpt.py +++ b/QEfficient/transformers/models/mpt/modeling_mpt.py @@ -39,6 +39,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, ): @@ -51,7 +52,9 @@ def forward( value_states = value_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) if past_key_value is not None: - cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale @@ -101,6 +104,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, use_cache: bool = False, output_attentions: bool = False, ): @@ -118,6 +122,7 @@ def forward( batch_index=batch_index, attention_mask=attention_mask, past_key_value=layer_past, + comp_ctx_lengths=comp_ctx_lengths, use_cache=use_cache, ) @@ -144,6 +149,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -205,6 +211,7 @@ def forward( outputs = block( hidden_states, layer_past=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=causal_mask, position_ids=position_ids, batch_index=batch_index, @@ -250,6 +257,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -271,6 +279,7 @@ def forward( transformer_outputs = self.transformer( input_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, position_ids=position_ids, batch_index=batch_index, diff --git a/QEfficient/transformers/models/olmo2/modeling_olmo2.py b/QEfficient/transformers/models/olmo2/modeling_olmo2.py index 6dae7ac84..b82fcadb1 100644 --- a/QEfficient/transformers/models/olmo2/modeling_olmo2.py +++ b/QEfficient/transformers/models/olmo2/modeling_olmo2.py @@ -129,6 +129,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -151,8 +152,10 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: + if comp_ctx_lengths 
is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -185,6 +188,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -200,6 +204,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -230,6 +235,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -283,6 +289,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -319,6 +326,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -340,6 +348,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/phi/modeling_phi.py b/QEfficient/transformers/models/phi/modeling_phi.py index 18557f1ca..a5e53216a 100644 --- a/QEfficient/transformers/models/phi/modeling_phi.py +++ b/QEfficient/transformers/models/phi/modeling_phi.py @@ -67,6 +67,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, @@ -104,8 +105,16 @@ def forward( key_states = torch.cat((key_rot, key_pass), dim=-1) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # Update the cache_kwargs with position_ids for Cloud AI 100 - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + "CCL": attention_mask.shape[-1], + } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 
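The hunk above shows the pattern this patch repeats in every decoder's attention module: when `comp_ctx_lengths` is passed, the causal mask is truncated to the active compute-context-length and the KV-cache update receives the same reduced width through `cache_kwargs["CCL"]`. A minimal, self-contained sketch of just that truncation step (tensor shapes here are illustrative and not taken from any particular model config):

import torch

def ccl_window(attention_mask: torch.Tensor, comp_ctx_lengths: torch.Tensor) -> torch.Tensor:
    # Keep only the first len(comp_ctx_lengths) key/value positions of the mask;
    # the cache update then runs with the same reduced width ("CCL").
    return attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]

mask = torch.zeros(1, 8, 1, 2048)                # [batch, heads, q_len, ctx_len]
ccl_bucket = torch.zeros(512, dtype=torch.long)  # active CCL bucket of 512 positions
print(ccl_window(mask, ccl_bucket).shape)        # torch.Size([1, 8, 1, 512])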
attention_interface: Callable = eager_attention_forward @@ -140,6 +149,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs, @@ -181,6 +191,7 @@ def forward( position_ids=position_ids, batch_index=batch_index, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -213,6 +224,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -274,6 +286,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -316,6 +329,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -370,6 +384,7 @@ def forward( position_ids=position_ids, batch_index=batch_index, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, diff --git a/QEfficient/transformers/models/phi3/modeling_phi3.py b/QEfficient/transformers/models/phi3/modeling_phi3.py index 4b5234a5a..851395f08 100644 --- a/QEfficient/transformers/models/phi3/modeling_phi3.py +++ b/QEfficient/transformers/models/phi3/modeling_phi3.py @@ -140,6 +140,7 @@ def forward( batch_index: Optional[torch.LongTensor] = None, position_ids=Optional[torch.Tensor], past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -162,9 +163,12 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] cache_kwargs = { "batch_index": batch_index, "position_ids": position_ids, + "CCL": attention_mask.shape[-1], } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -198,6 +202,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -235,6 +240,7 @@ def forward( position_ids=position_ids, batch_index=batch_index, past_key_value=past_key_value, + 
comp_ctx_lengths=comp_ctx_lengths, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -265,6 +271,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -314,6 +321,7 @@ def forward( position_ids=position_ids, batch_index=batch_index, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, use_cache=use_cache, cache_position=cache_position, **kwargs, @@ -350,6 +358,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -366,6 +375,7 @@ def forward( batch_index=batch_index, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, inputs_embeds=inputs_embeds, use_cache=use_cache, output_hidden_states=output_hidden_states, diff --git a/QEfficient/transformers/models/qwen2/modeling_qwen2.py b/QEfficient/transformers/models/qwen2/modeling_qwen2.py index 24e8df46c..1aca7039d 100644 --- a/QEfficient/transformers/models/qwen2/modeling_qwen2.py +++ b/QEfficient/transformers/models/qwen2/modeling_qwen2.py @@ -150,6 +150,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -166,7 +167,9 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -200,6 +203,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -231,6 +235,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -261,6 +266,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -313,6 +319,7 @@ def forward( attention_mask=causal_mask, 
position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -348,6 +355,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -364,6 +372,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 030dd7a56..3b1d116de 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -399,6 +399,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, @@ -425,8 +426,16 @@ def forward( ) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids[0]} + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids[0], + "CCL": attention_mask.shape[-1], + } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -457,6 +466,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -496,6 +506,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -528,6 +539,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -578,6 +590,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -616,6 +629,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + 
comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -639,6 +653,7 @@ def forward( position_ids=position_ids, attention_mask=attention_mask, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, @@ -680,7 +695,9 @@ def __init__(self, model): self.model = model self.language_model = self.model.model.language_model - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + ): inputs_embeds = self.model.get_input_embeddings()(input_ids) B, N, C = inputs_embeds.shape selected = input_ids == self.model.config.image_token_id @@ -691,7 +708,11 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds) outputs = self.model.model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + use_cache=True, ) logit_index = position_ids[0].to(torch.int32).argmax(1, keepdim=True) @@ -709,7 +730,7 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEffQwen_2_5_vl_DecoderWrapper(self) - def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): + def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False, **kwargs): inputs_shapes = {} inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) @@ -757,6 +778,9 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): for kv in ["key", "value"]: lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs @@ -775,6 +799,8 @@ def get_specializations( img_size: None, height: int = None, width: int = None, + comp_ctx_lengths_prefill: List[int] = None, + comp_ctx_lengths_decode: List[int] = None, kv_offload: bool = False, **compiler_options, ): @@ -856,20 +882,46 @@ def smart_resize( "grid_w": grid_w, } ] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "vision_size": vision_size, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "vision_size": vision_size, - }, - ] + if comp_ctx_lengths_prefill is not None: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang.append( + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "vision_size": vision_size, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + } + ) + + for i in range(0, len(comp_ctx_lengths_decode)): + lang.append( + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "vision_size": vision_size, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + } + ) + + else: + lang = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": 
ctx_len, + "vision_size": vision_size, + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "vision_size": vision_size, + }, + ] specializations = {} @@ -880,7 +932,7 @@ def smart_resize( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): # Define dynamic axes num_layers = self.config.num_hidden_layers @@ -899,6 +951,9 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + dynamic_axes = {} if kv_offload: dynamic_axes["vision"] = vision_dynamic_axes diff --git a/QEfficient/transformers/models/qwen3/modeling_qwen3.py b/QEfficient/transformers/models/qwen3/modeling_qwen3.py index ecdb36019..ccf918c2c 100644 --- a/QEfficient/transformers/models/qwen3/modeling_qwen3.py +++ b/QEfficient/transformers/models/qwen3/modeling_qwen3.py @@ -151,6 +151,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -167,7 +168,9 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -201,6 +204,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -232,6 +236,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -262,6 +267,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -314,6 +320,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -349,6 +356,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: 
Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -367,6 +375,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 591f7c1b0..c8a5ae2fd 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -201,6 +201,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -217,7 +218,9 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -243,6 +246,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -274,6 +278,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -300,6 +305,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, batch_index: Optional[torch.LongTensor] = None, @@ -342,6 +348,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -369,6 +376,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -385,6 +393,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, inputs_embeds=inputs_embeds, batch_index=batch_index, use_cache=use_cache, diff --git a/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py 
b/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py index 9a327761d..075b8aedb 100644 --- a/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py +++ b/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py @@ -69,6 +69,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -84,7 +85,9 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -118,6 +121,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -153,6 +157,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -184,6 +189,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -237,6 +243,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -273,6 +280,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -289,6 +297,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/utils/check_ccl_specializations.py b/QEfficient/utils/check_ccl_specializations.py new file mode 100644 index 000000000..dbfb08926 --- /dev/null +++ b/QEfficient/utils/check_ccl_specializations.py @@ -0,0 +1,43 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from typing import List, Optional + + +def process_ccl_specializations( + ccl_prefill: Optional[List[int]] = None, ccl_decode: Optional[List[int]] = None, ctx_len: Optional[int] = None +): + if ctx_len is None: + raise TypeError("`ctx_len` is required when loading the model.") + if ccl_prefill is None: + ccl_prefill = [ctx_len] + if ccl_decode is None: + ccl_decode = [ctx_len] + + # Step 1: Cap values to ctx_len + ccl_prefill = [min(x, ctx_len) for x in ccl_prefill] + ccl_decode = [min(x, ctx_len) for x in ccl_decode] + + # Step 2: Remove duplicates within each list + ccl_prefill = list(set(ccl_prefill)) + ccl_decode = list(set(ccl_decode)) + + # Step 3: Ensure no overlap between ccl_prefill and ccl_decode + updated_prefill = [] + for val in ccl_prefill: + while val in ccl_decode or val in updated_prefill: + val -= 1 + if val < 0: + break # Prevent negative values + if val >= 0: + updated_prefill.append(val) + + # Step 4: Sort both lists + updated_prefill.sort() + ccl_decode.sort() + + return updated_prefill, ccl_decode diff --git a/examples/ccl_image_text_to_text_inference.py b/examples/ccl_image_text_to_text_inference.py new file mode 100644 index 000000000..932a407b9 --- /dev/null +++ b/examples/ccl_image_text_to_text_inference.py @@ -0,0 +1,135 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import requests +from PIL import Image +from transformers import AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +# Add HuggingFace Token to access the model +HF_TOKEN = "" + + +def run_model( + model_name, + token, + query, + image_url, + kv_offload=False, + prefill_seq_len=32, + ctx_len=512, + comp_ctx_lengths_prefill=None, + comp_ctx_lengths_decode=None, + generation_len=128, + img_size=560, + num_cores=16, + num_devices=1, +): + ## STEP - 1 Load the Processor and Model + + processor = AutoProcessor.from_pretrained(model_name, token=token) + + # `kv_offload` is used to compile the model in a Single QPC or 2 QPCs. + # The Dual QPC approach splits the model to perform Image Encoding and Output generation in 2 different QPCs. + # The outputs of the Vision Encoder are then passed to the Language model via host in this case. 
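The `process_ccl_specializations` helper introduced above (QEfficient/utils/check_ccl_specializations.py) normalizes the two CCL lists before they are expanded into compiler specializations. A short usage sketch, with hypothetical input values, showing its three rules: cap at ctx_len, de-duplicate, and keep the prefill and decode lists disjoint.

from QEfficient.utils.check_ccl_specializations import process_ccl_specializations

prefill, decode = process_ccl_specializations(
    ccl_prefill=[4096, 4096], ccl_decode=[4096, 8192], ctx_len=8192
)
# Duplicates are dropped, values above ctx_len would be capped, and a prefill value
# that collides with a decode value is decremented until the lists no longer overlap.
print(prefill, decode)  # [4095] [4096, 8192]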
+ + model = QEFFAutoModelForImageTextToText.from_pretrained( + model_name, + token=token, + attn_implementation="eager", + kv_offload=kv_offload, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + ) + + ## STEP - 2 Export & Compile the Model + + model.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + img_size=img_size, + num_cores=num_cores, + num_devices=num_devices, + mxfp6_matmul=False, + ) + + ## STEP - 3 Load and process the inputs for Inference + + image = Image.open(requests.get(image_url, stream=True).raw) + messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": query}, + ], + } + ] + input_text = [processor.apply_chat_template(messages, add_generation_prompt=True)] + + inputs = processor( + text=input_text, + images=image, + return_tensors="pt", + add_special_tokens=False, + padding="max_length", + max_length=prefill_seq_len, + ) + + ## STEP - 4 Run Inference on the compiled model + + streamer = TextStreamer(processor.tokenizer) + output_statistics = model.generate(inputs=inputs, streamer=streamer, generation_len=generation_len) + print(output_statistics) + + +if __name__ == "__main__": + # Model name and Input parameters + model_name = "llava-hf/llava-1.5-7b-hf" + # model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct" + query = "Describe this image." + image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg" + + # Compilation parameters for the model + kv_offload = True + prefill_seq_len = 32 + ctx_len = 8192 + generation_len = 128 + img_size = 336 + # img_size = 560 + num_cores = 16 + num_devices = 4 + comp_ctx_lengths_prefill = [4096] + comp_ctx_lengths_decode = [6144, ctx_len] + + run_model( + model_name=model_name, + token=HF_TOKEN, + query=query, + kv_offload=kv_offload, + image_url=image_url, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + generation_len=generation_len, + img_size=img_size, + num_cores=num_cores, + num_devices=num_devices, + ) + + +""" +Expected Response: + +This image depicts a charming anthropomorphic rabbit standing on a dirt path in front of a picturesque stone cottage, surrounded by a serene landscape. + +The rabbit, with its light brown fur and distinctive long ears, is attired in a stylish blue coat, brown vest, and tan pants, exuding a sense of sophistication. The dirt path, flanked by vibrant flowers and lush greenery, leads to the cottage, which features a thatched roof and a chimney, adding to the rustic charm of the scene. In the background, rolling hills and trees create a breathtaking panorama, while the sky above is a brilliant blue with white clouds, completing the + +""" diff --git a/examples/ccl_llama4_example.py b/examples/ccl_llama4_example.py new file mode 100644 index 000000000..5fc715589 --- /dev/null +++ b/examples/ccl_llama4_example.py @@ -0,0 +1,126 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import torch +import transformers +from transformers import AutoConfig, AutoModelForImageTextToText, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" +config = AutoConfig.from_pretrained(model_id) +# For Testing Purpose Only +config.text_config.num_hidden_layers = 4 +config.vision_config.num_hidden_layers = 2 + +model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation="eager", config=config) +model.eval() +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) +processor = AutoProcessor.from_pretrained(model_id) + +### For running the model in single QPC approach use kv_offload=False. For Dual QPC approach use kv_offload=True ### +ctx_len = 8192 +comp_ctx_lengths_prefill = [3072] +comp_ctx_lengths_decode = [4096, ctx_len] + +qeff_model = QEFFAutoModelForImageTextToText( + model, + kv_offload=True, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, +) + +### use skip_vision=Ture, if want to run only text, ow false ### +skip_vision = False + +if skip_vision: + ## Only Text ## + qeff_model.compile( + prefill_seq_len=128, + ctx_len=ctx_len, + img_size=336, + num_cores=16, + num_devices=4, + max_num_tiles=17, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + skip_vision=True, + mos=1, + ) + + messages = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Can you describe the image in detail.", + }, + ], + }, + ] + + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, generation_len=700) + print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) + +else: + ## Vision + Text ## + qeff_model.compile( + prefill_seq_len=128, + ctx_len=ctx_len, + img_size=336, + num_cores=16, + num_devices=4, + max_num_tiles=17, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + ) + + ### IMAGE + TEXT ### + image_url = ( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png" + ) + + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "url": image_url}, + {"type": "text", "text": "Can you describe the image in detail."}, + ], + }, + ] + + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, generation_len=1024) + print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) + print() diff --git a/examples/ccl_mistral3_example.py b/examples/ccl_mistral3_example.py new file mode 100644 index 000000000..ed02a4fa9 --- /dev/null +++ b/examples/ccl_mistral3_example.py @@ -0,0 +1,120 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import requests +from PIL import Image +from transformers import AutoConfig, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + + +def run_model( + model_name, + query, + image_url, + kv_offload=False, + prefill_seq_len=128, + ctx_len=4096, + comp_ctx_lengths_prefill=None, + comp_ctx_lengths_decode=None, + generation_len=128, + img_size=1540, + num_cores=16, + num_devices=4, +): + ## STEP - 1 Load the Processor and Model + + processor = AutoProcessor.from_pretrained(model_name) + + # `kv_offload` is used to compile the model in a 2 QPCs.Currently we are not supporting 1 qpc so the flag false is not allowed. + # The `kv_offload` flag should always be set to True. + # The Dual QPC approach splits the model to perform Image Encoding and Output generation in 2 different QPCs. + # The outputs of the Vision Encoder are then passed to the Language model via host in this case. + + config = AutoConfig.from_pretrained(model_name) + config.vision_config._attn_implementation = "eager" + + model = QEFFAutoModelForImageTextToText.from_pretrained(model_name, + kv_offload=kv_offload, + config=config, + ctx_len=ctx_len, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode + ) + + ## STEP - 2 Export & Compile the Model + + model.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + img_size=img_size, + num_cores=num_cores, + num_devices=num_devices, + mxfp6_matmul=False, + ) + + ## STEP - 3 Load and process the inputs for Inference + + # We are resizing the image to (w x h) (1540 x 1540) so that any image can work on the model irrespective of image dimensssions + # we have a fixed size of height 1540 and width 1540 as defined in the config + + image = Image.open(requests.get(image_url, stream=True).raw) + image = image.resize((1540, 1540)) + + messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": query}]}] + input_text = processor.apply_chat_template(messages, add_generation_prompt=True) + inputs = processor(image, input_text, add_special_tokens=False, return_tensors="pt") + + ## STEP - 4 Run Inference on the compiled model + + streamer = TextStreamer(processor.tokenizer) + output = model.generate(inputs=inputs, streamer=streamer, generation_len=generation_len) + print(output) + + +if __name__ == "__main__": + # Model name and Input parameters + model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" + + # Please add prompt here + query = "Describe the image" + + # Please pass image url or image path .The format of the image should be jpg. + image_url = "https://www.ilankelman.org/stopsigns/australia.jpg" + + # Compilation parameters for the model + kv_offload = True + prefill_seq_len = 128 + ctx_len = 8192 + generation_len = 128 + num_cores = 16 + num_devices = 4 + comp_ctx_lengths_prefill = [4096] + comp_ctx_lengths_decode = [6144, ctx_len] + + run_model( + model_name=model_name, + query=query, + kv_offload=kv_offload, + image_url=image_url, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + generation_len=generation_len, + num_cores=num_cores, + num_devices=num_devices, + ) + + +""" +Expected Response: +The image depicts a street scene in what appears to be a Chinatown district. 
The focal point is a traditional Chinese archway, known as a paifang, which is intricately designed with red columns and ornate details. The archway features Chinese characters at the top, which translate to "Chinatown Gate." +In the foreground, there is a red stop sign mounted on a pole. The street is relatively quiet, with a single dark-colored SUV driving through the archway. On either side of the archway, there are stone lion statues, which are common decorative elements in Chinese architecture and symbolize protection. + + +""" diff --git a/examples/ccl_qwen2_5_vl_example.py b/examples/ccl_qwen2_5_vl_example.py new file mode 100644 index 000000000..7056011f2 --- /dev/null +++ b/examples/ccl_qwen2_5_vl_example.py @@ -0,0 +1,189 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import requests +import torch +import torch.nn.functional as F +import transformers +from PIL import Image +from qwen_vl_utils import process_vision_info +from transformers import AutoConfig, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +## For AWQ model update pytorch version to 2.8.* +model_id = "Qwen/Qwen2.5-VL-32B-Instruct" +config = AutoConfig.from_pretrained(model_id) + +## Use complete model without changing num_hidden_layers as it will not work for TF version 4.55.0 for Qwen2.5VL model + +ctx_len = 32768 + +comp_ctx_lengths_prefill = [4000] +comp_ctx_lengths_decode = [4096, 8192,16384, ctx_len] + +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + attn_implementation="eager", + kv_offload=True, + config=config +) +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) + +### use skip_vision=Ture, if want to run only text, ow false ### +skip_vision = False + +if skip_vision: + ## Only Text ## + + ## Set Batch_Size ## + batch_size = 2 + qeff_model.compile( + batch_size=batch_size, + prefill_seq_len=128, + ctx_len=ctx_len, + num_cores=16, + num_devices=4, + height=354, + width=536, + mxfp6_matmul=False, + aic_enable_depth_first=True, + skip_vision=True, + mos=1, + ) + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Tell me about yourself."}, + ], + }, + ] + + messages = [messages] * batch_size + + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + + pos_ids, rope_deltas = qeff_model.model.get_rope_index( + inputs["input_ids"], + image_grid_thw=None, + video_grid_thw=None, + second_per_grid_ts=None, + attention_mask=inputs["attention_mask"], + ) + + input_ids_length = inputs["input_ids"].shape[1] + + inputs["position_ids"] = torch.cat([pos_ids, pos_ids[0].unsqueeze(0)], dim=0) + + prefill_seq_len = 128 + num_chunks = -(input_ids_length // -prefill_seq_len) # ceil divide without float + padded_len = num_chunks * prefill_seq_len # Convert to a multiple of prompt_len + + inputs["position_ids"] = F.pad( + inputs["position_ids"], pad=(0, padded_len - input_ids_length), mode="constant", value=-1 + ) + + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, generation_len=100) + 
print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) + +else: + batch_size = 1 + ## Vision + Text ## + qeff_model.compile( + batch_size=batch_size, + prefill_seq_len=128, + ctx_len=ctx_len, + num_cores=16, + num_devices=4, + height=354, + width=536, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + ) + + ### IMAGE + TEXT ### + image_url = "https://picsum.photos/id/237/536/354" + + image = Image.open(requests.get(image_url, stream=True).raw) + + messages_1 = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": "Describe this image."}, + ], + }, + ] + + messages_2 = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": "Describe about the color of the dog."}, + ], + }, + ] + + messages = [messages_2] * batch_size + + texts = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in messages] + + image_inputs, video_inputs = process_vision_info(messages) + inputs = processor( + text=texts, + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ) + input_ids_length = inputs["input_ids"].shape[1] + + inputs["position_ids"] = torch.arange(input_ids_length).view(1, 1, input_ids_length).expand(-1, batch_size, -1) + + pos_ids, rope_deltas = qeff_model.model.model.get_rope_index( + inputs["input_ids"], + inputs["image_grid_thw"], + video_grid_thw=None, + second_per_grid_ts=None, + attention_mask=inputs["attention_mask"], + ) + + inputs["position_ids"] = torch.cat((inputs["position_ids"], pos_ids), dim=0) + + prefill_seq_len = 128 + num_chunks = -(input_ids_length // -prefill_seq_len) # ceil divide without float + padded_len = num_chunks * prefill_seq_len # Convert to a multiple of prompt_len + + inputs["position_ids"] = F.pad( + inputs["position_ids"], pad=(0, padded_len - input_ids_length), mode="constant", value=-1 + ) + + inputs.pop("image_grid_thw") + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, generation_len=100) + print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) diff --git a/examples/compute_context_length.py b/examples/compute_context_length.py new file mode 100644 index 000000000..554c61c84 --- /dev/null +++ b/examples/compute_context_length.py @@ -0,0 +1,61 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +## In this example, you can run a model for static and continuous batching with different Compute-Context-Length (CCL) inputs. ## + +from transformers import AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM + +## Using optional variable comp_ctx_lengths variable you can pass a list of context lengths. It will run the model with default context length if comp_ctx_lengths=None. ## +## - The first comp_ctx_lengths_prefill list shows the compute-ctx-length list for prefilling process. ## +## - The second comp_ctx_lengths_decode list will be used for decoding. During the decoding process, based on the position_id or cache index it will work with the specific compute-context-length in the list. 
It will start from the appropriate compute-context-length in the list, based on the input prompt length, and will gradually increase the compute-context-length once the cache index passes the current one. ##
+
+ctx_len = 1024
+comp_ctx_lengths_prefill = [256]
+comp_ctx_lengths_decode = [512,ctx_len]
+
+# model_name = "google/gemma-7b"
+# model_name = "google/gemma-2-2b"
+# model_name = "ibm-granite/granite-3.1-8b-instruct"
+# model_name = "Snowflake/Llama-3.1-SwiftKV-8B-Instruct"
+# model_name = "mistralai/Mistral-7B-v0.1"
+# model_name = "microsoft/phi-1_5"
+# model_name = "microsoft/Phi-3-mini-4k-instruct"
+# model_name = "Qwen/Qwen2.5-7B-Instruct"
+model_name = "meta-llama/Llama-3.2-1B"
+# model_name = "Qwen/Qwen3-1.7B"
+# model_name = "allenai/OLMo-2-0425-1B"
+# model_name = "ibm-granite/granite-3.3-2b-base"
+model = QEFFAutoModelForCausalLM.from_pretrained(
+    model_name,
+    continuous_batching=False,
+    comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
+    comp_ctx_lengths_decode=comp_ctx_lengths_decode,
+    ctx_len=ctx_len,
+)
+
+# Model compilation for either continuous or static batching. For continuous batching, full_batch_size is needed.
+model.compile(
+    prefill_seq_len=128,
+    ctx_len=ctx_len,
+    num_cores=16,
+    num_devices=1,
+    batch_size=1,
+    mxint8_kv_cache=True,
+    mxfp6_matmul=True,
+)
+
+# Create the tokenizer and run model.generate, passing the input prompts to it.
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model.generate(
+    prompts=[
+        "My name is ",
+    ],
+    tokenizer=tokenizer,
+    generation_len=128
+)
diff --git a/examples/gemma3_example/ccl_gemma3_mm.py b/examples/gemma3_example/ccl_gemma3_mm.py
new file mode 100644
index 000000000..484c0f8ce
--- /dev/null
+++ b/examples/gemma3_example/ccl_gemma3_mm.py
@@ -0,0 +1,119 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import torch
+import transformers
+from transformers import AutoConfig, AutoProcessor
+
+from QEfficient import QEFFAutoModelForImageTextToText
+
+# Change model_id to "google/gemma-3-27b-it" for 27B model
+model_id = "google/gemma-3-4b-it"
+config = AutoConfig.from_pretrained(model_id)
+# For testing purposes only
+# config.text_config.num_hidden_layers = 1
+# config.vision_config.num_hidden_layers = 2
+tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+processor = AutoProcessor.from_pretrained(model_id)
+
+# Pass HF_TOKEN if the model is gated
+# For running the model in the single QPC approach use kv_offload=False. For the Dual QPC approach use kv_offload=True ###
+ctx_len = 8192
+comp_ctx_lengths_prefill = [3072]
+comp_ctx_lengths_decode = [4096, ctx_len]
+
+qeff_model = QEFFAutoModelForImageTextToText.from_pretrained(
+    model_id,
+    config=config,
+    attn_implementation="eager",
+    kv_offload=True,
+    comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
+    comp_ctx_lengths_decode=comp_ctx_lengths_decode,
+    ctx_len=ctx_len,
+)
+
+### Use skip_vision=True to run only text, or False to run vision + text ###
+skip_vision = False
+
+if skip_vision:
+    ## Only Text ##
+    qeff_model.compile(
+        prefill_seq_len=128,
+        ctx_len=ctx_len,
+        img_size=896,
+        num_cores=16,
+        num_devices=4,
+        mxfp6_matmul=False,
+        mxint8_kv_cache=False,
+        aic_enable_depth_first=True,
+        skip_vision=True,
+        mos=1,
+        node_precision_info="examples/gemma3_example/fp32_nodes_gemma3_27b.yaml",
+    )
+
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "text", "text": "Describe the transformers architecture in LLMs."},
+            ],
+        },
+    ]
+
+    inputs = processor.apply_chat_template(
+        messages,
+        add_generation_prompt=True,
+        tokenize=True,
+        return_dict=True,
+        return_tensors="pt",
+    )
+
+    output = qeff_model.generate(inputs=inputs, generation_len=100)
+    print(tokenizer.batch_decode(output.generated_ids))
+    print(output)
+
+else:
+    ## Vision + Text ##
+    qeff_model.compile(
+        prefill_seq_len=128,
+        ctx_len=ctx_len,
+        img_size=896,
+        num_cores=16,
+        num_devices=4,
+        mxfp6_matmul=False,
+        mxint8_kv_cache=False,
+        aic_enable_depth_first=True,
+        mos=1,
+        node_precision_info="examples/gemma3_example/fp32_nodes_gemma3_27b.yaml",
+    )
+
+    ### IMAGE + TEXT ###
+    image_url = (
+        "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png"
+    )
+
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "image", "url": image_url},
+                {"type": "text", "text": "Can you describe the image in detail."},
+            ],
+        },
+    ]
+
+    inputs = processor.apply_chat_template(
+        messages,
+        add_generation_prompt=True,
+        tokenize=True,
+        return_dict=True,
+        return_tensors="pt",
+    )
+    inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32)
+    output = qeff_model.generate(inputs=inputs, generation_len=100)
+    print(tokenizer.batch_decode(output.generated_ids))
+    print(output)
diff --git a/examples/granite_example/ccl_granite_vision_inference.py b/examples/granite_example/ccl_granite_vision_inference.py
new file mode 100644
index 000000000..e03b94a5e
--- /dev/null
+++ b/examples/granite_example/ccl_granite_vision_inference.py
@@ -0,0 +1,127 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import requests
+from PIL import Image
+from transformers import AutoProcessor, TextStreamer
+
+from QEfficient import QEFFAutoModelForImageTextToText
+
+# Add HuggingFace Token to access the model
+HF_TOKEN = ""
+
+
+def run_model(
+    model_name,
+    token,
+    query,
+    image_url,
+    kv_offload=False,
+    prefill_seq_len=5500,
+    ctx_len=6000,
+    comp_ctx_lengths_prefill=None,
+    comp_ctx_lengths_decode=None,
+    generation_len=128,
+    img_size=384,
+    num_cores=16,
+    num_devices=1,
+):
+    ## STEP - 1 Load the Processor and Model
+
+    processor = AutoProcessor.from_pretrained(model_name, token=token)
+
+    # `kv_offload` is used to compile the model into 2 QPCs. Currently a single QPC is not supported, so setting this flag to False is not allowed.
+    # The `kv_offload` flag should always be set to True.
+    # The Dual QPC approach splits the model to perform Image Encoding and Output generation in 2 different QPCs.
+    # The outputs of the Vision Encoder are then passed to the Language model via host in this case.
+
+    model = QEFFAutoModelForImageTextToText.from_pretrained(
+        model_name,
+        token=token,
+        kv_offload=kv_offload,
+        comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
+        comp_ctx_lengths_decode=comp_ctx_lengths_decode,
+        ctx_len=ctx_len,
+    )
+
+    ## STEP - 2 Export & Compile the Model
+
+    model.compile(
+        prefill_seq_len=prefill_seq_len,
+        ctx_len=ctx_len,
+        img_size=img_size,
+        num_cores=num_cores,
+        num_devices=num_devices,
+        mxfp6_matmul=False,
+    )
+
+    ## STEP - 3 Load and process the inputs for Inference
+
+    # We resize the image to (w x h) = (1610 x 1109) so that any image can work with the model irrespective of its dimensions.
+    # We use a fixed size of width 1610 and height 1109.
+
+    image = Image.open(requests.get(image_url, stream=True).raw)
+    image = image.resize((1610, 1109))
+
+    messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": query}]}]
+    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
+    inputs = processor(image, input_text, add_special_tokens=False, return_tensors="pt")
+
+    ## STEP - 4 Run Inference on the compiled model
+
+    streamer = TextStreamer(processor.tokenizer)
+    output = model.generate(inputs=inputs, streamer=streamer, generation_len=generation_len)
+    print(output)
+
+
+if __name__ == "__main__":
+    # Model name and Input parameters
+    model_name = "ibm-granite/granite-vision-3.2-2b"
+
+    # Please add prompt here
+    query = "Describe the image"
+
+    # Please pass an image URL or image path. The image should be in JPG format.
+    image_url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+
+    # Compilation parameters for the model
+    kv_offload = True
+    prefill_seq_len = 5500
+    ctx_len = 8192
+    generation_len = 128
+    img_size = 384
+    num_cores = 16
+    num_devices = 4
+    ctx_len = 8192
+    comp_ctx_lengths_prefill = [5500]
+    comp_ctx_lengths_decode = [6144, ctx_len]
+
+    run_model(
+        model_name=model_name,
+        token=HF_TOKEN,
+        query=query,
+        kv_offload=kv_offload,
+        image_url=image_url,
+        prefill_seq_len=prefill_seq_len,
+        ctx_len=ctx_len,
+        comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
+        comp_ctx_lengths_decode=comp_ctx_lengths_decode,
+        generation_len=generation_len,
+        img_size=img_size,
+        num_cores=num_cores,
+        num_devices=num_devices,
+    )
+
+
+"""
+Expected Response:
+
+The image depicts two cats lying on a pink blanket that is spread out on a red couch. The cats are positioned in a relaxed manner, with their bodies stretched out and their heads resting on the blanket.
+The cat on the left is a smaller, tabby cat with a mix of black, gray, and white fur. It has a long, slender body and a distinctive tail that is curled up near its tail end. The cat on the right is a larger,
+tabby cat with a mix of gray, black, and brown fur. It has
+
+"""
diff --git a/examples/granite_example/ccl_granitemoe_inference.py b/examples/granite_example/ccl_granitemoe_inference.py
new file mode 100644
index 000000000..57668ca24
--- /dev/null
+++ b/examples/granite_example/ccl_granitemoe_inference.py
@@ -0,0 +1,40 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from transformers import AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.utils.constants import Constants + +model_name = "ibm-research/PowerMoE-3b" +""" +# For CB inference, set continuous_batching to True and add full_batch_size,mxfp6,mint8 argument in compile function +# We will use prompt_len=1 for compilation for both cb and non-cb inference +""" + +ctx_len = 2048 +comp_ctx_lengths_prefill = [256] +comp_ctx_lengths_decode = [512, 1024, ctx_len] + +model = QEFFAutoModelForCausalLM.from_pretrained( + model_name, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + continuous_batching=False, +) +model.compile( + prefill_seq_len=1, + ctx_len=ctx_len, + batch_size=1, + num_cores=16, + num_devices=4, + mxfp6_matmul=False, + mxint8_kv_cache=False, +) +tokenizer = AutoTokenizer.from_pretrained(model_name) +exec_info = model.generate(prompts=Constants.INPUT_STR, tokenizer=tokenizer) diff --git a/examples/intern_example/ccl_internvl_inference.py b/examples/intern_example/ccl_internvl_inference.py new file mode 100644 index 000000000..5595d26cd --- /dev/null +++ b/examples/intern_example/ccl_internvl_inference.py @@ -0,0 +1,286 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from io import BytesIO +from typing import List + +import requests +import torch +import torch.nn as nn +import torchvision.transforms as T +from PIL import Image +from torchvision.transforms.functional import InterpolationMode +from transformers import AutoTokenizer, TextStreamer + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.utils.logging_utils import logger + +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) + + +# Process the input messages to generate prompt for the model. +def get_prompt(messages) -> str: + """Get the prompt for generation.""" + ## Chat template used for InternVL + system_prompt = "<|im_start|>system\n你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。" + sep = "<|im_end|>\n" + + ret = system_prompt + sep + for role, message in messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + message + sep + else: + ret += role + return ret + + +# Processor class for InternVL models +class InternProcessor: + """ + InternVL model only has an AutoTokenizer so this class performs the processing tasks similar to an AutoProcessor. + The methods used here are borrowed from the original InternVL modelling files. 
+ "https://huggingface.co/OpenGVLab/InternVL2_5-1B/" + """ + + def __init__(self, model: nn.Module, tokenizer): + self.model = model + image_size = self.model.config.force_image_size or self.model.config.vision_config.image_size + patch_size = self.model.config.vision_config.patch_size + self.template = model.config.template + self.num_image_token = int((image_size // patch_size) ** 2 * (self.model.config.downsample_ratio**2)) + self.tokenizer = tokenizer + + def build_transform(self, input_size): + MEAN, STD = IMAGENET_MEAN, IMAGENET_STD + transform = T.Compose( + [ + T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), + T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=MEAN, std=STD), + ] + ) + return transform + + def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size): + best_ratio_diff = float("inf") + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + def dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=448, use_thumbnail=False): + orig_width, orig_height = image.size + aspect_ratio = orig_width / orig_height + # calculate the existing image aspect ratio + target_ratios = set( + (i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if i * j <= max_num and i * j >= min_num + ) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + # find the closest aspect ratio to the target + target_aspect_ratio = self.find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, image_size + ) + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + return processed_images + + def load_image(self, image, input_size=448, max_num=12): + transform = self.build_transform(input_size=input_size) + images = self.dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num) + pixel_values = [transform(image) for image in images] + pixel_values = torch.stack(pixel_values) + return pixel_values + + def __call__( + self, + pixel_values, + question, + messages, + roles, + history=None, + num_patches_list=None, + IMG_START_TOKEN="", + IMG_END_TOKEN="", + IMG_CONTEXT_TOKEN="", + verbose=False, + ) -> str: + if history is None and pixel_values is not None and "" not in question: + question = "\n" + question + if num_patches_list is None: + 
num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else [] + assert pixel_values is None or len(pixel_values) == sum(num_patches_list) + img_context_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN) + self.model.img_context_token_id = img_context_token_id + + messages.append([roles[0], question]) + messages.append([roles[1], None]) + query = get_prompt(messages) + if verbose and pixel_values is not None: + image_bs = pixel_values.shape[0] + logger.info(f"dynamic ViT batch size: {image_bs}") + + for num_patches in num_patches_list: + image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN + query = query.replace("", image_tokens, 1) + return query + + +def run_intern_on_aic( + model_name, + prompt, + image_url, + messages, + roles, + kv_offload=False, + prefill_seq_len=3840, + num_devices=1, + num_cores=16, +): + ## STEP 1 -- LOAD THE MODEL + + # The original Intern-VL model, despite being multimodal, is loaded using `AutoModelForCausalLM` in Huggingface. + # To maintain compatibility, we load this model using `QEFFAutoModelForCausalLM`. + + ctx_len = 8192 + comp_ctx_lengths_prefill = [4096] + comp_ctx_lengths_decode = [6144, ctx_len] + + # model = QEFFAutoModelForCausalLM.from_pretrained(model_name, kv_offload=kv_offload, trust_remote_code=True) + + model = QEFFAutoModelForCausalLM.from_pretrained( + model_name, + kv_offload=kv_offload, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + trust_remote_code=True, + ) + + ## STEP 2 -- EXPORT & COMPILE THE MODEL + + model.compile( + num_cores=num_cores, + num_devices=num_devices, + ctx_len=ctx_len, + prefill_seq_len=prefill_seq_len, + mxfp6_matmul=False, + ) + + ## STEP 3 -- SETUP THE PROCESSOR + + # InternVL doesn't have an AutoProcessor yet, so we will use our own processor class "InternProcessor" + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False) + internProcessor = InternProcessor(model.model, tokenizer) + + ## STEP 4 -- PREPROCESS THE INPUTS + + img = requests.get(image_url, stream=True) + image = Image.open(BytesIO(img.content)).convert("RGB") + + # Images are resized to (1000, 747) for inference + image = image.resize((1000, 747)) + + # preprocess the resized image + pixel_values = internProcessor.load_image(image, max_num=12) + question = "\n" + prompt + query = internProcessor(pixel_values, question, messages, roles) + inputs = tokenizer( + query, return_tensors="pt", padding="max_length", max_length=prefill_seq_len, padding_side="right" + ) + + inputs["pixel_values"] = pixel_values + + ## STEP 5 -- RUN INFERENCE VIA GENERATE FUNCTION + streamer = TextStreamer(tokenizer) + model.generate(inputs=inputs, streamer=streamer, generation_len=128) + + +if __name__ == "__main__": + model_name = "OpenGVLab/InternVL2_5-1B" + + # Chat Template information for prompt preprocessing + messages: List[List[str]] = [] + roles = ("<|im_start|>user\n", "<|im_start|>assistant\n") + + # Inputs for the model + prompt = "Please describe the image in detail." + image_url = "https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-1-2048.jpg" + + ## Compilation parameters + + # `kv_offload` is used to compile the model in a Single QPC or 2 QPCs. + # The Dual QPC approach splits the model to perform Image Encoding and Output generation in 2 different QPCs. 
+ # The outputs of the Vision Encoder are then passed to the Language model via host in this case. + + kv_offload = False + + # InternVL is an Early-Fusion model that uses placeholder tokens within the input_ids to interleave text_embeddings with + # Image embeddings and generate final input_embeds for outout generation. Hence we need very large prefill_seq_len (3840 in this case) to + # incorporate the memory for the merged embeddings. + + prefill_seq_len = 3840 + num_devices = 4 + num_cores = 16 + + run_intern_on_aic( + model_name=model_name, + prompt=prompt, + image_url=image_url, + messages=messages, + roles=roles, + kv_offload=kv_offload, + prefill_seq_len=prefill_seq_len, + num_devices=num_devices, + num_cores=num_cores, + ) + + +""" +Expected Response: + +The image is a promotional graphic for Microsoft Azure. It features a blue background with a hexagonal pattern on the left side. The hexagons are white and are arranged in a way that suggests a network or connectivity theme. + +On the right side of the image, the Microsoft Azure logo is prominently displayed. The logo consists of the Azure name in white, with the Microsoft logo above it, which includes four colored squares (blue, green, yellow, and red). Below the logo, the word "Azure" is written in large white letters. + +Below the logo, there is text that reads: +- "By Dinesh Kumar Wick +""" diff --git a/examples/qwen3moe_example/ccl_qwen3moe_inference.py b/examples/qwen3moe_example/ccl_qwen3moe_inference.py new file mode 100644 index 000000000..4d09b08f3 --- /dev/null +++ b/examples/qwen3moe_example/ccl_qwen3moe_inference.py @@ -0,0 +1,42 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from transformers import AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.utils.constants import Constants + +model_name = "Qwen/Qwen3-30B-A3B-Instruct-2507" +""" +# For CB inference, set continuous_batching to True and add full_batch_size,mxfp6,mint8 argument in compile function +# We will use prompt_len=1 for compilation for both cb and non-cb inference +""" + +ctx_len = 8192 + +comp_ctx_lengths_prefill = [4096] +comp_ctx_lengths_decode = [6144,8192] + +model = QEFFAutoModelForCausalLM.from_pretrained( + model_name, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + continuous_batching=False, +) +model.compile( + prefill_seq_len=1, + ctx_len=ctx_len, + batch_size=1, + num_cores=16, + num_devices=4, + mxfp6_matmul=True, + mxint8_kv_cache=True, + mos=1, +) +tokenizer = AutoTokenizer.from_pretrained(model_name) +exec_info = model.generate(prompts=Constants.INPUT_STR, tokenizer=tokenizer) diff --git a/tests/transformers/test_comp_ctx_length.py b/tests/transformers/test_comp_ctx_length.py new file mode 100644 index 000000000..e145ad698 --- /dev/null +++ b/tests/transformers/test_comp_ctx_length.py @@ -0,0 +1,193 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import copy +import os +from time import perf_counter + +import onnx +import pytest +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + +configs = [ + # name, max_position_embeddings, num_hidden_layers, num_attention_heads, hidden_size, intermediate_size, vocab_size, additional_params + ("gpt2", 256, 2, 4, 128, 512, 127, {}), + ("codegen", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), + ("falcon", 256, 2, 4, 128, 512, 127, {}), + ("gptj", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), + ("llama", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("mistral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("mixtral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("mpt", 256, 2, 4, 128, 512, 127, {}), + ("phi", 256, 2, 4, 128, 512, 127, {}), + ("phi3", 256, 2, 4, 128, 512, 127, {"pad_token_id": 0}), + ("qwen2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("starcoder2", 256, 2, 4, 128, 512, 127, {}), + ("granite", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), +] + +configs = [ + AutoConfig.for_model( + model_name, + max_position_embeddings=max_position_embeddings, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + vocab_size=vocab_size, + **additional_params, + ) + for ( + model_name, + max_position_embeddings, + num_hidden_layers, + num_attention_heads, + hidden_size, + intermediate_size, + vocab_size, + additional_params, + ) in configs +] +config_ids = [x.model_type for x in configs] + +model_kwargs = {"attn_implementation": "eager"} + + +@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) +def test_causal_lm_unsupported(cb): + model = AutoModelForCausalLM.from_config(AutoConfig.for_model("opt")) + with pytest.warns(): + QEFFAutoModelForCausalLM(model, cb) + + +@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) +@pytest.mark.parametrize("config", configs, ids=config_ids) +def test_causal_lm_init(config, cb): + model = AutoModelForCausalLM.from_config(config, **model_kwargs) + qeff_model = QEFFAutoModelForCausalLM(model, cb) + with pytest.raises(TypeError): + QEFFAutoModelForCausalLM(AutoModel.from_config(config, **model_kwargs), cb) + assert qeff_model.model.__class__.__name__.startswith("QEff") + + +@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) +@pytest.mark.parametrize("config", configs, ids=config_ids) +def test_causal_lm_pretrained(config, cb, tmp_path): + model = AutoModelForCausalLM.from_config(config, **model_kwargs) + model.save_pretrained(tmp_path) + + qeff_model = QEFFAutoModelForCausalLM.from_pretrained(tmp_path, cb) + assert qeff_model.model.__class__.__name__.startswith("QEff") + + +@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) +@pytest.mark.parametrize("config", configs, ids=config_ids) +def test_causal_lm_hash(config, cb): + hash_0_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb).model_hash + hash_0_1 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb).model_hash + + assert hash_0_0 == hash_0_1 + + cfg1 = copy.deepcopy(config) + cfg1.num_hidden_layers -= 1 + hash_1_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(cfg1, **model_kwargs), 
cb).model_hash + cfg2 = copy.deepcopy(config) + cfg2.num_hidden_layers -= 1 + hash_1_1 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(cfg2, **model_kwargs), cb).model_hash + assert hash_1_0 == hash_1_1 + + assert hash_0_0 != hash_1_0 + + if cb: + hash_0_no_cb = QEFFAutoModelForCausalLM( + AutoModelForCausalLM.from_config(config, **model_kwargs), False + ).model_hash + assert hash_0_0 != hash_0_no_cb + + +@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) +@pytest.mark.parametrize("config", configs, ids=config_ids) +def test_causal_lm_export(config, cb, tmp_path): + model = AutoModelForCausalLM.from_config(config, **model_kwargs) + ctx_len = 2048 + comp_ctx_lengths_prefill = [256] + comp_ctx_lengths_decode = [512, 1024, ctx_len] + + qeff_model = QEFFAutoModelForCausalLM( + model, + cb, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + ) + qeff_model.export(tmp_path) + model_path = tmp_path.with_name(tmp_path.name + "-" + qeff_model.model_hash) + assert model_path.is_dir() + assert qeff_model.onnx_path.is_file() + assert qeff_model.onnx_path.relative_to(model_path).parts == (qeff_model.model_name + ".onnx",) + + # Check if the KV-cache inputs and outputs are created + onnx_model = onnx.load(qeff_model.onnx_path, load_external_data=False) + retained_output_names = { + x.name[: -len("_RetainedState")] for x in onnx_model.graph.output if x.name.endswith("_RetainedState") + } + retained_output_names.issubset({x.name for x in onnx_model.graph.input}) + + # Check if there is no re-export + start = perf_counter() + qeff_model.export(tmp_path) + end = perf_counter() + export_time = end - start + assert export_time < 2.0 + + +@pytest.fixture +def tmp_cache(tmp_path, monkeypatch): + monkeypatch.setattr("QEfficient.base.modeling_qeff.QEFF_HOME", tmp_path) + yield tmp_path + + +@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) +@pytest.mark.parametrize("config", configs, ids=config_ids) +def test_causal_lm_compile(config, cb, tmp_cache): + model = AutoModelForCausalLM.from_config(config, **model_kwargs) + ctx_len = 2048 + comp_ctx_lengths_prefill = [256] + comp_ctx_lengths_decode = [512, 1024, ctx_len] + qeff_model = QEFFAutoModelForCausalLM( + model, + cb, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + ) + compile_params = {"prefill_seq_len": 8, "ctx_len": ctx_len} + if cb: + compile_params["full_batch_size"] = 32 + compile_params["batch_size"] = 8 + qeff_model.compile(**compile_params) + model_path = tmp_cache / (qeff_model.model_name + "-" + qeff_model.model_hash) + + # Check if ONNX is exported properly + assert model_path.is_dir() + assert qeff_model.onnx_path.is_file() + assert qeff_model.onnx_path.relative_to(model_path).parts == (qeff_model.model_name + ".onnx",) + + # Check if QPC is compiled properly + assert qeff_model.qpc_path.is_dir() + assert (qeff_model.qpc_path / "programqpc.bin").is_file() + assert qeff_model.qpc_path.relative_to(tmp_cache).parts[0] == qeff_model.model_name + "-" + qeff_model.model_hash + + # Check if there is no re-compilation + start = perf_counter() + qeff_model.compile(**compile_params) + end = perf_counter() + compile_time = end - start + assert compile_time < 2.0 + assert os.path.isfile(os.path.join(os.path.dirname(qeff_model.qpc_path), "qconfig.json")) From 5410733fa417b683411a1fa25019412f3df34c5b Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Thu, 16 Oct 2025 
23:29:20 -0700 Subject: [PATCH 02/41] Adding Compute-Context-Length(CCL) Signed-off-by: Vahid Janfaza --- examples/ccl_mistral3_example.py | 7 ++++--- examples/ccl_qwen2_5_vl_example.py | 10 +++++----- examples/compute_context_length.py | 4 ++-- examples/qwen3moe_example/ccl_qwen3moe_inference.py | 2 +- 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/examples/ccl_mistral3_example.py b/examples/ccl_mistral3_example.py index ed02a4fa9..b76227a22 100644 --- a/examples/ccl_mistral3_example.py +++ b/examples/ccl_mistral3_example.py @@ -38,12 +38,13 @@ def run_model( config = AutoConfig.from_pretrained(model_name) config.vision_config._attn_implementation = "eager" - model = QEFFAutoModelForImageTextToText.from_pretrained(model_name, - kv_offload=kv_offload, + model = QEFFAutoModelForImageTextToText.from_pretrained( + model_name, + kv_offload=kv_offload, config=config, ctx_len=ctx_len, comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, - comp_ctx_lengths_decode=comp_ctx_lengths_decode + comp_ctx_lengths_decode=comp_ctx_lengths_decode, ) ## STEP - 2 Export & Compile the Model diff --git a/examples/ccl_qwen2_5_vl_example.py b/examples/ccl_qwen2_5_vl_example.py index 7056011f2..74063929b 100644 --- a/examples/ccl_qwen2_5_vl_example.py +++ b/examples/ccl_qwen2_5_vl_example.py @@ -24,16 +24,16 @@ ctx_len = 32768 comp_ctx_lengths_prefill = [4000] -comp_ctx_lengths_decode = [4096, 8192,16384, ctx_len] +comp_ctx_lengths_decode = [4096, 8192, 16384, ctx_len] qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( - model_id, + model_id, comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, comp_ctx_lengths_decode=comp_ctx_lengths_decode, ctx_len=ctx_len, - attn_implementation="eager", - kv_offload=True, - config=config + attn_implementation="eager", + kv_offload=True, + config=config, ) tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) processor = AutoProcessor.from_pretrained(model_id) diff --git a/examples/compute_context_length.py b/examples/compute_context_length.py index 554c61c84..c1e5dc0df 100644 --- a/examples/compute_context_length.py +++ b/examples/compute_context_length.py @@ -17,7 +17,7 @@ ctx_len = 1024 comp_ctx_lengths_prefill = [256] -comp_ctx_lengths_decode = [512,ctx_len] +comp_ctx_lengths_decode = [512, ctx_len] # model_name = "google/gemma-7b" # model_name = "google/gemma-2-2b" @@ -57,5 +57,5 @@ "My name is ", ], tokenizer=tokenizer, - generation_len=128 + generation_len=128, ) diff --git a/examples/qwen3moe_example/ccl_qwen3moe_inference.py b/examples/qwen3moe_example/ccl_qwen3moe_inference.py index 4d09b08f3..4a7a16c1b 100644 --- a/examples/qwen3moe_example/ccl_qwen3moe_inference.py +++ b/examples/qwen3moe_example/ccl_qwen3moe_inference.py @@ -19,7 +19,7 @@ ctx_len = 8192 comp_ctx_lengths_prefill = [4096] -comp_ctx_lengths_decode = [6144,8192] +comp_ctx_lengths_decode = [6144, 8192] model = QEFFAutoModelForCausalLM.from_pretrained( model_name, From 13271c6e0501c3e60c3f42a27d8a40017d7bfdd7 Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Fri, 17 Oct 2025 10:40:44 -0700 Subject: [PATCH 03/41] Adding Compute-Context-Length(CCL) Signed-off-by: Vahid Janfaza --- examples/intern_example/ccl_internvl_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/intern_example/ccl_internvl_inference.py b/examples/intern_example/ccl_internvl_inference.py index 5595d26cd..0828b1d41 100644 --- a/examples/intern_example/ccl_internvl_inference.py +++ b/examples/intern_example/ccl_internvl_inference.py @@ -251,7 +251,7 @@ def 
run_intern_on_aic( # The Dual QPC approach splits the model to perform Image Encoding and Output generation in 2 different QPCs. # The outputs of the Vision Encoder are then passed to the Language model via host in this case. - kv_offload = False + kv_offload = True # InternVL is an Early-Fusion model that uses placeholder tokens within the input_ids to interleave text_embeddings with # Image embeddings and generate final input_embeds for outout generation. Hence we need very large prefill_seq_len (3840 in this case) to From 3332962971b27a3ad3e5e78adcc6308a69226977 Mon Sep 17 00:00:00 2001 From: vjanfaza Date: Fri, 17 Oct 2025 10:46:50 -0700 Subject: [PATCH 04/41] Delete examples/granite_example/ccl_granitemoe_inference.py Signed-off-by: vjanfaza --- .../ccl_granitemoe_inference.py | 40 ------------------- 1 file changed, 40 deletions(-) delete mode 100644 examples/granite_example/ccl_granitemoe_inference.py diff --git a/examples/granite_example/ccl_granitemoe_inference.py b/examples/granite_example/ccl_granitemoe_inference.py deleted file mode 100644 index 57668ca24..000000000 --- a/examples/granite_example/ccl_granitemoe_inference.py +++ /dev/null @@ -1,40 +0,0 @@ -# ----------------------------------------------------------------------------- -# -# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. -# SPDX-License-Identifier: BSD-3-Clause -# -# ----------------------------------------------------------------------------- - -from transformers import AutoTokenizer - -from QEfficient import QEFFAutoModelForCausalLM -from QEfficient.utils.constants import Constants - -model_name = "ibm-research/PowerMoE-3b" -""" -# For CB inference, set continuous_batching to True and add full_batch_size,mxfp6,mint8 argument in compile function -# We will use prompt_len=1 for compilation for both cb and non-cb inference -""" - -ctx_len = 2048 -comp_ctx_lengths_prefill = [256] -comp_ctx_lengths_decode = [512, 1024, ctx_len] - -model = QEFFAutoModelForCausalLM.from_pretrained( - model_name, - comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, - comp_ctx_lengths_decode=comp_ctx_lengths_decode, - ctx_len=ctx_len, - continuous_batching=False, -) -model.compile( - prefill_seq_len=1, - ctx_len=ctx_len, - batch_size=1, - num_cores=16, - num_devices=4, - mxfp6_matmul=False, - mxint8_kv_cache=False, -) -tokenizer = AutoTokenizer.from_pretrained(model_name) -exec_info = model.generate(prompts=Constants.INPUT_STR, tokenizer=tokenizer) From b4bf5f9ab5153aaea0dd093a4db19f7809439209 Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Sun, 19 Oct 2025 08:07:22 -0700 Subject: [PATCH 05/41] Adding Compute-Context-Length(CCL) Signed-off-by: Vahid Janfaza --- QEfficient/transformers/models/modeling_auto.py | 13 +++++++++---- .../qwen3moe_example/ccl_qwen3moe_inference.py | 14 +++++++++----- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 5c6f67ddc..ad08fd0ac 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2235,8 +2235,9 @@ def __init__( self.comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) self.comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) ctx_len = kwargs.pop("ctx_len", None) - - if self.comp_ctx_lengths_prefill: + prefill_seq_len = kwargs.pop("prefill_seq_len", 128) + + if self.comp_ctx_lengths_prefill and prefill_seq_len > 1: self.comp_ctx_lengths_prefill, 
self.comp_ctx_lengths_decode = process_ccl_specializations( self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len ) @@ -2338,7 +2339,9 @@ def from_pretrained( comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) ctx_len = kwargs.pop("ctx_len", None) - if comp_ctx_lengths_prefill: + prefill_seq_len = kwargs.pop("prefill_seq_len", 128) + + if comp_ctx_lengths_prefill and prefill_seq_len > 1: comp_ctx_lengths_prefill, comp_ctx_lengths_decode = process_ccl_specializations( comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len ) @@ -2356,6 +2359,7 @@ def from_pretrained( comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, comp_ctx_lengths_decode=comp_ctx_lengths_decode, ctx_len=ctx_len, + prefill_seq_len=prefill_seq_len, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs, @@ -2368,6 +2372,7 @@ def from_pretrained( comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, comp_ctx_lengths_decode=comp_ctx_lengths_decode, ctx_len=ctx_len, + prefill_seq_len=prefill_seq_len, **kwargs, ) @@ -2643,7 +2648,7 @@ def build_decode_specialization( A dictionary defining the decode specialization, or None if it would be a duplicate of the prefill specialization (e.g., if prefill_seq_len is 1 and not continuous batching). """ - if prefill_seq_len == 1 and not self.continuous_batching and comp_ctx_lengths is None: + if prefill_seq_len == 1 and not self.continuous_batching:# and comp_ctx_lengths is None return None # Avoid duplication with prefill spec = { "batch_size": full_batch_size if self.continuous_batching else batch_size, diff --git a/examples/qwen3moe_example/ccl_qwen3moe_inference.py b/examples/qwen3moe_example/ccl_qwen3moe_inference.py index 4a7a16c1b..98258affc 100644 --- a/examples/qwen3moe_example/ccl_qwen3moe_inference.py +++ b/examples/qwen3moe_example/ccl_qwen3moe_inference.py @@ -16,10 +16,11 @@ # We will use prompt_len=1 for compilation for both cb and non-cb inference """ -ctx_len = 8192 - -comp_ctx_lengths_prefill = [4096] -comp_ctx_lengths_decode = [6144, 8192] +ctx_len = 65536 +prefill_seq_len = 1 +# In moe models when compiling with prefill_seq_len=1 and non-continuous-batching mode, prefill and decode will share the same specializations. 
+comp_ctx_lengths_prefill = [4096,8192,16384,32768,ctx_len] +comp_ctx_lengths_decode = [4096,8192,16384,32768,ctx_len] model = QEFFAutoModelForCausalLM.from_pretrained( model_name, @@ -27,9 +28,11 @@ comp_ctx_lengths_decode=comp_ctx_lengths_decode, ctx_len=ctx_len, continuous_batching=False, + prefill_seq_len=prefill_seq_len, ) + # prefill_seq_len=prefill_seq_len, model.compile( - prefill_seq_len=1, + prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, batch_size=1, num_cores=16, @@ -38,5 +41,6 @@ mxint8_kv_cache=True, mos=1, ) + # mos=1, tokenizer = AutoTokenizer.from_pretrained(model_name) exec_info = model.generate(prompts=Constants.INPUT_STR, tokenizer=tokenizer) From 9363689e23794f7b4f78de7e86078b07c1235597 Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Sun, 19 Oct 2025 20:00:08 -0700 Subject: [PATCH 06/41] Adding Compute-Context-Length(CCL) Signed-off-by: Vahid Janfaza --- QEfficient/transformers/models/modeling_auto.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index ad08fd0ac..139ac30eb 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2648,8 +2648,10 @@ def build_decode_specialization( A dictionary defining the decode specialization, or None if it would be a duplicate of the prefill specialization (e.g., if prefill_seq_len is 1 and not continuous batching). """ - if prefill_seq_len == 1 and not self.continuous_batching:# and comp_ctx_lengths is None - return None # Avoid duplication with prefill + if prefill_seq_len == 1: + if not self.continuous_batching or batch_size==1: + return None # Avoid duplication with prefill + spec = { "batch_size": full_batch_size if self.continuous_batching else batch_size, "seq_len": (num_speculative_tokens + 1) if self.is_tlm else 1, From a4fca59c95286f8b74f508c027c2fe69c134eda3 Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Sun, 19 Oct 2025 20:13:16 -0700 Subject: [PATCH 07/41] Adding Compute-Context-Length(CCL) Signed-off-by: Vahid Janfaza --- QEfficient/transformers/models/modeling_auto.py | 8 ++++---- .../qwen3moe_example/ccl_qwen3moe_inference.py | 14 +++++++------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 139ac30eb..bf2a445ad 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2236,7 +2236,7 @@ def __init__( self.comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) ctx_len = kwargs.pop("ctx_len", None) prefill_seq_len = kwargs.pop("prefill_seq_len", 128) - + if self.comp_ctx_lengths_prefill and prefill_seq_len > 1: self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations( self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len @@ -2340,7 +2340,7 @@ def from_pretrained( comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) ctx_len = kwargs.pop("ctx_len", None) prefill_seq_len = kwargs.pop("prefill_seq_len", 128) - + if comp_ctx_lengths_prefill and prefill_seq_len > 1: comp_ctx_lengths_prefill, comp_ctx_lengths_decode = process_ccl_specializations( comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len @@ -2649,9 +2649,9 @@ def build_decode_specialization( of the prefill specialization (e.g., if prefill_seq_len is 1 and not continuous batching). 
""" if prefill_seq_len == 1: - if not self.continuous_batching or batch_size==1: + if not self.continuous_batching or batch_size == 1: return None # Avoid duplication with prefill - + spec = { "batch_size": full_batch_size if self.continuous_batching else batch_size, "seq_len": (num_speculative_tokens + 1) if self.is_tlm else 1, diff --git a/examples/qwen3moe_example/ccl_qwen3moe_inference.py b/examples/qwen3moe_example/ccl_qwen3moe_inference.py index 98258affc..12e9ca1fc 100644 --- a/examples/qwen3moe_example/ccl_qwen3moe_inference.py +++ b/examples/qwen3moe_example/ccl_qwen3moe_inference.py @@ -16,31 +16,31 @@ # We will use prompt_len=1 for compilation for both cb and non-cb inference """ -ctx_len = 65536 +ctx_len = 32768 prefill_seq_len = 1 # In moe models when compiling with prefill_seq_len=1 and non-continuous-batching mode, prefill and decode will share the same specializations. -comp_ctx_lengths_prefill = [4096,8192,16384,32768,ctx_len] -comp_ctx_lengths_decode = [4096,8192,16384,32768,ctx_len] +comp_ctx_lengths_prefill = [4096, 8192, 16384, ctx_len] +comp_ctx_lengths_decode = [4096, 8192, 16384, ctx_len] model = QEFFAutoModelForCausalLM.from_pretrained( model_name, comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, comp_ctx_lengths_decode=comp_ctx_lengths_decode, ctx_len=ctx_len, - continuous_batching=False, + continuous_batching=True, prefill_seq_len=prefill_seq_len, ) - # prefill_seq_len=prefill_seq_len, +# prefill_seq_len=prefill_seq_len, model.compile( prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, - batch_size=1, + full_batch_size=1, num_cores=16, num_devices=4, mxfp6_matmul=True, mxint8_kv_cache=True, mos=1, ) - # mos=1, +# mos=1, tokenizer = AutoTokenizer.from_pretrained(model_name) exec_info = model.generate(prompts=Constants.INPUT_STR, tokenizer=tokenizer) From 5f047b4dae773652a5a2d5e70d070f18cd53143f Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Tue, 21 Oct 2025 18:48:40 -0700 Subject: [PATCH 08/41] improving handeling CCL lists Signed-off-by: Vahid Janfaza --- .../transformers/models/modeling_auto.py | 64 ++----------------- QEfficient/utils/check_ccl_specializations.py | 25 ++++++-- 2 files changed, 26 insertions(+), 63 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index bf2a445ad..6421a5b91 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -879,13 +879,7 @@ def __init__( self.model = model self.config = model.config - self.comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) - self.comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) - ctx_len = kwargs.pop("ctx_len", None) - if self.comp_ctx_lengths_prefill: - self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations( - self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len - ) + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, _, _ = process_ccl_specializations(kwargs) self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs) self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs) @@ -933,14 +927,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs): kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) - comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) - comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) - ctx_len = kwargs.pop("ctx_len", None) - - if 
comp_ctx_lengths_prefill: - comp_ctx_lengths_prefill, comp_ctx_lengths_decode = process_ccl_specializations( - comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len - ) + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations(kwargs) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) return cls( @@ -1498,14 +1485,7 @@ def __init__( raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") super().__init__(model, **kwargs) - self.comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) - self.comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) - ctx_len = kwargs.pop("ctx_len", None) - - if self.comp_ctx_lengths_prefill: - self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations( - self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len - ) + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, _, _ = process_ccl_specializations(kwargs) # to handle internvl models if hasattr(self.model.config, "llm_config") and hasattr(self.model.config, "vision_config"): @@ -1554,14 +1534,7 @@ def from_pretrained( kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) - comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) - comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) - ctx_len = kwargs.pop("ctx_len", None) - - if comp_ctx_lengths_prefill: - comp_ctx_lengths_prefill, comp_ctx_lengths_decode = process_ccl_specializations( - comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len - ) + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations(kwargs) from transformers import AutoConfig @@ -2115,14 +2088,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) - comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) - comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) - ctx_len = kwargs.pop("ctx_len", None) - - if comp_ctx_lengths_prefill: - comp_ctx_lengths_prefill, comp_ctx_lengths_decode = process_ccl_specializations( - comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len - ) + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations(kwargs) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) return cls( @@ -2232,15 +2198,7 @@ def __init__( self.model, transformed = SpDTransform.apply(self.model, qaic_config, **kwargs) self.is_tlm = transformed - self.comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) - self.comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) - ctx_len = kwargs.pop("ctx_len", None) - prefill_seq_len = kwargs.pop("prefill_seq_len", 128) - - if self.comp_ctx_lengths_prefill and prefill_seq_len > 1: - self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations( - self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len - ) + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, _, _ = process_ccl_specializations(kwargs) self.hash_params["qeff_auto_class"] = self.__class__.__name__ @@ -2336,15 +2294,7 @@ def from_pretrained( kv_offload = kwargs.pop("kv_offload", None) - comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) - 
comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) - ctx_len = kwargs.pop("ctx_len", None) - prefill_seq_len = kwargs.pop("prefill_seq_len", 128) - - if comp_ctx_lengths_prefill and prefill_seq_len > 1: - comp_ctx_lengths_prefill, comp_ctx_lengths_decode = process_ccl_specializations( - comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len - ) + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations(kwargs) kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) diff --git a/QEfficient/utils/check_ccl_specializations.py b/QEfficient/utils/check_ccl_specializations.py index dbfb08926..8107447de 100644 --- a/QEfficient/utils/check_ccl_specializations.py +++ b/QEfficient/utils/check_ccl_specializations.py @@ -8,15 +8,28 @@ from typing import List, Optional +# def process_ccl_specializations( +# ccl_prefill: Optional[List[int]] = None, ccl_decode: Optional[List[int]] = None, ctx_len: Optional[int] = None +# ): def process_ccl_specializations( - ccl_prefill: Optional[List[int]] = None, ccl_decode: Optional[List[int]] = None, ctx_len: Optional[int] = None + kwargs ): + ccl_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) + ccl_decode = kwargs.pop("comp_ctx_lengths_decode", None) + ctx_len = kwargs.pop("ctx_len", None) + prefill_seq_len = kwargs.pop("prefill_seq_len", 128) + if ctx_len is None: raise TypeError("`ctx_len` is required when loading the model.") - if ccl_prefill is None: - ccl_prefill = [ctx_len] - if ccl_decode is None: - ccl_decode = [ctx_len] + + if ccl_prefill is None or ccl_decode is None: + return None, None, ctx_len, prefill_seq_len + + if prefill_seq_len == 1: + #both prefill and decode ccl can share the same specializations since prefill_seq_len=1. So, a sorted union of both lists can be used for both of them. 
+ ccl_union_all = sorted(set(ccl_prefill + ccl_decode)) + ccl_union_all = [min(x, ctx_len) for x in ccl_union_all] + return ccl_union_all, ccl_union_all, ctx_len, prefill_seq_len # Step 1: Cap values to ctx_len ccl_prefill = [min(x, ctx_len) for x in ccl_prefill] @@ -40,4 +53,4 @@ def process_ccl_specializations( updated_prefill.sort() ccl_decode.sort() - return updated_prefill, ccl_decode + return updated_prefill, ccl_decode, ctx_len, prefill_seq_len From 71c5182651dabbf228e68cac6d893c47b51487df Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Tue, 21 Oct 2025 18:53:54 -0700 Subject: [PATCH 09/41] improving handeling CCL lists Signed-off-by: Vahid Janfaza --- QEfficient/transformers/models/modeling_auto.py | 16 ++++++++++++---- QEfficient/utils/check_ccl_specializations.py | 10 +++------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 6421a5b91..9fb9a9c0a 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -927,7 +927,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs): kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) - comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations(kwargs) + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations( + kwargs + ) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) return cls( @@ -1534,7 +1536,9 @@ def from_pretrained( kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) - comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations(kwargs) + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations( + kwargs + ) from transformers import AutoConfig @@ -2088,7 +2092,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) - comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations(kwargs) + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations( + kwargs + ) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) return cls( @@ -2294,7 +2300,9 @@ def from_pretrained( kv_offload = kwargs.pop("kv_offload", None) - comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations(kwargs) + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations( + kwargs + ) kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) diff --git a/QEfficient/utils/check_ccl_specializations.py b/QEfficient/utils/check_ccl_specializations.py index 8107447de..45d2ea903 100644 --- a/QEfficient/utils/check_ccl_specializations.py +++ b/QEfficient/utils/check_ccl_specializations.py @@ -5,15 +5,11 @@ # # ----------------------------------------------------------------------------- -from typing import List, Optional - # def process_ccl_specializations( # ccl_prefill: Optional[List[int]] = None, ccl_decode: Optional[List[int]] = None, ctx_len: Optional[int] = None # ): -def 
process_ccl_specializations( - kwargs -): +def process_ccl_specializations(kwargs): ccl_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) ccl_decode = kwargs.pop("comp_ctx_lengths_decode", None) ctx_len = kwargs.pop("ctx_len", None) @@ -24,9 +20,9 @@ def process_ccl_specializations( if ccl_prefill is None or ccl_decode is None: return None, None, ctx_len, prefill_seq_len - + if prefill_seq_len == 1: - #both prefill and decode ccl can share the same specializations since prefill_seq_len=1. So, a sorted union of both lists can be used for both of them. + # both prefill and decode ccl can share the same specializations since prefill_seq_len=1. So, a sorted union of both lists can be used for both of them. ccl_union_all = sorted(set(ccl_prefill + ccl_decode)) ccl_union_all = [min(x, ctx_len) for x in ccl_union_all] return ccl_union_all, ccl_union_all, ctx_len, prefill_seq_len From 1d74b42ffc568ed73d1299d25d7b9f32ade42b32 Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Tue, 21 Oct 2025 18:57:22 -0700 Subject: [PATCH 10/41] improving handeling CCL lists Signed-off-by: Vahid Janfaza --- QEfficient/utils/check_ccl_specializations.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/QEfficient/utils/check_ccl_specializations.py b/QEfficient/utils/check_ccl_specializations.py index 45d2ea903..0c7555512 100644 --- a/QEfficient/utils/check_ccl_specializations.py +++ b/QEfficient/utils/check_ccl_specializations.py @@ -5,10 +5,6 @@ # # ----------------------------------------------------------------------------- - -# def process_ccl_specializations( -# ccl_prefill: Optional[List[int]] = None, ccl_decode: Optional[List[int]] = None, ctx_len: Optional[int] = None -# ): def process_ccl_specializations(kwargs): ccl_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) ccl_decode = kwargs.pop("comp_ctx_lengths_decode", None) From 811b1ce4dd99cd47b37ce4667acb44325372ba0d Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Tue, 21 Oct 2025 18:58:30 -0700 Subject: [PATCH 11/41] improving handeling CCL lists Signed-off-by: Vahid Janfaza --- QEfficient/utils/check_ccl_specializations.py | 1 + 1 file changed, 1 insertion(+) diff --git a/QEfficient/utils/check_ccl_specializations.py b/QEfficient/utils/check_ccl_specializations.py index 0c7555512..3e66bfd35 100644 --- a/QEfficient/utils/check_ccl_specializations.py +++ b/QEfficient/utils/check_ccl_specializations.py @@ -5,6 +5,7 @@ # # ----------------------------------------------------------------------------- + def process_ccl_specializations(kwargs): ccl_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) ccl_decode = kwargs.pop("comp_ctx_lengths_decode", None) From 7b57d90bed09383113badaf1c4bb9372d2b30e88 Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Tue, 21 Oct 2025 19:15:01 -0700 Subject: [PATCH 12/41] improving handeling CCL lists Signed-off-by: Vahid Janfaza --- QEfficient/utils/check_ccl_specializations.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/QEfficient/utils/check_ccl_specializations.py b/QEfficient/utils/check_ccl_specializations.py index 3e66bfd35..6cb54a6c5 100644 --- a/QEfficient/utils/check_ccl_specializations.py +++ b/QEfficient/utils/check_ccl_specializations.py @@ -12,12 +12,12 @@ def process_ccl_specializations(kwargs): ctx_len = kwargs.pop("ctx_len", None) prefill_seq_len = kwargs.pop("prefill_seq_len", 128) - if ctx_len is None: - raise TypeError("`ctx_len` is required when loading the model.") - if ccl_prefill is None or ccl_decode is None: return None, None, ctx_len, prefill_seq_len + if 
ctx_len is None: + raise TypeError("`ctx_len` is required when loading the model with CCL.") + if prefill_seq_len == 1: # both prefill and decode ccl can share the same specializations since prefill_seq_len=1. So, a sorted union of both lists can be used for both of them. ccl_union_all = sorted(set(ccl_prefill + ccl_decode)) From 0b88a32b5945d2ce5d721d2d6a04b2b082b9bb35 Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Wed, 22 Oct 2025 16:06:00 -0700 Subject: [PATCH 13/41] Adding Compute-Context-Length(CCL) Signed-off-by: Vahid Janfaza --- .../models/codegen/modeling_codegen.py | 11 +- .../models/falcon/modeling_falcon.py | 11 +- .../models/gemma3/modeling_gemma3.py | 24 +- .../transformers/models/gpt2/modeling_gpt2.py | 11 +- .../transformers/models/gptj/modeling_gptj.py | 11 +- .../models/grok_1/modeling_grok1.py | 11 +- .../models/internvl/modeling_internvl.py | 26 ++- .../models/llama4/modeling_llama4.py | 24 +- .../models/llava/modeling_llava.py | 26 ++- .../models/llava_next/modeling_llava_next.py | 18 +- .../models/mistral3/modeling_mistral3.py | 24 +- .../models/mllama/modeling_mllama.py | 8 +- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 16 +- .../models/whisper/modeling_whisper.py | 13 +- examples/ccl_qwen2_5_vl_example.py | 7 +- examples/compute_context_length.py | 11 +- .../ccl_qwen3moe_inference.py | 12 +- tests/transformers/test_comp_ctx_length.py | 205 ++++++++++++++---- 18 files changed, 365 insertions(+), 104 deletions(-) diff --git a/QEfficient/transformers/models/codegen/modeling_codegen.py b/QEfficient/transformers/models/codegen/modeling_codegen.py index 776bfce43..15efa2ce5 100644 --- a/QEfficient/transformers/models/codegen/modeling_codegen.py +++ b/QEfficient/transformers/models/codegen/modeling_codegen.py @@ -72,6 +72,7 @@ def forward( self, hidden_states: Optional[torch.FloatTensor], layer_past: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, @@ -123,7 +124,9 @@ def forward( query = query.permute(0, 2, 1, 3) if layer_past is not None: - cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index, "CCL": attention_mask.shape[-1]} key, value = layer_past.update(key.to(hidden_states.dtype), value, self.layer_idx, cache_kwargs) # compute self-attention: V x Softmax(QK^T) @@ -147,6 +150,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor]]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -245,6 +249,7 @@ def forward( outputs = block( hidden_states, layer_past=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, attention_mask=attention_mask, position_ids=position_ids, @@ -294,6 +299,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = 
None, @@ -312,6 +318,7 @@ def forward( transformer_outputs = self.transformer( input_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, token_type_ids=token_type_ids, batch_index=batch_index, @@ -348,6 +355,7 @@ def forward( self, hidden_states: Optional[torch.FloatTensor], layer_past: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -361,6 +369,7 @@ def forward( attn_outputs, attn_weights = self.attn( hidden_states=hidden_states, layer_past=layer_past, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, position_ids=position_ids, batch_index=batch_index, diff --git a/QEfficient/transformers/models/falcon/modeling_falcon.py b/QEfficient/transformers/models/falcon/modeling_falcon.py index 8f2c3730d..218852b15 100644 --- a/QEfficient/transformers/models/falcon/modeling_falcon.py +++ b/QEfficient/transformers/models/falcon/modeling_falcon.py @@ -117,6 +117,7 @@ def forward( attention_mask: torch.Tensor, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, layer_past: Optional[Cache] = None, head_mask: Optional[torch.Tensor] = None, @@ -140,7 +141,9 @@ def forward( query_layer, key_layer = qeff_apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids) if layer_past is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]} key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs) if attention_mask is not None: @@ -172,6 +175,7 @@ def forward( attention_mask: torch.Tensor, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, layer_past: Optional[Union[Cache, Tuple[torch.Tensor, torch.Tensor]]] = None, head_mask: Optional[torch.Tensor] = None, @@ -195,6 +199,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, alibi=alibi, head_mask=head_mask, @@ -245,6 +250,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, head_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -307,6 +313,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, head_mask=head_mask[i], use_cache=use_cache, @@ -352,6 +359,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, 
inputs_embeds: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, @@ -368,6 +376,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, head_mask=head_mask, inputs_embeds=inputs_embeds, diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py index 2e8494e8e..95ee662b4 100644 --- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -603,7 +603,13 @@ def __init__(self, model): self.lm_head = self.model.lm_head def forward( - self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, ): inputs_embeds = self.model.get_input_embeddings()(input_ids) B, N, C = inputs_embeds.shape @@ -637,7 +643,13 @@ def get_qeff_language_decoder(self): return QEffGemma3DecoderWrapper(self) def forward( - self, input_ids, position_ids, pixel_values, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + self, + input_ids, + position_ids, + pixel_values, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, ): image_features = self.get_image_features(pixel_values=pixel_values) inputs_embeds = self.get_input_embeddings()(input_ids) @@ -669,8 +681,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, - comp_ctx_lengths_prefill: List[int] = None, - comp_ctx_lengths_decode: List[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, **compiler_options, ): @@ -749,7 +761,7 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): # Define dynamic axes vision_dynamic_axes = {} lang_dynamic_axes = {} @@ -825,7 +837,7 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len): past_key_values.append(pkv) return past_key_values - def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): + def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): if vis_cfg := getattr(self.config, "vision_config", None): img_size = getattr(vis_cfg, "image_size", 896) else: diff --git a/QEfficient/transformers/models/gpt2/modeling_gpt2.py b/QEfficient/transformers/models/gpt2/modeling_gpt2.py index d68a65430..59d864907 100644 --- a/QEfficient/transformers/models/gpt2/modeling_gpt2.py +++ b/QEfficient/transformers/models/gpt2/modeling_gpt2.py @@ -65,6 +65,7 @@ def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -118,9 +119,11 @@ def forward( if (past_key_value is not None and not is_cross_attention) or ( past_key_value is not None and is_cross_attention and not is_updated ): + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # save all 
key/value_layer to cache to be re-used for fast auto-regressive generation # Update the cache_kwargs with position_ids for Cloud AI 100 - cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index, "CCL": attention_mask.shape[-1]} key_states, value_states = curr_past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs ) @@ -156,6 +159,7 @@ def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -174,6 +178,7 @@ def forward( hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, + comp_ctx_lengths=comp_ctx_lengths, position_ids=position_ids, batch_index=batch_index, head_mask=head_mask, @@ -232,6 +237,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -341,6 +347,7 @@ def forward( outputs = block( hidden_states, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, position_ids=position_ids, batch_index=batch_index, @@ -392,6 +399,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -418,6 +426,7 @@ def forward( transformer_outputs = self.transformer( input_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, diff --git a/QEfficient/transformers/models/gptj/modeling_gptj.py b/QEfficient/transformers/models/gptj/modeling_gptj.py index dc3e5e6d2..da5bd881c 100644 --- a/QEfficient/transformers/models/gptj/modeling_gptj.py +++ b/QEfficient/transformers/models/gptj/modeling_gptj.py @@ -83,6 +83,7 @@ def forward( self, hidden_states: torch.FloatTensor, layer_past: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -134,7 +135,9 @@ def forward( query = query.permute(0, 2, 1, 3) if layer_past is not None: - cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index, "CCL": attention_mask.shape[-1]} key, value = layer_past.update(key, value, self.layer_idx, cache_kwargs) # compute self-attention: V x Softmax(QK^T) @@ -151,6 +154,7 @@ def forward( self, hidden_states: Optional[torch.FloatTensor], layer_past: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: 
Optional[torch.LongTensor] = None, @@ -164,6 +168,7 @@ def forward( attn_outputs, attn_weights = self.attn( hidden_states=hidden_states, layer_past=layer_past, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, position_ids=position_ids, batch_index=batch_index, @@ -191,6 +196,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -270,6 +276,7 @@ def forward( outputs = block( hidden_states=hidden_states, layer_past=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=causal_mask, position_ids=position_ids, batch_index=batch_index, @@ -314,6 +321,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -339,6 +347,7 @@ def forward( transformer_outputs = self.transformer( input_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, diff --git a/QEfficient/transformers/models/grok_1/modeling_grok1.py b/QEfficient/transformers/models/grok_1/modeling_grok1.py index 567a8e070..a0f9cd915 100644 --- a/QEfficient/transformers/models/grok_1/modeling_grok1.py +++ b/QEfficient/transformers/models/grok_1/modeling_grok1.py @@ -55,6 +55,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, @@ -93,7 +94,9 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update(key_states, value_states, layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads @@ -205,6 +208,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = False, @@ -235,6 +239,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -277,6 +282,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, 
batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -351,6 +357,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -395,6 +402,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -441,6 +449,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/internvl/modeling_internvl.py b/QEfficient/transformers/models/internvl/modeling_internvl.py index 29e6ac9a4..96c59325f 100644 --- a/QEfficient/transformers/models/internvl/modeling_internvl.py +++ b/QEfficient/transformers/models/internvl/modeling_internvl.py @@ -5,7 +5,7 @@ # # ----------------------------------------------------------------------------- -from typing import List +from typing import List, Optional import torch import torch.nn as nn @@ -37,7 +37,13 @@ def __init__(self, model): self.language_model = self.model.language_model def forward( - self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, ): input_embeds = self.model.language_model.get_input_embeddings()(input_ids) B, N, C = input_embeds.shape @@ -82,8 +88,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, - comp_ctx_lengths_prefill: List[int] = None, - comp_ctx_lengths_decode: List[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, **compiler_options, ): @@ -172,7 +178,7 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): # Define dynamic axes vision_dynamic_axes = {} lang_dynamic_axes = {} @@ -216,7 +222,7 @@ def get_output_names(self, kv_offload: bool = False): return lang_output_names return output_names - def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): + def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): if vis_cfg := getattr(self.config, "vision_config", None): img_size = getattr(vis_cfg, "image_size", constants.INTERN_IMG_SIZE) else: @@ -291,7 +297,13 @@ def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool return inputs def forward( - self, input_ids, pixel_values, position_ids, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + self, + input_ids, + pixel_values, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, ): input_embeds = self.language_model.get_input_embeddings()(input_ids) 
vision_embeds = self.extract_feature(pixel_values) diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py index 82678e380..0fbdbea5f 100644 --- a/QEfficient/transformers/models/llama4/modeling_llama4.py +++ b/QEfficient/transformers/models/llama4/modeling_llama4.py @@ -850,7 +850,13 @@ def __init__(self, model): self.config = self.model.config def forward( - self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, ): inputs_embeds = self.model.language_model.get_input_embeddings()(input_ids) selected = input_ids == self.model.config.image_token_index @@ -880,7 +886,13 @@ def get_qeff_language_decoder(self): return QEffLlama4DecoderWrapper(self) def forward( - self, input_ids, position_ids, pixel_values, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + self, + input_ids, + position_ids, + pixel_values, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, ): inputs_embeds = self.language_model.get_input_embeddings()(input_ids) vision_feature_layer = self.config.vision_config.vision_feature_layer @@ -917,8 +929,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, - comp_ctx_lengths_prefill: List[int] = None, - comp_ctx_lengths_decode: List[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, **compiler_options, ): @@ -1034,7 +1046,7 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): # Define dynamic axes vision_dynamic_axes = {} lang_dynamic_axes = {} @@ -1109,7 +1121,7 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len): past_key_values.append(pkv) return past_key_values - def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): + def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): if vis_cfg := getattr(self.config, "vision_config", None): img_size = getattr(vis_cfg, "image_size", 336) else: diff --git a/QEfficient/transformers/models/llava/modeling_llava.py b/QEfficient/transformers/models/llava/modeling_llava.py index 450fc79b6..dc6653db0 100644 --- a/QEfficient/transformers/models/llava/modeling_llava.py +++ b/QEfficient/transformers/models/llava/modeling_llava.py @@ -5,7 +5,7 @@ # # ----------------------------------------------------------------------------- -from typing import List +from typing import List, Optional import torch import torch.nn as nn @@ -54,7 +54,13 @@ def __init__(self, model): self.lm_head = self.model.lm_head def forward( - self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, ): inputs_embeds = self.model.get_input_embeddings()(input_ids) vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype) @@ -89,7 +95,13 @@ def get_qeff_language_decoder(self): return QEFFLlavaDecoderWrapper(self) def forward( - self, input_ids, position_ids, 
pixel_values, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + self, + input_ids, + position_ids, + pixel_values, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, ): inputs_embeds = self.get_input_embeddings()(input_ids) # Image features @@ -128,7 +140,7 @@ def forward( image_idx = torch.where(image_idx < next_image_idx, next_image_idx, image_idx) return logits, pixel_values, image_idx, outputs.past_key_values - def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False, **kwargs): + def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, **kwargs): num_layers = self.config.text_config.num_hidden_layers num_key_value_heads = self.config.text_config.num_key_value_heads head_dim = self.config.text_config.hidden_size // self.config.text_config.num_attention_heads @@ -178,8 +190,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, - comp_ctx_lengths_prefill: List[int] = None, - comp_ctx_lengths_decode: List[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, **compiler_options, ): @@ -259,7 +271,7 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): # Define dynamic axes num_layers = self.config.text_config.num_hidden_layers diff --git a/QEfficient/transformers/models/llava_next/modeling_llava_next.py b/QEfficient/transformers/models/llava_next/modeling_llava_next.py index b23073fa7..2e4848b6b 100755 --- a/QEfficient/transformers/models/llava_next/modeling_llava_next.py +++ b/QEfficient/transformers/models/llava_next/modeling_llava_next.py @@ -6,7 +6,7 @@ # ----------------------------------------------------------------------------- -from typing import List +from typing import List, Optional import numpy as np import torch @@ -126,7 +126,13 @@ def __init__(self, model): self.lm_head = self.model.lm_head def forward( - self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, ): inputs_embeds = self.model.get_input_embeddings()(input_ids) image_features = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype) @@ -159,7 +165,7 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEffLlavaNextDecoderWrapper(self) - def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False, **kwargs): + def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, **kwargs): num_layers = self.config.text_config.num_hidden_layers num_key_value_heads = self.config.text_config.num_key_value_heads head_dim = self.config.text_config.hidden_size // self.config.text_config.num_attention_heads @@ -241,8 +247,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, - comp_ctx_lengths_prefill: List[int] = None, - comp_ctx_lengths_decode: List[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, **compiler_options, ): @@ -365,7 +371,7 @@ def get_specializations( else: 
return lang, compiler_options - def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): # Define dynamic axes num_layers = self.config.text_config.num_hidden_layers vision_dynamic_axes = { diff --git a/QEfficient/transformers/models/mistral3/modeling_mistral3.py b/QEfficient/transformers/models/mistral3/modeling_mistral3.py index a5f1301d2..694ed4cde 100644 --- a/QEfficient/transformers/models/mistral3/modeling_mistral3.py +++ b/QEfficient/transformers/models/mistral3/modeling_mistral3.py @@ -169,7 +169,13 @@ def __init__(self, model): self.language_model = self.model.language_model def forward( - self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, ): inputs_embeds = self.model.get_input_embeddings()(input_ids) vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype) @@ -204,7 +210,13 @@ def get_qeff_language_decoder(self): return QEFFMistral3DecoderWrapper(self) def forward( - self, input_ids, position_ids, pixel_values, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + self, + input_ids, + position_ids, + pixel_values, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, ): inputs_embeds = self.get_input_embeddings()(input_ids) image_sizes = torch.tensor([[pixel_values.shape[2], pixel_values.shape[3]]]).repeat(pixel_values.shape[0], 1) @@ -238,7 +250,7 @@ def forward( return logits, pixel_values, image_idx, outputs.past_key_values - def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False, **kwargs): + def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, **kwargs): inputs_shapes = {} inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) height = self.config.vision_config.image_size @@ -309,8 +321,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, - comp_ctx_lengths_prefill: List[int] = None, - comp_ctx_lengths_decode: List[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, **compiler_options, ): @@ -392,7 +404,7 @@ def get_specializations( lang[1].pop("vision_size") return lang, compiler_options - def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): # Define dynamic axes num_layers = self.config.text_config.num_hidden_layers diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index 2197bec91..d6fb1dcd2 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -899,7 +899,7 @@ def forward( logits = self.lm_head(hidden_states).float() return logits, image_idx, outputs.past_key_values, pixel_values - def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): + def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): BS = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE SEQ_LEN = 
constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN CTX_LEN = constants.ONNX_EXPORT_CTX_LEN @@ -983,8 +983,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, - comp_ctx_lengths_prefill: List[int] = None, - comp_ctx_lengths_decode: List[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, **compiler_options, ): @@ -1055,7 +1055,7 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): txt_cfg = self.config.get_text_config() num_hidden_layers = txt_cfg.num_hidden_layers cross_attention_layers = txt_cfg.cross_attention_layers diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 3b1d116de..ac91d5477 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -696,7 +696,13 @@ def __init__(self, model): self.language_model = self.model.model.language_model def forward( - self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, ): inputs_embeds = self.model.get_input_embeddings()(input_ids) B, N, C = inputs_embeds.shape @@ -730,7 +736,7 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEffQwen_2_5_vl_DecoderWrapper(self) - def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False, **kwargs): + def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, **kwargs): inputs_shapes = {} inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) @@ -799,8 +805,8 @@ def get_specializations( img_size: None, height: int = None, width: int = None, - comp_ctx_lengths_prefill: List[int] = None, - comp_ctx_lengths_decode: List[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, **compiler_options, ): @@ -932,7 +938,7 @@ def smart_resize( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): # Define dynamic axes num_layers = self.config.num_hidden_layers diff --git a/QEfficient/transformers/models/whisper/modeling_whisper.py b/QEfficient/transformers/models/whisper/modeling_whisper.py index e078493a7..79907818d 100644 --- a/QEfficient/transformers/models/whisper/modeling_whisper.py +++ b/QEfficient/transformers/models/whisper/modeling_whisper.py @@ -55,6 +55,7 @@ def forward( position_ids_layer: torch.Tensor = None, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, @@ -99,7 +100,9 @@ def forward( key_states = key_states.transpose(1, 2).contiguous() value_states = 
value_states.transpose(1, 2).contiguous() if past_key_value is not None: - cache_kwargs = {"position_ids": position_ids_layer} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"position_ids": position_ids_layer, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs ) @@ -181,6 +184,7 @@ def forward( layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.LongTensor] = None, @@ -215,6 +219,7 @@ def forward( hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=self_attn_past_key_value, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, @@ -388,6 +393,7 @@ def forward( cross_attn_head_mask=None, position_ids=None, past_key_values=None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, inputs_embeds=None, use_cache=None, output_attentions=None, @@ -532,6 +538,7 @@ def forward( layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, output_attentions=output_attentions, use_cache=use_cache, position_ids_layer=position_ids, @@ -643,6 +650,7 @@ def forward( cross_attn_head_mask=None, encoder_outputs=None, past_key_values=None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, decoder_inputs_embeds=None, use_cache=None, output_attentions=None, @@ -674,6 +682,7 @@ def forward( head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, inputs_embeds=decoder_inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, @@ -719,6 +728,7 @@ def forward( cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, position_ids: Optional[Tuple[torch.LongTensor]] = None, labels: Optional[torch.LongTensor] = None, @@ -740,6 +750,7 @@ def forward( decoder_head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, decoder_inputs_embeds=decoder_inputs_embeds, decoder_position_ids=position_ids, use_cache=use_cache, diff --git a/examples/ccl_qwen2_5_vl_example.py b/examples/ccl_qwen2_5_vl_example.py index 74063929b..b813462e3 100644 --- a/examples/ccl_qwen2_5_vl_example.py +++ b/examples/ccl_qwen2_5_vl_example.py @@ -21,10 +21,9 @@ ## Use complete model without changing num_hidden_layers as it will not work for TF version 4.55.0 for Qwen2.5VL model -ctx_len = 32768 - -comp_ctx_lengths_prefill = [4000] -comp_ctx_lengths_decode = [4096, 8192, 16384, ctx_len] +ctx_len = 8192 +comp_ctx_lengths_prefill = [4096] +comp_ctx_lengths_decode = [6144, ctx_len] qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( model_id, diff --git 
a/examples/compute_context_length.py b/examples/compute_context_length.py index c1e5dc0df..00d475ae0 100644 --- a/examples/compute_context_length.py +++ b/examples/compute_context_length.py @@ -31,9 +31,16 @@ # model_name = "Qwen/Qwen3-1.7B" # model_name = "allenai/OLMo-2-0425-1B" # model_name = "ibm-granite/granite-3.3-2b-base" +# model_name = "meta-llama/Llama-3.3-70B-Instruct" +# model_name = "Salesforce/codegen-350M-mono" +# model_name = "tiiuae/falcon-7b-instruct" +# model_name = "openai-community/gpt2" +# model_name = "EleutherAI/gpt-j-6b" +# model_name = "EleutherAI/gpt-j-6b" + model = QEFFAutoModelForCausalLM.from_pretrained( model_name, - continuous_batching=False, + continuous_batching=True, comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, comp_ctx_lengths_decode=comp_ctx_lengths_decode, ctx_len=ctx_len, @@ -45,7 +52,7 @@ ctx_len=ctx_len, num_cores=16, num_devices=1, - batch_size=1, + full_batch_size=1, mxint8_kv_cache=True, mxfp6_matmul=True, ) diff --git a/examples/qwen3moe_example/ccl_qwen3moe_inference.py b/examples/qwen3moe_example/ccl_qwen3moe_inference.py index 12e9ca1fc..f200c6fa6 100644 --- a/examples/qwen3moe_example/ccl_qwen3moe_inference.py +++ b/examples/qwen3moe_example/ccl_qwen3moe_inference.py @@ -16,25 +16,25 @@ # We will use prompt_len=1 for compilation for both cb and non-cb inference """ -ctx_len = 32768 +ctx_len = 1024 prefill_seq_len = 1 # In moe models when compiling with prefill_seq_len=1 and non-continuous-batching mode, prefill and decode will share the same specializations. -comp_ctx_lengths_prefill = [4096, 8192, 16384, ctx_len] -comp_ctx_lengths_decode = [4096, 8192, 16384, ctx_len] +comp_ctx_lengths_prefill = [256, 512, ctx_len] +comp_ctx_lengths_decode = [256, 512, ctx_len] model = QEFFAutoModelForCausalLM.from_pretrained( model_name, comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, comp_ctx_lengths_decode=comp_ctx_lengths_decode, ctx_len=ctx_len, - continuous_batching=True, + continuous_batching=False, prefill_seq_len=prefill_seq_len, ) -# prefill_seq_len=prefill_seq_len, + model.compile( prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, - full_batch_size=1, + batch_size=1, num_cores=16, num_devices=4, mxfp6_matmul=True, diff --git a/tests/transformers/test_comp_ctx_length.py b/tests/transformers/test_comp_ctx_length.py index e145ad698..31b9da07e 100644 --- a/tests/transformers/test_comp_ctx_length.py +++ b/tests/transformers/test_comp_ctx_length.py @@ -1,6 +1,6 @@ # ----------------------------------------------------------------------------- # -# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
# SPDX-License-Identifier: BSD-3-Clause # # ---------------------------------------------------------------------------- @@ -14,6 +14,8 @@ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM +from QEfficient.utils import constants, get_padding_shape_from_config +from QEfficient.utils.hash_utils import hash_dict_params configs = [ # name, max_position_embeddings, num_hidden_layers, num_attention_heads, hidden_size, intermediate_size, vocab_size, additional_params @@ -30,6 +32,7 @@ ("qwen2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), ("starcoder2", 256, 2, 4, 128, 512, 127, {}), ("granite", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("olmo2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), ] configs = [ @@ -62,17 +65,41 @@ @pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) def test_causal_lm_unsupported(cb): model = AutoModelForCausalLM.from_config(AutoConfig.for_model("opt")) + ctx_len = 32 + comp_ctx_lengths_prefill = [16] + comp_ctx_lengths_decode = [24, ctx_len] with pytest.warns(): - QEFFAutoModelForCausalLM(model, cb) + QEFFAutoModelForCausalLM( + model, + cb, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + ) @pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) @pytest.mark.parametrize("config", configs, ids=config_ids) def test_causal_lm_init(config, cb): model = AutoModelForCausalLM.from_config(config, **model_kwargs) - qeff_model = QEFFAutoModelForCausalLM(model, cb) + ctx_len = 32 + comp_ctx_lengths_prefill = [16] + comp_ctx_lengths_decode = [24, ctx_len] + qeff_model = QEFFAutoModelForCausalLM( + model, + cb, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + ) with pytest.raises(TypeError): - QEFFAutoModelForCausalLM(AutoModel.from_config(config, **model_kwargs), cb) + QEFFAutoModelForCausalLM( + AutoModel.from_config(config, **model_kwargs), + cb, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + ) assert qeff_model.model.__class__.__name__.startswith("QEff") @@ -82,43 +109,112 @@ def test_causal_lm_pretrained(config, cb, tmp_path): model = AutoModelForCausalLM.from_config(config, **model_kwargs) model.save_pretrained(tmp_path) - qeff_model = QEFFAutoModelForCausalLM.from_pretrained(tmp_path, cb) + ctx_len = 32 + comp_ctx_lengths_prefill = [16] + comp_ctx_lengths_decode = [24, ctx_len] + qeff_model = QEFFAutoModelForCausalLM.from_pretrained( + tmp_path, + cb, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + ) assert qeff_model.model.__class__.__name__.startswith("QEff") @pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) @pytest.mark.parametrize("config", configs, ids=config_ids) -def test_causal_lm_hash(config, cb): - hash_0_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb).model_hash - hash_0_1 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb).model_hash +def test_causal_lm_export_and_hash(config, cb, tmp_path): + ctx_len = 32 + comp_ctx_lengths_prefill = [16] + comp_ctx_lengths_decode = [24, ctx_len] + model_0_0 = QEFFAutoModelForCausalLM( + AutoModelForCausalLM.from_config(config, **model_kwargs), + cb, + 
comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + ) + model_0_0.export(tmp_path) + model_path = tmp_path.with_name(tmp_path.name + "-" + model_0_0.export_hash) + assert model_path.is_dir() + assert model_0_0.onnx_path.is_file() + assert model_0_0.onnx_path.relative_to(model_path).parts == (model_0_0.model_name + ".onnx",) + + # Check if the KV-cache inputs and outputs are created + onnx_model = onnx.load(model_0_0.onnx_path, load_external_data=False) + retained_output_names = { + x.name[: -len("_RetainedState")] for x in onnx_model.graph.output if x.name.endswith("_RetainedState") + } + retained_output_names.issubset({x.name for x in onnx_model.graph.input}) + + # Check if there is no re-export + start = perf_counter() + model_0_0.export(tmp_path) + end = perf_counter() + export_time = end - start + assert export_time < 2.0 + + # Check if hashing is happening properly + model_0_1 = QEFFAutoModelForCausalLM( + AutoModelForCausalLM.from_config(config, **model_kwargs), + cb, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + ) + model_0_1.export(tmp_path) + hash_0_0 = model_0_0.export_hash + hash_0_1 = model_0_1.export_hash assert hash_0_0 == hash_0_1 cfg1 = copy.deepcopy(config) cfg1.num_hidden_layers -= 1 - hash_1_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(cfg1, **model_kwargs), cb).model_hash + model_1_0 = QEFFAutoModelForCausalLM( + AutoModelForCausalLM.from_config(cfg1, **model_kwargs), + cb, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + ) + model_1_0.export(tmp_path) + hash_1_0 = model_1_0.export_hash cfg2 = copy.deepcopy(config) cfg2.num_hidden_layers -= 1 - hash_1_1 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(cfg2, **model_kwargs), cb).model_hash + model_1_1 = QEFFAutoModelForCausalLM( + AutoModelForCausalLM.from_config(cfg2, **model_kwargs), + cb, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + ) + model_1_1.export(tmp_path) + hash_1_1 = model_1_1.export_hash assert hash_1_0 == hash_1_1 assert hash_0_0 != hash_1_0 if cb: - hash_0_no_cb = QEFFAutoModelForCausalLM( - AutoModelForCausalLM.from_config(config, **model_kwargs), False - ).model_hash + model_0_no_cb = QEFFAutoModelForCausalLM( + AutoModelForCausalLM.from_config(config, **model_kwargs), + False, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + ) + model_0_no_cb.export(tmp_path) + hash_0_no_cb = model_0_no_cb.export_hash assert hash_0_0 != hash_0_no_cb @pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) @pytest.mark.parametrize("config", configs, ids=config_ids) -def test_causal_lm_export(config, cb, tmp_path): +def test_causal_lm_hash_creation(config, cb, tmp_path): model = AutoModelForCausalLM.from_config(config, **model_kwargs) - ctx_len = 2048 - comp_ctx_lengths_prefill = [256] - comp_ctx_lengths_decode = [512, 1024, ctx_len] - + ctx_len = 32 + comp_ctx_lengths_prefill = [16] + comp_ctx_lengths_decode = [24, ctx_len] qeff_model = QEFFAutoModelForCausalLM( model, cb, @@ -127,29 +223,59 @@ def test_causal_lm_export(config, cb, tmp_path): ctx_len=ctx_len, ) qeff_model.export(tmp_path) - model_path = tmp_path.with_name(tmp_path.name + "-" + qeff_model.model_hash) - assert model_path.is_dir() - assert 
qeff_model.onnx_path.is_file() - assert qeff_model.onnx_path.relative_to(model_path).parts == (qeff_model.model_name + ".onnx",) + hash_params = {} + hash_params["config"] = qeff_model.model.config.to_diff_dict() + hash_params["peft_config"] = None + hash_params["applied_transform_names"] = qeff_model._transform_names() + hash_params["qeff_auto_class"] = qeff_model.__class__.__name__ + hash_params["qaic_config"] = None - # Check if the KV-cache inputs and outputs are created - onnx_model = onnx.load(qeff_model.onnx_path, load_external_data=False) - retained_output_names = { - x.name[: -len("_RetainedState")] for x in onnx_model.graph.output if x.name.endswith("_RetainedState") + # Create parameters separately for hash creation + + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + kv_cache_shape = get_padding_shape_from_config( + qeff_model.model.config, fbs if qeff_model.continuous_batching else bs, seq_len + ) + dynamic_axes = { + "input_ids": {0: "batch_size", 1: "seq_len"}, + "position_ids": {0: "batch_size", 1: "seq_len"}, } - retained_output_names.issubset({x.name for x in onnx_model.graph.input}) + dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + if len(kv_cache_shape) == 3: # For GPTBigCode arch the pkv is 3d + pkv_dynamic_axes = { + 0: "full_batch_size" if qeff_model.continuous_batching else "batch_size", + 1: "ctx_len", + } + else: # pkv is 4d + pkv_dynamic_axes = { + 0: "full_batch_size" if qeff_model.continuous_batching else "batch_size", + 2: "ctx_len", + } + output_names = [] + output_names.append("logits") - # Check if there is no re-export - start = perf_counter() - qeff_model.export(tmp_path) - end = perf_counter() - export_time = end - start - assert export_time < 2.0 + for i in range(qeff_model.num_layers): + for kv in ["key", "value"]: + dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes + output_names.append(f"past_{kv}.{i}_RetainedState") + + if qeff_model.continuous_batching: + dynamic_axes["batch_index"] = {0: "batch_size"} + + export_params = {} + export_params["output_names"] = output_names + export_params["dynamic_axes"] = dynamic_axes + hash_params["export_params"] = export_params + manual_hash = hash_dict_params(hash_params) + + assert manual_hash == qeff_model.export_hash @pytest.fixture def tmp_cache(tmp_path, monkeypatch): - monkeypatch.setattr("QEfficient.base.modeling_qeff.QEFF_HOME", tmp_path) + monkeypatch.setattr("QEfficient.utils._utils.QEFF_HOME", tmp_path) yield tmp_path @@ -157,9 +283,10 @@ def tmp_cache(tmp_path, monkeypatch): @pytest.mark.parametrize("config", configs, ids=config_ids) def test_causal_lm_compile(config, cb, tmp_cache): model = AutoModelForCausalLM.from_config(config, **model_kwargs) - ctx_len = 2048 - comp_ctx_lengths_prefill = [256] - comp_ctx_lengths_decode = [512, 1024, ctx_len] + + ctx_len = 32 + comp_ctx_lengths_prefill = [16] + comp_ctx_lengths_decode = [24, ctx_len] qeff_model = QEFFAutoModelForCausalLM( model, cb, @@ -172,7 +299,7 @@ def test_causal_lm_compile(config, cb, tmp_cache): compile_params["full_batch_size"] = 32 compile_params["batch_size"] = 8 qeff_model.compile(**compile_params) - model_path = tmp_cache / (qeff_model.model_name + "-" + qeff_model.model_hash) + model_path = tmp_cache / qeff_model.model_name / (qeff_model.model_name + "-" + qeff_model.export_hash) # Check if ONNX is exported properly assert model_path.is_dir() @@ -182,7 +309,7 @@ def test_causal_lm_compile(config, cb, tmp_cache): # 
Check if QPC is compiled properly assert qeff_model.qpc_path.is_dir() assert (qeff_model.qpc_path / "programqpc.bin").is_file() - assert qeff_model.qpc_path.relative_to(tmp_cache).parts[0] == qeff_model.model_name + "-" + qeff_model.model_hash + assert qeff_model.qpc_path.relative_to(tmp_cache).parts[1] == qeff_model.model_name + "-" + qeff_model.export_hash # Check if there is no re-compilation start = perf_counter() From 2ade9137e62d1c118830ad4e33651489c37d7965 Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Wed, 22 Oct 2025 17:28:48 -0700 Subject: [PATCH 14/41] fixing lora testing Signed-off-by: Vahid Janfaza --- QEfficient/peft/lora/layers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/QEfficient/peft/lora/layers.py b/QEfficient/peft/lora/layers.py index 6b75e696f..79abeba77 100644 --- a/QEfficient/peft/lora/layers.py +++ b/QEfficient/peft/lora/layers.py @@ -42,15 +42,15 @@ def forward(self, x: torch.Tensor, lora_ids: torch.Tensor): # multilora implementation: lora_ids other_indices_a = torch.arange(self.lora_a_weights.shape[2]).view(1, 1, -1) selected_lora_a_weights = CtxGatherFuncCB.apply( - self.lora_a_weights, lora_ids, other_indices_a + self.lora_a_weights, lora_ids, other_indices_a, self.lora_a_weights.shape[2] ) # other_indices_b = torch.arange(self.lora_b_weights.shape[2]).view(1, 1, -1) selected_lora_b_weights = CtxGatherFuncCB.apply( - self.lora_b_weights, lora_ids, other_indices_b + self.lora_b_weights, lora_ids, other_indices_b, self.lora_b_weights.shape[2] ) # other_indices_s = torch.arange(self.lora_scalings.shape[2]).view(1, 1, -1) selected_lora_scalings = CtxGatherFuncCB.apply( - self.lora_scalings, lora_ids, other_indices_s + self.lora_scalings, lora_ids, other_indices_s, self.lora_scalings.shape[2] ) # selected_lora_a_weights = selected_lora_a_weights.squeeze(1) From acf35442fa0ffb3e7f1c6cb09468dbeba863a359 Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Wed, 22 Oct 2025 17:32:04 -0700 Subject: [PATCH 15/41] Adding Compute-Context-Length(CCL) Signed-off-by: Vahid Janfaza --- examples/compute_context_length.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/examples/compute_context_length.py b/examples/compute_context_length.py index 00d475ae0..dc6991b16 100644 --- a/examples/compute_context_length.py +++ b/examples/compute_context_length.py @@ -41,9 +41,6 @@ model = QEFFAutoModelForCausalLM.from_pretrained( model_name, continuous_batching=True, - comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, - comp_ctx_lengths_decode=comp_ctx_lengths_decode, - ctx_len=ctx_len, ) # model compilation for either continuous or static batching. For continuous batching full_batch_size is needed. 
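A note on how the CCL pieces in this series fit together (illustrative, not taken verbatim from any single patch): inside each attention layer the causal mask is sliced to comp_ctx_lengths.shape[-1] and that length is handed to the KV cache as the "CCL" entry of cache_kwargs, so cache reads and attention scores only cover the active bucket. On the host side, the text-generation runtime chooses the bucket from the current position ids and feeds "comp_ctx_lengths" as a zero-filled placeholder whose length, not its values, selects the matching compiled specialization (the model code shown here consumes it only through comp_ctx_lengths.shape[-1]). From the CLI the buckets are passed as comma-separated lists through --comp-ctx-lengths-prefill and --comp-ctx-lengths-decode. The sketch below mirrors the decode-side selection in initialize_ccl()/run_decode(); the bucket values and the helper name are made up for illustration.

import numpy as np
from typing import List

def pick_decode_bucket(max_position_id: int, ccl_decode: List[int]) -> int:
    # Smallest decode bucket that still covers the largest cache position
    # (same rule as initialize_ccl in text_generation_inference.py).
    for i, ccl in enumerate(ccl_decode):
        if max_position_id < ccl:
            return i
    return len(ccl_decode) - 1

ccl_decode = [512, 1024, 2048]                   # assumed buckets; the last one equals ctx_len
ccl_buffers = [np.zeros(c) for c in ccl_decode]  # placeholder inputs, one per specialization
ccl_id = pick_decode_bucket(400, ccl_decode)     # -> 0: run decode with the 512-CCL graph
cache_index = 511
# After each generated token the cache grows by one; once it reaches the current
# bucket boundary, step up to the next larger specialization, as run_decode does.
if cache_index >= ccl_decode[ccl_id] - 1:
    ccl_id = min(ccl_id + 1, len(ccl_decode) - 1)  # -> 1: switch to the 1024-CCL graph
decode_inputs = {"comp_ctx_lengths": ccl_buffers[ccl_id]}

Prefill follows the same pattern, except the bucket is advanced per prefill chunk whenever (i + 1) * prefill_seq_len exceeds the current prefill bucket.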
From 2643e9f1d157f8011266bda3891219936fd786a8 Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Thu, 16 Oct 2025 23:25:29 -0700 Subject: [PATCH 16/41] Adding Compute-Context-Length(CCL) Signed-off-by: Vahid Janfaza --- QEfficient/cloud/infer.py | 12 + QEfficient/customop/ctx_scatter_gather.py | 16 +- QEfficient/customop/ctx_scatter_gather_cb.py | 18 +- .../generation/text_generation_inference.py | 76 +++++ QEfficient/transformers/cache_utils.py | 50 +-- .../models/gemma/modeling_gemma.py | 11 +- .../models/gemma2/modeling_gemma2.py | 17 +- .../models/gemma3/modeling_gemma3.py | 111 +++++-- .../models/granite/modeling_granite.py | 15 +- .../models/granitemoe/modeling_granitemoe.py | 9 + .../models/internvl/modeling_internvl.py | 100 ++++-- .../models/llama/modeling_llama.py | 11 +- .../models/llama4/modeling_llama4.py | 125 ++++++-- .../llama_swiftkv/modeling_llama_swiftkv.py | 28 +- .../models/llava/modeling_llava.py | 92 ++++-- .../models/llava_next/modeling_llava_next.py | 106 +++++-- .../models/mistral/modeling_mistral.py | 11 +- .../models/mistral3/modeling_mistral3.py | 86 ++++-- .../models/mixtral_moe/modeling_mixtral.py | 11 +- .../models/mllama/modeling_mllama.py | 98 ++++-- .../transformers/models/modeling_auto.py | 272 +++++++++++++++-- .../transformers/models/mpt/modeling_mpt.py | 11 +- .../models/olmo2/modeling_olmo2.py | 11 +- .../transformers/models/phi/modeling_phi.py | 17 +- .../transformers/models/phi3/modeling_phi3.py | 10 + .../models/qwen2/modeling_qwen2.py | 11 +- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 93 ++++-- .../models/qwen3/modeling_qwen3.py | 11 +- .../models/qwen3_moe/modeling_qwen3_moe.py | 11 +- .../models/starcoder2/modeling_starcoder2.py | 11 +- QEfficient/utils/check_ccl_specializations.py | 43 +++ examples/ccl_image_text_to_text_inference.py | 135 +++++++++ examples/ccl_llama4_example.py | 126 ++++++++ examples/ccl_mistral3_example.py | 120 ++++++++ examples/ccl_qwen2_5_vl_example.py | 189 ++++++++++++ examples/compute_context_length.py | 61 ++++ examples/gemma3_example/ccl_gemma3_mm.py | 119 ++++++++ .../ccl_granite_vision_inference.py | 127 ++++++++ .../ccl_granitemoe_inference.py | 40 +++ .../intern_example/ccl_internvl_inference.py | 286 ++++++++++++++++++ .../ccl_qwen3moe_inference.py | 42 +++ tests/transformers/test_comp_ctx_length.py | 193 ++++++++++++ 42 files changed, 2685 insertions(+), 257 deletions(-) create mode 100644 QEfficient/utils/check_ccl_specializations.py create mode 100644 examples/ccl_image_text_to_text_inference.py create mode 100644 examples/ccl_llama4_example.py create mode 100644 examples/ccl_mistral3_example.py create mode 100644 examples/ccl_qwen2_5_vl_example.py create mode 100644 examples/compute_context_length.py create mode 100644 examples/gemma3_example/ccl_gemma3_mm.py create mode 100644 examples/granite_example/ccl_granite_vision_inference.py create mode 100644 examples/granite_example/ccl_granitemoe_inference.py create mode 100644 examples/intern_example/ccl_internvl_inference.py create mode 100644 examples/qwen3moe_example/ccl_qwen3moe_inference.py create mode 100644 tests/transformers/test_comp_ctx_length.py diff --git a/QEfficient/cloud/infer.py b/QEfficient/cloud/infer.py index 814122b9d..fbff5b18b 100644 --- a/QEfficient/cloud/infer.py +++ b/QEfficient/cloud/infer.py @@ -340,6 +340,18 @@ def main( "--prompt-len", "--prompt_len", default=32, type=int, help="Sequence length for text generation." 
) parser.add_argument("--ctx-len", "--ctx_len", default=128, type=int, help="Context length for text generation.") + parser.add_argument( + "--comp-ctx-lengths-prefill", + type=lambda comp_ctx_lengths_prefill: [int(x) for x in comp_ctx_lengths_prefill.split(",")], + default=[512], + help="Define ccl list in csv format (e.g.,--comp-ctx-lengths 512,1024,2048).", + ) + parser.add_argument( + "--comp-ctx-lengths-decode", + type=lambda comp_ctx_lengths_decode: [int(x) for x in comp_ctx_lengths_decode.split(",")], + default=[2048], + help="Define ccl list in csv format (e.g.,--comp-ctx-lengths 512,1024,2048).", + ) parser.add_argument( "--mxfp6", "--mxfp6_matmul", diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py index c4f5a7bbd..269ccb0be 100644 --- a/QEfficient/customop/ctx_scatter_gather.py +++ b/QEfficient/customop/ctx_scatter_gather.py @@ -115,8 +115,14 @@ def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> tor @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) -def CtxGather(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT: - ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[3], axes=[0])) +def CtxGather( + data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32 +) -> onnxscript.FLOAT: + # Create a shape tensor based on comp_ctx_len + shape_tensor = ops.Concat(ops.Shape(data)[:2], ops.Reshape(comp_ctx_len, [1]), axis=0) + + # Directly use the shape tensor without validation + ctx_indices = ops.Expand(ctx_indices, shape_tensor) ctx_indices = ops.Unsqueeze(ctx_indices, [-1]) return ops.GatherND(data, ctx_indices, batch_dims=2) @@ -127,7 +133,7 @@ class CtxGatherFunc(torch.autograd.Function): """ @staticmethod - def forward(data: torch.Tensor, ctx_indices: torch.Tensor): + def forward(data: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int): batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1) head_indices = torch.arange(data.shape[1]).view(1, -1, 1) return data[batch_indices, head_indices, ctx_indices] @@ -137,5 +143,5 @@ def setup_context(ctx, inputs, outputs): pass @staticmethod - def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value: - return g.onnxscript_op(CtxGather, data, ctx_indices).setTypeAs(data) + def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int) -> torch.Value: + return g.onnxscript_op(CtxGather, data, ctx_indices, comp_ctx_len).setTypeAs(data) diff --git a/QEfficient/customop/ctx_scatter_gather_cb.py b/QEfficient/customop/ctx_scatter_gather_cb.py index 75d9a12ef..cc9693716 100644 --- a/QEfficient/customop/ctx_scatter_gather_cb.py +++ b/QEfficient/customop/ctx_scatter_gather_cb.py @@ -97,16 +97,20 @@ def symbolic( @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) def CtxGatherCB( - data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32 + data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32 ) -> onnxscript.FLOAT: batch_size = ops.Gather(ops.Shape(batch_index), [0]) num_heads = ops.Gather(ops.Shape(data), [1]) - ctx_len = ops.Gather(ops.Shape(data), [2]) + # using compute-context-length (CCL) instead of context-length to do gather process based on CCL and later do attention computations based on CCL as well. 
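+    # comp_ctx_len arrives as an explicit graph input; it is reshaped to rank-1 below so it
+    # can be concatenated with batch_size and num_heads into the expand shape for the gather.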
+ ctx_len = ops.Reshape(comp_ctx_len, [1]) # Expanded shape to create indices zero = ops.Constant(value_ints=[0]) one = ops.Constant(value_ints=[1]) - exp_shape = ops.Concat(batch_size, num_heads, ctx_len, one, axis=0) + # exp_shape = ops.Concat(batch_size, num_heads, ctx_len, one, axis=0) + exp_shape = ops.Concat( + ops.Reshape(batch_size, [1]), ops.Reshape(num_heads, [1]), ops.Reshape(ctx_len, [1]), one, axis=0 + ) # Create indices batch_idx = ops.Expand(ops.Unsqueeze(batch_index, [2, 3]), exp_shape) @@ -119,7 +123,7 @@ def CtxGatherCB( class CtxGatherFuncCB(torch.autograd.Function): @staticmethod - def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor): + def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int): batch_indices = batch_index.view(-1, 1, 1) head_indices = torch.arange(data.shape[1]).view(1, -1, 1) return data[batch_indices, head_indices, ctx_indices] @@ -129,8 +133,10 @@ def setup_context(ctx, inputs, outputs): pass @staticmethod - def symbolic(g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value) -> torch.Value: - return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices).setTypeAs(data) + def symbolic( + g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int + ) -> torch.Value: + return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices, comp_ctx_len).setTypeAs(data) @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index 6d04cf573..cf4b6aa27 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -318,6 +318,8 @@ def cloud_ai_100_exec_kv( prompts_txt_file_path: Optional[str] = None, device_id: Optional[List[int]] = None, generation_len: Optional[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, enable_debug_logs: bool = False, stream: bool = True, write_io_dir: Optional[str] = None, @@ -384,6 +386,8 @@ def cloud_ai_100_exec_kv( qpc_path=qpc_path, device_id=device_id, ctx_len=ctx_len, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, enable_debug_logs=enable_debug_logs, write_io_dir=write_io_dir, full_batch_size=full_batch_size, @@ -430,6 +434,8 @@ def __init__( qpc_path: str, full_batch_size: Optional[int] = None, ctx_len: Optional[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, device_id: Optional[List[int]] = None, enable_debug_logs: bool = False, write_io_dir: Optional[str] = None, @@ -439,6 +445,8 @@ def __init__( sampling_params: Optional[Dict[str, Any]] = None, ) -> None: self._ctx_len = ctx_len + self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill + self.comp_ctx_lengths_decode = comp_ctx_lengths_decode self._write_io_dir = write_io_dir self.is_tlm = is_tlm self.return_pdfs = return_pdfs @@ -797,7 +805,17 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i batch_lora_ids = [self._prompt_to_lora_id_mapping_prefill.popleft() for i in range(self.batch_size)] inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1) + if self.comp_ctx_lengths_prefill is not None: + self.list_of_comp_ctx_lengths_prefill = [np.zeros(length) 
for length in self.comp_ctx_lengths_prefill] + prefill_ccl_id = 0 + inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + for i in range(num_chunks): + if self.comp_ctx_lengths_prefill is not None: + if (i + 1) * self._prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id]: + prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1) + inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + chunk_inputs = inputs.copy() chunk_inputs["input_ids"] = inputs["input_ids"][ :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len @@ -816,6 +834,19 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i generation_len, ) + def initialize_ccl(self, decode_inputs): + self.list_of_comp_ctx_lengths_decode = [np.zeros(length) for length in self.comp_ctx_lengths_decode] + max_ccl_id = len(self.comp_ctx_lengths_decode) - 1 + max_position_id = np.max(decode_inputs["position_ids"]) + ccl_id_initial = 0 + ccl_id = ccl_id_initial + for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)): + if max_position_id < self.comp_ctx_lengths_decode[i]: + ccl_id = i + break + + return ccl_id, max_ccl_id + def run_continuous_batching_decode(self, prompt_queue, generation_len): """ Runs continuous batching decode for the given prompt queue and generation length. @@ -847,6 +878,10 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): # Prepare decode inputs inputs. decode_inputs = self.prepare_decode_inputs() + if self.comp_ctx_lengths_decode is not None: + ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs) + decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id] + while prompt_queue or current_decode_ongoing.any(): outputs = self._session.run(decode_inputs) @@ -884,6 +919,20 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): batch_id_map[decode_batch_id] ] + if self.comp_ctx_lengths_decode is not None: + ###Recalculate ccl_id based on position ids### + # Determine the maximum value of position_ids across all batch elements + max_position_id = np.max(decode_inputs["position_ids"]) + + # Update ccl_id and comp_ctx_lengths_decode based on the maximum position id + ccl_id_initial = 0 + ccl_id = ccl_id_initial + for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)): + if max_position_id < self.comp_ctx_lengths_decode[i]: + ccl_id = i + break + decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id] + else: current_decode_ongoing[decode_batch_id] = False else: @@ -896,6 +945,15 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len): if self.include_sampler: decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"] + if self.comp_ctx_lengths_decode is not None: + # Update ccl_id and comp_ctx_lengths_decode based on the maximum position id + if ( + decode_inputs["position_ids"][decode_batch_id, -1] + >= self.comp_ctx_lengths_decode[ccl_id] - 1 + ): + ccl_id = min(ccl_id + 1, max_ccl_id) + decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id] + generated_id_current_index[decode_batch_id] += 1 return decode_pause_time @@ -922,7 +980,18 @@ def run_decode( self._session.set_buffers({"logits": logits_out_placeholder}) finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id num_token = 0 + + if self.comp_ctx_lengths_decode is not None: + ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs) + 
decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id] + + cache_index = np.max(decode_inputs["position_ids"]) for num_token in range(1, generation_len): + if self.comp_ctx_lengths_decode is not None: + if cache_index >= self.comp_ctx_lengths_decode[ccl_id] - 1: + ccl_id = min(ccl_id + 1, max_ccl_id) + decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id] + if streamer: streamer.put(decode_inputs["input_ids"][0]) outputs = self._session.run(decode_inputs) @@ -934,6 +1003,7 @@ def run_decode( # Prepare inputs for next iteration decode_inputs["input_ids"] = self._fetch_next_token_id(outputs) decode_inputs["position_ids"][:, -1] += 1 + cache_index += 1 self.generated_ids[:, num_token] = decode_inputs["input_ids"][:, -1] finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id if self.include_sampler: @@ -983,6 +1053,8 @@ def __init__( qpc_path: str, full_batch_size: Optional[int] = None, ctx_len: Optional[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, device_id: Optional[List[int]] = None, enable_debug_logs: bool = False, write_io_dir: Optional[str] = None, @@ -996,6 +1068,8 @@ def __init__( qpc_path=qpc_path, full_batch_size=full_batch_size, ctx_len=ctx_len, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, device_id=device_id, enable_debug_logs=enable_debug_logs, write_io_dir=write_io_dir, @@ -1007,6 +1081,8 @@ def __init__( self._full_batch_size = self._qaic_model.full_batch_size self._tokenizer = self._qaic_model.tokenizer self._ctx_len = ctx_len + self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill + self.comp_ctx_lengths_decode = comp_ctx_lengths_decode self._perf_metrics = None self._prompt_queue = None self._text_streamer = None diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index bbd937d52..0d123d25f 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -40,8 +40,9 @@ def read_only(self, cache_kwargs): k_out, v_out = self.keys, self.values position_ids = cache_kwargs.get("position_ids") batch_index = cache_kwargs.get("batch_index", None) - ctx_len = k_out.shape[2] - ctx_indices = torch.arange(ctx_len)[None, None, ...] + comp_ctx_len = cache_kwargs.get("CCL") + + ctx_indices = torch.arange(comp_ctx_len)[None, None, ...] 
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) invalid_mask = ctx_indices > gather_limit @@ -53,12 +54,11 @@ def read_only(self, cache_kwargs): ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) if batch_index is not None: - k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices) - v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices) + k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, comp_ctx_len) + v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, comp_ctx_len) else: - k_out = CtxGatherFunc.apply(k_out, ctx_indices) - v_out = CtxGatherFunc.apply(v_out, ctx_indices) - + k_out = CtxGatherFunc.apply(k_out, ctx_indices, comp_ctx_len) + v_out = CtxGatherFunc.apply(v_out, ctx_indices, comp_ctx_len) v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) return k_out, v_out @@ -121,6 +121,7 @@ def update( else: position_ids = cache_kwargs.get("position_ids") batch_index = cache_kwargs.get("batch_index", None) # Check and fetch batch index value form the kwargs + comp_ctx_len = cache_kwargs.get("CCL") # Scatter if batch_index is not None: @@ -137,8 +138,7 @@ def update( k_out, v_out = self.keys, self.values # Gather - ctx_len = k_out.shape[2] - ctx_indices = torch.arange(ctx_len)[None, None, ...] + ctx_indices = torch.arange(comp_ctx_len)[None, None, ...] gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) invalid_mask = ctx_indices > gather_limit @@ -149,11 +149,11 @@ def update( ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) if batch_index is not None: - k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices) - v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices) + k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, comp_ctx_len) + v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, comp_ctx_len) else: - k_out = CtxGatherFunc.apply(k_out, ctx_indices) - v_out = CtxGatherFunc.apply(v_out, ctx_indices) + k_out = CtxGatherFunc.apply(k_out, ctx_indices, comp_ctx_len) + v_out = CtxGatherFunc.apply(v_out, ctx_indices, comp_ctx_len) v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) return k_out, v_out @@ -392,6 +392,8 @@ def update( else: position_ids = cache_kwargs.get("position_ids") sliding_window_pattern = cache_kwargs.get("sliding_window_pattern") + comp_ctx_len = cache_kwargs.get("CCL") + is_sliding_layer = torch.tensor(bool((layer_idx + 1) % sliding_window_pattern)) layer_ctx_len = self.key_cache[layer_idx].shape[2] kv_position_ids = torch.where( @@ -417,20 +419,24 @@ def update( ctx_len = self.key_cache[layer_idx].shape[2] ctx_indices = torch.arange(ctx_len)[None, None, ...] 
gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1) - invalid_mask = ctx_indices > gather_limit if torch.onnx.is_in_onnx_export(): invalid_idx_value = torch.iinfo(torch.int32).max else: invalid_idx_value = 0 + + ctx_indices = ctx_indices[:, :, :comp_ctx_len] + invalid_mask = ctx_indices > gather_limit + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) all_indices = torch.arange(layer_ctx_len) + kv_position_ids.max() + 1 rolling_indices = torch.where(all_indices > layer_ctx_len - 1, all_indices % layer_ctx_len, all_indices) + rolling_indices = rolling_indices[:comp_ctx_len] final_indices = torch.where( (is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), rolling_indices, ctx_indices ) - k_out = CtxGatherFunc.apply(k_out, final_indices) - v_out = CtxGatherFunc.apply(v_out, final_indices) + k_out = CtxGatherFunc.apply(k_out, final_indices, comp_ctx_len) + v_out = CtxGatherFunc.apply(v_out, final_indices, comp_ctx_len) ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out) return k_out, v_out @@ -492,6 +498,8 @@ def update( else: position_ids = cache_kwargs.get("position_ids") + comp_ctx_len = cache_kwargs.get("CCL") + is_sliding_layer = torch.tensor(bool(self.is_sliding[layer_idx])) # Update the position_ids to handle the sliding window @@ -519,21 +527,25 @@ def update( ctx_len = min(layer_ctx_len, k_out.shape[2]) ctx_indices = torch.arange(ctx_len)[None, None, ...] gather_limit = kv_position_ids.max(1, keepdim=True).values.unsqueeze(1) - invalid_mask = ctx_indices > gather_limit if torch.onnx.is_in_onnx_export(): invalid_idx_value = torch.iinfo(torch.int32).max else: invalid_idx_value = 0 + + ctx_indices = ctx_indices[:, :, :comp_ctx_len] + invalid_mask = ctx_indices > gather_limit + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) # Rolling indices for sliding window all_indices = torch.arange(layer_ctx_len) + kv_position_ids.max() + 1 rolling_indices = torch.where(all_indices > layer_ctx_len - 1, all_indices % layer_ctx_len, all_indices) + rolling_indices = rolling_indices[:comp_ctx_len] final_indices = torch.where( (is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), rolling_indices, ctx_indices ) - k_out = CtxGatherFunc.apply(k_out, final_indices) - v_out = CtxGatherFunc.apply(v_out, final_indices) + k_out = CtxGatherFunc.apply(k_out, final_indices, comp_ctx_len) + v_out = CtxGatherFunc.apply(v_out, final_indices, comp_ctx_len) ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out) return k_out, v_out diff --git a/QEfficient/transformers/models/gemma/modeling_gemma.py b/QEfficient/transformers/models/gemma/modeling_gemma.py index eea1e3898..4c64109d8 100644 --- a/QEfficient/transformers/models/gemma/modeling_gemma.py +++ b/QEfficient/transformers/models/gemma/modeling_gemma.py @@ -137,6 +137,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -153,7 +154,9 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, 
cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -186,6 +189,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -214,6 +218,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -243,6 +248,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -299,6 +305,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -334,6 +341,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -350,6 +358,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/gemma2/modeling_gemma2.py b/QEfficient/transformers/models/gemma2/modeling_gemma2.py index be3ba942d..85bba2989 100644 --- a/QEfficient/transformers/models/gemma2/modeling_gemma2.py +++ b/QEfficient/transformers/models/gemma2/modeling_gemma2.py @@ -144,6 +144,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -160,8 +161,16 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + "CCL": attention_mask.shape[-1], + } key_states, 
value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -194,6 +203,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -226,6 +236,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -266,6 +277,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -338,6 +350,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -381,6 +394,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -404,6 +418,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py index 20b7036fd..2e8494e8e 100644 --- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -6,7 +6,7 @@ # ----------------------------------------------------------------------------- import copy -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch from torch import nn @@ -215,6 +215,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -245,6 +246,8 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { "sin": sin, @@ -253,6 +256,7 @@ def forward( "position_ids": position_ids, "is_sliding": self.is_sliding, "sliding_window_pattern": self.config.sliding_window_pattern, + "CCL": attention_mask.shape[-1], } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -297,6 +301,7 @@ def forward( 
attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -323,6 +328,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -363,6 +369,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -429,6 +436,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -466,6 +474,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -525,6 +534,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, @@ -592,7 +602,9 @@ def __init__(self, model): self.config = self.model.config self.lm_head = self.model.lm_head - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + ): inputs_embeds = self.model.get_input_embeddings()(input_ids) B, N, C = inputs_embeds.shape selected = input_ids == self.model.config.image_token_index @@ -603,7 +615,11 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds) outputs = self.language_model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + use_cache=True, ) image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) @@ -620,7 +636,9 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEffGemma3DecoderWrapper(self) - def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_values): + def forward( + self, input_ids, position_ids, pixel_values, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + ): image_features = self.get_image_features(pixel_values=pixel_values) inputs_embeds = self.get_input_embeddings()(input_ids) B, N, C = inputs_embeds.shape @@ -632,7 +650,11 @@ def forward(self, input_ids, position_ids, pixel_values, image_idx, 
past_key_val image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds) outputs = self.language_model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + use_cache=True, ) image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) @@ -647,6 +669,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, + comp_ctx_lengths_prefill: List[int] = None, + comp_ctx_lengths_decode: List[int] = None, kv_offload: bool = False, **compiler_options, ): @@ -667,24 +691,55 @@ def get_specializations( "ctx_len": ctx_len, } ] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "sliding_window": self.language_model.config.sliding_window, - "img_size": img_size, - "mm_tokens_per_image": mm_tokens_per_image, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "sliding_window": self.language_model.config.sliding_window, - "img_size": img_size, - "mm_tokens_per_image": mm_tokens_per_image, - }, - ] + if comp_ctx_lengths_prefill and comp_ctx_lengths_decode: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang.append( + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "sliding_window": self.language_model.config.sliding_window, + "img_size": img_size, + "mm_tokens_per_image": mm_tokens_per_image, + } + ) + + for i in range(0, len(comp_ctx_lengths_decode)): + lang.append( + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "sliding_window": self.language_model.config.sliding_window, + "img_size": img_size, + "mm_tokens_per_image": mm_tokens_per_image, + } + ) + + else: + lang = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "sliding_window": self.language_model.config.sliding_window, + "img_size": img_size, + "mm_tokens_per_image": mm_tokens_per_image, + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "sliding_window": self.language_model.config.sliding_window, + "img_size": img_size, + "mm_tokens_per_image": mm_tokens_per_image, + }, + ] + specializations = {} if kv_offload: @@ -694,7 +749,7 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): # Define dynamic axes vision_dynamic_axes = {} lang_dynamic_axes = {} @@ -719,6 +774,9 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): ) lang_dynamic_axes[f"past_{kv}.{i}"] = apply_dynamic_axes + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + dynamic_axes = {} if kv_offload: dynamic_axes["vision"] = vision_dynamic_axes @@ -767,7 +825,7 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len): past_key_values.append(pkv) return past_key_values - def get_dummy_inputs(self, kv_offload: bool = False): + def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): if vis_cfg := 
getattr(self.config, "vision_config", None): img_size = getattr(vis_cfg, "image_size", 896) else: @@ -813,6 +871,9 @@ def get_dummy_inputs(self, kv_offload: bool = False): seq_len=constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN, ) + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs diff --git a/QEfficient/transformers/models/granite/modeling_granite.py b/QEfficient/transformers/models/granite/modeling_granite.py index 2a2d47d6d..dd3d6c7f3 100644 --- a/QEfficient/transformers/models/granite/modeling_granite.py +++ b/QEfficient/transformers/models/granite/modeling_granite.py @@ -129,6 +129,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -145,8 +146,16 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + "CCL": attention_mask.shape[-1], + } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -171,6 +180,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -226,6 +236,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -267,6 +278,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -319,6 +331,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py index c085f6a5e..07031d7fc 100644 --- a/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py +++ b/QEfficient/transformers/models/granitemoe/modeling_granitemoe.py @@ -123,6 +123,7 @@ def forward( position_embeddings: Tuple[torch.Tensor, torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, + 
comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, @@ -142,6 +143,8 @@ def forward( cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { "sin": sin, @@ -149,6 +152,7 @@ def forward( "cache_position": cache_position, "batch_index": batch_index, "position_ids": position_ids, + "CCL": attention_mask.shape[-1], } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -209,6 +213,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -286,6 +291,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -297,6 +303,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -492,6 +499,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -546,6 +554,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/internvl/modeling_internvl.py b/QEfficient/transformers/models/internvl/modeling_internvl.py index 38d0fe167..29e6ac9a4 100644 --- a/QEfficient/transformers/models/internvl/modeling_internvl.py +++ b/QEfficient/transformers/models/internvl/modeling_internvl.py @@ -5,6 +5,8 @@ # # ----------------------------------------------------------------------------- +from typing import List + import torch import torch.nn as nn import torch.nn.functional as F @@ -34,7 +36,9 @@ def __init__(self, model): self.config = self.model.language_model.config self.language_model = self.model.language_model - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + ): input_embeds = self.model.language_model.get_input_embeddings()(input_ids) B, N, C = input_embeds.shape image_input_embeds = input_embeds.reshape(B * N, C) @@ -55,7 +59,11 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, 
image_input_embeds) inputs_embeds = inputs_embeds.reshape(B, N, C) outputs = self.model.language_model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + use_cache=True, ) image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) return outputs.logits, vision_embeds, image_idx, outputs.past_key_values @@ -74,6 +82,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, + comp_ctx_lengths_prefill: List[int] = None, + comp_ctx_lengths_decode: List[int] = None, kv_offload: bool = False, **compiler_options, ): @@ -104,24 +114,54 @@ def get_specializations( "batched_num_patches": batch_size * num_patches, } ] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "num_patches": num_patches, - "img_size": img_size, - "vision_size": vision_size, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "num_patches": num_patches, - "img_size": img_size, - "vision_size": vision_size, - }, - ] + if comp_ctx_lengths_prefill and comp_ctx_lengths_decode: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang.append( + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "num_patches": num_patches, + "img_size": img_size, + "vision_size": vision_size, + } + ) + + for i in range(0, len(comp_ctx_lengths_decode)): + lang.append( + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "num_patches": num_patches, + "img_size": img_size, + "vision_size": vision_size, + } + ) + + else: + lang = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "num_patches": num_patches, + "img_size": img_size, + "vision_size": vision_size, + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "num_patches": num_patches, + "img_size": img_size, + "vision_size": vision_size, + }, + ] specializations = {} @@ -132,7 +172,7 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): # Define dynamic axes vision_dynamic_axes = {} lang_dynamic_axes = {} @@ -146,6 +186,9 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): for kv in ["key", "value"]: lang_dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + dynamic_axes = {} if kv_offload: dynamic_axes["vision"] = vision_dynamic_axes @@ -173,7 +216,7 @@ def get_output_names(self, kv_offload: bool = False): return lang_output_names return output_names - def get_dummy_inputs(self, kv_offload: bool = False): + def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): if vis_cfg := getattr(self.config, "vision_config", None): img_size = getattr(vis_cfg, "image_size", constants.INTERN_IMG_SIZE) else: @@ -234,6 +277,9 @@ def get_dummy_inputs(self, kv_offload: bool = False): for kv in ["key", "value"]: lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), 
dtype=torch.long) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs @@ -244,7 +290,9 @@ def get_dummy_inputs(self, kv_offload: bool = False): return inputs - def forward(self, input_ids, pixel_values, position_ids, image_idx, past_key_values): + def forward( + self, input_ids, pixel_values, position_ids, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + ): input_embeds = self.language_model.get_input_embeddings()(input_ids) vision_embeds = self.extract_feature(pixel_values) B, N, C = input_embeds.shape @@ -266,7 +314,11 @@ def forward(self, input_ids, pixel_values, position_ids, image_idx, past_key_val inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), input_embeds, image_input_embeds) inputs_embeds = inputs_embeds.reshape(B, N, C) outputs = self.language_model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + use_cache=True, ) next_image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) image_idx = torch.where(image_idx < next_image_idx, next_image_idx, image_idx) diff --git a/QEfficient/transformers/models/llama/modeling_llama.py b/QEfficient/transformers/models/llama/modeling_llama.py index f2a68f80e..58d174270 100644 --- a/QEfficient/transformers/models/llama/modeling_llama.py +++ b/QEfficient/transformers/models/llama/modeling_llama.py @@ -132,6 +132,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: bool = False, cache_position: Optional[torch.LongTensor] = None, @@ -154,7 +155,9 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -187,6 +190,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -202,6 +206,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -229,6 +234,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -277,6 +283,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + 
comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -310,6 +317,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -326,6 +334,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py index 212fe16ae..82678e380 100644 --- a/QEfficient/transformers/models/llama4/modeling_llama4.py +++ b/QEfficient/transformers/models/llama4/modeling_llama4.py @@ -470,6 +470,7 @@ def forward( position_embeddings: Tuple[torch.Tensor, torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -503,6 +504,8 @@ def forward( if past_key_value is not None: chunk_position_ids = position_ids + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] if self.use_rope: chunk_position_ids = torch.where( @@ -510,7 +513,11 @@ def forward( ) # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"batch_index": batch_index, "position_ids": chunk_position_ids} + cache_kwargs = { + "batch_index": batch_index, + "position_ids": chunk_position_ids, + "CCL": attention_mask.shape[-1], + } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -543,6 +550,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = False, @@ -562,6 +570,7 @@ def forward( position_embeddings=position_embeddings, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -615,6 +624,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -682,6 +692,7 @@ def forward( attention_mask=causal_mask_mapping[decoder_layer.attention_type], position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -731,6 +742,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, 
List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -754,6 +766,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, @@ -836,7 +849,9 @@ def __init__(self, model): self.language_model = self.model.language_model self.config = self.model.config - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + ): inputs_embeds = self.model.language_model.get_input_embeddings()(input_ids) selected = input_ids == self.model.config.image_token_index indices1 = selected.to(torch.int64).cumsum(1) - 1 @@ -846,7 +861,11 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va image_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds) outputs = self.model.language_model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + use_cache=True, ) next_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) image_idx = torch.where(image_idx < next_idx, next_idx, image_idx) @@ -860,7 +879,9 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEffLlama4DecoderWrapper(self) - def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_values): + def forward( + self, input_ids, position_ids, pixel_values, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + ): inputs_embeds = self.language_model.get_input_embeddings()(input_ids) vision_feature_layer = self.config.vision_config.vision_feature_layer vision_feature_select_strategy = self.config.vision_config.vision_feature_select_strategy @@ -880,7 +901,11 @@ def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_val image_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds) outputs = self.language_model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + use_cache=True, ) next_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) image_idx = torch.where(image_idx < next_idx, next_idx, image_idx) @@ -892,6 +917,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, + comp_ctx_lengths_prefill: List[int] = None, + comp_ctx_lengths_decode: List[int] = None, kv_offload: bool = False, **compiler_options, ): @@ -941,28 +968,62 @@ def get_specializations( "img_size": img_size, } ] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "max_num_tiles": max_num_tiles, - "img_size": img_size, - "vision_size": vision_size, - "chunk_length": prefill_seq_len, - "chunk_ctx_len": chunk_ctx_len, - 
}, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "max_num_tiles": max_num_tiles, - "img_size": img_size, - "vision_size": vision_size, - "chunk_length": prefill_seq_len, - "chunk_ctx_len": chunk_ctx_len, - }, - ] + if comp_ctx_lengths_prefill is not None: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang.append( + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "max_num_tiles": max_num_tiles, + "img_size": img_size, + "vision_size": vision_size, + "chunk_length": prefill_seq_len, + "chunk_ctx_len": chunk_ctx_len, + } + ) + + for i in range(0, len(comp_ctx_lengths_decode)): + lang.append( + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "max_num_tiles": max_num_tiles, + "img_size": img_size, + "vision_size": vision_size, + "chunk_length": prefill_seq_len, + "chunk_ctx_len": chunk_ctx_len, + } + ) + + else: + lang = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "max_num_tiles": max_num_tiles, + "img_size": img_size, + "vision_size": vision_size, + "chunk_length": prefill_seq_len, + "chunk_ctx_len": chunk_ctx_len, + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "max_num_tiles": max_num_tiles, + "img_size": img_size, + "vision_size": vision_size, + "chunk_length": prefill_seq_len, + "chunk_ctx_len": chunk_ctx_len, + }, + ] specializations = {} @@ -973,7 +1034,7 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): # Define dynamic axes vision_dynamic_axes = {} lang_dynamic_axes = {} @@ -993,6 +1054,9 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): for kv in ["key", "value"]: lang_dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + dynamic_axes = {} if kv_offload: dynamic_axes["vision"] = vision_dynamic_axes @@ -1045,7 +1109,7 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len): past_key_values.append(pkv) return past_key_values - def get_dummy_inputs(self, kv_offload: bool = False): + def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): if vis_cfg := getattr(self.config, "vision_config", None): img_size = getattr(vis_cfg, "image_size", 336) else: @@ -1102,6 +1166,9 @@ def get_dummy_inputs(self, kv_offload: bool = False): for kv in ["key", "value"]: lang_inputs["past_key_values"][i].append(torch.zeros(past_key_values[0][0].shape, dtype=torch.float32)) + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs diff --git a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py index 9fd1ed782..5b36b1019 100644 --- a/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py +++ b/QEfficient/transformers/models/llama_swiftkv/modeling_llama_swiftkv.py @@ -89,6 +89,7 @@ def forward( hidden_states: torch.Tensor, position_ids: torch.LongTensor, past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, 
attention_mask: torch.Tensor = None, batch_index: Optional[torch.LongTensor] = None, ) -> torch.Tensor: @@ -105,8 +106,10 @@ def forward( "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " "with a layer index." ) + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] kv_seq_len = past_key_value.get_seq_length(self.layer_idx) - cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.read_only(self.layer_idx, cache_kwargs=cache_kwargs) cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) @@ -155,6 +158,7 @@ def forward( hidden_states: torch.Tensor, position_ids: torch.Tensor, past_key_values, + comp_ctx_lengths, causal_mask, batch_index: Optional[torch.LongTensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: @@ -166,6 +170,7 @@ def forward( hidden_states=hidden_states, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=causal_mask, batch_index=batch_index, ) @@ -201,11 +206,19 @@ def __init__(self, config: QEffLlamaSwiftKVConfig): self.norm_swiftkv = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) def _run_swiftkv_layers( - self, hidden_states: torch.Tensor, position_ids: torch.Tensor, past_key_values, causal_mask, batch_index + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + past_key_values, + comp_ctx_lengths, + causal_mask, + batch_index, ) -> torch.Tensor: for layer_idx in range(self.config.num_key_value_layers, self.config.num_hidden_layers): layer = self.layers[layer_idx] - hidden_states = layer(hidden_states, position_ids, past_key_values, causal_mask, batch_index) + hidden_states = layer( + hidden_states, position_ids, past_key_values, comp_ctx_lengths, causal_mask, batch_index + ) hidden_states = self.norm(hidden_states) return hidden_states, past_key_values @@ -289,6 +302,7 @@ def forward( input_ids: Optional[torch.Tensor], position_ids: torch.Tensor, past_key_values: List[torch.Tensor], + comp_ctx_lengths: Optional[torch.LongTensor], batch_index: Optional[torch.LongTensor] = None, ): inputs_embeds = self.embed_tokens(input_ids) @@ -328,6 +342,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=False, use_cache=True, @@ -373,7 +388,7 @@ def forward( causal_mask = causal_mask[torch.arange(bsz).reshape(-1, 1), :, last_pos_id, :] hidden_states, next_decoder_cache = self._run_swiftkv_layers( - hidden_states, position_ids, past_key_values, causal_mask, batch_index + hidden_states, position_ids, past_key_values, comp_ctx_lengths, causal_mask, batch_index ) # We can fill the orig_hidden_states with the processed hidden_states here but it's not needed as for next token prediction # we only need the last valid pos_indices hidden_states. 
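# A minimal, self-contained sketch (assumed shapes, not taken from this patch) of the
# per-layer pattern the model files above repeat: when comp_ctx_lengths is supplied,
# the causal mask is sliced to that many key positions and the same value is handed to
# the KV-cache update through the "CCL" kwarg, so attention and the cache gather both
# run over the compute context length rather than the full context length.
import torch

batch, q_len, ctx_len, ccl = 1, 1, 4096, 1024          # hypothetical sizes
attention_mask = torch.zeros(batch, 1, q_len, ctx_len, dtype=torch.bool)
comp_ctx_lengths = torch.zeros(ccl, dtype=torch.long)  # only its length is consumed

if comp_ctx_lengths is not None:
    attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]]
cache_kwargs = {"batch_index": None, "position_ids": None, "CCL": attention_mask.shape[-1]}

print(attention_mask.shape[-1], cache_kwargs["CCL"])   # 1024 1024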
@@ -405,9 +420,12 @@ def forward( input_ids: torch.Tensor, position_ids: torch.Tensor, past_key_values: Optional[Union[List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, ): - hidden_states, output_past_key_values = self.model(input_ids, position_ids, past_key_values, batch_index) + hidden_states, output_past_key_values = self.model( + input_ids, position_ids, past_key_values, comp_ctx_lengths, batch_index + ) logits = self.lm_head(hidden_states) return CausalLMOutputWithPast( loss=None, diff --git a/QEfficient/transformers/models/llava/modeling_llava.py b/QEfficient/transformers/models/llava/modeling_llava.py index e260beb05..450fc79b6 100644 --- a/QEfficient/transformers/models/llava/modeling_llava.py +++ b/QEfficient/transformers/models/llava/modeling_llava.py @@ -5,6 +5,8 @@ # # ----------------------------------------------------------------------------- +from typing import List + import torch import torch.nn as nn import torch.utils.checkpoint @@ -51,7 +53,9 @@ def __init__(self, model): self.language_model = self.model.language_model self.lm_head = self.model.lm_head - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + ): inputs_embeds = self.model.get_input_embeddings()(input_ids) vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype) mask = input_ids == self.model.config.image_token_index @@ -65,6 +69,7 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, return_dict=True, ) @@ -83,7 +88,9 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEFFLlavaDecoderWrapper(self) - def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_values): + def forward( + self, input_ids, position_ids, pixel_values, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + ): inputs_embeds = self.get_input_embeddings()(input_ids) # Image features image_outputs = self.vision_tower(pixel_values, output_hidden_states=True) @@ -109,6 +116,7 @@ def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_val inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, ) logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) @@ -120,7 +128,7 @@ def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_val image_idx = torch.where(image_idx < next_image_idx, next_image_idx, image_idx) return logits, pixel_values, image_idx, outputs.past_key_values - def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): + def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False, **kwargs): num_layers = self.config.text_config.num_hidden_layers num_key_value_heads = self.config.text_config.num_key_value_heads head_dim = self.config.text_config.hidden_size // self.config.text_config.num_attention_heads @@ -150,6 +158,10 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): ) ) lang_inputs["position_ids"] = torch.full(lang_inputs["position_ids"].shape, CTX_LEN - 1) + + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + 
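        # The dummy comp_ctx_lengths tensor is only a shape carrier for ONNX export;
        # the modeling code shown here reads just comp_ctx_lengths.shape[-1] to bound
        # attention and KV gathers. At run time the generator binds a zero placeholder
        # whose length equals the CCL bucket selected for the current step.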
inputs = {} if kv_offload: @@ -166,6 +178,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, + comp_ctx_lengths_prefill: List[int] = None, + comp_ctx_lengths_decode: List[int] = None, kv_offload: bool = False, **compiler_options, ): @@ -187,24 +201,55 @@ def get_specializations( "img_size": img_size, } ] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "max_num_images": max_num_images, - "img_size": img_size, - "vision_size": vision_size, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "max_num_images": max_num_images, - "img_size": img_size, - "vision_size": vision_size, - }, - ] + + if comp_ctx_lengths_prefill and comp_ctx_lengths_decode: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang.append( + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + } + ) + + for i in range(0, len(comp_ctx_lengths_decode)): + lang.append( + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + } + ) + else: + lang = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + }, + ] + specializations = {} if kv_offload: @@ -214,7 +259,7 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): # Define dynamic axes num_layers = self.config.text_config.num_hidden_layers @@ -230,6 +275,9 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + dynamic_axes = {} if kv_offload: dynamic_axes["vision"] = vision_dynamic_axes diff --git a/QEfficient/transformers/models/llava_next/modeling_llava_next.py b/QEfficient/transformers/models/llava_next/modeling_llava_next.py index 2fa1d9234..b23073fa7 100755 --- a/QEfficient/transformers/models/llava_next/modeling_llava_next.py +++ b/QEfficient/transformers/models/llava_next/modeling_llava_next.py @@ -6,6 +6,8 @@ # ----------------------------------------------------------------------------- +from typing import List + import numpy as np import torch import torch.nn as nn @@ -123,7 +125,9 @@ def __init__(self, model): self.language_model = self.model.language_model self.lm_head = self.model.lm_head - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + ): inputs_embeds = self.model.get_input_embeddings()(input_ids) image_features = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype) mask = input_ids == self.config.image_token_index @@ -138,6 +142,7 @@ def 
forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, ) image_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) @@ -154,7 +159,7 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEffLlavaNextDecoderWrapper(self) - def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): + def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False, **kwargs): num_layers = self.config.text_config.num_hidden_layers num_key_value_heads = self.config.text_config.num_key_value_heads head_dim = self.config.text_config.hidden_size // self.config.text_config.num_attention_heads @@ -217,6 +222,10 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): ) ) lang_inputs["position_ids"] = torch.full(lang_inputs["position_ids"].shape, constants.GRANITEVISION_CTX_LEN - 1) + + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs @@ -232,6 +241,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, + comp_ctx_lengths_prefill: List[int] = None, + comp_ctx_lengths_decode: List[int] = None, kv_offload: bool = False, **compiler_options, ): @@ -285,30 +296,67 @@ def get_specializations( "img_size": img_size, } ] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "image_size_height": image_size_height, - "image_size_width": image_size_width, - "num_patches": num_patches, - "max_num_images": max_num_images, - "img_size": img_size, - "vision_size": vision_size, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "image_size_height": image_size_height, - "image_size_width": image_size_width, - "num_patches": num_patches, - "max_num_images": max_num_images, - "img_size": img_size, - "vision_size": vision_size, - }, - ] + if comp_ctx_lengths_prefill is not None: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang.append( + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "image_size_height": image_size_height, + "image_size_width": image_size_width, + "num_patches": num_patches, + "max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + } + ) + + # Remaining elements use comp_ctx_lengths[1:] in a loop + for i in range(0, len(comp_ctx_lengths_decode)): + lang.append( + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "image_size_height": image_size_height, + "image_size_width": image_size_width, + "num_patches": num_patches, + "max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + } + ) + else: + lang = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "image_size_height": image_size_height, + "image_size_width": image_size_width, + "num_patches": num_patches, + "max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "image_size_height": image_size_height, + "image_size_width": image_size_width, + "num_patches": num_patches, + 
"max_num_images": max_num_images, + "img_size": img_size, + "vision_size": vision_size, + }, + ] + specializations = {} if kv_offload: specializations["vision"] = vision @@ -317,7 +365,7 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): # Define dynamic axes num_layers = self.config.text_config.num_hidden_layers vision_dynamic_axes = { @@ -332,6 +380,10 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): for i in range(num_layers): lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + dynamic_axes = {} if kv_offload: dynamic_axes["vision"] = vision_dynamic_axes diff --git a/QEfficient/transformers/models/mistral/modeling_mistral.py b/QEfficient/transformers/models/mistral/modeling_mistral.py index ca23cc144..30c73ae8b 100644 --- a/QEfficient/transformers/models/mistral/modeling_mistral.py +++ b/QEfficient/transformers/models/mistral/modeling_mistral.py @@ -140,6 +140,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, @@ -163,7 +164,9 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -196,6 +199,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -226,6 +230,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -256,6 +261,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -316,6 +322,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -354,6 +361,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: 
Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -377,6 +385,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/mistral3/modeling_mistral3.py b/QEfficient/transformers/models/mistral3/modeling_mistral3.py index 735eec9e5..a5f1301d2 100644 --- a/QEfficient/transformers/models/mistral3/modeling_mistral3.py +++ b/QEfficient/transformers/models/mistral3/modeling_mistral3.py @@ -5,7 +5,7 @@ # # ----------------------------------------------------------------------------- -from typing import Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import torch import torch.nn as nn @@ -106,6 +106,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, vision_feature_layer: Optional[Union[int, list[int]]] = None, use_cache: Optional[bool] = None, @@ -126,6 +127,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, @@ -166,7 +168,9 @@ def __init__(self, model): self.config = self.model.config self.language_model = self.model.language_model - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + ): inputs_embeds = self.model.get_input_embeddings()(input_ids) vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype) mask = input_ids == self.model.config.image_token_index @@ -179,6 +183,7 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va inputs_embeds=inputs_embeds_1, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, ) # Cast to int32 to avoid ONNXRT issue @@ -198,7 +203,9 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEFFMistral3DecoderWrapper(self) - def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_values): + def forward( + self, input_ids, position_ids, pixel_values, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + ): inputs_embeds = self.get_input_embeddings()(input_ids) image_sizes = torch.tensor([[pixel_values.shape[2], pixel_values.shape[3]]]).repeat(pixel_values.shape[0], 1) image_features = self.get_image_features( @@ -219,6 +226,7 @@ def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_val inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, ) # Cast to int32 to avoid ONNXRT issue logit_idx = position_ids.to(torch.int32).argmax(1, keepdim=True) @@ -230,7 +238,7 @@ def forward(self, input_ids, position_ids, pixel_values, image_idx, past_key_val return logits, pixel_values, image_idx, outputs.past_key_values - def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): 
+ def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False, **kwargs): inputs_shapes = {} inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) height = self.config.vision_config.image_size @@ -282,6 +290,9 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): for kv in ["key", "value"]: lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs @@ -298,6 +309,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, + comp_ctx_lengths_prefill: List[int] = None, + comp_ctx_lengths_decode: List[int] = None, kv_offload: bool = False, **compiler_options, ): @@ -323,22 +336,50 @@ def get_specializations( "vision_size": vision_size, } ] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "image_size": img_size, - "vision_size": vision_size, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "image_size": img_size, - "vision_size": vision_size, - }, - ] + if comp_ctx_lengths_prefill is not None: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang.append( + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "image_size": img_size, + "vision_size": vision_size, + } + ) + + # Remaining elements use comp_ctx_lengths[1:] in a loop + for i in range(0, len(comp_ctx_lengths_decode)): + lang.append( + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "image_size": img_size, + "vision_size": vision_size, + } + ) + else: + lang = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "image_size": img_size, + "vision_size": vision_size, + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "image_size": img_size, + "vision_size": vision_size, + }, + ] specializations = {} @@ -351,7 +392,7 @@ def get_specializations( lang[1].pop("vision_size") return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): # Define dynamic axes num_layers = self.config.text_config.num_hidden_layers @@ -368,6 +409,9 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + dynamic_axes = {} if kv_offload: dynamic_axes["vision"] = vision_dynamic_axes diff --git a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py index 9b9e3448a..6e61568ac 100644 --- a/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py +++ b/QEfficient/transformers/models/mixtral_moe/modeling_mixtral.py @@ -137,6 +137,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, **kwargs, ) 
-> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -159,7 +160,9 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -245,6 +248,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_router_logits: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -282,6 +286,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -314,6 +319,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -375,6 +381,7 @@ def forward( position_ids=position_ids, batch_index=batch_index, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, output_router_logits=output_router_logits, use_cache=use_cache, cache_position=cache_position, @@ -412,6 +419,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -435,6 +443,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index cb24f1de4..2197bec91 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -177,6 +177,7 @@ def forward( hidden_states: torch.Tensor, cross_attention_states: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, use_cache: bool = None, @@ -249,6 +250,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, position_embeddings: torch.Tensor = None, use_cache: bool = False, @@ -278,9 +280,12 @@ def forward( query_states, key_states = 
qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] cache_kwargs = { "batch_index": batch_index, "position_ids": position_ids, + "CCL": attention_mask.shape[-1], } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -316,6 +321,7 @@ def forward( full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -350,6 +356,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -379,6 +386,7 @@ def forward( cross_attention_states: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, use_cache: bool = None, @@ -396,13 +404,15 @@ def forward( key_states = key_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) value_states = value_states.view(bsz, -1, self.num_key_value_heads, self.head_dim).transpose(1, 2) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # if we have a new image + new tokens, we only computed key_states on that new image # we still update the cross key states, past_image, new_image. And use it! 
key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, - {"batch_index": batch_index, "position_ids": position_ids}, + {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]}, ) elif past_key_value is not None: key_states, value_states = ( @@ -448,6 +458,7 @@ def forward( full_text_row_masked_out_mask: Tuple[torch.Tensor, torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -461,6 +472,7 @@ def forward( attention_mask=cross_attention_mask, cross_attention_states=cross_attention_states, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, cache_position=cache_position, ) @@ -594,6 +606,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cross_attention_states: Optional[torch.FloatTensor] = None, cross_attention_mask: Optional[torch.Tensor] = None, @@ -658,6 +671,7 @@ def forward( full_text_row_masked_out_mask=full_text_row_masked_out_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, use_cache=use_cache, cache_position=cache_position, ) @@ -688,6 +702,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cross_attention_states: Optional[torch.LongTensor] = None, cross_attention_mask: Optional[torch.LongTensor] = None, @@ -707,6 +722,7 @@ def forward( cross_attention_mask=cross_attention_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, inputs_embeds=inputs_embeds, use_cache=use_cache, cache_position=cache_position, @@ -774,6 +790,7 @@ def forward( cross_attention_states: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, cache_position: Optional[torch.LongTensor] = None, @@ -820,6 +837,7 @@ def forward( cross_attention_mask=cross_attention_mask, full_text_row_masked_out_mask=full_text_row_masked_out_mask, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, use_cache=use_cache, inputs_embeds=inputs_embeds, cache_position=cache_position, @@ -853,6 +871,7 @@ def forward( cross_attention_states: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -869,6 +888,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, inputs_embeds=inputs_embeds, use_cache=use_cache, 
cache_position=cache_position, @@ -879,7 +899,7 @@ def forward( logits = self.lm_head(hidden_states).float() return logits, image_idx, outputs.past_key_values, pixel_values - def get_dummy_inputs(self, kv_offload: bool = False): + def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): BS = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE SEQ_LEN = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN CTX_LEN = constants.ONNX_EXPORT_CTX_LEN @@ -943,6 +963,10 @@ def get_dummy_inputs(self, kv_offload: bool = False): lang_inputs["past_key_values"] = lang_inputs["past_key_values"].to_legacy_cache() lang_inputs["position_ids"] = torch.full(lang_inputs["position_ids"].shape, CTX_LEN - 1) + + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + inputs = {} if kv_offload: @@ -959,6 +983,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, + comp_ctx_lengths_prefill: List[int] = None, + comp_ctx_lengths_decode: List[int] = None, kv_offload: bool = False, **compiler_options, ): @@ -973,22 +999,53 @@ def get_specializations( logger.warning("Setting `img_size=448` as it was neither passed nor found in vision_config") vision = [{"batch_size": batch_size, "max_num_images": max_num_images, "img_size": img_size}] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "max_num_images": max_num_images, - "img_size": img_size, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "max_num_images": max_num_images, - "img_size": img_size, - }, - ] + + if comp_ctx_lengths_prefill is not None: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang.append( + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "max_num_images": max_num_images, + "img_size": img_size, + } + ) + + # Remaining elements use comp_ctx_lengths[1:] in a loop + for i in range(0, len(comp_ctx_lengths_decode)): + lang.append( + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "max_num_images": max_num_images, + "img_size": img_size, + } + ) + + else: + lang = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "max_num_images": max_num_images, + "img_size": img_size, + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "max_num_images": max_num_images, + "img_size": img_size, + }, + ] + specializations = {} if kv_offload: @@ -998,7 +1055,7 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): txt_cfg = self.config.get_text_config() num_hidden_layers = txt_cfg.num_hidden_layers cross_attention_layers = txt_cfg.cross_attention_layers @@ -1023,6 +1080,9 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + dynamic_axes = {} if kv_offload: dynamic_axes["vision"] = vision_dynamic_axes diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 633a0b29d..5c6f67ddc 100644 --- 
a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -56,6 +56,7 @@ constants, get_padding_shape_from_config, ) +from QEfficient.utils.check_ccl_specializations import process_ccl_specializations from QEfficient.utils.logging_utils import logger @@ -877,6 +878,15 @@ def __init__( raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") self.model = model self.config = model.config + + self.comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) + self.comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) + ctx_len = kwargs.pop("ctx_len", None) + if self.comp_ctx_lengths_prefill: + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations( + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len + ) + self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs) self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs) self.input_shapes, self.output_names = None, None @@ -922,8 +932,25 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs): logger.warning("Updating low_cpu_mem_usage=False") kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) + + comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) + comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) + ctx_len = kwargs.pop("ctx_len", None) + + if comp_ctx_lengths_prefill: + comp_ctx_lengths_prefill, comp_ctx_lengths_decode = process_ccl_specializations( + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len + ) + model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) - return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) + return cls( + model, + pretrained_model_name_or_path=pretrained_model_name_or_path, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + **kwargs, + ) @property def onnx_path(self): @@ -978,8 +1005,8 @@ def export( List[str] A list containing the paths to the generated ONNX graph files for both components. 
""" - inputs = self.model.get_dummy_inputs(kv_offload=True) - dynamic_axes = self.model.get_onnx_dynamic_axes(kv_offload=True) + inputs = self.model.get_dummy_inputs(self.comp_ctx_lengths_decode, kv_offload=True) + dynamic_axes = self.model.get_onnx_dynamic_axes(self.comp_ctx_lengths_decode, kv_offload=True) output_names = self.model.get_output_names(kv_offload=True) self.vision_model.export( @@ -1083,6 +1110,8 @@ def compile( batch_size=batch_size, prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, + comp_ctx_lengths_prefill=self.comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=self.comp_ctx_lengths_decode, img_size=img_size, kv_offload=True, **compiler_options, @@ -1332,6 +1361,11 @@ def kv_offload_generate( lang_session.set_buffers(vision_outputs) + if self.comp_ctx_lengths_prefill is not None: + list_of_comp_ctx_lengths_prefill = [np.zeros(length) for length in self.comp_ctx_lengths_prefill] + prefill_ccl_id = 0 + lang_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + # Prepare inputs for prefill chunk_inputs = lang_inputs.copy() prefill_start = perf_counter() @@ -1339,6 +1373,13 @@ def kv_offload_generate( # Run prefill chunk_inputs = lang_inputs.copy() for i in range(num_chunks): + if ( + self.comp_ctx_lengths_prefill is not None + and (i + 1) * prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id] + ): + prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1) + chunk_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + chunk_inputs["input_ids"] = lang_inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len] chunk_inputs["position_ids"] = lang_inputs["position_ids"][ ..., i * prefill_seq_len : (i + 1) * prefill_seq_len @@ -1368,8 +1409,25 @@ def kv_offload_generate( streamer.put(lang_inputs["input_ids"][0]) # Decode loop + if self.comp_ctx_lengths_decode is not None: + max_ccl_id = len(self.comp_ctx_lengths_decode) - 1 + list_of_comp_ctx_lengths_decode = [np.zeros(length) for length in self.comp_ctx_lengths_decode] + max_position_id = np.max(lang_inputs["position_ids"]) + ccl_id_initial = 0 + ccl_id = ccl_id_initial + for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)): + if max_position_id < self.comp_ctx_lengths_decode[i]: + ccl_id = i + break + lang_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_decode[ccl_id] + decode_start = perf_counter() for num_token in range(1, generation_len): + if self.comp_ctx_lengths_decode is not None: + if max_position_id >= self.comp_ctx_lengths_decode[ccl_id] - 1: + ccl_id = min(ccl_id + 1, max_ccl_id) + lang_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_decode[ccl_id] + outputs = lang_session.run(lang_inputs) # Prepare inputs for next iteration @@ -1440,6 +1498,15 @@ def __init__( raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") super().__init__(model, **kwargs) + self.comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) + self.comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) + ctx_len = kwargs.pop("ctx_len", None) + + if self.comp_ctx_lengths_prefill: + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations( + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len + ) + # to handle internvl models if hasattr(self.model.config, "llm_config") and hasattr(self.model.config, "vision_config"): self.model.config.llm_config.use_cache = True @@ -1486,6 +1553,16 @@ def 
from_pretrained( logger.warning("Updating low_cpu_mem_usage=False") kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) + + comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) + comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) + ctx_len = kwargs.pop("ctx_len", None) + + if comp_ctx_lengths_prefill: + comp_ctx_lengths_prefill, comp_ctx_lengths_decode = process_ccl_specializations( + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len + ) + from transformers import AutoConfig config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True) @@ -1493,7 +1570,14 @@ def from_pretrained( config.vision_config.use_flash_attn = "false" model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, config, *args, **kwargs) - return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) + return cls( + model, + pretrained_model_name_or_path=pretrained_model_name_or_path, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + **kwargs, + ) def export( self, @@ -1515,8 +1599,8 @@ def export( str Path to the generated ONNX graph file. """ - inputs = self.model.get_dummy_inputs() - dynamic_axes = self.model.get_onnx_dynamic_axes() + inputs = self.model.get_dummy_inputs(self.comp_ctx_lengths_decode) + dynamic_axes = self.model.get_onnx_dynamic_axes(self.comp_ctx_lengths_decode) output_names = self.model.get_output_names() return self._export(inputs, output_names, dynamic_axes, export_dir=export_dir) @@ -1598,6 +1682,8 @@ def compile( batch_size=batch_size, prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, + comp_ctx_lengths_prefill=self.comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=self.comp_ctx_lengths_decode, img_size=img_size, **compiler_options, ) @@ -1782,12 +1868,24 @@ def cloud_ai_100_generate( inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) inputs["image_idx"] = np.array([[0]]) + if self.comp_ctx_lengths_prefill is not None: + list_of_comp_ctx_lengths_prefill = [np.zeros(length) for length in self.comp_ctx_lengths_prefill] + prefill_ccl_id = 0 + inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + qpc_session.activate() chunk_inputs = inputs.copy() prefill_start = perf_counter() # Run prefill for i in range(num_chunks): + if ( + self.comp_ctx_lengths_prefill is not None + and (i + 1) * prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id] + ): + prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1) + chunk_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + chunk_inputs["input_ids"] = inputs["input_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len] chunk_inputs["position_ids"] = inputs["position_ids"][:, i * prefill_seq_len : (i + 1) * prefill_seq_len] outputs = qpc_session.run(chunk_inputs) @@ -1811,8 +1909,25 @@ def cloud_ai_100_generate( inputs.pop("pixel_values") # Decode loop + if self.comp_ctx_lengths_decode is not None: + list_of_comp_ctx_lengths_decode = [np.zeros(length) for length in self.comp_ctx_lengths_decode] + max_ccl_id = len(self.comp_ctx_lengths_decode) - 1 + max_position_id = np.max(inputs["position_ids"]) + ccl_id_initial = 0 + ccl_id = ccl_id_initial + for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)): + if max_position_id < self.comp_ctx_lengths_decode[i]: + ccl_id = i + break + inputs["comp_ctx_lengths"] = 
list_of_comp_ctx_lengths_decode[ccl_id] + decode_start = perf_counter() for num_token in range(1, generation_len): + if self.comp_ctx_lengths_decode is not None: + if max_position_id >= self.comp_ctx_lengths_decode[ccl_id] - 1: + ccl_id = min(ccl_id + 1, max_ccl_id) + inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_decode[ccl_id] + outputs = qpc_session.run(inputs) # Prepare inputs for next iteration inputs["input_ids"] = outputs["logits"].argmax(2) @@ -1950,6 +2065,9 @@ def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, **kwargs) Union[_QEffAutoModelForImageTextToTextDualQPC, _QEFFAutoModelForImageTextToTextSingleQPC] The wrapped model instance, configured for either dual or single QPC. """ + self.comp_ctx_lengths_prefill = kwargs.get("comp_ctx_lengths_prefill", None) + self.comp_ctx_lengths_decode = kwargs.get("comp_ctx_lengths_decode", None) + if kv_offload: return _QEffAutoModelForImageTextToTextDualQPC(model, **kwargs) else: @@ -1996,8 +2114,26 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) + + comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) + comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) + ctx_len = kwargs.pop("ctx_len", None) + + if comp_ctx_lengths_prefill: + comp_ctx_lengths_prefill, comp_ctx_lengths_decode = process_ccl_specializations( + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len + ) + model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) - return cls(model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) + return cls( + model, + kv_offload=kv_offload, + pretrained_model_name_or_path=pretrained_model_name_or_path, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + **kwargs, + ) MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP = { @@ -2096,6 +2232,15 @@ def __init__( self.model, transformed = SpDTransform.apply(self.model, qaic_config, **kwargs) self.is_tlm = transformed + self.comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) + self.comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) + ctx_len = kwargs.pop("ctx_len", None) + + if self.comp_ctx_lengths_prefill: + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations( + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len + ) + self.hash_params["qeff_auto_class"] = self.__class__.__name__ # ---Sampling--- @@ -2190,6 +2335,14 @@ def from_pretrained( kv_offload = kwargs.pop("kv_offload", None) + comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) + comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) + ctx_len = kwargs.pop("ctx_len", None) + if comp_ctx_lengths_prefill: + comp_ctx_lengths_prefill, comp_ctx_lengths_decode = process_ccl_specializations( + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len + ) + kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) if qaic_config is not None: @@ -2199,13 +2352,22 @@ def from_pretrained( if model.__class__.__name__ in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP: return 
MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model.__class__.__name__]( - model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs + model, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + kv_offload=kv_offload, + pretrained_model_name_or_path=pretrained_model_name_or_path, + **kwargs, ) return cls( model, continuous_batching=continuous_batching, qaic_config=qaic_config, pretrained_model_name_or_path=pretrained_model_name_or_path, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, **kwargs, ) @@ -2255,6 +2417,10 @@ def export(self, export_dir: Optional[str] = None) -> str: "input_ids": {0: "batch_size", 1: "seq_len"}, "position_ids": {0: "batch_size", 1: "seq_len"}, } + if self.comp_ctx_lengths_prefill is not None: + example_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + if len(kv_cache_shape) == 3: # For GPTBigCode arch the pkv is 3d pkv_dynamic_axes = { 0: "full_batch_size" if self.continuous_batching else "batch_size", @@ -2400,6 +2566,7 @@ def build_prefill_specialization( self, prefill_seq_len: int = 32, ctx_len: int = 128, + comp_ctx_lengths: Optional[int] = None, batch_size: int = 1, kv_cache_batch_size: Optional[int] = None, full_batch_size: Optional[int] = None, @@ -2431,6 +2598,9 @@ def build_prefill_specialization( "ctx_len": ctx_len, "num_logits_to_keep": 1 if self.is_tlm else None, } + if comp_ctx_lengths is not None: + spec["comp_ctx_lengths"] = comp_ctx_lengths + if self.continuous_batching: spec["full_batch_size"] = kv_cache_batch_size else: @@ -2443,6 +2613,7 @@ def build_decode_specialization( self, prefill_seq_len: int = 32, ctx_len: int = 128, + comp_ctx_lengths: Optional[int] = None, batch_size: int = 1, kv_cache_batch_size: Optional[int] = None, full_batch_size: Optional[int] = None, @@ -2472,7 +2643,7 @@ def build_decode_specialization( A dictionary defining the decode specialization, or None if it would be a duplicate of the prefill specialization (e.g., if prefill_seq_len is 1 and not continuous batching). """ - if prefill_seq_len == 1 and not self.continuous_batching: + if prefill_seq_len == 1 and not self.continuous_batching and comp_ctx_lengths is None: return None # Avoid duplication with prefill spec = { "batch_size": full_batch_size if self.continuous_batching else batch_size, @@ -2480,6 +2651,8 @@ def build_decode_specialization( "ctx_len": ctx_len, "num_logits_to_keep": (num_speculative_tokens + 1) if self.is_tlm else None, } + if comp_ctx_lengths is not None: + spec["comp_ctx_lengths"] = comp_ctx_lengths if self.continuous_batching: spec["full_batch_size"] = kv_cache_batch_size @@ -2494,6 +2667,8 @@ def compile( *, prefill_seq_len: int = 32, ctx_len: int = 128, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, batch_size: int = 1, full_batch_size: Optional[int] = None, kv_cache_batch_size: Optional[int] = None, @@ -2581,6 +2756,23 @@ def compile( If `prefill_seq_len` is less than `num_speculative_tokens + 1` for TLM models. 
""" + # For comp_ctx_lengths Disaggregated applications + if self.comp_ctx_lengths_prefill is None: + if comp_ctx_lengths_prefill is not None: + import ast + + if isinstance(comp_ctx_lengths_prefill, str): + try: + # Safely evaluate the string to a Python list for disaggregated input + self.comp_ctx_lengths_prefill = ast.literal_eval(comp_ctx_lengths_prefill) + self.comp_ctx_lengths_decode = ast.literal_eval(comp_ctx_lengths_decode) + + except (ValueError, SyntaxError): + raise ValueError("Invalid format for comp_ctx_lengths. Expected a list-like string.") + else: + self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill + self.comp_ctx_lengths_decode = comp_ctx_lengths_decode + # --- Validation --- if prefill_only is not None and not isinstance(prefill_only, bool): raise TypeError("`prefill_only` must be a boolean.") @@ -2611,26 +2803,58 @@ def compile( # --- Specializations --- specializations = [] if prefill_only is None or prefill_only or prefill_seq_len == 1: - specializations.append( - self.build_prefill_specialization( + if self.comp_ctx_lengths_prefill is not None: + # Adding elements from self.comp_ctx_lengths_prefill to prefill_specialization + for i in range(0, len(self.comp_ctx_lengths_prefill)): + specializations.append( + self.build_prefill_specialization( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + comp_ctx_lengths=self.comp_ctx_lengths_prefill[i], + batch_size=batch_size, + kv_cache_batch_size=kv_cache_batch_size, + full_batch_size=full_batch_size, + ) + ) + + else: + specializations.append( + self.build_prefill_specialization( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + batch_size=batch_size, + kv_cache_batch_size=kv_cache_batch_size, + full_batch_size=full_batch_size, + ) + ) + + if prefill_only is None or not prefill_only: + if self.comp_ctx_lengths_decode is not None: + # Adding elements from self.comp_ctx_lengths_decode to decode_specialization + for i in range(0, len(self.comp_ctx_lengths_decode)): + decode_spec = self.build_decode_specialization( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + comp_ctx_lengths=self.comp_ctx_lengths_decode[i], + batch_size=batch_size, + kv_cache_batch_size=kv_cache_batch_size, + full_batch_size=full_batch_size, + num_speculative_tokens=num_speculative_tokens, + ) + if decode_spec: + specializations.append(decode_spec) + + else: + decode_spec = self.build_decode_specialization( prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, batch_size=batch_size, kv_cache_batch_size=kv_cache_batch_size, full_batch_size=full_batch_size, + num_speculative_tokens=num_speculative_tokens, ) - ) - if prefill_only is None or not prefill_only: - decode_spec = self.build_decode_specialization( - prefill_seq_len=prefill_seq_len, - ctx_len=ctx_len, - batch_size=batch_size, - kv_cache_batch_size=kv_cache_batch_size, - full_batch_size=full_batch_size, - num_speculative_tokens=num_speculative_tokens, - ) - if decode_spec: - specializations.append(decode_spec) + if decode_spec: + specializations.append(decode_spec) # --- Compilation --- kv_cache_dtype = "mxint8" if mxint8_kv_cache else "float16" @@ -2708,6 +2932,8 @@ def generate( tokenizer, self.qpc_path, prompt=prompts, + comp_ctx_lengths_prefill=self.comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=self.comp_ctx_lengths_decode, device_id=device_id, generation_len=generation_len, automation=kwargs.pop("automation", False), diff --git a/QEfficient/transformers/models/mpt/modeling_mpt.py b/QEfficient/transformers/models/mpt/modeling_mpt.py index 9bf6a4422..16ca54051 100644 --- 
a/QEfficient/transformers/models/mpt/modeling_mpt.py +++ b/QEfficient/transformers/models/mpt/modeling_mpt.py @@ -39,6 +39,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, ): @@ -51,7 +52,9 @@ def forward( value_states = value_states.reshape(batch_size, seq_length, self.n_heads, self.head_dim).transpose(1, 2) if past_key_value is not None: - cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_scores = torch.matmul(query_states, key_states.transpose(-1, -2)) * self.softmax_scale @@ -101,6 +104,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, use_cache: bool = False, output_attentions: bool = False, ): @@ -118,6 +122,7 @@ def forward( batch_index=batch_index, attention_mask=attention_mask, past_key_value=layer_past, + comp_ctx_lengths=comp_ctx_lengths, use_cache=use_cache, ) @@ -144,6 +149,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -205,6 +211,7 @@ def forward( outputs = block( hidden_states, layer_past=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=causal_mask, position_ids=position_ids, batch_index=batch_index, @@ -250,6 +257,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -271,6 +279,7 @@ def forward( transformer_outputs = self.transformer( input_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, position_ids=position_ids, batch_index=batch_index, diff --git a/QEfficient/transformers/models/olmo2/modeling_olmo2.py b/QEfficient/transformers/models/olmo2/modeling_olmo2.py index 0d23729c1..bf946da39 100644 --- a/QEfficient/transformers/models/olmo2/modeling_olmo2.py +++ b/QEfficient/transformers/models/olmo2/modeling_olmo2.py @@ -132,6 +132,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -154,8 +155,10 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: + if comp_ctx_lengths 
is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -188,6 +191,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -203,6 +207,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -233,6 +238,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -286,6 +292,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -322,6 +329,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -343,6 +351,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/phi/modeling_phi.py b/QEfficient/transformers/models/phi/modeling_phi.py index 18557f1ca..a5e53216a 100644 --- a/QEfficient/transformers/models/phi/modeling_phi.py +++ b/QEfficient/transformers/models/phi/modeling_phi.py @@ -67,6 +67,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, @@ -104,8 +105,16 @@ def forward( key_states = torch.cat((key_rot, key_pass), dim=-1) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # Update the cache_kwargs with position_ids for Cloud AI 100 - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + "CCL": attention_mask.shape[-1], + } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) 
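# (Editorial sketch, not part of the patch.) The "CCL" kwarg passed above is assumed to bound
# how much of the cached sequence is returned for attention; this is a simplification that
# ignores the actual scatter/gather indexing implemented in QEfficient/transformers/cache_utils.py:
import torch

def bounded_cache_read(past_keys: torch.Tensor, ccl: int) -> torch.Tensor:
    # past_keys: [batch, heads, ctx_len, head_dim]; only the first `ccl` positions are attended to
    return past_keys[:, :, :ccl, :]

example_keys = torch.zeros(1, 8, 4096, 64)
print(bounded_cache_read(example_keys, ccl=1024).shape)  # torch.Size([1, 8, 1024, 64])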
attention_interface: Callable = eager_attention_forward @@ -140,6 +149,7 @@ def forward( output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # will become mandatory in v4.46 **kwargs, @@ -181,6 +191,7 @@ def forward( position_ids=position_ids, batch_index=batch_index, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, output_attentions=output_attentions, use_cache=use_cache, cache_position=cache_position, @@ -213,6 +224,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -274,6 +286,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -316,6 +329,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -370,6 +384,7 @@ def forward( position_ids=position_ids, batch_index=batch_index, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, inputs_embeds=inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, diff --git a/QEfficient/transformers/models/phi3/modeling_phi3.py b/QEfficient/transformers/models/phi3/modeling_phi3.py index 4b5234a5a..851395f08 100644 --- a/QEfficient/transformers/models/phi3/modeling_phi3.py +++ b/QEfficient/transformers/models/phi3/modeling_phi3.py @@ -140,6 +140,7 @@ def forward( batch_index: Optional[torch.LongTensor] = None, position_ids=Optional[torch.Tensor], past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: @@ -162,9 +163,12 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] cache_kwargs = { "batch_index": batch_index, "position_ids": position_ids, + "CCL": attention_mask.shape[-1], } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) @@ -198,6 +202,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -235,6 +240,7 @@ def forward( position_ids=position_ids, batch_index=batch_index, past_key_value=past_key_value, + 
comp_ctx_lengths=comp_ctx_lengths, use_cache=use_cache, cache_position=cache_position, position_embeddings=position_embeddings, @@ -265,6 +271,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -314,6 +321,7 @@ def forward( position_ids=position_ids, batch_index=batch_index, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, use_cache=use_cache, cache_position=cache_position, **kwargs, @@ -350,6 +358,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_hidden_states: Optional[bool] = None, @@ -366,6 +375,7 @@ def forward( batch_index=batch_index, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, inputs_embeds=inputs_embeds, use_cache=use_cache, output_hidden_states=output_hidden_states, diff --git a/QEfficient/transformers/models/qwen2/modeling_qwen2.py b/QEfficient/transformers/models/qwen2/modeling_qwen2.py index 24e8df46c..1aca7039d 100644 --- a/QEfficient/transformers/models/qwen2/modeling_qwen2.py +++ b/QEfficient/transformers/models/qwen2/modeling_qwen2.py @@ -150,6 +150,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -166,7 +167,9 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -200,6 +203,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -231,6 +235,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -261,6 +266,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -313,6 +319,7 @@ def forward( attention_mask=causal_mask, 
position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -348,6 +355,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -364,6 +372,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 030dd7a56..3b1d116de 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -399,6 +399,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, @@ -425,8 +426,16 @@ def forward( ) if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids[0]} + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids[0], + "CCL": attention_mask.shape[-1], + } key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -457,6 +466,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, @@ -496,6 +506,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -528,6 +539,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -578,6 +590,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -616,6 +629,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + 
comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -639,6 +653,7 @@ def forward( position_ids=position_ids, attention_mask=attention_mask, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, @@ -680,7 +695,9 @@ def __init__(self, model): self.model = model self.language_model = self.model.model.language_model - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + ): inputs_embeds = self.model.get_input_embeddings()(input_ids) B, N, C = inputs_embeds.shape selected = input_ids == self.model.config.image_token_id @@ -691,7 +708,11 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va image_input_embeds = torch.where(selected.unsqueeze(-1), image_features_expanded, inputs_embeds) inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_input_embeds) outputs = self.model.model( - inputs_embeds=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + inputs_embeds=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + use_cache=True, ) logit_index = position_ids[0].to(torch.int32).argmax(1, keepdim=True) @@ -709,7 +730,7 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEffQwen_2_5_vl_DecoderWrapper(self) - def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): + def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False, **kwargs): inputs_shapes = {} inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) @@ -757,6 +778,9 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): for kv in ["key", "value"]: lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs @@ -775,6 +799,8 @@ def get_specializations( img_size: None, height: int = None, width: int = None, + comp_ctx_lengths_prefill: List[int] = None, + comp_ctx_lengths_decode: List[int] = None, kv_offload: bool = False, **compiler_options, ): @@ -856,20 +882,46 @@ def smart_resize( "grid_w": grid_w, } ] - lang = [ - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "vision_size": vision_size, - }, - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "vision_size": vision_size, - }, - ] + if comp_ctx_lengths_prefill is not None: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang.append( + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "vision_size": vision_size, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + } + ) + + for i in range(0, len(comp_ctx_lengths_decode)): + lang.append( + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "vision_size": vision_size, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + } + ) + + else: + lang = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": 
ctx_len, + "vision_size": vision_size, + }, + { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "vision_size": vision_size, + }, + ] specializations = {} @@ -880,7 +932,7 @@ def smart_resize( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): # Define dynamic axes num_layers = self.config.num_hidden_layers @@ -899,6 +951,9 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + dynamic_axes = {} if kv_offload: dynamic_axes["vision"] = vision_dynamic_axes diff --git a/QEfficient/transformers/models/qwen3/modeling_qwen3.py b/QEfficient/transformers/models/qwen3/modeling_qwen3.py index ecdb36019..ccf918c2c 100644 --- a/QEfficient/transformers/models/qwen3/modeling_qwen3.py +++ b/QEfficient/transformers/models/qwen3/modeling_qwen3.py @@ -151,6 +151,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -167,7 +168,9 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -201,6 +204,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -232,6 +236,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -262,6 +267,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -314,6 +320,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -349,6 +356,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: 
Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -367,6 +375,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index 591f7c1b0..c8a5ae2fd 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -201,6 +201,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -217,7 +218,9 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -243,6 +246,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -274,6 +278,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -300,6 +305,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, batch_index: Optional[torch.LongTensor] = None, @@ -342,6 +348,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -369,6 +376,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -385,6 +393,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, inputs_embeds=inputs_embeds, batch_index=batch_index, use_cache=use_cache, diff --git a/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py 
b/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py index 9a327761d..075b8aedb 100644 --- a/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py +++ b/QEfficient/transformers/models/starcoder2/modeling_starcoder2.py @@ -69,6 +69,7 @@ def forward( attention_mask: Optional[torch.Tensor], position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, cache_position: Optional[torch.LongTensor] = None, **kwargs, @@ -84,7 +85,9 @@ def forward( query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) attention_interface = eager_attention_forward @@ -118,6 +121,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = False, cache_position: Optional[torch.LongTensor] = None, @@ -153,6 +157,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -184,6 +189,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -237,6 +243,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, cache_position=cache_position, @@ -273,6 +280,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -289,6 +297,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/utils/check_ccl_specializations.py b/QEfficient/utils/check_ccl_specializations.py new file mode 100644 index 000000000..dbfb08926 --- /dev/null +++ b/QEfficient/utils/check_ccl_specializations.py @@ -0,0 +1,43 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from typing import List, Optional + + +def process_ccl_specializations( + ccl_prefill: Optional[List[int]] = None, ccl_decode: Optional[List[int]] = None, ctx_len: Optional[int] = None +): + if ctx_len is None: + raise TypeError("`ctx_len` is required when loading the model.") + if ccl_prefill is None: + ccl_prefill = [ctx_len] + if ccl_decode is None: + ccl_decode = [ctx_len] + + # Step 1: Cap values to ctx_len + ccl_prefill = [min(x, ctx_len) for x in ccl_prefill] + ccl_decode = [min(x, ctx_len) for x in ccl_decode] + + # Step 2: Remove duplicates within each list + ccl_prefill = list(set(ccl_prefill)) + ccl_decode = list(set(ccl_decode)) + + # Step 3: Ensure no overlap between ccl_prefill and ccl_decode + updated_prefill = [] + for val in ccl_prefill: + while val in ccl_decode or val in updated_prefill: + val -= 1 + if val < 0: + break # Prevent negative values + if val >= 0: + updated_prefill.append(val) + + # Step 4: Sort both lists + updated_prefill.sort() + ccl_decode.sort() + + return updated_prefill, ccl_decode diff --git a/examples/ccl_image_text_to_text_inference.py b/examples/ccl_image_text_to_text_inference.py new file mode 100644 index 000000000..932a407b9 --- /dev/null +++ b/examples/ccl_image_text_to_text_inference.py @@ -0,0 +1,135 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import requests +from PIL import Image +from transformers import AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +# Add HuggingFace Token to access the model +HF_TOKEN = "" + + +def run_model( + model_name, + token, + query, + image_url, + kv_offload=False, + prefill_seq_len=32, + ctx_len=512, + comp_ctx_lengths_prefill=None, + comp_ctx_lengths_decode=None, + generation_len=128, + img_size=560, + num_cores=16, + num_devices=1, +): + ## STEP - 1 Load the Processor and Model + + processor = AutoProcessor.from_pretrained(model_name, token=token) + + # `kv_offload` is used to compile the model in a Single QPC or 2 QPCs. + # The Dual QPC approach splits the model to perform Image Encoding and Output generation in 2 different QPCs. + # The outputs of the Vision Encoder are then passed to the Language model via host in this case. 
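+ # A rough sketch of how the comp_ctx_lengths_prefill / comp_ctx_lengths_decode lists are normalized (see QEfficient/utils/check_ccl_specializations.py in this patch): each value is capped to ctx_len, duplicates are dropped, and the two lists are kept disjoint, e.g. process_ccl_specializations([4096], [6144, 8192], ctx_len=8192) -> ([4096], [6144, 8192]).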
+ + model = QEFFAutoModelForImageTextToText.from_pretrained( + model_name, + token=token, + attn_implementation="eager", + kv_offload=kv_offload, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + ) + + ## STEP - 2 Export & Compile the Model + + model.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + img_size=img_size, + num_cores=num_cores, + num_devices=num_devices, + mxfp6_matmul=False, + ) + + ## STEP - 3 Load and process the inputs for Inference + + image = Image.open(requests.get(image_url, stream=True).raw) + messages = [ + { + "role": "user", + "content": [ + {"type": "image"}, + {"type": "text", "text": query}, + ], + } + ] + input_text = [processor.apply_chat_template(messages, add_generation_prompt=True)] + + inputs = processor( + text=input_text, + images=image, + return_tensors="pt", + add_special_tokens=False, + padding="max_length", + max_length=prefill_seq_len, + ) + + ## STEP - 4 Run Inference on the compiled model + + streamer = TextStreamer(processor.tokenizer) + output_statistics = model.generate(inputs=inputs, streamer=streamer, generation_len=generation_len) + print(output_statistics) + + +if __name__ == "__main__": + # Model name and Input parameters + model_name = "llava-hf/llava-1.5-7b-hf" + # model_name = "meta-llama/Llama-3.2-11B-Vision-Instruct" + query = "Describe this image." + image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg" + + # Compilation parameters for the model + kv_offload = True + prefill_seq_len = 32 + ctx_len = 8192 + generation_len = 128 + img_size = 336 + # img_size = 560 + num_cores = 16 + num_devices = 4 + comp_ctx_lengths_prefill = [4096] + comp_ctx_lengths_decode = [6144, ctx_len] + + run_model( + model_name=model_name, + token=HF_TOKEN, + query=query, + kv_offload=kv_offload, + image_url=image_url, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + generation_len=generation_len, + img_size=img_size, + num_cores=num_cores, + num_devices=num_devices, + ) + + +""" +Expected Response: + +This image depicts a charming anthropomorphic rabbit standing on a dirt path in front of a picturesque stone cottage, surrounded by a serene landscape. + +The rabbit, with its light brown fur and distinctive long ears, is attired in a stylish blue coat, brown vest, and tan pants, exuding a sense of sophistication. The dirt path, flanked by vibrant flowers and lush greenery, leads to the cottage, which features a thatched roof and a chimney, adding to the rustic charm of the scene. In the background, rolling hills and trees create a breathtaking panorama, while the sky above is a brilliant blue with white clouds, completing the + +""" diff --git a/examples/ccl_llama4_example.py b/examples/ccl_llama4_example.py new file mode 100644 index 000000000..5fc715589 --- /dev/null +++ b/examples/ccl_llama4_example.py @@ -0,0 +1,126 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import torch +import transformers +from transformers import AutoConfig, AutoModelForImageTextToText, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" +config = AutoConfig.from_pretrained(model_id) +# For Testing Purpose Only +config.text_config.num_hidden_layers = 4 +config.vision_config.num_hidden_layers = 2 + +model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation="eager", config=config) +model.eval() +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) +processor = AutoProcessor.from_pretrained(model_id) + +### For running the model in single QPC approach use kv_offload=False. For Dual QPC approach use kv_offload=True ### +ctx_len = 8192 +comp_ctx_lengths_prefill = [3072] +comp_ctx_lengths_decode = [4096, ctx_len] + +qeff_model = QEFFAutoModelForImageTextToText( + model, + kv_offload=True, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, +) + +### use skip_vision=Ture, if want to run only text, ow false ### +skip_vision = False + +if skip_vision: + ## Only Text ## + qeff_model.compile( + prefill_seq_len=128, + ctx_len=ctx_len, + img_size=336, + num_cores=16, + num_devices=4, + max_num_tiles=17, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + skip_vision=True, + mos=1, + ) + + messages = [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "Can you describe the image in detail.", + }, + ], + }, + ] + + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, generation_len=700) + print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) + +else: + ## Vision + Text ## + qeff_model.compile( + prefill_seq_len=128, + ctx_len=ctx_len, + img_size=336, + num_cores=16, + num_devices=4, + max_num_tiles=17, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + ) + + ### IMAGE + TEXT ### + image_url = ( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png" + ) + + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "url": image_url}, + {"type": "text", "text": "Can you describe the image in detail."}, + ], + }, + ] + + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, generation_len=1024) + print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) + print() diff --git a/examples/ccl_mistral3_example.py b/examples/ccl_mistral3_example.py new file mode 100644 index 000000000..ed02a4fa9 --- /dev/null +++ b/examples/ccl_mistral3_example.py @@ -0,0 +1,120 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import requests +from PIL import Image +from transformers import AutoConfig, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + + +def run_model( + model_name, + query, + image_url, + kv_offload=False, + prefill_seq_len=128, + ctx_len=4096, + comp_ctx_lengths_prefill=None, + comp_ctx_lengths_decode=None, + generation_len=128, + img_size=1540, + num_cores=16, + num_devices=4, +): + ## STEP - 1 Load the Processor and Model + + processor = AutoProcessor.from_pretrained(model_name) + + # `kv_offload` is used to compile the model into 2 QPCs. Currently the single-QPC path is not supported, so setting this flag to False is not allowed. + # The `kv_offload` flag should always be set to True. + # The Dual QPC approach splits the model to perform Image Encoding and Output generation in 2 different QPCs. + # The outputs of the Vision Encoder are then passed to the Language model via host in this case. + + config = AutoConfig.from_pretrained(model_name) + config.vision_config._attn_implementation = "eager" + + model = QEFFAutoModelForImageTextToText.from_pretrained(model_name, + kv_offload=kv_offload, + config=config, + ctx_len=ctx_len, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode + ) + + ## STEP - 2 Export & Compile the Model + + model.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + img_size=img_size, + num_cores=num_cores, + num_devices=num_devices, + mxfp6_matmul=False, + ) + + ## STEP - 3 Load and process the inputs for Inference + + # We are resizing the image to (w x h) (1540 x 1540) so that any image can work on the model irrespective of image dimensions. + # We have a fixed size of height 1540 and width 1540 as defined in the config. + + image = Image.open(requests.get(image_url, stream=True).raw) + image = image.resize((1540, 1540)) + + messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": query}]}] + input_text = processor.apply_chat_template(messages, add_generation_prompt=True) + inputs = processor(image, input_text, add_special_tokens=False, return_tensors="pt") + + ## STEP - 4 Run Inference on the compiled model + + streamer = TextStreamer(processor.tokenizer) + output = model.generate(inputs=inputs, streamer=streamer, generation_len=generation_len) + print(output) + + +if __name__ == "__main__": + # Model name and Input parameters + model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503" + + # Please add prompt here + query = "Describe the image" + + # Please pass an image URL or image path. The format of the image should be jpg. + image_url = "https://www.ilankelman.org/stopsigns/australia.jpg" + + # Compilation parameters for the model + kv_offload = True + prefill_seq_len = 128 + ctx_len = 8192 + generation_len = 128 + num_cores = 16 + num_devices = 4 + comp_ctx_lengths_prefill = [4096] + comp_ctx_lengths_decode = [6144, ctx_len] + + run_model( + model_name=model_name, + query=query, + kv_offload=kv_offload, + image_url=image_url, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + generation_len=generation_len, + num_cores=num_cores, + num_devices=num_devices, + ) + + +""" +Expected Response: +The image depicts a street scene in what appears to be a Chinatown district.
The focal point is a traditional Chinese archway, known as a paifang, which is intricately designed with red columns and ornate details. The archway features Chinese characters at the top, which translate to "Chinatown Gate." +In the foreground, there is a red stop sign mounted on a pole. The street is relatively quiet, with a single dark-colored SUV driving through the archway. On either side of the archway, there are stone lion statues, which are common decorative elements in Chinese architecture and symbolize protection. + + +""" diff --git a/examples/ccl_qwen2_5_vl_example.py b/examples/ccl_qwen2_5_vl_example.py new file mode 100644 index 000000000..7056011f2 --- /dev/null +++ b/examples/ccl_qwen2_5_vl_example.py @@ -0,0 +1,189 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import requests +import torch +import torch.nn.functional as F +import transformers +from PIL import Image +from qwen_vl_utils import process_vision_info +from transformers import AutoConfig, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +## For AWQ model update pytorch version to 2.8.* +model_id = "Qwen/Qwen2.5-VL-32B-Instruct" +config = AutoConfig.from_pretrained(model_id) + +## Use complete model without changing num_hidden_layers as it will not work for TF version 4.55.0 for Qwen2.5VL model + +ctx_len = 32768 + +comp_ctx_lengths_prefill = [4000] +comp_ctx_lengths_decode = [4096, 8192,16384, ctx_len] + +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + attn_implementation="eager", + kv_offload=True, + config=config +) +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) + +### use skip_vision=Ture, if want to run only text, ow false ### +skip_vision = False + +if skip_vision: + ## Only Text ## + + ## Set Batch_Size ## + batch_size = 2 + qeff_model.compile( + batch_size=batch_size, + prefill_seq_len=128, + ctx_len=ctx_len, + num_cores=16, + num_devices=4, + height=354, + width=536, + mxfp6_matmul=False, + aic_enable_depth_first=True, + skip_vision=True, + mos=1, + ) + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Tell me about yourself."}, + ], + }, + ] + + messages = [messages] * batch_size + + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + + pos_ids, rope_deltas = qeff_model.model.get_rope_index( + inputs["input_ids"], + image_grid_thw=None, + video_grid_thw=None, + second_per_grid_ts=None, + attention_mask=inputs["attention_mask"], + ) + + input_ids_length = inputs["input_ids"].shape[1] + + inputs["position_ids"] = torch.cat([pos_ids, pos_ids[0].unsqueeze(0)], dim=0) + + prefill_seq_len = 128 + num_chunks = -(input_ids_length // -prefill_seq_len) # ceil divide without float + padded_len = num_chunks * prefill_seq_len # Convert to a multiple of prompt_len + + inputs["position_ids"] = F.pad( + inputs["position_ids"], pad=(0, padded_len - input_ids_length), mode="constant", value=-1 + ) + + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, generation_len=100) + 
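# A sketch of the expected CCL behaviour for this run (following the description in examples/compute_context_length.py): prefill uses the comp_ctx_lengths_prefill bucket, while decode starts from the smallest comp_ctx_lengths_decode value that covers the current cache index and steps up through 4096 -> 8192 -> 16384 -> ctx_len as generation proceeds. +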
print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) + +else: + batch_size = 1 + ## Vision + Text ## + qeff_model.compile( + batch_size=batch_size, + prefill_seq_len=128, + ctx_len=ctx_len, + num_cores=16, + num_devices=4, + height=354, + width=536, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + ) + + ### IMAGE + TEXT ### + image_url = "https://picsum.photos/id/237/536/354" + + image = Image.open(requests.get(image_url, stream=True).raw) + + messages_1 = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": "Describe this image."}, + ], + }, + ] + + messages_2 = [ + { + "role": "user", + "content": [ + {"type": "image", "image": image}, + {"type": "text", "text": "Describe about the color of the dog."}, + ], + }, + ] + + messages = [messages_2] * batch_size + + texts = [processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in messages] + + image_inputs, video_inputs = process_vision_info(messages) + inputs = processor( + text=texts, + images=image_inputs, + videos=video_inputs, + padding=True, + return_tensors="pt", + ) + input_ids_length = inputs["input_ids"].shape[1] + + inputs["position_ids"] = torch.arange(input_ids_length).view(1, 1, input_ids_length).expand(-1, batch_size, -1) + + pos_ids, rope_deltas = qeff_model.model.model.get_rope_index( + inputs["input_ids"], + inputs["image_grid_thw"], + video_grid_thw=None, + second_per_grid_ts=None, + attention_mask=inputs["attention_mask"], + ) + + inputs["position_ids"] = torch.cat((inputs["position_ids"], pos_ids), dim=0) + + prefill_seq_len = 128 + num_chunks = -(input_ids_length // -prefill_seq_len) # ceil divide without float + padded_len = num_chunks * prefill_seq_len # Convert to a multiple of prompt_len + + inputs["position_ids"] = F.pad( + inputs["position_ids"], pad=(0, padded_len - input_ids_length), mode="constant", value=-1 + ) + + inputs.pop("image_grid_thw") + streamer = TextStreamer(tokenizer) + output = qeff_model.generate(inputs=inputs, generation_len=100) + print(output.generated_ids) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) diff --git a/examples/compute_context_length.py b/examples/compute_context_length.py new file mode 100644 index 000000000..554c61c84 --- /dev/null +++ b/examples/compute_context_length.py @@ -0,0 +1,61 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +## In this example, you can run a model for static and continuous batching with different Compute-Context-Length (CCL) inputs. ## + +from transformers import AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM + +## Using optional variable comp_ctx_lengths variable you can pass a list of context lengths. It will run the model with default context length if comp_ctx_lengths=None. ## +## - The first comp_ctx_lengths_prefill list shows the compute-ctx-length list for prefilling process. ## +## - The second comp_ctx_lengths_decode list will be used for decoding. During the decoding process, based on the position_id or cache index it will work with the specific compute-context-length in the list. 
It will start from a proper compute-context-length in the list based on input prompt length and will gradually increase the compute-context-length if the cache index passes the current compute-context-length. ## + +ctx_len = 1024 +comp_ctx_lengths_prefill = [256] +comp_ctx_lengths_decode = [512,ctx_len] + +# model_name = "google/gemma-7b" +# model_name = "google/gemma-2-2b" +# model_name = "ibm-granite/granite-3.1-8b-instruct" +# model_name = "Snowflake/Llama-3.1-SwiftKV-8B-Instruct" +# model_name = "mistralai/Mistral-7B-v0.1" +# model_name = "microsoft/phi-1_5" +# model_name = "microsoft/Phi-3-mini-4k-instruct" +# model_name = "Qwen/Qwen2.5-7B-Instruct" +model_name = "meta-llama/Llama-3.2-1B" +# model_name = "Qwen/Qwen3-1.7B" +# model_name = "allenai/OLMo-2-0425-1B" +# model_name = "ibm-granite/granite-3.3-2b-base" +model = QEFFAutoModelForCausalLM.from_pretrained( + model_name, + continuous_batching=False, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, +) + +# model compilation for either continuous or static batching. For continuous batching full_batch_size is needed. +model.compile( + prefill_seq_len=128, + ctx_len=ctx_len, + num_cores=16, + num_devices=1, + batch_size=1, + mxint8_kv_cache=True, + mxfp6_matmul=True, +) + +# Create tokenizer and run model.generate and passes the input prompts to it. +tokenizer = AutoTokenizer.from_pretrained(model_name) +model.generate( + prompts=[ + "My name is ", + ], + tokenizer=tokenizer, + generation_len=128 +) diff --git a/examples/gemma3_example/ccl_gemma3_mm.py b/examples/gemma3_example/ccl_gemma3_mm.py new file mode 100644 index 000000000..484c0f8ce --- /dev/null +++ b/examples/gemma3_example/ccl_gemma3_mm.py @@ -0,0 +1,119 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import torch +import transformers +from transformers import AutoConfig, AutoProcessor + +from QEfficient import QEFFAutoModelForImageTextToText + +# Change model_id to "google/gemma-3-27b-it" for 27B model +model_id = "google/gemma-3-4b-it" +config = AutoConfig.from_pretrained(model_id) +# For Testing Purpose Only +# config.text_config.num_hidden_layers = 1 +# config.vision_config.num_hidden_layers = 2 +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) +processor = AutoProcessor.from_pretrained(model_id) + +# pass HF_TOKEN if gated model +# For running the model in single QPC approach use kv_offload=False. 
For Dual QPC approach use kv_offload=True ### +ctx_len = 8192 +comp_ctx_lengths_prefill = [3072] +comp_ctx_lengths_decode = [4096, ctx_len] + +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + config=config, + attn_implementation="eager", + kv_offload=True, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, +) + +### use skip_vision=True if you want to run only text, otherwise False ### +skip_vision = False + +if skip_vision: + ## Only Text ## + qeff_model.compile( + prefill_seq_len=128, + ctx_len=ctx_len, + img_size=896, + num_cores=16, + num_devices=4, + mxfp6_matmul=False, + mxint8_kv_cache=False, + aic_enable_depth_first=True, + skip_vision=True, + mos=1, + node_precision_info="examples/gemma3_example/fp32_nodes_gemma3_27b.yaml", + ) + + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe the transformers architecture in LLMs."}, + ], + }, + ] + + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + + output = qeff_model.generate(inputs=inputs, generation_len=100) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) + +else: + ## Vision + Text ## + qeff_model.compile( + prefill_seq_len=128, + ctx_len=ctx_len, + img_size=896, + num_cores=16, + num_devices=4, + mxfp6_matmul=False, + mxint8_kv_cache=False, + aic_enable_depth_first=True, + mos=1, + node_precision_info="examples/gemma3_example/fp32_nodes_gemma3_27b.yaml", + ) + + ### IMAGE + TEXT ### + image_url = ( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png" + ) + + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "url": image_url}, + {"type": "text", "text": "Can you describe the image in detail."}, + ], + }, + ] + + inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) + output = qeff_model.generate(inputs=inputs, generation_len=100) + print(tokenizer.batch_decode(output.generated_ids)) + print(output) diff --git a/examples/granite_example/ccl_granite_vision_inference.py b/examples/granite_example/ccl_granite_vision_inference.py new file mode 100644 index 000000000..e03b94a5e --- /dev/null +++ b/examples/granite_example/ccl_granite_vision_inference.py @@ -0,0 +1,127 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import requests +from PIL import Image +from transformers import AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +# Add HuggingFace Token to access the model +HF_TOKEN = "" + + +def run_model( + model_name, + token, + query, + image_url, + kv_offload=False, + prefill_seq_len=5500, + ctx_len=6000, + comp_ctx_lengths_prefill=None, + comp_ctx_lengths_decode=None, + generation_len=128, + img_size=384, + num_cores=16, + num_devices=1, +): + ## STEP - 1 Load the Processor and Model + + processor = AutoProcessor.from_pretrained(model_name, token=token) + + # `kv_offload` is used to compile the model into 2 QPCs. Currently the single-QPC path is not supported, so setting this flag to False is not allowed.
+ # The `kv_offload` flag should always be set to True. + # The Dual QPC approach splits the model to perform Image Encoding and Output generation in 2 different QPCs. + # The outputs of the Vision Encoder are then passed to the Language model via host in this case. + + model = QEFFAutoModelForImageTextToText.from_pretrained( + model_name, + token=token, + kv_offload=kv_offload, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + ) + + ## STEP - 2 Export & Compile the Model + + model.compile( + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + img_size=img_size, + num_cores=num_cores, + num_devices=num_devices, + mxfp6_matmul=False, + ) + + ## STEP - 3 Load and process the inputs for Inference + + # We are resizing the image to (w x h) (1610 x 1109) so that any image can work on the model irrespective of image dimensions. + # We have a fixed size of height 1109 and width 1610. + + image = Image.open(requests.get(image_url, stream=True).raw) + image = image.resize((1610, 1109)) + + messages = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": query}]}] + input_text = processor.apply_chat_template(messages, add_generation_prompt=True) + inputs = processor(image, input_text, add_special_tokens=False, return_tensors="pt") + + ## STEP - 4 Run Inference on the compiled model + + streamer = TextStreamer(processor.tokenizer) + output = model.generate(inputs=inputs, streamer=streamer, generation_len=generation_len) + print(output) + + +if __name__ == "__main__": + # Model name and Input parameters + model_name = "ibm-granite/granite-vision-3.2-2b" + + # Please add prompt here + query = "Describe the image" + + # Please pass an image URL or image path. The format of the image should be jpg. + image_url = "http://images.cocodataset.org/val2017/000000039769.jpg" + + # Compilation parameters for the model + kv_offload = True + prefill_seq_len = 5500 + ctx_len = 8192 + generation_len = 128 + img_size = 384 + num_cores = 16 + num_devices = 4 + comp_ctx_lengths_prefill = [5500] + comp_ctx_lengths_decode = [6144, ctx_len] + + run_model( + model_name=model_name, + token=HF_TOKEN, + query=query, + kv_offload=kv_offload, + image_url=image_url, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + generation_len=generation_len, + img_size=img_size, + num_cores=num_cores, + num_devices=num_devices, + ) + + +""" +Expected Response: + +The image depicts two cats lying on a pink blanket that is spread out on a red couch. The cats are positioned in a relaxed manner, with their bodies stretched out and their heads resting on the blanket. +The cat on the left is a smaller, tabby cat with a mix of black, gray, and white fur. It has a long, slender body and a distinctive tail that is curled up near its tail end. The cat on the right is a larger, +tabby cat with a mix of gray, black, and brown fur. It has + +""" diff --git a/examples/granite_example/ccl_granitemoe_inference.py b/examples/granite_example/ccl_granitemoe_inference.py new file mode 100644 index 000000000..57668ca24 --- /dev/null +++ b/examples/granite_example/ccl_granitemoe_inference.py @@ -0,0 +1,40 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from transformers import AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.utils.constants import Constants + +model_name = "ibm-research/PowerMoE-3b" +""" +# For CB inference, set continuous_batching to True and add the full_batch_size, mxfp6, and mxint8 arguments in the compile function +# We will use prompt_len=1 for compilation for both CB and non-CB inference +""" + +ctx_len = 2048 +comp_ctx_lengths_prefill = [256] +comp_ctx_lengths_decode = [512, 1024, ctx_len] + +model = QEFFAutoModelForCausalLM.from_pretrained( + model_name, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + continuous_batching=False, +) +model.compile( + prefill_seq_len=1, + ctx_len=ctx_len, + batch_size=1, + num_cores=16, + num_devices=4, + mxfp6_matmul=False, + mxint8_kv_cache=False, +) +tokenizer = AutoTokenizer.from_pretrained(model_name) +exec_info = model.generate(prompts=Constants.INPUT_STR, tokenizer=tokenizer) diff --git a/examples/intern_example/ccl_internvl_inference.py b/examples/intern_example/ccl_internvl_inference.py new file mode 100644 index 000000000..5595d26cd --- /dev/null +++ b/examples/intern_example/ccl_internvl_inference.py @@ -0,0 +1,286 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from io import BytesIO +from typing import List + +import requests +import torch +import torch.nn as nn +import torchvision.transforms as T +from PIL import Image +from torchvision.transforms.functional import InterpolationMode +from transformers import AutoTokenizer, TextStreamer + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.utils.logging_utils import logger + +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) + + +# Process the input messages to generate prompt for the model. +def get_prompt(messages) -> str: + """Get the prompt for generation.""" + ## Chat template used for InternVL + system_prompt = "<|im_start|>system\n你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。" + sep = "<|im_end|>\n" + + ret = system_prompt + sep + for role, message in messages: + if message: + if type(message) is tuple: + message, _, _ = message + ret += role + message + sep + else: + ret += role + return ret + + +# Processor class for InternVL models +class InternProcessor: + """ + The InternVL model only has an AutoTokenizer, so this class performs the processing tasks that an AutoProcessor would normally provide. + The methods used here are borrowed from the original InternVL modelling files.
+ "https://huggingface.co/OpenGVLab/InternVL2_5-1B/" + """ + + def __init__(self, model: nn.Module, tokenizer): + self.model = model + image_size = self.model.config.force_image_size or self.model.config.vision_config.image_size + patch_size = self.model.config.vision_config.patch_size + self.template = model.config.template + self.num_image_token = int((image_size // patch_size) ** 2 * (self.model.config.downsample_ratio**2)) + self.tokenizer = tokenizer + + def build_transform(self, input_size): + MEAN, STD = IMAGENET_MEAN, IMAGENET_STD + transform = T.Compose( + [ + T.Lambda(lambda img: img.convert("RGB") if img.mode != "RGB" else img), + T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), + T.ToTensor(), + T.Normalize(mean=MEAN, std=STD), + ] + ) + return transform + + def find_closest_aspect_ratio(self, aspect_ratio, target_ratios, width, height, image_size): + best_ratio_diff = float("inf") + best_ratio = (1, 1) + area = width * height + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + def dynamic_preprocess(self, image, min_num=1, max_num=12, image_size=448, use_thumbnail=False): + orig_width, orig_height = image.size + aspect_ratio = orig_width / orig_height + # calculate the existing image aspect ratio + target_ratios = set( + (i, j) + for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if i * j <= max_num and i * j >= min_num + ) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + # find the closest aspect ratio to the target + target_aspect_ratio = self.find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, image_size + ) + # calculate the target width and height + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + # resize the image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size, + ) + # split the image + split_img = resized_img.crop(box) + processed_images.append(split_img) + assert len(processed_images) == blocks + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + return processed_images + + def load_image(self, image, input_size=448, max_num=12): + transform = self.build_transform(input_size=input_size) + images = self.dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num) + pixel_values = [transform(image) for image in images] + pixel_values = torch.stack(pixel_values) + return pixel_values + + def __call__( + self, + pixel_values, + question, + messages, + roles, + history=None, + num_patches_list=None, + IMG_START_TOKEN="", + IMG_END_TOKEN="", + IMG_CONTEXT_TOKEN="", + verbose=False, + ) -> str: + if history is None and pixel_values is not None and "" not in question: + question = "\n" + question + if num_patches_list is None: + 
num_patches_list = [pixel_values.shape[0]] if pixel_values is not None else [] + assert pixel_values is None or len(pixel_values) == sum(num_patches_list) + img_context_token_id = self.tokenizer.convert_tokens_to_ids(IMG_CONTEXT_TOKEN) + self.model.img_context_token_id = img_context_token_id + + messages.append([roles[0], question]) + messages.append([roles[1], None]) + query = get_prompt(messages) + if verbose and pixel_values is not None: + image_bs = pixel_values.shape[0] + logger.info(f"dynamic ViT batch size: {image_bs}") + + for num_patches in num_patches_list: + image_tokens = IMG_START_TOKEN + IMG_CONTEXT_TOKEN * self.num_image_token * num_patches + IMG_END_TOKEN + query = query.replace("", image_tokens, 1) + return query + + +def run_intern_on_aic( + model_name, + prompt, + image_url, + messages, + roles, + kv_offload=False, + prefill_seq_len=3840, + num_devices=1, + num_cores=16, +): + ## STEP 1 -- LOAD THE MODEL + + # The original Intern-VL model, despite being multimodal, is loaded using `AutoModelForCausalLM` in Huggingface. + # To maintain compatibility, we load this model using `QEFFAutoModelForCausalLM`. + + ctx_len = 8192 + comp_ctx_lengths_prefill = [4096] + comp_ctx_lengths_decode = [6144, ctx_len] + + # model = QEFFAutoModelForCausalLM.from_pretrained(model_name, kv_offload=kv_offload, trust_remote_code=True) + + model = QEFFAutoModelForCausalLM.from_pretrained( + model_name, + kv_offload=kv_offload, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + trust_remote_code=True, + ) + + ## STEP 2 -- EXPORT & COMPILE THE MODEL + + model.compile( + num_cores=num_cores, + num_devices=num_devices, + ctx_len=ctx_len, + prefill_seq_len=prefill_seq_len, + mxfp6_matmul=False, + ) + + ## STEP 3 -- SETUP THE PROCESSOR + + # InternVL doesn't have an AutoProcessor yet, so we will use our own processor class "InternProcessor" + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False) + internProcessor = InternProcessor(model.model, tokenizer) + + ## STEP 4 -- PREPROCESS THE INPUTS + + img = requests.get(image_url, stream=True) + image = Image.open(BytesIO(img.content)).convert("RGB") + + # Images are resized to (1000, 747) for inference + image = image.resize((1000, 747)) + + # preprocess the resized image + pixel_values = internProcessor.load_image(image, max_num=12) + question = "\n" + prompt + query = internProcessor(pixel_values, question, messages, roles) + inputs = tokenizer( + query, return_tensors="pt", padding="max_length", max_length=prefill_seq_len, padding_side="right" + ) + + inputs["pixel_values"] = pixel_values + + ## STEP 5 -- RUN INFERENCE VIA GENERATE FUNCTION + streamer = TextStreamer(tokenizer) + model.generate(inputs=inputs, streamer=streamer, generation_len=128) + + +if __name__ == "__main__": + model_name = "OpenGVLab/InternVL2_5-1B" + + # Chat Template information for prompt preprocessing + messages: List[List[str]] = [] + roles = ("<|im_start|>user\n", "<|im_start|>assistant\n") + + # Inputs for the model + prompt = "Please describe the image in detail." + image_url = "https://image.slidesharecdn.com/azureintroduction-191206101932/75/Introduction-to-Microsoft-Azure-Cloud-1-2048.jpg" + + ## Compilation parameters + + # `kv_offload` is used to compile the model in a Single QPC or 2 QPCs. + # The Dual QPC approach splits the model to perform Image Encoding and Output generation in 2 different QPCs. 
+ # The outputs of the Vision Encoder are then passed to the Language model via host in this case. + + kv_offload = False + + # InternVL is an Early-Fusion model that uses placeholder tokens within the input_ids to interleave text_embeddings with + # Image embeddings and generate final input_embeds for outout generation. Hence we need very large prefill_seq_len (3840 in this case) to + # incorporate the memory for the merged embeddings. + + prefill_seq_len = 3840 + num_devices = 4 + num_cores = 16 + + run_intern_on_aic( + model_name=model_name, + prompt=prompt, + image_url=image_url, + messages=messages, + roles=roles, + kv_offload=kv_offload, + prefill_seq_len=prefill_seq_len, + num_devices=num_devices, + num_cores=num_cores, + ) + + +""" +Expected Response: + +The image is a promotional graphic for Microsoft Azure. It features a blue background with a hexagonal pattern on the left side. The hexagons are white and are arranged in a way that suggests a network or connectivity theme. + +On the right side of the image, the Microsoft Azure logo is prominently displayed. The logo consists of the Azure name in white, with the Microsoft logo above it, which includes four colored squares (blue, green, yellow, and red). Below the logo, the word "Azure" is written in large white letters. + +Below the logo, there is text that reads: +- "By Dinesh Kumar Wick +""" diff --git a/examples/qwen3moe_example/ccl_qwen3moe_inference.py b/examples/qwen3moe_example/ccl_qwen3moe_inference.py new file mode 100644 index 000000000..4d09b08f3 --- /dev/null +++ b/examples/qwen3moe_example/ccl_qwen3moe_inference.py @@ -0,0 +1,42 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from transformers import AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.utils.constants import Constants + +model_name = "Qwen/Qwen3-30B-A3B-Instruct-2507" +""" +# For CB inference, set continuous_batching to True and add full_batch_size,mxfp6,mint8 argument in compile function +# We will use prompt_len=1 for compilation for both cb and non-cb inference +""" + +ctx_len = 8192 + +comp_ctx_lengths_prefill = [4096] +comp_ctx_lengths_decode = [6144,8192] + +model = QEFFAutoModelForCausalLM.from_pretrained( + model_name, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + continuous_batching=False, +) +model.compile( + prefill_seq_len=1, + ctx_len=ctx_len, + batch_size=1, + num_cores=16, + num_devices=4, + mxfp6_matmul=True, + mxint8_kv_cache=True, + mos=1, +) +tokenizer = AutoTokenizer.from_pretrained(model_name) +exec_info = model.generate(prompts=Constants.INPUT_STR, tokenizer=tokenizer) diff --git a/tests/transformers/test_comp_ctx_length.py b/tests/transformers/test_comp_ctx_length.py new file mode 100644 index 000000000..e145ad698 --- /dev/null +++ b/tests/transformers/test_comp_ctx_length.py @@ -0,0 +1,193 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import copy +import os +from time import perf_counter + +import onnx +import pytest +from transformers import AutoConfig, AutoModel, AutoModelForCausalLM + +from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + +configs = [ + # name, max_position_embeddings, num_hidden_layers, num_attention_heads, hidden_size, intermediate_size, vocab_size, additional_params + ("gpt2", 256, 2, 4, 128, 512, 127, {}), + ("codegen", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), + ("falcon", 256, 2, 4, 128, 512, 127, {}), + ("gptj", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), + ("llama", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("mistral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("mixtral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("mpt", 256, 2, 4, 128, 512, 127, {}), + ("phi", 256, 2, 4, 128, 512, 127, {}), + ("phi3", 256, 2, 4, 128, 512, 127, {"pad_token_id": 0}), + ("qwen2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("starcoder2", 256, 2, 4, 128, 512, 127, {}), + ("granite", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), +] + +configs = [ + AutoConfig.for_model( + model_name, + max_position_embeddings=max_position_embeddings, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + vocab_size=vocab_size, + **additional_params, + ) + for ( + model_name, + max_position_embeddings, + num_hidden_layers, + num_attention_heads, + hidden_size, + intermediate_size, + vocab_size, + additional_params, + ) in configs +] +config_ids = [x.model_type for x in configs] + +model_kwargs = {"attn_implementation": "eager"} + + +@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) +def test_causal_lm_unsupported(cb): + model = AutoModelForCausalLM.from_config(AutoConfig.for_model("opt")) + with pytest.warns(): + QEFFAutoModelForCausalLM(model, cb) + + +@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) +@pytest.mark.parametrize("config", configs, ids=config_ids) +def test_causal_lm_init(config, cb): + model = AutoModelForCausalLM.from_config(config, **model_kwargs) + qeff_model = QEFFAutoModelForCausalLM(model, cb) + with pytest.raises(TypeError): + QEFFAutoModelForCausalLM(AutoModel.from_config(config, **model_kwargs), cb) + assert qeff_model.model.__class__.__name__.startswith("QEff") + + +@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) +@pytest.mark.parametrize("config", configs, ids=config_ids) +def test_causal_lm_pretrained(config, cb, tmp_path): + model = AutoModelForCausalLM.from_config(config, **model_kwargs) + model.save_pretrained(tmp_path) + + qeff_model = QEFFAutoModelForCausalLM.from_pretrained(tmp_path, cb) + assert qeff_model.model.__class__.__name__.startswith("QEff") + + +@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) +@pytest.mark.parametrize("config", configs, ids=config_ids) +def test_causal_lm_hash(config, cb): + hash_0_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb).model_hash + hash_0_1 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb).model_hash + + assert hash_0_0 == hash_0_1 + + cfg1 = copy.deepcopy(config) + cfg1.num_hidden_layers -= 1 + hash_1_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(cfg1, **model_kwargs), 
cb).model_hash + cfg2 = copy.deepcopy(config) + cfg2.num_hidden_layers -= 1 + hash_1_1 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(cfg2, **model_kwargs), cb).model_hash + assert hash_1_0 == hash_1_1 + + assert hash_0_0 != hash_1_0 + + if cb: + hash_0_no_cb = QEFFAutoModelForCausalLM( + AutoModelForCausalLM.from_config(config, **model_kwargs), False + ).model_hash + assert hash_0_0 != hash_0_no_cb + + +@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) +@pytest.mark.parametrize("config", configs, ids=config_ids) +def test_causal_lm_export(config, cb, tmp_path): + model = AutoModelForCausalLM.from_config(config, **model_kwargs) + ctx_len = 2048 + comp_ctx_lengths_prefill = [256] + comp_ctx_lengths_decode = [512, 1024, ctx_len] + + qeff_model = QEFFAutoModelForCausalLM( + model, + cb, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + ) + qeff_model.export(tmp_path) + model_path = tmp_path.with_name(tmp_path.name + "-" + qeff_model.model_hash) + assert model_path.is_dir() + assert qeff_model.onnx_path.is_file() + assert qeff_model.onnx_path.relative_to(model_path).parts == (qeff_model.model_name + ".onnx",) + + # Check if the KV-cache inputs and outputs are created + onnx_model = onnx.load(qeff_model.onnx_path, load_external_data=False) + retained_output_names = { + x.name[: -len("_RetainedState")] for x in onnx_model.graph.output if x.name.endswith("_RetainedState") + } + retained_output_names.issubset({x.name for x in onnx_model.graph.input}) + + # Check if there is no re-export + start = perf_counter() + qeff_model.export(tmp_path) + end = perf_counter() + export_time = end - start + assert export_time < 2.0 + + +@pytest.fixture +def tmp_cache(tmp_path, monkeypatch): + monkeypatch.setattr("QEfficient.base.modeling_qeff.QEFF_HOME", tmp_path) + yield tmp_path + + +@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) +@pytest.mark.parametrize("config", configs, ids=config_ids) +def test_causal_lm_compile(config, cb, tmp_cache): + model = AutoModelForCausalLM.from_config(config, **model_kwargs) + ctx_len = 2048 + comp_ctx_lengths_prefill = [256] + comp_ctx_lengths_decode = [512, 1024, ctx_len] + qeff_model = QEFFAutoModelForCausalLM( + model, + cb, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + ) + compile_params = {"prefill_seq_len": 8, "ctx_len": ctx_len} + if cb: + compile_params["full_batch_size"] = 32 + compile_params["batch_size"] = 8 + qeff_model.compile(**compile_params) + model_path = tmp_cache / (qeff_model.model_name + "-" + qeff_model.model_hash) + + # Check if ONNX is exported properly + assert model_path.is_dir() + assert qeff_model.onnx_path.is_file() + assert qeff_model.onnx_path.relative_to(model_path).parts == (qeff_model.model_name + ".onnx",) + + # Check if QPC is compiled properly + assert qeff_model.qpc_path.is_dir() + assert (qeff_model.qpc_path / "programqpc.bin").is_file() + assert qeff_model.qpc_path.relative_to(tmp_cache).parts[0] == qeff_model.model_name + "-" + qeff_model.model_hash + + # Check if there is no re-compilation + start = perf_counter() + qeff_model.compile(**compile_params) + end = perf_counter() + compile_time = end - start + assert compile_time < 2.0 + assert os.path.isfile(os.path.join(os.path.dirname(qeff_model.qpc_path), "qconfig.json")) From 495b44f8334a3e69c9ea3c23672d651db82f388a Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Thu, 16 Oct 2025 
23:29:20 -0700 Subject: [PATCH 17/41] Adding Compute-Context-Length(CCL) Signed-off-by: Vahid Janfaza --- examples/ccl_mistral3_example.py | 7 ++++--- examples/ccl_qwen2_5_vl_example.py | 10 +++++----- examples/compute_context_length.py | 4 ++-- examples/qwen3moe_example/ccl_qwen3moe_inference.py | 2 +- 4 files changed, 12 insertions(+), 11 deletions(-) diff --git a/examples/ccl_mistral3_example.py b/examples/ccl_mistral3_example.py index ed02a4fa9..b76227a22 100644 --- a/examples/ccl_mistral3_example.py +++ b/examples/ccl_mistral3_example.py @@ -38,12 +38,13 @@ def run_model( config = AutoConfig.from_pretrained(model_name) config.vision_config._attn_implementation = "eager" - model = QEFFAutoModelForImageTextToText.from_pretrained(model_name, - kv_offload=kv_offload, + model = QEFFAutoModelForImageTextToText.from_pretrained( + model_name, + kv_offload=kv_offload, config=config, ctx_len=ctx_len, comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, - comp_ctx_lengths_decode=comp_ctx_lengths_decode + comp_ctx_lengths_decode=comp_ctx_lengths_decode, ) ## STEP - 2 Export & Compile the Model diff --git a/examples/ccl_qwen2_5_vl_example.py b/examples/ccl_qwen2_5_vl_example.py index 7056011f2..74063929b 100644 --- a/examples/ccl_qwen2_5_vl_example.py +++ b/examples/ccl_qwen2_5_vl_example.py @@ -24,16 +24,16 @@ ctx_len = 32768 comp_ctx_lengths_prefill = [4000] -comp_ctx_lengths_decode = [4096, 8192,16384, ctx_len] +comp_ctx_lengths_decode = [4096, 8192, 16384, ctx_len] qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( - model_id, + model_id, comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, comp_ctx_lengths_decode=comp_ctx_lengths_decode, ctx_len=ctx_len, - attn_implementation="eager", - kv_offload=True, - config=config + attn_implementation="eager", + kv_offload=True, + config=config, ) tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) processor = AutoProcessor.from_pretrained(model_id) diff --git a/examples/compute_context_length.py b/examples/compute_context_length.py index 554c61c84..c1e5dc0df 100644 --- a/examples/compute_context_length.py +++ b/examples/compute_context_length.py @@ -17,7 +17,7 @@ ctx_len = 1024 comp_ctx_lengths_prefill = [256] -comp_ctx_lengths_decode = [512,ctx_len] +comp_ctx_lengths_decode = [512, ctx_len] # model_name = "google/gemma-7b" # model_name = "google/gemma-2-2b" @@ -57,5 +57,5 @@ "My name is ", ], tokenizer=tokenizer, - generation_len=128 + generation_len=128, ) diff --git a/examples/qwen3moe_example/ccl_qwen3moe_inference.py b/examples/qwen3moe_example/ccl_qwen3moe_inference.py index 4d09b08f3..4a7a16c1b 100644 --- a/examples/qwen3moe_example/ccl_qwen3moe_inference.py +++ b/examples/qwen3moe_example/ccl_qwen3moe_inference.py @@ -19,7 +19,7 @@ ctx_len = 8192 comp_ctx_lengths_prefill = [4096] -comp_ctx_lengths_decode = [6144,8192] +comp_ctx_lengths_decode = [6144, 8192] model = QEFFAutoModelForCausalLM.from_pretrained( model_name, From 9d1a63ac5511f69665a9b25d562c270a709a9ed6 Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Fri, 17 Oct 2025 10:40:44 -0700 Subject: [PATCH 18/41] Adding Compute-Context-Length(CCL) Signed-off-by: Vahid Janfaza --- examples/intern_example/ccl_internvl_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/intern_example/ccl_internvl_inference.py b/examples/intern_example/ccl_internvl_inference.py index 5595d26cd..0828b1d41 100644 --- a/examples/intern_example/ccl_internvl_inference.py +++ b/examples/intern_example/ccl_internvl_inference.py @@ -251,7 +251,7 @@ def 
run_intern_on_aic( # The Dual QPC approach splits the model to perform Image Encoding and Output generation in 2 different QPCs. # The outputs of the Vision Encoder are then passed to the Language model via host in this case. - kv_offload = False + kv_offload = True # InternVL is an Early-Fusion model that uses placeholder tokens within the input_ids to interleave text_embeddings with # Image embeddings and generate final input_embeds for outout generation. Hence we need very large prefill_seq_len (3840 in this case) to From eb3aea51f1b49c1d5f885066d1fd078f2db5691d Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Sun, 19 Oct 2025 08:07:22 -0700 Subject: [PATCH 19/41] Adding Compute-Context-Length(CCL) Signed-off-by: Vahid Janfaza --- QEfficient/transformers/models/modeling_auto.py | 13 +++++++++---- .../qwen3moe_example/ccl_qwen3moe_inference.py | 14 +++++++++----- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 5c6f67ddc..ad08fd0ac 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2235,8 +2235,9 @@ def __init__( self.comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) self.comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) ctx_len = kwargs.pop("ctx_len", None) - - if self.comp_ctx_lengths_prefill: + prefill_seq_len = kwargs.pop("prefill_seq_len", 128) + + if self.comp_ctx_lengths_prefill and prefill_seq_len > 1: self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations( self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len ) @@ -2338,7 +2339,9 @@ def from_pretrained( comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) ctx_len = kwargs.pop("ctx_len", None) - if comp_ctx_lengths_prefill: + prefill_seq_len = kwargs.pop("prefill_seq_len", 128) + + if comp_ctx_lengths_prefill and prefill_seq_len > 1: comp_ctx_lengths_prefill, comp_ctx_lengths_decode = process_ccl_specializations( comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len ) @@ -2356,6 +2359,7 @@ def from_pretrained( comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, comp_ctx_lengths_decode=comp_ctx_lengths_decode, ctx_len=ctx_len, + prefill_seq_len=prefill_seq_len, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs, @@ -2368,6 +2372,7 @@ def from_pretrained( comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, comp_ctx_lengths_decode=comp_ctx_lengths_decode, ctx_len=ctx_len, + prefill_seq_len=prefill_seq_len, **kwargs, ) @@ -2643,7 +2648,7 @@ def build_decode_specialization( A dictionary defining the decode specialization, or None if it would be a duplicate of the prefill specialization (e.g., if prefill_seq_len is 1 and not continuous batching). 
""" - if prefill_seq_len == 1 and not self.continuous_batching and comp_ctx_lengths is None: + if prefill_seq_len == 1 and not self.continuous_batching:# and comp_ctx_lengths is None return None # Avoid duplication with prefill spec = { "batch_size": full_batch_size if self.continuous_batching else batch_size, diff --git a/examples/qwen3moe_example/ccl_qwen3moe_inference.py b/examples/qwen3moe_example/ccl_qwen3moe_inference.py index 4a7a16c1b..98258affc 100644 --- a/examples/qwen3moe_example/ccl_qwen3moe_inference.py +++ b/examples/qwen3moe_example/ccl_qwen3moe_inference.py @@ -16,10 +16,11 @@ # We will use prompt_len=1 for compilation for both cb and non-cb inference """ -ctx_len = 8192 - -comp_ctx_lengths_prefill = [4096] -comp_ctx_lengths_decode = [6144, 8192] +ctx_len = 65536 +prefill_seq_len = 1 +# In moe models when compiling with prefill_seq_len=1 and non-continuous-batching mode, prefill and decode will share the same specializations. +comp_ctx_lengths_prefill = [4096,8192,16384,32768,ctx_len] +comp_ctx_lengths_decode = [4096,8192,16384,32768,ctx_len] model = QEFFAutoModelForCausalLM.from_pretrained( model_name, @@ -27,9 +28,11 @@ comp_ctx_lengths_decode=comp_ctx_lengths_decode, ctx_len=ctx_len, continuous_batching=False, + prefill_seq_len=prefill_seq_len, ) + # prefill_seq_len=prefill_seq_len, model.compile( - prefill_seq_len=1, + prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, batch_size=1, num_cores=16, @@ -38,5 +41,6 @@ mxint8_kv_cache=True, mos=1, ) + # mos=1, tokenizer = AutoTokenizer.from_pretrained(model_name) exec_info = model.generate(prompts=Constants.INPUT_STR, tokenizer=tokenizer) From 736c775287e3a73ec436fa456edbecf3729e9c04 Mon Sep 17 00:00:00 2001 From: vjanfaza Date: Fri, 17 Oct 2025 10:46:50 -0700 Subject: [PATCH 20/41] Delete examples/granite_example/ccl_granitemoe_inference.py Signed-off-by: vjanfaza --- .../ccl_granitemoe_inference.py | 40 ------------------- 1 file changed, 40 deletions(-) delete mode 100644 examples/granite_example/ccl_granitemoe_inference.py diff --git a/examples/granite_example/ccl_granitemoe_inference.py b/examples/granite_example/ccl_granitemoe_inference.py deleted file mode 100644 index 57668ca24..000000000 --- a/examples/granite_example/ccl_granitemoe_inference.py +++ /dev/null @@ -1,40 +0,0 @@ -# ----------------------------------------------------------------------------- -# -# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
-# SPDX-License-Identifier: BSD-3-Clause -# -# ----------------------------------------------------------------------------- - -from transformers import AutoTokenizer - -from QEfficient import QEFFAutoModelForCausalLM -from QEfficient.utils.constants import Constants - -model_name = "ibm-research/PowerMoE-3b" -""" -# For CB inference, set continuous_batching to True and add full_batch_size,mxfp6,mint8 argument in compile function -# We will use prompt_len=1 for compilation for both cb and non-cb inference -""" - -ctx_len = 2048 -comp_ctx_lengths_prefill = [256] -comp_ctx_lengths_decode = [512, 1024, ctx_len] - -model = QEFFAutoModelForCausalLM.from_pretrained( - model_name, - comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, - comp_ctx_lengths_decode=comp_ctx_lengths_decode, - ctx_len=ctx_len, - continuous_batching=False, -) -model.compile( - prefill_seq_len=1, - ctx_len=ctx_len, - batch_size=1, - num_cores=16, - num_devices=4, - mxfp6_matmul=False, - mxint8_kv_cache=False, -) -tokenizer = AutoTokenizer.from_pretrained(model_name) -exec_info = model.generate(prompts=Constants.INPUT_STR, tokenizer=tokenizer) From 027625cf820eb29c28ee85d1167c30d65805bfc9 Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Sun, 19 Oct 2025 20:00:08 -0700 Subject: [PATCH 21/41] Adding Compute-Context-Length(CCL) Signed-off-by: Vahid Janfaza --- QEfficient/transformers/models/modeling_auto.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index ad08fd0ac..139ac30eb 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2648,8 +2648,10 @@ def build_decode_specialization( A dictionary defining the decode specialization, or None if it would be a duplicate of the prefill specialization (e.g., if prefill_seq_len is 1 and not continuous batching). 
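        Example (illustrative only, matching the condition introduced below):
            prefill_seq_len=1,  continuous_batching=False -> None (prefill specialization is reused)
            prefill_seq_len=1,  batch_size=1              -> None (prefill specialization is reused)
            prefill_seq_len=128                           -> a separate decode specialization dict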
""" - if prefill_seq_len == 1 and not self.continuous_batching:# and comp_ctx_lengths is None - return None # Avoid duplication with prefill + if prefill_seq_len == 1: + if not self.continuous_batching or batch_size==1: + return None # Avoid duplication with prefill + spec = { "batch_size": full_batch_size if self.continuous_batching else batch_size, "seq_len": (num_speculative_tokens + 1) if self.is_tlm else 1, From 42b4b7f752a8498e86b52d72f4f58326c981d34c Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Sun, 19 Oct 2025 20:13:16 -0700 Subject: [PATCH 22/41] Adding Compute-Context-Length(CCL) Signed-off-by: Vahid Janfaza --- QEfficient/transformers/models/modeling_auto.py | 8 ++++---- .../qwen3moe_example/ccl_qwen3moe_inference.py | 14 +++++++------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 139ac30eb..bf2a445ad 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2236,7 +2236,7 @@ def __init__( self.comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) ctx_len = kwargs.pop("ctx_len", None) prefill_seq_len = kwargs.pop("prefill_seq_len", 128) - + if self.comp_ctx_lengths_prefill and prefill_seq_len > 1: self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations( self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len @@ -2340,7 +2340,7 @@ def from_pretrained( comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) ctx_len = kwargs.pop("ctx_len", None) prefill_seq_len = kwargs.pop("prefill_seq_len", 128) - + if comp_ctx_lengths_prefill and prefill_seq_len > 1: comp_ctx_lengths_prefill, comp_ctx_lengths_decode = process_ccl_specializations( comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len @@ -2649,9 +2649,9 @@ def build_decode_specialization( of the prefill specialization (e.g., if prefill_seq_len is 1 and not continuous batching). """ if prefill_seq_len == 1: - if not self.continuous_batching or batch_size==1: + if not self.continuous_batching or batch_size == 1: return None # Avoid duplication with prefill - + spec = { "batch_size": full_batch_size if self.continuous_batching else batch_size, "seq_len": (num_speculative_tokens + 1) if self.is_tlm else 1, diff --git a/examples/qwen3moe_example/ccl_qwen3moe_inference.py b/examples/qwen3moe_example/ccl_qwen3moe_inference.py index 98258affc..12e9ca1fc 100644 --- a/examples/qwen3moe_example/ccl_qwen3moe_inference.py +++ b/examples/qwen3moe_example/ccl_qwen3moe_inference.py @@ -16,31 +16,31 @@ # We will use prompt_len=1 for compilation for both cb and non-cb inference """ -ctx_len = 65536 +ctx_len = 32768 prefill_seq_len = 1 # In moe models when compiling with prefill_seq_len=1 and non-continuous-batching mode, prefill and decode will share the same specializations. 
-comp_ctx_lengths_prefill = [4096,8192,16384,32768,ctx_len] -comp_ctx_lengths_decode = [4096,8192,16384,32768,ctx_len] +comp_ctx_lengths_prefill = [4096, 8192, 16384, ctx_len] +comp_ctx_lengths_decode = [4096, 8192, 16384, ctx_len] model = QEFFAutoModelForCausalLM.from_pretrained( model_name, comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, comp_ctx_lengths_decode=comp_ctx_lengths_decode, ctx_len=ctx_len, - continuous_batching=False, + continuous_batching=True, prefill_seq_len=prefill_seq_len, ) - # prefill_seq_len=prefill_seq_len, +# prefill_seq_len=prefill_seq_len, model.compile( prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, - batch_size=1, + full_batch_size=1, num_cores=16, num_devices=4, mxfp6_matmul=True, mxint8_kv_cache=True, mos=1, ) - # mos=1, +# mos=1, tokenizer = AutoTokenizer.from_pretrained(model_name) exec_info = model.generate(prompts=Constants.INPUT_STR, tokenizer=tokenizer) From fa3c2f6d92d44ea1601ce21605bf70feca099c4b Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Tue, 21 Oct 2025 18:48:40 -0700 Subject: [PATCH 23/41] improving handeling CCL lists Signed-off-by: Vahid Janfaza --- .../transformers/models/modeling_auto.py | 64 ++----------------- QEfficient/utils/check_ccl_specializations.py | 25 ++++++-- 2 files changed, 26 insertions(+), 63 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index bf2a445ad..6421a5b91 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -879,13 +879,7 @@ def __init__( self.model = model self.config = model.config - self.comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) - self.comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) - ctx_len = kwargs.pop("ctx_len", None) - if self.comp_ctx_lengths_prefill: - self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations( - self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len - ) + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, _, _ = process_ccl_specializations(kwargs) self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs) self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs) @@ -933,14 +927,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs): kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) - comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) - comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) - ctx_len = kwargs.pop("ctx_len", None) - - if comp_ctx_lengths_prefill: - comp_ctx_lengths_prefill, comp_ctx_lengths_decode = process_ccl_specializations( - comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len - ) + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations(kwargs) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) return cls( @@ -1498,14 +1485,7 @@ def __init__( raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") super().__init__(model, **kwargs) - self.comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) - self.comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) - ctx_len = kwargs.pop("ctx_len", None) - - if self.comp_ctx_lengths_prefill: - self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations( - 
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len - ) + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, _, _ = process_ccl_specializations(kwargs) # to handle internvl models if hasattr(self.model.config, "llm_config") and hasattr(self.model.config, "vision_config"): @@ -1554,14 +1534,7 @@ def from_pretrained( kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) - comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) - comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) - ctx_len = kwargs.pop("ctx_len", None) - - if comp_ctx_lengths_prefill: - comp_ctx_lengths_prefill, comp_ctx_lengths_decode = process_ccl_specializations( - comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len - ) + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations(kwargs) from transformers import AutoConfig @@ -2115,14 +2088,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) - comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) - comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) - ctx_len = kwargs.pop("ctx_len", None) - - if comp_ctx_lengths_prefill: - comp_ctx_lengths_prefill, comp_ctx_lengths_decode = process_ccl_specializations( - comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len - ) + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations(kwargs) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) return cls( @@ -2232,15 +2198,7 @@ def __init__( self.model, transformed = SpDTransform.apply(self.model, qaic_config, **kwargs) self.is_tlm = transformed - self.comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) - self.comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) - ctx_len = kwargs.pop("ctx_len", None) - prefill_seq_len = kwargs.pop("prefill_seq_len", 128) - - if self.comp_ctx_lengths_prefill and prefill_seq_len > 1: - self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations( - self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len - ) + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, _, _ = process_ccl_specializations(kwargs) self.hash_params["qeff_auto_class"] = self.__class__.__name__ @@ -2336,15 +2294,7 @@ def from_pretrained( kv_offload = kwargs.pop("kv_offload", None) - comp_ctx_lengths_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) - comp_ctx_lengths_decode = kwargs.pop("comp_ctx_lengths_decode", None) - ctx_len = kwargs.pop("ctx_len", None) - prefill_seq_len = kwargs.pop("prefill_seq_len", 128) - - if comp_ctx_lengths_prefill and prefill_seq_len > 1: - comp_ctx_lengths_prefill, comp_ctx_lengths_decode = process_ccl_specializations( - comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len - ) + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations(kwargs) kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) diff --git a/QEfficient/utils/check_ccl_specializations.py b/QEfficient/utils/check_ccl_specializations.py index dbfb08926..8107447de 100644 --- a/QEfficient/utils/check_ccl_specializations.py +++ b/QEfficient/utils/check_ccl_specializations.py @@ 
-8,15 +8,28 @@ from typing import List, Optional +# def process_ccl_specializations( +# ccl_prefill: Optional[List[int]] = None, ccl_decode: Optional[List[int]] = None, ctx_len: Optional[int] = None +# ): def process_ccl_specializations( - ccl_prefill: Optional[List[int]] = None, ccl_decode: Optional[List[int]] = None, ctx_len: Optional[int] = None + kwargs ): + ccl_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) + ccl_decode = kwargs.pop("comp_ctx_lengths_decode", None) + ctx_len = kwargs.pop("ctx_len", None) + prefill_seq_len = kwargs.pop("prefill_seq_len", 128) + if ctx_len is None: raise TypeError("`ctx_len` is required when loading the model.") - if ccl_prefill is None: - ccl_prefill = [ctx_len] - if ccl_decode is None: - ccl_decode = [ctx_len] + + if ccl_prefill is None or ccl_decode is None: + return None, None, ctx_len, prefill_seq_len + + if prefill_seq_len == 1: + #both prefill and decode ccl can share the same specializations since prefill_seq_len=1. So, a sorted union of both lists can be used for both of them. + ccl_union_all = sorted(set(ccl_prefill + ccl_decode)) + ccl_union_all = [min(x, ctx_len) for x in ccl_union_all] + return ccl_union_all, ccl_union_all, ctx_len, prefill_seq_len # Step 1: Cap values to ctx_len ccl_prefill = [min(x, ctx_len) for x in ccl_prefill] @@ -40,4 +53,4 @@ def process_ccl_specializations( updated_prefill.sort() ccl_decode.sort() - return updated_prefill, ccl_decode + return updated_prefill, ccl_decode, ctx_len, prefill_seq_len From 8fb32659f8cd12ae3b1df9d97b8fdd949aa29ca1 Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Tue, 21 Oct 2025 18:53:54 -0700 Subject: [PATCH 24/41] improving handeling CCL lists Signed-off-by: Vahid Janfaza --- QEfficient/transformers/models/modeling_auto.py | 16 ++++++++++++---- QEfficient/utils/check_ccl_specializations.py | 10 +++------- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 6421a5b91..9fb9a9c0a 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -927,7 +927,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs): kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) - comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations(kwargs) + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations( + kwargs + ) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) return cls( @@ -1534,7 +1536,9 @@ def from_pretrained( kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) - comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations(kwargs) + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations( + kwargs + ) from transformers import AutoConfig @@ -2088,7 +2092,9 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) - comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations(kwargs) + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations( + kwargs + ) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) 
return cls( @@ -2294,7 +2300,9 @@ def from_pretrained( kv_offload = kwargs.pop("kv_offload", None) - comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations(kwargs) + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations( + kwargs + ) kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs) diff --git a/QEfficient/utils/check_ccl_specializations.py b/QEfficient/utils/check_ccl_specializations.py index 8107447de..45d2ea903 100644 --- a/QEfficient/utils/check_ccl_specializations.py +++ b/QEfficient/utils/check_ccl_specializations.py @@ -5,15 +5,11 @@ # # ----------------------------------------------------------------------------- -from typing import List, Optional - # def process_ccl_specializations( # ccl_prefill: Optional[List[int]] = None, ccl_decode: Optional[List[int]] = None, ctx_len: Optional[int] = None # ): -def process_ccl_specializations( - kwargs -): +def process_ccl_specializations(kwargs): ccl_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) ccl_decode = kwargs.pop("comp_ctx_lengths_decode", None) ctx_len = kwargs.pop("ctx_len", None) @@ -24,9 +20,9 @@ def process_ccl_specializations( if ccl_prefill is None or ccl_decode is None: return None, None, ctx_len, prefill_seq_len - + if prefill_seq_len == 1: - #both prefill and decode ccl can share the same specializations since prefill_seq_len=1. So, a sorted union of both lists can be used for both of them. + # both prefill and decode ccl can share the same specializations since prefill_seq_len=1. So, a sorted union of both lists can be used for both of them. ccl_union_all = sorted(set(ccl_prefill + ccl_decode)) ccl_union_all = [min(x, ctx_len) for x in ccl_union_all] return ccl_union_all, ccl_union_all, ctx_len, prefill_seq_len From ee2f54e48b5d3d366510c810d52a45bbc5f2bdf5 Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Tue, 21 Oct 2025 18:57:22 -0700 Subject: [PATCH 25/41] improving handeling CCL lists Signed-off-by: Vahid Janfaza --- QEfficient/utils/check_ccl_specializations.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/QEfficient/utils/check_ccl_specializations.py b/QEfficient/utils/check_ccl_specializations.py index 45d2ea903..0c7555512 100644 --- a/QEfficient/utils/check_ccl_specializations.py +++ b/QEfficient/utils/check_ccl_specializations.py @@ -5,10 +5,6 @@ # # ----------------------------------------------------------------------------- - -# def process_ccl_specializations( -# ccl_prefill: Optional[List[int]] = None, ccl_decode: Optional[List[int]] = None, ctx_len: Optional[int] = None -# ): def process_ccl_specializations(kwargs): ccl_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) ccl_decode = kwargs.pop("comp_ctx_lengths_decode", None) From bb2a2073b90e2c6a935f503aadd3b86ff1264b20 Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Tue, 21 Oct 2025 18:58:30 -0700 Subject: [PATCH 26/41] improving handeling CCL lists Signed-off-by: Vahid Janfaza --- QEfficient/utils/check_ccl_specializations.py | 1 + 1 file changed, 1 insertion(+) diff --git a/QEfficient/utils/check_ccl_specializations.py b/QEfficient/utils/check_ccl_specializations.py index 0c7555512..3e66bfd35 100644 --- a/QEfficient/utils/check_ccl_specializations.py +++ b/QEfficient/utils/check_ccl_specializations.py @@ -5,6 +5,7 @@ # # ----------------------------------------------------------------------------- + def 
process_ccl_specializations(kwargs): ccl_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) ccl_decode = kwargs.pop("comp_ctx_lengths_decode", None) From 0e9c851346b761adba0937a4d368492c791c7a35 Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Tue, 21 Oct 2025 19:15:01 -0700 Subject: [PATCH 27/41] improving handeling CCL lists Signed-off-by: Vahid Janfaza --- QEfficient/utils/check_ccl_specializations.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/QEfficient/utils/check_ccl_specializations.py b/QEfficient/utils/check_ccl_specializations.py index 3e66bfd35..6cb54a6c5 100644 --- a/QEfficient/utils/check_ccl_specializations.py +++ b/QEfficient/utils/check_ccl_specializations.py @@ -12,12 +12,12 @@ def process_ccl_specializations(kwargs): ctx_len = kwargs.pop("ctx_len", None) prefill_seq_len = kwargs.pop("prefill_seq_len", 128) - if ctx_len is None: - raise TypeError("`ctx_len` is required when loading the model.") - if ccl_prefill is None or ccl_decode is None: return None, None, ctx_len, prefill_seq_len + if ctx_len is None: + raise TypeError("`ctx_len` is required when loading the model with CCL.") + if prefill_seq_len == 1: # both prefill and decode ccl can share the same specializations since prefill_seq_len=1. So, a sorted union of both lists can be used for both of them. ccl_union_all = sorted(set(ccl_prefill + ccl_decode)) From 6cedad2aba5dcc987f5e27b7db3b7de4085e11d4 Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Wed, 22 Oct 2025 16:06:00 -0700 Subject: [PATCH 28/41] Adding Compute-Context-Length(CCL) Signed-off-by: Vahid Janfaza --- .../models/codegen/modeling_codegen.py | 11 +- .../models/falcon/modeling_falcon.py | 11 +- .../models/gemma3/modeling_gemma3.py | 24 +- .../transformers/models/gpt2/modeling_gpt2.py | 11 +- .../transformers/models/gptj/modeling_gptj.py | 11 +- .../models/grok_1/modeling_grok1.py | 11 +- .../models/internvl/modeling_internvl.py | 26 ++- .../models/llama4/modeling_llama4.py | 24 +- .../models/llava/modeling_llava.py | 26 ++- .../models/llava_next/modeling_llava_next.py | 18 +- .../models/mistral3/modeling_mistral3.py | 24 +- .../models/mllama/modeling_mllama.py | 8 +- .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 16 +- .../models/whisper/modeling_whisper.py | 13 +- examples/ccl_qwen2_5_vl_example.py | 7 +- examples/compute_context_length.py | 11 +- .../ccl_qwen3moe_inference.py | 12 +- tests/transformers/test_comp_ctx_length.py | 205 ++++++++++++++---- 18 files changed, 365 insertions(+), 104 deletions(-) diff --git a/QEfficient/transformers/models/codegen/modeling_codegen.py b/QEfficient/transformers/models/codegen/modeling_codegen.py index 776bfce43..15efa2ce5 100644 --- a/QEfficient/transformers/models/codegen/modeling_codegen.py +++ b/QEfficient/transformers/models/codegen/modeling_codegen.py @@ -72,6 +72,7 @@ def forward( self, hidden_states: Optional[torch.FloatTensor], layer_past: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, @@ -123,7 +124,9 @@ def forward( query = query.permute(0, 2, 1, 3) if layer_past is not None: - cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index, "CCL": attention_mask.shape[-1]} key, value = 
layer_past.update(key.to(hidden_states.dtype), value, self.layer_idx, cache_kwargs) # compute self-attention: V x Softmax(QK^T) @@ -147,6 +150,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, tuple[tuple[torch.Tensor]]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -245,6 +249,7 @@ def forward( outputs = block( hidden_states, layer_past=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, attention_mask=attention_mask, position_ids=position_ids, @@ -294,6 +299,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -312,6 +318,7 @@ def forward( transformer_outputs = self.transformer( input_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, token_type_ids=token_type_ids, batch_index=batch_index, @@ -348,6 +355,7 @@ def forward( self, hidden_states: Optional[torch.FloatTensor], layer_past: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -361,6 +369,7 @@ def forward( attn_outputs, attn_weights = self.attn( hidden_states=hidden_states, layer_past=layer_past, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, position_ids=position_ids, batch_index=batch_index, diff --git a/QEfficient/transformers/models/falcon/modeling_falcon.py b/QEfficient/transformers/models/falcon/modeling_falcon.py index 8f2c3730d..218852b15 100644 --- a/QEfficient/transformers/models/falcon/modeling_falcon.py +++ b/QEfficient/transformers/models/falcon/modeling_falcon.py @@ -117,6 +117,7 @@ def forward( attention_mask: torch.Tensor, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, layer_past: Optional[Cache] = None, head_mask: Optional[torch.Tensor] = None, @@ -140,7 +141,9 @@ def forward( query_layer, key_layer = qeff_apply_rotary_pos_emb(query_layer, key_layer, cos, sin, position_ids) if layer_past is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]} key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs) if attention_mask is not None: @@ -172,6 +175,7 @@ def forward( attention_mask: torch.Tensor, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, layer_past: Optional[Union[Cache, Tuple[torch.Tensor, torch.Tensor]]] = None, head_mask: Optional[torch.Tensor] = None, @@ -195,6 +199,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, 
past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, alibi=alibi, head_mask=head_mask, @@ -245,6 +250,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, past_key_values: Optional[Union[Cache, Tuple[Tuple[torch.Tensor, torch.Tensor], ...]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, head_mask: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.LongTensor] = None, use_cache: Optional[bool] = None, @@ -307,6 +313,7 @@ def forward( attention_mask=causal_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, head_mask=head_mask[i], use_cache=use_cache, @@ -352,6 +359,7 @@ def forward( position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, head_mask: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None, use_cache: Optional[bool] = None, @@ -368,6 +376,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, head_mask=head_mask, inputs_embeds=inputs_embeds, diff --git a/QEfficient/transformers/models/gemma3/modeling_gemma3.py b/QEfficient/transformers/models/gemma3/modeling_gemma3.py index 2e8494e8e..95ee662b4 100644 --- a/QEfficient/transformers/models/gemma3/modeling_gemma3.py +++ b/QEfficient/transformers/models/gemma3/modeling_gemma3.py @@ -603,7 +603,13 @@ def __init__(self, model): self.lm_head = self.model.lm_head def forward( - self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, ): inputs_embeds = self.model.get_input_embeddings()(input_ids) B, N, C = inputs_embeds.shape @@ -637,7 +643,13 @@ def get_qeff_language_decoder(self): return QEffGemma3DecoderWrapper(self) def forward( - self, input_ids, position_ids, pixel_values, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + self, + input_ids, + position_ids, + pixel_values, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, ): image_features = self.get_image_features(pixel_values=pixel_values) inputs_embeds = self.get_input_embeddings()(input_ids) @@ -669,8 +681,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, - comp_ctx_lengths_prefill: List[int] = None, - comp_ctx_lengths_decode: List[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, **compiler_options, ): @@ -749,7 +761,7 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): # Define dynamic axes vision_dynamic_axes = {} lang_dynamic_axes = {} @@ -825,7 +837,7 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len): past_key_values.append(pkv) return past_key_values - def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): + def get_dummy_inputs(self, 
comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): if vis_cfg := getattr(self.config, "vision_config", None): img_size = getattr(vis_cfg, "image_size", 896) else: diff --git a/QEfficient/transformers/models/gpt2/modeling_gpt2.py b/QEfficient/transformers/models/gpt2/modeling_gpt2.py index d68a65430..59d864907 100644 --- a/QEfficient/transformers/models/gpt2/modeling_gpt2.py +++ b/QEfficient/transformers/models/gpt2/modeling_gpt2.py @@ -65,6 +65,7 @@ def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -118,9 +119,11 @@ def forward( if (past_key_value is not None and not is_cross_attention) or ( past_key_value is not None and is_cross_attention and not is_updated ): + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # save all key/value_layer to cache to be re-used for fast auto-regressive generation # Update the cache_kwargs with position_ids for Cloud AI 100 - cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index, "CCL": attention_mask.shape[-1]} key_states, value_states = curr_past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs ) @@ -156,6 +159,7 @@ def forward( self, hidden_states: Optional[Tuple[torch.FloatTensor]], past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -174,6 +178,7 @@ def forward( hidden_states, past_key_value=past_key_value, attention_mask=attention_mask, + comp_ctx_lengths=comp_ctx_lengths, position_ids=position_ids, batch_index=batch_index, head_mask=head_mask, @@ -232,6 +237,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -341,6 +347,7 @@ def forward( outputs = block( hidden_states, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, position_ids=position_ids, batch_index=batch_index, @@ -392,6 +399,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -418,6 +426,7 @@ def forward( transformer_outputs = self.transformer( input_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, diff --git a/QEfficient/transformers/models/gptj/modeling_gptj.py b/QEfficient/transformers/models/gptj/modeling_gptj.py index dc3e5e6d2..da5bd881c 100644 --- a/QEfficient/transformers/models/gptj/modeling_gptj.py +++ b/QEfficient/transformers/models/gptj/modeling_gptj.py @@ -83,6 +83,7 @@ def forward( self, 
hidden_states: torch.FloatTensor, layer_past: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -134,7 +135,9 @@ def forward( query = query.permute(0, 2, 1, 3) if layer_past is not None: - cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index, "CCL": attention_mask.shape[-1]} key, value = layer_past.update(key, value, self.layer_idx, cache_kwargs) # compute self-attention: V x Softmax(QK^T) @@ -151,6 +154,7 @@ def forward( self, hidden_states: Optional[torch.FloatTensor], layer_past: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -164,6 +168,7 @@ def forward( attn_outputs, attn_weights = self.attn( hidden_states=hidden_states, layer_past=layer_past, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, position_ids=position_ids, batch_index=batch_index, @@ -191,6 +196,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -270,6 +276,7 @@ def forward( outputs = block( hidden_states=hidden_states, layer_past=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=causal_mask, position_ids=position_ids, batch_index=batch_index, @@ -314,6 +321,7 @@ def forward( self, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, token_type_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -339,6 +347,7 @@ def forward( transformer_outputs = self.transformer( input_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, token_type_ids=token_type_ids, position_ids=position_ids, diff --git a/QEfficient/transformers/models/grok_1/modeling_grok1.py b/QEfficient/transformers/models/grok_1/modeling_grok1.py index 567a8e070..a0f9cd915 100644 --- a/QEfficient/transformers/models/grok_1/modeling_grok1.py +++ b/QEfficient/transformers/models/grok_1/modeling_grok1.py @@ -55,6 +55,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: bool = False, use_cache: bool = False, @@ -93,7 +94,9 @@ def forward( query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: - cache_kwargs = {"batch_index": batch_index, "position_ids": position_ids} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"batch_index": 
batch_index, "position_ids": position_ids, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update(key_states, value_states, layer_idx, cache_kwargs) # repeat k/v heads if n_kv_heads < n_heads @@ -205,6 +208,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, output_router_logits: Optional[bool] = False, @@ -235,6 +239,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -277,6 +282,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, @@ -351,6 +357,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, output_attentions=output_attentions, use_cache=use_cache, @@ -395,6 +402,7 @@ def forward( attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[List[torch.FloatTensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, @@ -441,6 +449,7 @@ def forward( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, inputs_embeds=inputs_embeds, use_cache=use_cache, diff --git a/QEfficient/transformers/models/internvl/modeling_internvl.py b/QEfficient/transformers/models/internvl/modeling_internvl.py index 29e6ac9a4..96c59325f 100644 --- a/QEfficient/transformers/models/internvl/modeling_internvl.py +++ b/QEfficient/transformers/models/internvl/modeling_internvl.py @@ -5,7 +5,7 @@ # # ----------------------------------------------------------------------------- -from typing import List +from typing import List, Optional import torch import torch.nn as nn @@ -37,7 +37,13 @@ def __init__(self, model): self.language_model = self.model.language_model def forward( - self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, ): input_embeds = self.model.language_model.get_input_embeddings()(input_ids) B, N, C = input_embeds.shape @@ -82,8 +88,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, - comp_ctx_lengths_prefill: List[int] = None, - comp_ctx_lengths_decode: List[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, **compiler_options, ): @@ -172,7 +178,7 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = 
False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): # Define dynamic axes vision_dynamic_axes = {} lang_dynamic_axes = {} @@ -216,7 +222,7 @@ def get_output_names(self, kv_offload: bool = False): return lang_output_names return output_names - def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): + def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): if vis_cfg := getattr(self.config, "vision_config", None): img_size = getattr(vis_cfg, "image_size", constants.INTERN_IMG_SIZE) else: @@ -291,7 +297,13 @@ def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool return inputs def forward( - self, input_ids, pixel_values, position_ids, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + self, + input_ids, + pixel_values, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, ): input_embeds = self.language_model.get_input_embeddings()(input_ids) vision_embeds = self.extract_feature(pixel_values) diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py index 82678e380..0fbdbea5f 100644 --- a/QEfficient/transformers/models/llama4/modeling_llama4.py +++ b/QEfficient/transformers/models/llama4/modeling_llama4.py @@ -850,7 +850,13 @@ def __init__(self, model): self.config = self.model.config def forward( - self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, ): inputs_embeds = self.model.language_model.get_input_embeddings()(input_ids) selected = input_ids == self.model.config.image_token_index @@ -880,7 +886,13 @@ def get_qeff_language_decoder(self): return QEffLlama4DecoderWrapper(self) def forward( - self, input_ids, position_ids, pixel_values, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + self, + input_ids, + position_ids, + pixel_values, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, ): inputs_embeds = self.language_model.get_input_embeddings()(input_ids) vision_feature_layer = self.config.vision_config.vision_feature_layer @@ -917,8 +929,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, - comp_ctx_lengths_prefill: List[int] = None, - comp_ctx_lengths_decode: List[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, **compiler_options, ): @@ -1034,7 +1046,7 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): # Define dynamic axes vision_dynamic_axes = {} lang_dynamic_axes = {} @@ -1109,7 +1121,7 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len): past_key_values.append(pkv) return past_key_values - def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): + def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): if vis_cfg := getattr(self.config, "vision_config", None): img_size = getattr(vis_cfg, "image_size", 336) else: diff --git 
a/QEfficient/transformers/models/llava/modeling_llava.py b/QEfficient/transformers/models/llava/modeling_llava.py index 450fc79b6..dc6653db0 100644 --- a/QEfficient/transformers/models/llava/modeling_llava.py +++ b/QEfficient/transformers/models/llava/modeling_llava.py @@ -5,7 +5,7 @@ # # ----------------------------------------------------------------------------- -from typing import List +from typing import List, Optional import torch import torch.nn as nn @@ -54,7 +54,13 @@ def __init__(self, model): self.lm_head = self.model.lm_head def forward( - self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, ): inputs_embeds = self.model.get_input_embeddings()(input_ids) vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype) @@ -89,7 +95,13 @@ def get_qeff_language_decoder(self): return QEFFLlavaDecoderWrapper(self) def forward( - self, input_ids, position_ids, pixel_values, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + self, + input_ids, + position_ids, + pixel_values, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, ): inputs_embeds = self.get_input_embeddings()(input_ids) # Image features @@ -128,7 +140,7 @@ def forward( image_idx = torch.where(image_idx < next_image_idx, next_image_idx, image_idx) return logits, pixel_values, image_idx, outputs.past_key_values - def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False, **kwargs): + def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, **kwargs): num_layers = self.config.text_config.num_hidden_layers num_key_value_heads = self.config.text_config.num_key_value_heads head_dim = self.config.text_config.hidden_size // self.config.text_config.num_attention_heads @@ -178,8 +190,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, - comp_ctx_lengths_prefill: List[int] = None, - comp_ctx_lengths_decode: List[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, **compiler_options, ): @@ -259,7 +271,7 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): # Define dynamic axes num_layers = self.config.text_config.num_hidden_layers diff --git a/QEfficient/transformers/models/llava_next/modeling_llava_next.py b/QEfficient/transformers/models/llava_next/modeling_llava_next.py index b23073fa7..2e4848b6b 100755 --- a/QEfficient/transformers/models/llava_next/modeling_llava_next.py +++ b/QEfficient/transformers/models/llava_next/modeling_llava_next.py @@ -6,7 +6,7 @@ # ----------------------------------------------------------------------------- -from typing import List +from typing import List, Optional import numpy as np import torch @@ -126,7 +126,13 @@ def __init__(self, model): self.lm_head = self.model.lm_head def forward( - self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, ): inputs_embeds = 
self.model.get_input_embeddings()(input_ids) image_features = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype) @@ -159,7 +165,7 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEffLlavaNextDecoderWrapper(self) - def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False, **kwargs): + def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, **kwargs): num_layers = self.config.text_config.num_hidden_layers num_key_value_heads = self.config.text_config.num_key_value_heads head_dim = self.config.text_config.hidden_size // self.config.text_config.num_attention_heads @@ -241,8 +247,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, - comp_ctx_lengths_prefill: List[int] = None, - comp_ctx_lengths_decode: List[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, **compiler_options, ): @@ -365,7 +371,7 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): # Define dynamic axes num_layers = self.config.text_config.num_hidden_layers vision_dynamic_axes = { diff --git a/QEfficient/transformers/models/mistral3/modeling_mistral3.py b/QEfficient/transformers/models/mistral3/modeling_mistral3.py index a5f1301d2..694ed4cde 100644 --- a/QEfficient/transformers/models/mistral3/modeling_mistral3.py +++ b/QEfficient/transformers/models/mistral3/modeling_mistral3.py @@ -169,7 +169,13 @@ def __init__(self, model): self.language_model = self.model.language_model def forward( - self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, ): inputs_embeds = self.model.get_input_embeddings()(input_ids) vision_embeds = vision_embeds.to(inputs_embeds.device, inputs_embeds.dtype) @@ -204,7 +210,13 @@ def get_qeff_language_decoder(self): return QEFFMistral3DecoderWrapper(self) def forward( - self, input_ids, position_ids, pixel_values, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + self, + input_ids, + position_ids, + pixel_values, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, ): inputs_embeds = self.get_input_embeddings()(input_ids) image_sizes = torch.tensor([[pixel_values.shape[2], pixel_values.shape[3]]]).repeat(pixel_values.shape[0], 1) @@ -238,7 +250,7 @@ def forward( return logits, pixel_values, image_idx, outputs.past_key_values - def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False, **kwargs): + def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, **kwargs): inputs_shapes = {} inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) height = self.config.vision_config.image_size @@ -309,8 +321,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, - comp_ctx_lengths_prefill: List[int] = None, - comp_ctx_lengths_decode: List[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, 
**compiler_options, ): @@ -392,7 +404,7 @@ def get_specializations( lang[1].pop("vision_size") return lang, compiler_options - def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): # Define dynamic axes num_layers = self.config.text_config.num_hidden_layers diff --git a/QEfficient/transformers/models/mllama/modeling_mllama.py b/QEfficient/transformers/models/mllama/modeling_mllama.py index 2197bec91..d6fb1dcd2 100644 --- a/QEfficient/transformers/models/mllama/modeling_mllama.py +++ b/QEfficient/transformers/models/mllama/modeling_mllama.py @@ -899,7 +899,7 @@ def forward( logits = self.lm_head(hidden_states).float() return logits, image_idx, outputs.past_key_values, pixel_values - def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): + def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): BS = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE SEQ_LEN = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN CTX_LEN = constants.ONNX_EXPORT_CTX_LEN @@ -983,8 +983,8 @@ def get_specializations( prefill_seq_len: int, ctx_len: int, img_size: int, - comp_ctx_lengths_prefill: List[int] = None, - comp_ctx_lengths_decode: List[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, **compiler_options, ): @@ -1055,7 +1055,7 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): txt_cfg = self.config.get_text_config() num_hidden_layers = txt_cfg.num_hidden_layers cross_attention_layers = txt_cfg.cross_attention_layers diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index 3b1d116de..ac91d5477 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -696,7 +696,13 @@ def __init__(self, model): self.language_model = self.model.model.language_model def forward( - self, input_ids, vision_embeds, position_ids, image_idx, past_key_values, comp_ctx_lengths: List[int] = None + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, ): inputs_embeds = self.model.get_input_embeddings()(input_ids) B, N, C = inputs_embeds.shape @@ -730,7 +736,7 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEffQwen_2_5_vl_DecoderWrapper(self) - def get_dummy_inputs(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False, **kwargs): + def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, **kwargs): inputs_shapes = {} inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) @@ -799,8 +805,8 @@ def get_specializations( img_size: None, height: int = None, width: int = None, - comp_ctx_lengths_prefill: List[int] = None, - comp_ctx_lengths_decode: List[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, kv_offload: bool = False, **compiler_options, ): @@ 
-932,7 +938,7 @@ def smart_resize( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, comp_ctx_lengths: List[int] = None, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): # Define dynamic axes num_layers = self.config.num_hidden_layers diff --git a/QEfficient/transformers/models/whisper/modeling_whisper.py b/QEfficient/transformers/models/whisper/modeling_whisper.py index e078493a7..79907818d 100644 --- a/QEfficient/transformers/models/whisper/modeling_whisper.py +++ b/QEfficient/transformers/models/whisper/modeling_whisper.py @@ -55,6 +55,7 @@ def forward( position_ids_layer: torch.Tensor = None, key_value_states: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, layer_head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False, @@ -99,7 +100,9 @@ def forward( key_states = key_states.transpose(1, 2).contiguous() value_states = value_states.transpose(1, 2).contiguous() if past_key_value is not None: - cache_kwargs = {"position_ids": position_ids_layer} + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + cache_kwargs = {"position_ids": position_ids_layer, "CCL": attention_mask.shape[-1]} key_states, value_states = past_key_value.update( key_states, value_states, self.layer_idx, cache_kwargs ) @@ -181,6 +184,7 @@ def forward( layer_head_mask: Optional[torch.Tensor] = None, cross_attn_layer_head_mask: Optional[torch.Tensor] = None, past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = True, cache_position: Optional[torch.LongTensor] = None, @@ -215,6 +219,7 @@ def forward( hidden_states, self_attn_weights = self.self_attn( hidden_states=hidden_states, past_key_value=self_attn_past_key_value, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, layer_head_mask=layer_head_mask, output_attentions=output_attentions, @@ -388,6 +393,7 @@ def forward( cross_attn_head_mask=None, position_ids=None, past_key_values=None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, inputs_embeds=None, use_cache=None, output_attentions=None, @@ -532,6 +538,7 @@ def forward( layer_head_mask=(head_mask[idx] if head_mask is not None else None), cross_attn_layer_head_mask=(cross_attn_head_mask[idx] if cross_attn_head_mask is not None else None), past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, output_attentions=output_attentions, use_cache=use_cache, position_ids_layer=position_ids, @@ -643,6 +650,7 @@ def forward( cross_attn_head_mask=None, encoder_outputs=None, past_key_values=None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, decoder_inputs_embeds=None, use_cache=None, output_attentions=None, @@ -674,6 +682,7 @@ def forward( head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, inputs_embeds=decoder_inputs_embeds, use_cache=use_cache, output_attentions=output_attentions, @@ -719,6 +728,7 @@ def forward( cross_attn_head_mask: Optional[torch.Tensor] = None, encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, past_key_values: Optional[Union[EncoderDecoderCache, Tuple[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, 
decoder_inputs_embeds: Optional[Tuple[torch.FloatTensor]] = None, position_ids: Optional[Tuple[torch.LongTensor]] = None, labels: Optional[torch.LongTensor] = None, @@ -740,6 +750,7 @@ def forward( decoder_head_mask=decoder_head_mask, cross_attn_head_mask=cross_attn_head_mask, past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, decoder_inputs_embeds=decoder_inputs_embeds, decoder_position_ids=position_ids, use_cache=use_cache, diff --git a/examples/ccl_qwen2_5_vl_example.py b/examples/ccl_qwen2_5_vl_example.py index 74063929b..b813462e3 100644 --- a/examples/ccl_qwen2_5_vl_example.py +++ b/examples/ccl_qwen2_5_vl_example.py @@ -21,10 +21,9 @@ ## Use complete model without changing num_hidden_layers as it will not work for TF version 4.55.0 for Qwen2.5VL model -ctx_len = 32768 - -comp_ctx_lengths_prefill = [4000] -comp_ctx_lengths_decode = [4096, 8192, 16384, ctx_len] +ctx_len = 8192 +comp_ctx_lengths_prefill = [4096] +comp_ctx_lengths_decode = [6144, ctx_len] qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( model_id, diff --git a/examples/compute_context_length.py b/examples/compute_context_length.py index c1e5dc0df..00d475ae0 100644 --- a/examples/compute_context_length.py +++ b/examples/compute_context_length.py @@ -31,9 +31,16 @@ # model_name = "Qwen/Qwen3-1.7B" # model_name = "allenai/OLMo-2-0425-1B" # model_name = "ibm-granite/granite-3.3-2b-base" +# model_name = "meta-llama/Llama-3.3-70B-Instruct" +# model_name = "Salesforce/codegen-350M-mono" +# model_name = "tiiuae/falcon-7b-instruct" +# model_name = "openai-community/gpt2" +# model_name = "EleutherAI/gpt-j-6b" +# model_name = "EleutherAI/gpt-j-6b" + model = QEFFAutoModelForCausalLM.from_pretrained( model_name, - continuous_batching=False, + continuous_batching=True, comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, comp_ctx_lengths_decode=comp_ctx_lengths_decode, ctx_len=ctx_len, @@ -45,7 +52,7 @@ ctx_len=ctx_len, num_cores=16, num_devices=1, - batch_size=1, + full_batch_size=1, mxint8_kv_cache=True, mxfp6_matmul=True, ) diff --git a/examples/qwen3moe_example/ccl_qwen3moe_inference.py b/examples/qwen3moe_example/ccl_qwen3moe_inference.py index 12e9ca1fc..f200c6fa6 100644 --- a/examples/qwen3moe_example/ccl_qwen3moe_inference.py +++ b/examples/qwen3moe_example/ccl_qwen3moe_inference.py @@ -16,25 +16,25 @@ # We will use prompt_len=1 for compilation for both cb and non-cb inference """ -ctx_len = 32768 +ctx_len = 1024 prefill_seq_len = 1 # In moe models when compiling with prefill_seq_len=1 and non-continuous-batching mode, prefill and decode will share the same specializations. 
-comp_ctx_lengths_prefill = [4096, 8192, 16384, ctx_len] -comp_ctx_lengths_decode = [4096, 8192, 16384, ctx_len] +comp_ctx_lengths_prefill = [256, 512, ctx_len] +comp_ctx_lengths_decode = [256, 512, ctx_len] model = QEFFAutoModelForCausalLM.from_pretrained( model_name, comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, comp_ctx_lengths_decode=comp_ctx_lengths_decode, ctx_len=ctx_len, - continuous_batching=True, + continuous_batching=False, prefill_seq_len=prefill_seq_len, ) -# prefill_seq_len=prefill_seq_len, + model.compile( prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, - full_batch_size=1, + batch_size=1, num_cores=16, num_devices=4, mxfp6_matmul=True, diff --git a/tests/transformers/test_comp_ctx_length.py b/tests/transformers/test_comp_ctx_length.py index e145ad698..31b9da07e 100644 --- a/tests/transformers/test_comp_ctx_length.py +++ b/tests/transformers/test_comp_ctx_length.py @@ -1,6 +1,6 @@ # ----------------------------------------------------------------------------- # -# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved. +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. # SPDX-License-Identifier: BSD-3-Clause # # ---------------------------------------------------------------------------- @@ -14,6 +14,8 @@ from transformers import AutoConfig, AutoModel, AutoModelForCausalLM from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM +from QEfficient.utils import constants, get_padding_shape_from_config +from QEfficient.utils.hash_utils import hash_dict_params configs = [ # name, max_position_embeddings, num_hidden_layers, num_attention_heads, hidden_size, intermediate_size, vocab_size, additional_params @@ -30,6 +32,7 @@ ("qwen2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), ("starcoder2", 256, 2, 4, 128, 512, 127, {}), ("granite", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + ("olmo2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), ] configs = [ @@ -62,17 +65,41 @@ @pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) def test_causal_lm_unsupported(cb): model = AutoModelForCausalLM.from_config(AutoConfig.for_model("opt")) + ctx_len = 32 + comp_ctx_lengths_prefill = [16] + comp_ctx_lengths_decode = [24, ctx_len] with pytest.warns(): - QEFFAutoModelForCausalLM(model, cb) + QEFFAutoModelForCausalLM( + model, + cb, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + ) @pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) @pytest.mark.parametrize("config", configs, ids=config_ids) def test_causal_lm_init(config, cb): model = AutoModelForCausalLM.from_config(config, **model_kwargs) - qeff_model = QEFFAutoModelForCausalLM(model, cb) + ctx_len = 32 + comp_ctx_lengths_prefill = [16] + comp_ctx_lengths_decode = [24, ctx_len] + qeff_model = QEFFAutoModelForCausalLM( + model, + cb, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + ) with pytest.raises(TypeError): - QEFFAutoModelForCausalLM(AutoModel.from_config(config, **model_kwargs), cb) + QEFFAutoModelForCausalLM( + AutoModel.from_config(config, **model_kwargs), + cb, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + ) assert qeff_model.model.__class__.__name__.startswith("QEff") @@ -82,43 +109,112 @@ def test_causal_lm_pretrained(config, cb, tmp_path): model = 
AutoModelForCausalLM.from_config(config, **model_kwargs) model.save_pretrained(tmp_path) - qeff_model = QEFFAutoModelForCausalLM.from_pretrained(tmp_path, cb) + ctx_len = 32 + comp_ctx_lengths_prefill = [16] + comp_ctx_lengths_decode = [24, ctx_len] + qeff_model = QEFFAutoModelForCausalLM.from_pretrained( + tmp_path, + cb, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + ) assert qeff_model.model.__class__.__name__.startswith("QEff") @pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) @pytest.mark.parametrize("config", configs, ids=config_ids) -def test_causal_lm_hash(config, cb): - hash_0_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb).model_hash - hash_0_1 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(config, **model_kwargs), cb).model_hash +def test_causal_lm_export_and_hash(config, cb, tmp_path): + ctx_len = 32 + comp_ctx_lengths_prefill = [16] + comp_ctx_lengths_decode = [24, ctx_len] + model_0_0 = QEFFAutoModelForCausalLM( + AutoModelForCausalLM.from_config(config, **model_kwargs), + cb, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + ) + model_0_0.export(tmp_path) + model_path = tmp_path.with_name(tmp_path.name + "-" + model_0_0.export_hash) + assert model_path.is_dir() + assert model_0_0.onnx_path.is_file() + assert model_0_0.onnx_path.relative_to(model_path).parts == (model_0_0.model_name + ".onnx",) + + # Check if the KV-cache inputs and outputs are created + onnx_model = onnx.load(model_0_0.onnx_path, load_external_data=False) + retained_output_names = { + x.name[: -len("_RetainedState")] for x in onnx_model.graph.output if x.name.endswith("_RetainedState") + } + retained_output_names.issubset({x.name for x in onnx_model.graph.input}) + + # Check if there is no re-export + start = perf_counter() + model_0_0.export(tmp_path) + end = perf_counter() + export_time = end - start + assert export_time < 2.0 + + # Check if hashing is happening properly + model_0_1 = QEFFAutoModelForCausalLM( + AutoModelForCausalLM.from_config(config, **model_kwargs), + cb, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + ) + model_0_1.export(tmp_path) + hash_0_0 = model_0_0.export_hash + hash_0_1 = model_0_1.export_hash assert hash_0_0 == hash_0_1 cfg1 = copy.deepcopy(config) cfg1.num_hidden_layers -= 1 - hash_1_0 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(cfg1, **model_kwargs), cb).model_hash + model_1_0 = QEFFAutoModelForCausalLM( + AutoModelForCausalLM.from_config(cfg1, **model_kwargs), + cb, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + ) + model_1_0.export(tmp_path) + hash_1_0 = model_1_0.export_hash cfg2 = copy.deepcopy(config) cfg2.num_hidden_layers -= 1 - hash_1_1 = QEFFAutoModelForCausalLM(AutoModelForCausalLM.from_config(cfg2, **model_kwargs), cb).model_hash + model_1_1 = QEFFAutoModelForCausalLM( + AutoModelForCausalLM.from_config(cfg2, **model_kwargs), + cb, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + ) + model_1_1.export(tmp_path) + hash_1_1 = model_1_1.export_hash assert hash_1_0 == hash_1_1 assert hash_0_0 != hash_1_0 if cb: - hash_0_no_cb = QEFFAutoModelForCausalLM( - AutoModelForCausalLM.from_config(config, 
**model_kwargs), False - ).model_hash + model_0_no_cb = QEFFAutoModelForCausalLM( + AutoModelForCausalLM.from_config(config, **model_kwargs), + False, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + ) + model_0_no_cb.export(tmp_path) + hash_0_no_cb = model_0_no_cb.export_hash assert hash_0_0 != hash_0_no_cb @pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) @pytest.mark.parametrize("config", configs, ids=config_ids) -def test_causal_lm_export(config, cb, tmp_path): +def test_causal_lm_hash_creation(config, cb, tmp_path): model = AutoModelForCausalLM.from_config(config, **model_kwargs) - ctx_len = 2048 - comp_ctx_lengths_prefill = [256] - comp_ctx_lengths_decode = [512, 1024, ctx_len] - + ctx_len = 32 + comp_ctx_lengths_prefill = [16] + comp_ctx_lengths_decode = [24, ctx_len] qeff_model = QEFFAutoModelForCausalLM( model, cb, @@ -127,29 +223,59 @@ def test_causal_lm_export(config, cb, tmp_path): ctx_len=ctx_len, ) qeff_model.export(tmp_path) - model_path = tmp_path.with_name(tmp_path.name + "-" + qeff_model.model_hash) - assert model_path.is_dir() - assert qeff_model.onnx_path.is_file() - assert qeff_model.onnx_path.relative_to(model_path).parts == (qeff_model.model_name + ".onnx",) + hash_params = {} + hash_params["config"] = qeff_model.model.config.to_diff_dict() + hash_params["peft_config"] = None + hash_params["applied_transform_names"] = qeff_model._transform_names() + hash_params["qeff_auto_class"] = qeff_model.__class__.__name__ + hash_params["qaic_config"] = None - # Check if the KV-cache inputs and outputs are created - onnx_model = onnx.load(qeff_model.onnx_path, load_external_data=False) - retained_output_names = { - x.name[: -len("_RetainedState")] for x in onnx_model.graph.output if x.name.endswith("_RetainedState") + # Create parameters separately for hash creation + + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + kv_cache_shape = get_padding_shape_from_config( + qeff_model.model.config, fbs if qeff_model.continuous_batching else bs, seq_len + ) + dynamic_axes = { + "input_ids": {0: "batch_size", 1: "seq_len"}, + "position_ids": {0: "batch_size", 1: "seq_len"}, } - retained_output_names.issubset({x.name for x in onnx_model.graph.input}) + dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + if len(kv_cache_shape) == 3: # For GPTBigCode arch the pkv is 3d + pkv_dynamic_axes = { + 0: "full_batch_size" if qeff_model.continuous_batching else "batch_size", + 1: "ctx_len", + } + else: # pkv is 4d + pkv_dynamic_axes = { + 0: "full_batch_size" if qeff_model.continuous_batching else "batch_size", + 2: "ctx_len", + } + output_names = [] + output_names.append("logits") - # Check if there is no re-export - start = perf_counter() - qeff_model.export(tmp_path) - end = perf_counter() - export_time = end - start - assert export_time < 2.0 + for i in range(qeff_model.num_layers): + for kv in ["key", "value"]: + dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes + output_names.append(f"past_{kv}.{i}_RetainedState") + + if qeff_model.continuous_batching: + dynamic_axes["batch_index"] = {0: "batch_size"} + + export_params = {} + export_params["output_names"] = output_names + export_params["dynamic_axes"] = dynamic_axes + hash_params["export_params"] = export_params + manual_hash = hash_dict_params(hash_params) + + assert manual_hash == qeff_model.export_hash @pytest.fixture def 
tmp_cache(tmp_path, monkeypatch): - monkeypatch.setattr("QEfficient.base.modeling_qeff.QEFF_HOME", tmp_path) + monkeypatch.setattr("QEfficient.utils._utils.QEFF_HOME", tmp_path) yield tmp_path @@ -157,9 +283,10 @@ def tmp_cache(tmp_path, monkeypatch): @pytest.mark.parametrize("config", configs, ids=config_ids) def test_causal_lm_compile(config, cb, tmp_cache): model = AutoModelForCausalLM.from_config(config, **model_kwargs) - ctx_len = 2048 - comp_ctx_lengths_prefill = [256] - comp_ctx_lengths_decode = [512, 1024, ctx_len] + + ctx_len = 32 + comp_ctx_lengths_prefill = [16] + comp_ctx_lengths_decode = [24, ctx_len] qeff_model = QEFFAutoModelForCausalLM( model, cb, @@ -172,7 +299,7 @@ def test_causal_lm_compile(config, cb, tmp_cache): compile_params["full_batch_size"] = 32 compile_params["batch_size"] = 8 qeff_model.compile(**compile_params) - model_path = tmp_cache / (qeff_model.model_name + "-" + qeff_model.model_hash) + model_path = tmp_cache / qeff_model.model_name / (qeff_model.model_name + "-" + qeff_model.export_hash) # Check if ONNX is exported properly assert model_path.is_dir() @@ -182,7 +309,7 @@ def test_causal_lm_compile(config, cb, tmp_cache): # Check if QPC is compiled properly assert qeff_model.qpc_path.is_dir() assert (qeff_model.qpc_path / "programqpc.bin").is_file() - assert qeff_model.qpc_path.relative_to(tmp_cache).parts[0] == qeff_model.model_name + "-" + qeff_model.model_hash + assert qeff_model.qpc_path.relative_to(tmp_cache).parts[1] == qeff_model.model_name + "-" + qeff_model.export_hash # Check if there is no re-compilation start = perf_counter() From 528ad38e24b24f3f2f993f68802dc00e02ee9a83 Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Wed, 22 Oct 2025 17:28:48 -0700 Subject: [PATCH 29/41] fixing lora testing Signed-off-by: Vahid Janfaza --- QEfficient/peft/lora/layers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/QEfficient/peft/lora/layers.py b/QEfficient/peft/lora/layers.py index 6b75e696f..79abeba77 100644 --- a/QEfficient/peft/lora/layers.py +++ b/QEfficient/peft/lora/layers.py @@ -42,15 +42,15 @@ def forward(self, x: torch.Tensor, lora_ids: torch.Tensor): # multilora implementation: lora_ids other_indices_a = torch.arange(self.lora_a_weights.shape[2]).view(1, 1, -1) selected_lora_a_weights = CtxGatherFuncCB.apply( - self.lora_a_weights, lora_ids, other_indices_a + self.lora_a_weights, lora_ids, other_indices_a, self.lora_a_weights.shape[2] ) # other_indices_b = torch.arange(self.lora_b_weights.shape[2]).view(1, 1, -1) selected_lora_b_weights = CtxGatherFuncCB.apply( - self.lora_b_weights, lora_ids, other_indices_b + self.lora_b_weights, lora_ids, other_indices_b, self.lora_b_weights.shape[2] ) # other_indices_s = torch.arange(self.lora_scalings.shape[2]).view(1, 1, -1) selected_lora_scalings = CtxGatherFuncCB.apply( - self.lora_scalings, lora_ids, other_indices_s + self.lora_scalings, lora_ids, other_indices_s, self.lora_scalings.shape[2] ) # selected_lora_a_weights = selected_lora_a_weights.squeeze(1) From ba18a3e1502e5e95b960f03930efecee5ca52350 Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Wed, 22 Oct 2025 17:32:04 -0700 Subject: [PATCH 30/41] Adding Compute-Context-Length(CCL) Signed-off-by: Vahid Janfaza --- examples/compute_context_length.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/examples/compute_context_length.py b/examples/compute_context_length.py index 00d475ae0..dc6991b16 100644 --- a/examples/compute_context_length.py +++ b/examples/compute_context_length.py @@ -41,9 +41,6 @@ model = 
QEFFAutoModelForCausalLM.from_pretrained( model_name, continuous_batching=True, - comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, - comp_ctx_lengths_decode=comp_ctx_lengths_decode, - ctx_len=ctx_len, ) # model compilation for either continuous or static batching. For continuous batching full_batch_size is needed. From 6d056f9a5cc6522eeeb4ec1f31c1da9031713b6e Mon Sep 17 00:00:00 2001 From: Rishin Raj Date: Thu, 23 Oct 2025 09:59:13 +0000 Subject: [PATCH 31/41] Updated the test Signed-off-by: Rishin Raj --- .../ccl/test_ccl_causal_lm_models.py | 328 ++++++++++++++++++ .../test_ccl_export_compile.py} | 26 +- 2 files changed, 341 insertions(+), 13 deletions(-) create mode 100644 tests/transformers/ccl/test_ccl_causal_lm_models.py rename tests/transformers/{test_comp_ctx_length.py => ccl/test_ccl_export_compile.py} (93%) diff --git a/tests/transformers/ccl/test_ccl_causal_lm_models.py b/tests/transformers/ccl/test_ccl_causal_lm_models.py new file mode 100644 index 000000000..21224651e --- /dev/null +++ b/tests/transformers/ccl/test_ccl_causal_lm_models.py @@ -0,0 +1,328 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import copy +import os +from typing import Optional + +import numpy as np +import pytest +import torch +from transformers import AutoConfig, AutoModelForCausalLM + +from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM +from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers +from QEfficient.utils import hf_download +from QEfficient.utils._utils import load_hf_tokenizer +from QEfficient.utils.constants import Constants +from QEfficient.utils.device_utils import get_available_device_id +from QEfficient.utils.run_utils import ApiRunner +from QEfficient.utils.test_utils import ModelConfig + +# Test models for CCL feature +test_models_ccl = [ + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + "gpt2", + "Qwen/Qwen2-0.5B", +] + + +def get_custom_n_layers(model_name): + """ + Function to set number of layers for various types of models. + + Args: + model_name: str - Model name + + Returns: + n_layer: int or None - Number of layers to use + """ + if model_name in {"microsoft/Phi-3-mini-4k-instruct"}: + return 2 + return 16 + + +def load_causal_lm_model(model_name, n_layer=1, config=None): + """ + Function to load model from HuggingFace and transform to KV model. 
+ + Args: + model_name: str - HuggingFace model name + n_layer: int - Number of layers + config: AutoConfig - Custom config (optional) + + Returns: + model_hf: Loaded model + params: Number of parameters + """ + torch.manual_seed(42) + model_path = hf_download( + repo_id=model_name, + ignore_patterns=["*.onnx", "*.ot", "*.md", "*.tflite", "*.pdf", "*.h5", "*.msgpack"], + ) + if config is None: + if n_layer is not None: + model_hf = AutoModelForCausalLM.from_pretrained( + model_path, + use_cache=True, + num_hidden_layers=n_layer, + attn_implementation="eager", + low_cpu_mem_usage=False, + trust_remote_code=model_name in ModelConfig.EXTERNAL_MODELS, + ) + else: + model_hf = AutoModelForCausalLM.from_pretrained( + model_path, + use_cache=True, + attn_implementation="eager", + low_cpu_mem_usage=False, + trust_remote_code=model_name in ModelConfig.EXTERNAL_MODELS, + ) + else: + model_hf = AutoModelForCausalLM.from_config( + config, + attn_implementation="eager", + trust_remote_code=model_name in ModelConfig.EXTERNAL_MODELS, + ) + + # Convert to FP32 if model is in BF16 or FP16 + torch_dtype = getattr(model_hf.config, "torch_dtype", None) + if torch_dtype == torch.bfloat16 or torch_dtype == torch.float16: + model_hf = model_hf.to(torch.float32) + + params = sum(p.numel() for p in model_hf.parameters()) + model_hf.eval() + return model_hf, params + + +def check_ccl_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name: str, + prompt_len: int = Constants.PROMPT_LEN, + ctx_len: int = 128, + comp_ctx_lengths_prefill: Optional[list] = None, + comp_ctx_lengths_decode: Optional[list] = None, + n_layer: int = 1, + config: Optional[AutoConfig] = None, + pytorch_hf_tokens: Optional[list] = None, +): + """ + Validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, + and the Cloud AI 100 model with CCL (Compute Context Length) feature, both with + and without continuous batching. + + Args: + model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + prompt_len (int): Prompt length for the model to compile. + ctx_len (int): Maximum context length to compile the model. + comp_ctx_lengths_prefill (list): List of compute context lengths for prefill. + comp_ctx_lengths_decode (list): List of compute context lengths for decode. + n_layer (int): Number of layers for the Model. + config (AutoConfig): Custom model config. + pytorch_hf_tokens (list): Pre-computed PyTorch tokens for external models. 
+ """ + replace_transformers_quantizers() + + # Set default CCL values if not provided + if comp_ctx_lengths_prefill is None: + comp_ctx_lengths_prefill = [64] + if comp_ctx_lengths_decode is None: + comp_ctx_lengths_decode = [96, ctx_len] + + if config is None: + model_hf, _ = load_causal_lm_model(model_name, n_layer=n_layer) + else: + model_hf, _ = load_causal_lm_model(model_name, config=config) + + tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_name) + config = model_hf.config + batch_size = len(Constants.INPUT_STR) + + api_runner = ApiRunner( + batch_size, + tokenizer, + config, + Constants.INPUT_STR, + Constants.PROMPT_LEN, + Constants.CTX_LEN, + ) + + # Run PyTorch HF model if not external model + if model_name not in ModelConfig.SWIFTKV_MODELS and model_name not in ModelConfig.EXTERNAL_MODELS: + pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf) + + # Create QEFF model with CCL parameters + qeff_model = QEFFAutoModelForCausalLM( + copy.deepcopy(model_hf), + pretrained_model_name_or_path=model_name, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, + ) + + pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model) + + if model_name not in ModelConfig.SWIFTKV_MODELS: + assert (pytorch_hf_tokens == pytorch_kv_tokens).all(), ( + "Tokens don't match for HF PyTorch model output and KV PyTorch model output with CCL" + ) + + # Export to ONNX + onnx_model_path = qeff_model.export() + + # Note: Skipping ORT validation for CCL models as ApiRunner doesn't support comp_ctx_lengths input + # The CCL feature is validated through PyTorch and Cloud AI 100 execution + gen_len = pytorch_kv_tokens.shape[-1] + + if not get_available_device_id(): + pytest.skip("No available devices to run model on Cloud AI 100") + + # Compile for Cloud AI 100 with CCL + qpc_path = qeff_model.compile( + prefill_seq_len=prompt_len, + ctx_len=ctx_len, + num_cores=14, + mxfp6=False, + aic_enable_depth_first=False, + ) + + exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR) + cloud_ai_100_tokens = exec_info.generated_ids[0][:, :gen_len] + + # Validate Cloud AI 100 output matches PyTorch KV output + assert (pytorch_kv_tokens == cloud_ai_100_tokens).all(), ( + "Tokens don't match for PyTorch KV output and Cloud AI 100 output with CCL." + ) + assert os.path.isfile(os.path.join(os.path.dirname(qpc_path), "qconfig.json")) + + # Note: Continuous batching tests for CCL are skipped as they require additional runtime support + # The CCL feature validation is complete with the single-batch tests above + + +@pytest.mark.on_qaic +@pytest.mark.regular +@pytest.mark.ccl +@pytest.mark.parametrize("model_name", test_models_ccl) +def test_custom_ccl_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, custom_causal_model_config_dict): + """ + Test function to validate the dummy PyTorch model with CCL, the PyTorch model after KV changes, + the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. 
+ + Args: + model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + custom_causal_model_config_dict: Fixture providing custom model configs + """ + config = custom_causal_model_config_dict.get(model_name) + + # Using fixed reference tokens for external models + pytorch_hf_tokens = None + if model_name in ModelConfig.EXTERNAL_MODELS: + pytorch_hf_tokens = ModelConfig.EXTERNAL_MODELS[model_name]["pytorch_hf_tokens_custom_case"] + + if model_name in ModelConfig.QUANTIZED_MODELS: + check_ccl_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name, n_layer=2, pytorch_hf_tokens=pytorch_hf_tokens + ) + else: + check_ccl_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name, config=config, pytorch_hf_tokens=pytorch_hf_tokens + ) + + +@pytest.mark.nightly +@pytest.mark.on_qaic +@pytest.mark.ccl +@pytest.mark.parametrize("model_name", test_models_ccl) +def test_ccl_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): + """ + Test function to validate the PyTorch model with CCL, the PyTorch model after KV changes, + the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. + + Args: + model_name (str): Hugging Face Model Card name, Example: ``gpt2`` + """ + # Using fixed reference tokens for external models + pytorch_hf_tokens = None + if model_name in ModelConfig.EXTERNAL_MODELS: + pytorch_hf_tokens = ModelConfig.EXTERNAL_MODELS[model_name]["pytorch_hf_tokens_normal_case"] + + check_ccl_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, n_layer=2, pytorch_hf_tokens=pytorch_hf_tokens + ) + + +@pytest.mark.on_qaic +@pytest.mark.ccl +def test_ccl_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1(): + """ + Test function to validate the PyTorch model with CCL, the PyTorch model after KV changes, + the ONNX model, and the Cloud AI 100 model for a prompt length of 1, both with and + without continuous batching. + """ + model_name = "gpt2" + prompt_len = 1 + ctx_len = 128 + comp_ctx_lengths_prefill = [64] + comp_ctx_lengths_decode = [96, ctx_len] + + check_ccl_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name=model_name, + prompt_len=prompt_len, + ctx_len=ctx_len, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ) + + +@pytest.mark.on_qaic +@pytest.mark.ccl +def test_ccl_causal_lm_with_different_ctx_lengths(): + """ + Test CCL feature with different context length configurations. + """ + model_name = "gpt2" + n_layer = 1 + + # Test case 1: Small context lengths + check_ccl_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name, + n_layer=n_layer, + ctx_len=64, + comp_ctx_lengths_prefill=[32], + comp_ctx_lengths_decode=[48, 64], + ) + + # Test case 2: Larger context lengths + check_ccl_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name, + n_layer=n_layer, + ctx_len=256, + comp_ctx_lengths_prefill=[128], + comp_ctx_lengths_decode=[192, 256], + ) + + +@pytest.mark.on_qaic +@pytest.mark.ccl +def test_ccl_causal_lm_with_multiple_prefill_decode_lengths(): + """ + Test CCL feature with multiple compute context lengths for both prefill and decode. 
+ """ + model_name = "gpt2" + n_layer = 1 + ctx_len = 256 + + # Multiple CCL values for prefill and decode + comp_ctx_lengths_prefill = [64, 128] + comp_ctx_lengths_decode = [160, 192, 224, 256] + + check_ccl_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( + model_name, + n_layer=n_layer, + ctx_len=ctx_len, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ) diff --git a/tests/transformers/test_comp_ctx_length.py b/tests/transformers/ccl/test_ccl_export_compile.py similarity index 93% rename from tests/transformers/test_comp_ctx_length.py rename to tests/transformers/ccl/test_ccl_export_compile.py index 31b9da07e..4e8271fcb 100644 --- a/tests/transformers/test_comp_ctx_length.py +++ b/tests/transformers/ccl/test_ccl_export_compile.py @@ -20,19 +20,19 @@ configs = [ # name, max_position_embeddings, num_hidden_layers, num_attention_heads, hidden_size, intermediate_size, vocab_size, additional_params ("gpt2", 256, 2, 4, 128, 512, 127, {}), - ("codegen", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), - ("falcon", 256, 2, 4, 128, 512, 127, {}), - ("gptj", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), - ("llama", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - ("mistral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - ("mixtral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - ("mpt", 256, 2, 4, 128, 512, 127, {}), - ("phi", 256, 2, 4, 128, 512, 127, {}), - ("phi3", 256, 2, 4, 128, 512, 127, {"pad_token_id": 0}), - ("qwen2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - ("starcoder2", 256, 2, 4, 128, 512, 127, {}), - ("granite", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), - ("olmo2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("codegen", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), + # ("falcon", 256, 2, 4, 128, 512, 127, {}), + # ("gptj", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), + # ("llama", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("mistral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("mixtral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("mpt", 256, 2, 4, 128, 512, 127, {}), + # ("phi", 256, 2, 4, 128, 512, 127, {}), + # ("phi3", 256, 2, 4, 128, 512, 127, {"pad_token_id": 0}), + # ("qwen2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("starcoder2", 256, 2, 4, 128, 512, 127, {}), + # ("granite", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), + # ("olmo2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}), ] configs = [ From c9aaaecd41c63a58be28e20e7bd4bbda82190011 Mon Sep 17 00:00:00 2001 From: Rishin Raj Date: Thu, 23 Oct 2025 10:06:39 +0000 Subject: [PATCH 32/41] Lint fix Signed-off-by: Rishin Raj --- .../ccl/test_ccl_causal_lm_models.py | 61 +++++++++---------- 1 file changed, 29 insertions(+), 32 deletions(-) diff --git a/tests/transformers/ccl/test_ccl_causal_lm_models.py b/tests/transformers/ccl/test_ccl_causal_lm_models.py index 21224651e..239378a1c 100644 --- a/tests/transformers/ccl/test_ccl_causal_lm_models.py +++ b/tests/transformers/ccl/test_ccl_causal_lm_models.py @@ -9,7 +9,6 @@ import os from typing import Optional -import numpy as np import pytest import torch from transformers import AutoConfig, AutoModelForCausalLM @@ -34,10 +33,10 @@ def get_custom_n_layers(model_name): """ Function to set number of layers for various types of models. 
- + Args: model_name: str - Model name - + Returns: n_layer: int or None - Number of layers to use """ @@ -49,12 +48,12 @@ def get_custom_n_layers(model_name): def load_causal_lm_model(model_name, n_layer=1, config=None): """ Function to load model from HuggingFace and transform to KV model. - + Args: model_name: str - HuggingFace model name n_layer: int - Number of layers config: AutoConfig - Custom config (optional) - + Returns: model_hf: Loaded model params: Number of parameters @@ -88,7 +87,7 @@ def load_causal_lm_model(model_name, n_layer=1, config=None): attn_implementation="eager", trust_remote_code=model_name in ModelConfig.EXTERNAL_MODELS, ) - + # Convert to FP32 if model is in BF16 or FP16 torch_dtype = getattr(model_hf.config, "torch_dtype", None) if torch_dtype == torch.bfloat16 or torch_dtype == torch.float16: @@ -110,10 +109,10 @@ def check_ccl_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( pytorch_hf_tokens: Optional[list] = None, ): """ - Validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, - and the Cloud AI 100 model with CCL (Compute Context Length) feature, both with + Validate the PyTorch model, the PyTorch model after KV changes, the ONNX model, + and the Cloud AI 100 model with CCL (Compute Context Length) feature, both with and without continuous batching. - + Args: model_name (str): Hugging Face Model Card name, Example: ``gpt2`` prompt_len (int): Prompt length for the model to compile. @@ -125,13 +124,13 @@ def check_ccl_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( pytorch_hf_tokens (list): Pre-computed PyTorch tokens for external models. """ replace_transformers_quantizers() - + # Set default CCL values if not provided if comp_ctx_lengths_prefill is None: comp_ctx_lengths_prefill = [64] if comp_ctx_lengths_decode is None: comp_ctx_lengths_decode = [96, ctx_len] - + if config is None: model_hf, _ = load_causal_lm_model(model_name, n_layer=n_layer) else: @@ -140,7 +139,7 @@ def check_ccl_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_name) config = model_hf.config batch_size = len(Constants.INPUT_STR) - + api_runner = ApiRunner( batch_size, tokenizer, @@ -149,7 +148,7 @@ def check_ccl_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( Constants.PROMPT_LEN, Constants.CTX_LEN, ) - + # Run PyTorch HF model if not external model if model_name not in ModelConfig.SWIFTKV_MODELS and model_name not in ModelConfig.EXTERNAL_MODELS: pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf) @@ -162,24 +161,24 @@ def check_ccl_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( comp_ctx_lengths_decode=comp_ctx_lengths_decode, ctx_len=ctx_len, ) - + pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model) if model_name not in ModelConfig.SWIFTKV_MODELS: assert (pytorch_hf_tokens == pytorch_kv_tokens).all(), ( "Tokens don't match for HF PyTorch model output and KV PyTorch model output with CCL" ) - + # Export to ONNX - onnx_model_path = qeff_model.export() - + _ = qeff_model.export() + # Note: Skipping ORT validation for CCL models as ApiRunner doesn't support comp_ctx_lengths input # The CCL feature is validated through PyTorch and Cloud AI 100 execution gen_len = pytorch_kv_tokens.shape[-1] if not get_available_device_id(): pytest.skip("No available devices to run model on Cloud AI 100") - + # Compile for Cloud AI 100 with CCL qpc_path = qeff_model.compile( prefill_seq_len=prompt_len, @@ -188,10 +187,10 @@ def check_ccl_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( mxfp6=False, 
aic_enable_depth_first=False, ) - + exec_info = qeff_model.generate(tokenizer, prompts=Constants.INPUT_STR) cloud_ai_100_tokens = exec_info.generated_ids[0][:, :gen_len] - + # Validate Cloud AI 100 output matches PyTorch KV output assert (pytorch_kv_tokens == cloud_ai_100_tokens).all(), ( "Tokens don't match for PyTorch KV output and Cloud AI 100 output with CCL." @@ -208,9 +207,9 @@ def check_ccl_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( @pytest.mark.parametrize("model_name", test_models_ccl) def test_custom_ccl_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, custom_causal_model_config_dict): """ - Test function to validate the dummy PyTorch model with CCL, the PyTorch model after KV changes, + Test function to validate the dummy PyTorch model with CCL, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. - + Args: model_name (str): Hugging Face Model Card name, Example: ``gpt2`` custom_causal_model_config_dict: Fixture providing custom model configs @@ -223,9 +222,7 @@ def test_custom_ccl_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, custom_c pytorch_hf_tokens = ModelConfig.EXTERNAL_MODELS[model_name]["pytorch_hf_tokens_custom_case"] if model_name in ModelConfig.QUANTIZED_MODELS: - check_ccl_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( - model_name, n_layer=2, pytorch_hf_tokens=pytorch_hf_tokens - ) + check_ccl_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, n_layer=2, pytorch_hf_tokens=pytorch_hf_tokens) else: check_ccl_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( model_name, config=config, pytorch_hf_tokens=pytorch_hf_tokens @@ -238,9 +235,9 @@ def test_custom_ccl_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name, custom_c @pytest.mark.parametrize("model_name", test_models_ccl) def test_ccl_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): """ - Test function to validate the PyTorch model with CCL, the PyTorch model after KV changes, + Test function to validate the PyTorch model with CCL, the PyTorch model after KV changes, the ONNX model, and the Cloud AI 100 model, both with and without continuous batching. - + Args: model_name (str): Hugging Face Model Card name, Example: ``gpt2`` """ @@ -258,8 +255,8 @@ def test_ccl_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100(model_name): @pytest.mark.ccl def test_ccl_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100_pl1(): """ - Test function to validate the PyTorch model with CCL, the PyTorch model after KV changes, - the ONNX model, and the Cloud AI 100 model for a prompt length of 1, both with and + Test function to validate the PyTorch model with CCL, the PyTorch model after KV changes, + the ONNX model, and the Cloud AI 100 model for a prompt length of 1, both with and without continuous batching. 
""" model_name = "gpt2" @@ -285,7 +282,7 @@ def test_ccl_causal_lm_with_different_ctx_lengths(): """ model_name = "gpt2" n_layer = 1 - + # Test case 1: Small context lengths check_ccl_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( model_name, @@ -294,7 +291,7 @@ def test_ccl_causal_lm_with_different_ctx_lengths(): comp_ctx_lengths_prefill=[32], comp_ctx_lengths_decode=[48, 64], ) - + # Test case 2: Larger context lengths check_ccl_causal_lm_pytorch_vs_kv_vs_ort_vs_ai100( model_name, @@ -314,7 +311,7 @@ def test_ccl_causal_lm_with_multiple_prefill_decode_lengths(): model_name = "gpt2" n_layer = 1 ctx_len = 256 - + # Multiple CCL values for prefill and decode comp_ctx_lengths_prefill = [64, 128] comp_ctx_lengths_decode = [160, 192, 224, 256] From 5765779de854d510406a7297a384d407e3baa40b Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Thu, 23 Oct 2025 15:54:44 -0700 Subject: [PATCH 33/41] Adding the support of modeling_gpt_bigcode with CCL Signed-off-by: Vahid Janfaza --- .../models/gpt_bigcode/modeling_gpt_bigcode.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/QEfficient/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py b/QEfficient/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py index af233870b..cb6a0f0d0 100644 --- a/QEfficient/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +++ b/QEfficient/transformers/models/gpt_bigcode/modeling_gpt_bigcode.py @@ -98,6 +98,7 @@ def forward( self, hidden_states: torch.Tensor, layer_past: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, batch_index: Optional[torch.LongTensor] = None, @@ -151,8 +152,10 @@ def forward( ) if layer_past is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] # save all key/value_states to cache to be re-used for fast auto-regressive generation - cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index} + cache_kwargs = {"position_ids": position_ids, "batch_index": batch_index, "CCL": attention_mask.shape[-1]} key, value = curr_past_key_value.update(key, value, self.layer_idx, cache_kwargs) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if self.is_cross_attention: @@ -180,6 +183,7 @@ def forward( self, hidden_states: Optional[Tuple[torch.Tensor]], layer_past: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, @@ -194,6 +198,7 @@ def forward( attn_outputs = self.attn( hidden_states, layer_past=layer_past, + comp_ctx_lengths=comp_ctx_lengths, attention_mask=attention_mask, position_ids=position_ids, batch_index=batch_index, @@ -242,6 +247,7 @@ def forward( self, input_ids: Optional[torch.Tensor] = None, past_key_values: Optional[list[torch.Tensor]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.Tensor] = None, token_type_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, @@ -333,6 +339,7 @@ def forward( outputs = block( hidden_states, layer_past=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, position_ids=position_ids, batch_index=batch_index, attention_mask=attention_mask, @@ -374,6 +381,7 @@ def forward( self, input_ids: Optional[torch.Tensor] = None, past_key_values: 
Optional[tuple[tuple[torch.Tensor]]] = None,
+        comp_ctx_lengths: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         token_type_ids: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
@@ -399,6 +407,7 @@ def forward(
         transformer_outputs = self.transformer(
             input_ids,
             past_key_values=past_key_values,
+            comp_ctx_lengths=comp_ctx_lengths,
             attention_mask=attention_mask,
             token_type_ids=token_type_ids,
             position_ids=position_ids,

From 0468a908b040e3f03090d4d8cca7cb7734639469 Mon Sep 17 00:00:00 2001
From: Rishin Raj
Date: Fri, 24 Oct 2025 03:12:08 +0000
Subject: [PATCH 34/41] Removed redundant test

Signed-off-by: Rishin Raj

---
 tests/transformers/test_comp_ctx_length.py | 320 ---------------------
 1 file changed, 320 deletions(-)
 delete mode 100644 tests/transformers/test_comp_ctx_length.py

diff --git a/tests/transformers/test_comp_ctx_length.py b/tests/transformers/test_comp_ctx_length.py
deleted file mode 100644
index 31b9da07e..000000000
--- a/tests/transformers/test_comp_ctx_length.py
+++ /dev/null
@@ -1,320 +0,0 @@
-# -----------------------------------------------------------------------------
-#
-# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
-# SPDX-License-Identifier: BSD-3-Clause
-#
-# ----------------------------------------------------------------------------
-
-import copy
-import os
-from time import perf_counter
-
-import onnx
-import pytest
-from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
-
-from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM
-from QEfficient.utils import constants, get_padding_shape_from_config
-from QEfficient.utils.hash_utils import hash_dict_params
-
-configs = [
-    # name, max_position_embeddings, num_hidden_layers, num_attention_heads, hidden_size, intermediate_size, vocab_size, additional_params
-    ("gpt2", 256, 2, 4, 128, 512, 127, {}),
-    ("codegen", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}),
-    ("falcon", 256, 2, 4, 128, 512, 127, {}),
-    ("gptj", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}),
-    ("llama", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
-    ("mistral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
-    ("mixtral", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
-    ("mpt", 256, 2, 4, 128, 512, 127, {}),
-    ("phi", 256, 2, 4, 128, 512, 127, {}),
-    ("phi3", 256, 2, 4, 128, 512, 127, {"pad_token_id": 0}),
-    ("qwen2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
-    ("starcoder2", 256, 2, 4, 128, 512, 127, {}),
-    ("granite", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
-    ("olmo2", 256, 2, 4, 128, 512, 127, {"num_key_value_heads": 2}),
-]
-
-configs = [
-    AutoConfig.for_model(
-        model_name,
-        max_position_embeddings=max_position_embeddings,
-        num_hidden_layers=num_hidden_layers,
-        num_attention_heads=num_attention_heads,
-        hidden_size=hidden_size,
-        intermediate_size=intermediate_size,
-        vocab_size=vocab_size,
-        **additional_params,
-    )
-    for (
-        model_name,
-        max_position_embeddings,
-        num_hidden_layers,
-        num_attention_heads,
-        hidden_size,
-        intermediate_size,
-        vocab_size,
-        additional_params,
-    ) in configs
-]
-config_ids = [x.model_type for x in configs]
-
-model_kwargs = {"attn_implementation": "eager"}
-
-
-@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"])
-def test_causal_lm_unsupported(cb):
-    model = AutoModelForCausalLM.from_config(AutoConfig.for_model("opt"))
-    ctx_len = 32
-    comp_ctx_lengths_prefill = [16]
-    comp_ctx_lengths_decode = [24, ctx_len]
-    with
pytest.warns(): - QEFFAutoModelForCausalLM( - model, - cb, - comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, - comp_ctx_lengths_decode=comp_ctx_lengths_decode, - ctx_len=ctx_len, - ) - - -@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) -@pytest.mark.parametrize("config", configs, ids=config_ids) -def test_causal_lm_init(config, cb): - model = AutoModelForCausalLM.from_config(config, **model_kwargs) - ctx_len = 32 - comp_ctx_lengths_prefill = [16] - comp_ctx_lengths_decode = [24, ctx_len] - qeff_model = QEFFAutoModelForCausalLM( - model, - cb, - comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, - comp_ctx_lengths_decode=comp_ctx_lengths_decode, - ctx_len=ctx_len, - ) - with pytest.raises(TypeError): - QEFFAutoModelForCausalLM( - AutoModel.from_config(config, **model_kwargs), - cb, - comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, - comp_ctx_lengths_decode=comp_ctx_lengths_decode, - ctx_len=ctx_len, - ) - assert qeff_model.model.__class__.__name__.startswith("QEff") - - -@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) -@pytest.mark.parametrize("config", configs, ids=config_ids) -def test_causal_lm_pretrained(config, cb, tmp_path): - model = AutoModelForCausalLM.from_config(config, **model_kwargs) - model.save_pretrained(tmp_path) - - ctx_len = 32 - comp_ctx_lengths_prefill = [16] - comp_ctx_lengths_decode = [24, ctx_len] - qeff_model = QEFFAutoModelForCausalLM.from_pretrained( - tmp_path, - cb, - comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, - comp_ctx_lengths_decode=comp_ctx_lengths_decode, - ctx_len=ctx_len, - ) - assert qeff_model.model.__class__.__name__.startswith("QEff") - - -@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) -@pytest.mark.parametrize("config", configs, ids=config_ids) -def test_causal_lm_export_and_hash(config, cb, tmp_path): - ctx_len = 32 - comp_ctx_lengths_prefill = [16] - comp_ctx_lengths_decode = [24, ctx_len] - model_0_0 = QEFFAutoModelForCausalLM( - AutoModelForCausalLM.from_config(config, **model_kwargs), - cb, - comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, - comp_ctx_lengths_decode=comp_ctx_lengths_decode, - ctx_len=ctx_len, - ) - model_0_0.export(tmp_path) - model_path = tmp_path.with_name(tmp_path.name + "-" + model_0_0.export_hash) - assert model_path.is_dir() - assert model_0_0.onnx_path.is_file() - assert model_0_0.onnx_path.relative_to(model_path).parts == (model_0_0.model_name + ".onnx",) - - # Check if the KV-cache inputs and outputs are created - onnx_model = onnx.load(model_0_0.onnx_path, load_external_data=False) - retained_output_names = { - x.name[: -len("_RetainedState")] for x in onnx_model.graph.output if x.name.endswith("_RetainedState") - } - retained_output_names.issubset({x.name for x in onnx_model.graph.input}) - - # Check if there is no re-export - start = perf_counter() - model_0_0.export(tmp_path) - end = perf_counter() - export_time = end - start - assert export_time < 2.0 - - # Check if hashing is happening properly - model_0_1 = QEFFAutoModelForCausalLM( - AutoModelForCausalLM.from_config(config, **model_kwargs), - cb, - comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, - comp_ctx_lengths_decode=comp_ctx_lengths_decode, - ctx_len=ctx_len, - ) - model_0_1.export(tmp_path) - hash_0_0 = model_0_0.export_hash - hash_0_1 = model_0_1.export_hash - - assert hash_0_0 == hash_0_1 - - cfg1 = copy.deepcopy(config) - cfg1.num_hidden_layers -= 1 - model_1_0 = QEFFAutoModelForCausalLM( - AutoModelForCausalLM.from_config(cfg1, **model_kwargs), - cb, - 
comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, - comp_ctx_lengths_decode=comp_ctx_lengths_decode, - ctx_len=ctx_len, - ) - model_1_0.export(tmp_path) - hash_1_0 = model_1_0.export_hash - cfg2 = copy.deepcopy(config) - cfg2.num_hidden_layers -= 1 - model_1_1 = QEFFAutoModelForCausalLM( - AutoModelForCausalLM.from_config(cfg2, **model_kwargs), - cb, - comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, - comp_ctx_lengths_decode=comp_ctx_lengths_decode, - ctx_len=ctx_len, - ) - model_1_1.export(tmp_path) - hash_1_1 = model_1_1.export_hash - assert hash_1_0 == hash_1_1 - - assert hash_0_0 != hash_1_0 - - if cb: - model_0_no_cb = QEFFAutoModelForCausalLM( - AutoModelForCausalLM.from_config(config, **model_kwargs), - False, - comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, - comp_ctx_lengths_decode=comp_ctx_lengths_decode, - ctx_len=ctx_len, - ) - model_0_no_cb.export(tmp_path) - hash_0_no_cb = model_0_no_cb.export_hash - assert hash_0_0 != hash_0_no_cb - - -@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) -@pytest.mark.parametrize("config", configs, ids=config_ids) -def test_causal_lm_hash_creation(config, cb, tmp_path): - model = AutoModelForCausalLM.from_config(config, **model_kwargs) - ctx_len = 32 - comp_ctx_lengths_prefill = [16] - comp_ctx_lengths_decode = [24, ctx_len] - qeff_model = QEFFAutoModelForCausalLM( - model, - cb, - comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, - comp_ctx_lengths_decode=comp_ctx_lengths_decode, - ctx_len=ctx_len, - ) - qeff_model.export(tmp_path) - hash_params = {} - hash_params["config"] = qeff_model.model.config.to_diff_dict() - hash_params["peft_config"] = None - hash_params["applied_transform_names"] = qeff_model._transform_names() - hash_params["qeff_auto_class"] = qeff_model.__class__.__name__ - hash_params["qaic_config"] = None - - # Create parameters separately for hash creation - - bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE - seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN - fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS - kv_cache_shape = get_padding_shape_from_config( - qeff_model.model.config, fbs if qeff_model.continuous_batching else bs, seq_len - ) - dynamic_axes = { - "input_ids": {0: "batch_size", 1: "seq_len"}, - "position_ids": {0: "batch_size", 1: "seq_len"}, - } - dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} - if len(kv_cache_shape) == 3: # For GPTBigCode arch the pkv is 3d - pkv_dynamic_axes = { - 0: "full_batch_size" if qeff_model.continuous_batching else "batch_size", - 1: "ctx_len", - } - else: # pkv is 4d - pkv_dynamic_axes = { - 0: "full_batch_size" if qeff_model.continuous_batching else "batch_size", - 2: "ctx_len", - } - output_names = [] - output_names.append("logits") - - for i in range(qeff_model.num_layers): - for kv in ["key", "value"]: - dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes - output_names.append(f"past_{kv}.{i}_RetainedState") - - if qeff_model.continuous_batching: - dynamic_axes["batch_index"] = {0: "batch_size"} - - export_params = {} - export_params["output_names"] = output_names - export_params["dynamic_axes"] = dynamic_axes - hash_params["export_params"] = export_params - manual_hash = hash_dict_params(hash_params) - - assert manual_hash == qeff_model.export_hash - - -@pytest.fixture -def tmp_cache(tmp_path, monkeypatch): - monkeypatch.setattr("QEfficient.utils._utils.QEFF_HOME", tmp_path) - yield tmp_path - - -@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) -@pytest.mark.parametrize("config", configs, ids=config_ids) -def 
test_causal_lm_compile(config, cb, tmp_cache): - model = AutoModelForCausalLM.from_config(config, **model_kwargs) - - ctx_len = 32 - comp_ctx_lengths_prefill = [16] - comp_ctx_lengths_decode = [24, ctx_len] - qeff_model = QEFFAutoModelForCausalLM( - model, - cb, - comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, - comp_ctx_lengths_decode=comp_ctx_lengths_decode, - ctx_len=ctx_len, - ) - compile_params = {"prefill_seq_len": 8, "ctx_len": ctx_len} - if cb: - compile_params["full_batch_size"] = 32 - compile_params["batch_size"] = 8 - qeff_model.compile(**compile_params) - model_path = tmp_cache / qeff_model.model_name / (qeff_model.model_name + "-" + qeff_model.export_hash) - - # Check if ONNX is exported properly - assert model_path.is_dir() - assert qeff_model.onnx_path.is_file() - assert qeff_model.onnx_path.relative_to(model_path).parts == (qeff_model.model_name + ".onnx",) - - # Check if QPC is compiled properly - assert qeff_model.qpc_path.is_dir() - assert (qeff_model.qpc_path / "programqpc.bin").is_file() - assert qeff_model.qpc_path.relative_to(tmp_cache).parts[1] == qeff_model.model_name + "-" + qeff_model.export_hash - - # Check if there is no re-compilation - start = perf_counter() - qeff_model.compile(**compile_params) - end = perf_counter() - compile_time = end - start - assert compile_time < 2.0 - assert os.path.isfile(os.path.join(os.path.dirname(qeff_model.qpc_path), "qconfig.json")) From d8f4eab9bad9464d8177aa3ebba6270ca48add18 Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Thu, 23 Oct 2025 23:25:40 -0700 Subject: [PATCH 35/41] Adding Compute-Context-Length(CCL) Signed-off-by: Vahid Janfaza --- QEfficient/transformers/models/modeling_auto.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 9fb9a9c0a..9cf9109dc 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2607,7 +2607,7 @@ def build_decode_specialization( of the prefill specialization (e.g., if prefill_seq_len is 1 and not continuous batching). """ if prefill_seq_len == 1: - if not self.continuous_batching or batch_size == 1: + if not self.continuous_batching:# or batch_size == 1 return None # Avoid duplication with prefill spec = { From 7e952ad9590196d7ef23d76539efb47a8910e4be Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Thu, 23 Oct 2025 23:28:40 -0700 Subject: [PATCH 36/41] Adding Compute-Context-Length(CCL) Signed-off-by: Vahid Janfaza --- QEfficient/transformers/models/modeling_auto.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 9cf9109dc..12234df04 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2607,7 +2607,7 @@ def build_decode_specialization( of the prefill specialization (e.g., if prefill_seq_len is 1 and not continuous batching). 
""" if prefill_seq_len == 1: - if not self.continuous_batching:# or batch_size == 1 + if not self.continuous_batching: # or batch_size == 1 return None # Avoid duplication with prefill spec = { From 2d137f989c13f916f8a08fe422d795f9d74849ce Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Sat, 25 Oct 2025 12:11:28 -0700 Subject: [PATCH 37/41] Add CCL support to molmo model Signed-off-by: Vahid Janfaza --- .../models/molmo/modeling_molmo.py | 149 ++++++++++++++---- examples/ccl_molmo_example.py | 98 ++++++++++++ 2 files changed, 220 insertions(+), 27 deletions(-) create mode 100644 examples/ccl_molmo_example.py diff --git a/QEfficient/transformers/models/molmo/modeling_molmo.py b/QEfficient/transformers/models/molmo/modeling_molmo.py index 4f92316ca..a0b20fe28 100644 --- a/QEfficient/transformers/models/molmo/modeling_molmo.py +++ b/QEfficient/transformers/models/molmo/modeling_molmo.py @@ -243,6 +243,7 @@ def attention( attention_bias: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, layer_past: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: bool = False, **kwargs, @@ -278,8 +279,17 @@ def attention( q, k = qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, self.config) if layer_past is not None: + if comp_ctx_lengths is not None: + attention_bias = attention_bias[:, :, :, : comp_ctx_lengths.shape[-1]] + print(f"attention_bias: {attention_bias.shape}") # sin and cos are specific to RoPE models; cache_position needed for the static cache - cache_kwargs = {"sin": sin, "cos": cos, "batch_index": batch_index, "position_ids": position_ids} + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + "CCL": attention_bias.shape[-1], + } k, v = layer_past.update(k, v, self.layer_id, cache_kwargs) attention_interface: Callable = eager_attention_forward @@ -311,6 +321,7 @@ def forward( attention_bias: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, layer_past: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: bool = False, **kwargs, @@ -334,6 +345,7 @@ def forward( attention_bias, position_ids=position_ids, layer_past=layer_past, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, ) @@ -380,6 +392,7 @@ def forward( subsegment_ids: Optional[torch.Tensor] = None, position_ids: Optional[torch.Tensor] = None, past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, batch_index: Optional[torch.LongTensor] = None, use_cache: bool = False, last_logits_only: bool = False, @@ -496,6 +509,7 @@ def forward( attention_bias=causal_mask, position_ids=position_ids, layer_past=layer_past, + comp_ctx_lengths=comp_ctx_lengths, batch_index=batch_index, use_cache=use_cache, ) @@ -518,6 +532,7 @@ def forward( attention_bias=causal_mask, position_ids=position_ids, layers_past=layers_past, + comp_ctx_lengths=comp_ctx_lengths, use_cache=use_cache, ) @@ -574,7 +589,15 @@ def __init__(self, model): # self.language_model = self.model.language_model self.config = self.model.config - def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_values): + def forward( + self, + input_ids, + vision_embeds, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, + ): if input_ids is not 
None: input_ids = input_ids * (input_ids != -1).to(input_ids.dtype) inputs_embeds = self.model.model.transformer.wte(input_ids) @@ -587,7 +610,11 @@ def forward(self, input_ids, vision_embeds, position_ids, image_idx, past_key_va # inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds) outputs = self.model.model.forward( - input_embeddings=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + input_embeddings=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + use_cache=True, ) next_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) image_idx = torch.where(image_idx < next_idx, next_idx, image_idx) @@ -608,7 +635,16 @@ def get_qeff_language_decoder(self): """ def forward( - self, pixel_values, image_masks, image_input_idx, valid_idx, input_ids, position_ids, image_idx, past_key_values + self, + pixel_values, + image_masks, + image_input_idx, + valid_idx, + input_ids, + position_ids, + image_idx, + past_key_values, + comp_ctx_lengths: Optional[List[int]] = None, ): image_features, _ = self.model.vision_backbone(pixel_values, image_masks) num_image, num_patch = image_features.shape[1:3] @@ -637,7 +673,11 @@ def forward( inputs_embeds = torch.where(input_ids.shape[1] == torch.tensor(1), inputs_embeds, image_embeds) outputs = self.model.forward( - input_embeddings=inputs_embeds, position_ids=position_ids, past_key_values=past_key_values, use_cache=True + input_embeddings=inputs_embeds, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + use_cache=True, ) next_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) image_idx = torch.where(image_idx < next_idx, next_idx, image_idx) @@ -651,6 +691,8 @@ def get_specializations( ctx_len: int, num_images: int = None, img_size: int = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, valid_size: int = None, kv_offload: bool = False, **compiler_options, @@ -679,30 +721,77 @@ def get_specializations( } ] - lang_prefill = { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "valid_size": valid_size, - } - - lang_decode = {"batch_size": batch_size, "seq_len": "1", "ctx_len": ctx_len, "valid_size": valid_size} + if comp_ctx_lengths_prefill is not None and comp_ctx_lengths_decode is not None: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang_prefill = { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "valid_size": valid_size, + } + if kv_offload: + values = { + "img_size": img_size, + "img_tile": img_tile, + "num_images": num_images, + "num_patch": num_patch, + } + + for key, value in values.items(): + lang_prefill[key] = value + + lang.append(lang_prefill) + + for i in range(0, len(comp_ctx_lengths_decode)): + lang_decode = { + "batch_size": batch_size, + "seq_len": "1", + "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "valid_size": valid_size, + } + if kv_offload: + values = { + "img_size": img_size, + "img_tile": img_tile, + "num_images": num_images, + "num_patch": num_patch, + } + + for key, value in values.items(): + lang_decode[key] = value + + lang.append(lang_decode) - if kv_offload: - values = { - "img_size": img_size, - "img_tile": img_tile, - "num_images": num_images, - "num_patch": num_patch, + else: + lang_prefill = { + 
"batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "valid_size": valid_size, } - for key, value in values.items(): - lang_prefill[key] = value - lang_decode[key] = value + lang_decode = {"batch_size": batch_size, "seq_len": "1", "ctx_len": ctx_len, "valid_size": valid_size} + + if kv_offload: + values = { + "img_size": img_size, + "img_tile": img_tile, + "num_images": num_images, + "num_patch": num_patch, + } + + for key, value in values.items(): + lang_prefill[key] = value + lang_decode[key] = value + + lang = [] + lang.append(lang_prefill) + lang.append(lang_decode) - lang = [] - lang.append(lang_prefill) - lang.append(lang_decode) specializations = {} if kv_offload: @@ -712,7 +801,7 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, kv_offload: bool = False): + def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): # Define dynamic axes vision_dynamic_axes = {} lang_dynamic_axes = {} @@ -731,6 +820,9 @@ def get_onnx_dynamic_axes(self, kv_offload: bool = False): lang_dynamic_axes[f"past_key.{i}"] = {0: "batch_size", 2: "ctx_len"} lang_dynamic_axes[f"past_value.{i}"] = {0: "batch_size", 2: "ctx_len"} + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + dynamic_axes = {} if kv_offload: dynamic_axes["vision"] = vision_dynamic_axes @@ -760,7 +852,7 @@ def get_output_names(self, kv_offload: bool = False): return lang_output_names return output_names - def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): + def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, **kwargs): inputs_shapes = {} inputs_shapes_lang = {} inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) @@ -823,6 +915,9 @@ def get_dummy_inputs(self, kv_offload: bool = False, **kwargs): for kv in ["key", "value"]: lang_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs diff --git a/examples/ccl_molmo_example.py b/examples/ccl_molmo_example.py new file mode 100644 index 000000000..c52d9172b --- /dev/null +++ b/examples/ccl_molmo_example.py @@ -0,0 +1,98 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+import requests
+import torch
+import transformers
+from PIL import Image
+from transformers import AutoConfig, AutoProcessor, TextStreamer
+
+from QEfficient import QEFFAutoModelForCausalLM
+
+model_id = "allenai/Molmo-7B-D-0924"
+config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)
+
+# config.num_hidden_layers = 2
+
+# load the model
+ctx_len = 32768
+comp_ctx_lengths_prefill = [3072]
+comp_ctx_lengths_decode = [4096, 8192, ctx_len]
+
+qeff_model = QEFFAutoModelForCausalLM.from_pretrained(
+    model_id,
+    kv_offload=True,
+    comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
+    comp_ctx_lengths_decode=comp_ctx_lengths_decode,
+    ctx_len=ctx_len,
+    trust_remote_code=True,
+    config=config,
+)
+tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
+
+### Use skip_vision=True if you want to run text only, otherwise False ###
+skip_vision = False
+
+if skip_vision:
+    ## Only Text ##
+    qeff_model.compile(
+        prefill_seq_len=128,
+        ctx_len=ctx_len,
+        num_cores=16,
+        num_devices=4,
+        mxint8_kv_cache=True,
+        aic_enable_depth_first=True,
+        skip_vision=True,
+        mos=1,
+    )
+
+    inputs = processor.process(text="Tell me about yourself")
+    inputs = {k: v.unsqueeze(0) for k, v in inputs.items()}
+    inputs["input_ids"] = inputs["input_ids"].to(torch.int64)
+    inputs["attention_mask"] = torch.ones((inputs["input_ids"].shape), dtype=torch.int64)
+
+    streamer = TextStreamer(tokenizer)
+    output = qeff_model.generate(inputs=inputs, device_ids=[0, 1, 2, 3], generation_len=100)
+    print(output.generated_ids)
+    print(tokenizer.batch_decode(output.generated_ids))
+    print(output)
+
+else:
+    ## Vision + Text ##
+    qeff_model.compile(
+        prefill_seq_len=128,
+        ctx_len=ctx_len,
+        num_cores=16,
+        num_devices=4,
+        mxint8_kv_cache=True,
+        aic_enable_depth_first=True,
+        mos=1,
+    )
+
+    ### IMAGE + TEXT ###
+    image_url = "https://picsum.photos/id/237/536/354"
+
+    image = Image.open(requests.get(image_url, stream=True).raw)
+    image = image.resize((536, 354))
+
+    inputs = processor.process(images=[image], text="Can you describe the image in detail.")
+
+    inputs = {k: v.unsqueeze(0) for k, v in inputs.items()}
+    inputs["pixel_values"] = inputs.pop("images")
+    inputs["attention_mask"] = torch.ones((inputs["input_ids"].shape), dtype=torch.int64)
+
+    valid = inputs["image_input_idx"] > 0
+    valid = valid.reshape(1, -1)
+    inputs["valid_idx"] = torch.nonzero(valid)[:, 1].unsqueeze(0)
+
+    streamer = TextStreamer(tokenizer)
+    output = qeff_model.generate(inputs=inputs, device_ids=[8, 9, 10, 11], generation_len=100)
+    print(output.generated_ids)
+    print(tokenizer.batch_decode(output.generated_ids))
+    print(output)
+    print()

From 65a76bd62967062cb7b0695ae87002b41a15b871 Mon Sep 17 00:00:00 2001
From: Vahid Janfaza
Date: Tue, 28 Oct 2025 17:13:47 -0700
Subject: [PATCH 38/41] Adding support of multimodal models in vllm with CCL

Signed-off-by: Vahid Janfaza

---
 QEfficient/transformers/models/modeling_auto.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index 12234df04..d0b420990 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -1094,6 +1094,10 @@ def compile(
             raise ValueError("Expected at least one of 'skip_lang' or 'skip_vision'
to be False") output_names = self.model.get_output_names(kv_offload=True) + # For supporting VLLM and Disaggregated with CCL + if "comp_ctx_lengths_prefill" in compiler_options: + self.comp_ctx_lengths_prefill = compiler_options.pop("comp_ctx_lengths_prefill") + self.comp_ctx_lengths_decode = compiler_options.pop("comp_ctx_lengths_decode") specializations, compiler_options = self.model.get_specializations( batch_size=batch_size, @@ -1652,6 +1656,10 @@ def compile( ) output_names = self.model.get_output_names() + # For supporting VLLM and Disaggregated with CCL + if "comp_ctx_lengths_prefill" in compiler_options: + self.comp_ctx_lengths_prefill = compiler_options.pop("comp_ctx_lengths_prefill") + self.comp_ctx_lengths_decode = compiler_options.pop("comp_ctx_lengths_decode") # Get specializations from modelling file # TODO: expose this via the auto class as well From 4fac443a8d1f70a75b17bafbbc64afa6e2caed9e Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Mon, 3 Nov 2025 16:29:21 -0800 Subject: [PATCH 39/41] Update the test script Signed-off-by: Vahid Janfaza --- examples/compute_context_length.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/examples/compute_context_length.py b/examples/compute_context_length.py index dc6991b16..3be1a9eab 100644 --- a/examples/compute_context_length.py +++ b/examples/compute_context_length.py @@ -16,8 +16,8 @@ ## - The second comp_ctx_lengths_decode list will be used for decoding. During the decoding process, based on the position_id or cache index it will work with the specific compute-context-length in the list. It will start from a proper compute-context-length in the list based on input prompt length and will gradually increase the compute-context-length if the cache index passes the current compute-context-length. ## ctx_len = 1024 -comp_ctx_lengths_prefill = [256] -comp_ctx_lengths_decode = [512, ctx_len] +comp_ctx_lengths_prefill = [256] # None +comp_ctx_lengths_decode = [ctx_len] # None # model_name = "google/gemma-7b" # model_name = "google/gemma-2-2b" @@ -27,20 +27,23 @@ # model_name = "microsoft/phi-1_5" # model_name = "microsoft/Phi-3-mini-4k-instruct" # model_name = "Qwen/Qwen2.5-7B-Instruct" -model_name = "meta-llama/Llama-3.2-1B" +# model_name = "meta-llama/Llama-3.2-1B" # model_name = "Qwen/Qwen3-1.7B" # model_name = "allenai/OLMo-2-0425-1B" -# model_name = "ibm-granite/granite-3.3-2b-base" +model_name = "ibm-granite/granite-3.3-2b-base" +# model_name = "ibm-granite/granite-3.2-8b-instruct" # model_name = "meta-llama/Llama-3.3-70B-Instruct" # model_name = "Salesforce/codegen-350M-mono" # model_name = "tiiuae/falcon-7b-instruct" # model_name = "openai-community/gpt2" # model_name = "EleutherAI/gpt-j-6b" -# model_name = "EleutherAI/gpt-j-6b" model = QEFFAutoModelForCausalLM.from_pretrained( model_name, continuous_batching=True, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + ctx_len=ctx_len, ) # model compilation for either continuous or static batching. For continuous batching full_batch_size is needed. 
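+# Note on how the CCL lists above are consumed: each value becomes its own compiler specialization.
+# At runtime, prefill starts from the first comp_ctx_lengths_prefill bucket and moves to the next one
+# as more prompt chunks are processed, while decode starts from the smallest comp_ctx_lengths_decode
+# bucket that covers the prompt and steps up whenever the cache index passes the active bucket.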
@@ -48,7 +51,7 @@ prefill_seq_len=128, ctx_len=ctx_len, num_cores=16, - num_devices=1, + num_devices=4, full_batch_size=1, mxint8_kv_cache=True, mxfp6_matmul=True, From 069f86a1a5f6f978fc940ddf36ad96fcc7949171 Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Wed, 5 Nov 2025 10:03:15 -0800 Subject: [PATCH 40/41] Manual fixes before merge --- QEfficient/generation/text_generation_inference.py | 2 ++ QEfficient/utils/constants.py | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index cf4b6aa27..560131de2 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -809,12 +809,14 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i self.list_of_comp_ctx_lengths_prefill = [np.zeros(length) for length in self.comp_ctx_lengths_prefill] prefill_ccl_id = 0 inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + print(f"CCL Prefill: {self.comp_ctx_lengths_prefill[prefill_ccl_id]}") for i in range(num_chunks): if self.comp_ctx_lengths_prefill is not None: if (i + 1) * self._prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id]: prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1) inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + print(f"CCL Prefill: {self.comp_ctx_lengths_prefill[prefill_ccl_id]}") chunk_inputs = inputs.copy() chunk_inputs["input_ids"] = inputs["input_ids"][ diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 57fba282b..9196fa1c0 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -125,6 +125,10 @@ def get_models_dir(): # Wav2Vec2 Constant WAV2VEC2_MAX_SEQ_LEN = 480000 # 30 seconds of audio at 16 kHz sampling rate (16,000 samples/sec × 30 sec) +# Qwen2_5_vl Constants +QWEN2_5_VL_HEIGHT = 354 +QWEN2_5_VL_WIDTH = 536 + class Constants: # Export Constants. 
@@ -235,3 +239,4 @@ class QnnConstants: }, "SKIP_QNN_CONVERTER_STEP": False, } + From 9b8c8f53ef0f0905fee46b8cd455a3679ed86ba7 Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Thu, 6 Nov 2025 23:17:47 -0800 Subject: [PATCH 41/41] Adding Compute-Context-Length(CCL) Signed-off-by: Vahid Janfaza --- .../generation/text_generation_inference.py | 2 - QEfficient/generation/vlm_generation.py | 800 ++++++++++++++++++ QEfficient/transformers/cache_utils.py | 123 +++ .../models/gpt_oss/modeling_gpt_oss.py | 746 ++++++++++++++++ .../models/llama4/modeling_llama4.py | 92 +- .../transformers/models/modeling_auto.py | 192 +++-- .../models/molmo/modeling_molmo.py | 1 - .../models/qwen2_5_vl/modeling_qwen2_5_vl.py | 114 ++- QEfficient/utils/check_ccl_specializations.py | 16 +- examples/ccl_gpt_oss.py | 50 ++ examples/ccl_image_text_to_text_inference.py | 8 +- examples/ccl_llama4_CB_example_vision_lang.py | 109 +++ examples/ccl_llama4_example.py | 30 +- examples/ccl_llama4_multi_image_example.py | 89 ++ examples/ccl_mistral3_example.py | 8 +- examples/ccl_molmo_example.py | 14 +- examples/ccl_qwen2_5_vl_CB.py | 81 ++ examples/ccl_qwen2_5_vl_example.py | 64 +- examples/compute_context_length.py | 22 +- examples/gemma3_example/ccl_gemma3_mm.py | 8 +- .../ccl_granite_vision_inference.py | 8 +- .../intern_example/ccl_internvl_inference.py | 8 +- .../ccl_qwen3moe_inference.py | 10 +- 23 files changed, 2382 insertions(+), 213 deletions(-) create mode 100644 QEfficient/generation/vlm_generation.py create mode 100644 QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py create mode 100644 examples/ccl_gpt_oss.py create mode 100644 examples/ccl_llama4_CB_example_vision_lang.py create mode 100644 examples/ccl_llama4_multi_image_example.py create mode 100644 examples/ccl_qwen2_5_vl_CB.py diff --git a/QEfficient/generation/text_generation_inference.py b/QEfficient/generation/text_generation_inference.py index 560131de2..cf4b6aa27 100755 --- a/QEfficient/generation/text_generation_inference.py +++ b/QEfficient/generation/text_generation_inference.py @@ -809,14 +809,12 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_i self.list_of_comp_ctx_lengths_prefill = [np.zeros(length) for length in self.comp_ctx_lengths_prefill] prefill_ccl_id = 0 inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id] - print(f"CCL Prefill: {self.comp_ctx_lengths_prefill[prefill_ccl_id]}") for i in range(num_chunks): if self.comp_ctx_lengths_prefill is not None: if (i + 1) * self._prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id]: prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1) inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id] - print(f"CCL Prefill: {self.comp_ctx_lengths_prefill[prefill_ccl_id]}") chunk_inputs = inputs.copy() chunk_inputs["input_ids"] = inputs["input_ids"][ diff --git a/QEfficient/generation/vlm_generation.py b/QEfficient/generation/vlm_generation.py new file mode 100644 index 000000000..5eb91d142 --- /dev/null +++ b/QEfficient/generation/vlm_generation.py @@ -0,0 +1,800 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +This module provides the VisionLanguageGeneration class that inherits from +QEffTextGenerationBase, enabling all advanced text generation features while +maintaining full API compatibility with the original VisionLanguageGeneration. + +Key enhancements: +- Continuous batching support for vision models +- Advanced streaming capabilities +- On-device sampling support +- LoRA adapter support +- Better performance metrics +""" + +from collections import deque +from time import perf_counter +from typing import Any, Dict, List, Optional, Union + +import numpy as np +from transformers import AutoImageProcessor, PreTrainedTokenizer, PreTrainedTokenizerFast + +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.generation.embedding_handler import VisionHandler +from QEfficient.generation.text_generation_inference import ( + CloudAI100ExecInfo, + PerfMetrics, + QEffTextGenerationBase, + TextGeneration, + calculate_latency, + write_io_files, +) +from QEfficient.utils import LRUCache +from QEfficient.utils.logging_utils import logger + + +class VisionLanguageGeneration(QEffTextGenerationBase): + """ + Enhanced vision-language generation class inheriting from QEffTextGenerationBase. + + This class maintains full API compatibility with VisionLanguageGeneration while + adding advanced features like continuous batching, streaming, and sampling. + + Example: + >>> # Drop-in replacement for VisionLanguageGeneration + >>> vlm = VisionLanguageGeneration( + ... tokenizer=tokenizer, + ... processor=processor, + ... lang_qpc_path="path/to/lang.qpc", + ... vision_qpc_path="path/to/vision.qpc", + ... device_id=[0] + ... ) + >>> result = vlm.generate( + ... images=["image1.jpg"], + ... prompts=["Describe this image"], + ... generation_len=512 + ... ) + + >>> # Enhanced usage with new features + >>> vlm_enhanced = VisionLanguageGeneration( + ... tokenizer=tokenizer, + ... processor=processor, + ... lang_qpc_path="path/to/lang.qpc", + ... vision_qpc_path="path/to/vision.qpc", + ... device_id=[0], + ... full_batch_size=8, # Enable continuous batching + ... include_sampler=True, # Enable on-device sampling + ... sampling_params=sampling_config + ... 
) + """ + + def __init__( + self, + qeff_model, + tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast], + processor: AutoImageProcessor, + lang_qpc_path: str, + vision_qpc_path: str, + device_id: Optional[List[int]] = None, + ctx_len: Optional[int] = None, + comp_ctx_lengths_prefill: Optional[List[int]] = None, + comp_ctx_lengths_decode: Optional[List[int]] = None, + enable_debug_logs: bool = False, + write_io_dir: Optional[str] = None, + full_batch_size: Optional[int] = None, + is_tlm: bool = False, + include_sampler: bool = False, + return_pdfs: bool = False, + sampling_params: Optional[Dict[str, Any]] = None, + ): + """ + Initialize vision-language generation with enhanced capabilities + + Args: + qeff_model: QEff model instance + tokenizer: Text tokenizer + processor: Image processor + lang_qpc_path: Path to language model QPC + vision_qpc_path: Path to vision encoder QPC + device_id: Device IDs for execution (default: [0]) + ctx_len: Context length + enable_debug_logs: Enable debug logging + write_io_dir: Directory for I/O file writing + full_batch_size: Enable continuous batching (new feature) + is_tlm: Target language model flag + include_sampler: Enable on-device sampling (new feature) + return_pdfs: Return probability distributions + sampling_params: Sampling parameters for on-device sampling + """ + # Validate required parameters + if not lang_qpc_path: + raise TypeError("lang_qpc_path is required") + if not vision_qpc_path: + raise TypeError("vision_qpc_path is required") + + # Initialize base class with language QPC + # Pass activate=False to prevent premature activation before vision components are ready + super().__init__( + tokenizer=tokenizer, + qpc_path=lang_qpc_path, + full_batch_size=full_batch_size, + ctx_len=ctx_len, + comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + device_id=device_id, + enable_debug_logs=enable_debug_logs, + write_io_dir=write_io_dir, + is_tlm=is_tlm, + include_sampler=include_sampler, + return_pdfs=return_pdfs, + sampling_params=sampling_params, + activate=False, # vision components need to be initialized first + ) + + # Vision-specific initialization + self.is_qwen2_5_vl = ( + hasattr(qeff_model.model.config, "model_type") and qeff_model.model.config.model_type == "qwen2_5_vl" + ) + self.qeff_model = qeff_model + self.processor = processor + self._vision_qpc_path = vision_qpc_path + self.device_id = device_id # Store device_id for vision components + self.enable_debug_logs = enable_debug_logs # Store for vision components + self._vision_outputs_cache = LRUCache(max_size=100) # LRU cache for vision outputs + self._vision_cache = {} # Cache for vision outputs across batches + self._init_vision_components() + + # Now that vision components are initialized, activate the text session + self._session.activate() + + logger.info( + f"VisionLanguageGeneration initialized: batch_size={self.batch_size}, " + f"prefill_seq_len={self._prefill_seq_len}, ctx_len={ctx_len}, " + f"continuous_batching={'enabled' if full_batch_size else 'disabled'}, " + f"sampling={'enabled' if include_sampler else 'disabled'}" + ) + + def _init_vision_components(self): + """Initialize vision-specific components""" + # Vision session (separate from base class language session) + self._vision_session = QAICInferenceSession( + self._vision_qpc_path, self.device_id, activate=False, enable_debug_logs=self.enable_debug_logs + ) + + # Vision handler with language session coordination + vision_config = 
self._get_vision_config() + self._vision_handler = VisionHandler( + qeff_model=self.qeff_model, + vision_session=self._vision_session, + processor=self.processor, + config=vision_config, + lang_session=self._session, # Pass language session for coordination + ) + + # Setup vision buffer skipping + self._setup_vision_buffer_skipping() + + def _get_vision_config(self) -> Dict[str, Any]: + """ + Derive vision config from session + + Returns: + Dictionary with vision configuration + """ + config = {} + if self._vision_session: + try: + shapes = {} + for output_name in self._vision_session.output_names: + if ( + hasattr(self._vision_session, "bindings") + and output_name in self._vision_session.binding_index_map + ): + binding_idx = self._vision_session.binding_index_map[output_name] + if hasattr(self._vision_session.bindings[binding_idx], "dims"): + shapes[output_name] = tuple(self._vision_session.bindings[binding_idx].dims) + + if shapes: + config["vision_output_shapes"] = shapes + except Exception as e: + logger.warning(f"Could not derive vision config from session: {e}") + + return config + + def _setup_vision_buffer_skipping(self): + """Skip KV cache and retained state buffers for vision session""" + # Pre-compute skip buffers + self._vision_skip_buffers = [ + x + for x in self._vision_session.input_names + self._vision_session.output_names + if x.startswith("past_") or x.endswith("_RetainedState") + ] + self._vision_session.skip_buffers(self._vision_skip_buffers) + + # Pre-compute language skip buffers + self._lang_skip_buffers = [ + x + for x in self._session.input_names + self._session.output_names + if x.startswith("past_") or x.endswith("_RetainedState") + ] + + def run_prefill_for_all_inputs(self, prompt_queue, generation_len): + """ + Runs prefill for all inputs in the prompt queue and updates the decode input. + + Method iterates over the full batch size and for each decode batch ID, it pops the next prompt from the queue. It then runs prefill for the next prompt and updates the decode input with the outputs. + + Args: + prompt_queue (deque): The queue of prompts. + generation_len (int): The generation length. + + """ + for decode_batch_id in range(self.full_batch_size): + next_prompt = prompt_queue.popleft() + + # run prefill for num_chunks + outputs, position_ids, generation_len = self.run_prefill( + next_prompt, generation_len, decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1) + ) + + if self.is_qwen2_5_vl: + _ = self.update_decode_inputs_qwen2_5_vl(outputs, position_ids, generation_len, decode_batch_id) + else: + _ = self.update_decode_input(outputs, position_ids, generation_len, decode_batch_id) + + def update_decode_inputs_qwen2_5_vl(self, outputs, position_ids, generation_len, decode_batch_id=None): + """ + Updates the decode input with the generated values. + Args: + outputs (dict): The outputs of the model. + position_ids (array): The position IDs. + generation_len (int): The generation length. + decode_batch_id (int, optional): The decode batch ID. If None, all values are updated. Defaults to None. + + Returns: + next_token_id (array): The next token ID. + """ + next_token_id = self._fetch_next_token_id(outputs) + + # Store the generated values. 
+ self.decode_input_ids[decode_batch_id or slice(None)] = next_token_id + self.decode_pos_ids[:, decode_batch_id] = position_ids.squeeze(1) + self.generated_ids[decode_batch_id or slice(None), 0] = next_token_id.squeeze(1) + self.generation_len[decode_batch_id or slice(None)] = generation_len + return next_token_id + + def _execute_chunked_prefill( + self, + lang_inputs: Dict[str, np.ndarray], + num_chunks: int, + decode_batch_id: Optional[np.ndarray] = None, + prefill_logit_bs: int = 1, + ) -> Dict[str, np.ndarray]: + """ + Execute chunked prefill with language inputs + + Args: + lang_inputs: Pre-processed language inputs with input_ids, position_ids, etc. + num_chunks: Number of chunks to process + decode_batch_id: Batch ID for continuous batching (optional) + prefill_logit_bs: Batch size for prefill logits + + Returns: + Final prefill outputs + """ + # Set output buffers + self._set_output_buffers(batch_size=prefill_logit_bs, sequence_length=1) + + # Skip buffers for dual-QPC coordination + self._session.skip_buffers(self._lang_skip_buffers) + + # Run chunked prefill + outputs = None + chunk_image_idx = None + + if self.comp_ctx_lengths_prefill is not None: + self.list_of_comp_ctx_lengths_prefill = [np.zeros(length) for length in self.comp_ctx_lengths_prefill] + prefill_ccl_id = 0 + lang_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + + for i in range(num_chunks): + input_ids_slice = lang_inputs["input_ids"][:, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len] + position_ids_slice = lang_inputs["position_ids"][ + ..., i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len + ] + + chunk_inputs = { + "input_ids": input_ids_slice, + "position_ids": position_ids_slice, + "image_idx": chunk_image_idx if chunk_image_idx is not None else np.array([[0]], dtype=np.int64), + } + + if decode_batch_id is not None: + chunk_inputs["batch_index"] = decode_batch_id + + if "cross_attention_mask" in lang_inputs: + chunk_inputs["cross_attention_mask"] = lang_inputs["cross_attention_mask"] + + if self.comp_ctx_lengths_prefill is not None: + if (i + 1) * self._prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id]: + prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1) + lang_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id] + + chunk_inputs["comp_ctx_lengths"] = lang_inputs["comp_ctx_lengths"] + + outputs = self._session.run(chunk_inputs) + + if "image_idx_output" in outputs: + chunk_image_idx = outputs["image_idx_output"] + + if self._write_io_dir is not None: + write_io_files(lang_inputs, outputs, self._write_io_dir, "prefill", "aic_batch_io", True, False) + + # Prepare decode-time cross_attention_mask + if "cross_attention_mask" in lang_inputs: + bs, _, num_images, img_tiles = lang_inputs["cross_attention_mask"].shape + self._decode_cross_attention_mask = np.ones((bs, 1, num_images, img_tiles), dtype=np.int64) + else: + self._decode_cross_attention_mask = None + + return outputs + + def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None): + """ + Override base class prefill to handle vision processing + + Args: + prompt: Can be string or tuple (image_path, text_prompt) + generation_len: Generation length + prefill_logit_bs: Prefill batch size + decode_batch_id: Batch ID for continuous batching + + Returns: + Same as base class: (outputs, position_ids, generation_len) + """ + # Normalize prompt: TextGeneration passes a list even for batch_size=1 + if 
isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], tuple) and len(prompt[0]) == 2: + # Unwrap single (image_path, text_prompt) tuple + if len(prompt) == 1: + prompt = prompt[0] + else: + raise NotImplementedError( + "VisionLanguageGeneration.run_prefill currently supports a single (image, text) pair per call." + ) + # Check if this is a vision-language prompt + if isinstance(prompt, tuple) and len(prompt) == 2: + image_path, text_prompt = prompt + + # Check cache for vision outputs + cache_key = image_path if isinstance(image_path, str) else str(image_path) + if cache_key in self._vision_cache: + lang_inputs, vision_outputs, num_chunks = self._vision_cache[cache_key] + logger.debug(f"Using cached vision outputs for {cache_key}") + else: + # Build language inputs with processor-aware vision/text integration + lang_inputs, vision_outputs, num_chunks = self._vision_handler.get_processed_inputs( + image_url=image_path, query=text_prompt, prefill_seq_len=self._prefill_seq_len + ) + # Cache for future use + self._vision_cache[cache_key] = (lang_inputs, vision_outputs, num_chunks) + logger.debug(f"Cached vision outputs for {cache_key}") + + # Set vision buffers in language session + self._session.set_buffers(vision_outputs) + logger.debug(f"Vision buffers set: {list(vision_outputs.keys())}") + self._vision_processed = True + self._vision_outputs = vision_outputs + + # Calculate generation_len consistent with ctx_len + max_gen_len = self._ctx_len - np.where(lang_inputs["position_ids"] != -1, 1, 0).sum(1, keepdims=True).max() + generation_len = self._fetch_generation_len(generation_len, max_gen_len) + + # Execute chunked prefill + outputs = self._execute_chunked_prefill(lang_inputs, num_chunks, decode_batch_id, prefill_logit_bs) + + self._session.skip_buffers(vision_outputs) + + # Prepare position_ids for decode phase (next position after prefill) + position_ids_decode = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1 + + return outputs, position_ids_decode, generation_len + else: + # Fall back to base class for text-only + return super().run_prefill(prompt, generation_len, prefill_logit_bs, decode_batch_id) + + def _prepare_vision_language_prompt(self, text_prompt, image_path): + """ + Prepare text prompt with vision context + + This method handles the integration of vision and text inputs + according to the specific model's requirements. + """ + # For most vision-language models, we need to apply the chat template + # that includes both image and text components + try: + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": text_prompt}, + {"type": "image"}, + ], + }, + ] + + # Apply chat template + processed_prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True) + + return processed_prompt + + except Exception as e: + logger.warning(f"Failed to apply chat template: {e}. 
Using original prompt.") + return text_prompt + + def generate( + self, images: List[str], prompts: List[str], generation_len: Optional[int] = None, stream: bool = True, **kwargs + ) -> CloudAI100ExecInfo: + """ + Main generation method maintaining API compatibility with VisionLanguageGeneration + + Args: + images: List of image URLs/paths + prompts: List of text prompts + generation_len: Max generation length + stream: Enable streaming output + **kwargs: Additional arguments passed to base class + + Returns: + CloudAI100ExecInfo with results and metrics + + Raises: + ValueError: If images and prompts lengths don't match + """ + if len(images) != len(prompts): + raise ValueError(f"Number of images ({len(images)}) must match number of prompts ({len(prompts)})") + + # Clear vision cache for fresh generation + self._vision_cache.clear() + + logger.info(f"Generating for {len(images)} image-prompt pairs") + + # Convert to base class format: list of (image, prompt) tuples + vision_prompts = [(img, prompt) for img, prompt in zip(images, prompts)] + + # Use base class generate method with vision prompts + if self.full_batch_size is not None: + # Continuous batching mode (new capability) + return self._generate_continuous_batching(vision_prompts, generation_len, stream, **kwargs) + else: + # Regular batching mode + return self._generate_regular_batching(vision_prompts, generation_len, stream, **kwargs) + + def _generate_regular_batching(self, vision_prompts, generation_len, stream, **kwargs): + """Handle regular batching for vision-language generation without creating a second language session""" + batch_results = [] + for i in range(0, len(vision_prompts), self.batch_size): + batch = vision_prompts[i : i + self.batch_size] + + if stream: + print( + f"\nProcessing batch {i // self.batch_size + 1}/{(len(vision_prompts) - 1) // self.batch_size + 1}" + ) + for j, (img, prompt) in enumerate(batch): + print(f"Image: {img}") + print(f"Prompt: {prompt}") + print("Completion:", flush=True, end="") + + # Setup decode storage arrays for this batch (use ctx_len or generation_len whichever is larger) + exec_batch_size = self.batch_size + max_gen_length = self._ctx_len if not generation_len else max(self._ctx_len, generation_len) + self.initialize_decode_inputs( + num_prompts=len(batch), execution_batch_size=exec_batch_size, max_gen_length=max_gen_length + ) + + # Prefill using VLM-aware run_prefill (batch is a list of (image, text)) + start = perf_counter() + outputs, position_ids, generation_len_final = self.run_prefill( + batch, generation_len, prefill_logit_bs=self.batch_size + ) + self.update_decode_input(outputs, position_ids, generation_len_final) + + # Prepare decode + decode_inputs = self.prepare_decode_inputs() + + # Decode loop + loop_start = perf_counter() + num_token = self.run_decode(decode_inputs, generation_len_final, automation=False, streamer=None) + end = perf_counter() + + # Decode generated texts + generated_texts = self.tokenizer.batch_decode(self.generated_ids, skip_special_tokens=True) + + # Latency metrics + total_decode_tokens = num_token + prefill_time, decode_perf, total_perf, total_time = calculate_latency( + total_decode_tokens, loop_start, start, end + ) + perf_metrics = PerfMetrics(prefill_time, decode_perf, total_perf, total_time) + + # Package result for this batch + batch_results.append( + CloudAI100ExecInfo( + batch_size=self.batch_size, + generated_texts=generated_texts, + generated_ids=self.generated_ids, + perf_metrics=perf_metrics, + ) + ) + + # Aggregate results across 
batches + return self._aggregate_batch_results(batch_results) + + def _generate_continuous_batching(self, vision_prompts, generation_len, stream, **kwargs): + """Enable continuous batching for vision-language models (new capability)""" + logger.info("Using continuous batching for vision-language generation") + + if stream: + logger.warning("Streaming output not fully supported with continuous batching") + + # Reset vision processing state for new generation + self._vision_processed = False + self._vision_outputs = None + self._vision_outputs_cache = {} + + # Initialize decode inputs + num_prompts = len(vision_prompts) + execution_batch_size = self.full_batch_size + max_gen_length = self._ctx_len if not generation_len else max(self._ctx_len, generation_len) + + self.initialize_decode_inputs(num_prompts, execution_batch_size, max_gen_length) + if self.is_qwen2_5_vl: + self.decode_pos_ids = np.zeros((4, execution_batch_size, 1), np.int64) + + # Create prompt queue + prompt_queue = deque(vision_prompts) + + start = perf_counter() + + # Pre-process ALL vision inputs and cache them + logger.info("Pre-processing all vision inputs...") + for batch_id in range(min(self.full_batch_size, len(vision_prompts))): + img, prompt = vision_prompts[batch_id] + + # Process vision for this slot + lang_inputs, vision_outputs, num_chunks = self._vision_handler.get_processed_inputs( + image_url=img, query=prompt, prefill_seq_len=self._prefill_seq_len + ) + + # Cache vision outputs for this batch slot + self._vision_outputs_cache[batch_id] = { + "vision_outputs": vision_outputs, + "lang_inputs": lang_inputs, + "num_chunks": num_chunks, + } + + logger.debug(f"Cached vision outputs for batch_id {batch_id}") + + # Reset prompt queue for prefill + prompt_queue = deque(vision_prompts) + + self.batch_index = None + + # Run prefill for all inputs using cached vision + self.run_prefill_for_all_inputs_with_cached_vision(prompt_queue, generation_len) + + # Set vision buffers for decode (use first slot's vision for now) + # For identical images, any slot's vision works + cached_slot_0 = self._vision_outputs_cache.get(0) + if cached_slot_0: + self._session.set_buffers(cached_slot_0["vision_outputs"]) + logger.debug("Set vision buffers from slot 0 for decode phase") + + # Now set batch_index for decode phase + self.batch_index = np.arange(self.full_batch_size).reshape(-1, 1) + + loop_start = perf_counter() + decode_pause_time = self.run_continuous_batching_decode(prompt_queue, generation_len) + end = perf_counter() + + generated_texts = self.tokenizer.batch_decode(self.generated_ids, skip_special_tokens=True) + + total_decode_tokens = sum( + np.sum(self.generated_ids[i] != self.tokenizer.pad_token_id) - 1 for i in range(len(vision_prompts)) + ) + prefill_time, decode_perf, total_perf, total_time = calculate_latency( + total_decode_tokens, loop_start, start, end, decode_pause_time + ) + prefill_time /= len(vision_prompts) # Average prefill time for continuous batching + + perf_metrics = PerfMetrics(prefill_time, decode_perf, total_perf, total_time) + + return CloudAI100ExecInfo( + batch_size=1, generated_texts=generated_texts, generated_ids=self.generated_ids, perf_metrics=perf_metrics + ) + + def run_prefill_for_all_inputs_with_cached_vision(self, prompt_queue, generation_len): + """ + Runs prefill for all inputs using pre-cached vision outputs. + + This avoids the vision buffer overwriting issue by using cached vision + outputs instead of processing vision during each prefill iteration. 
+ + Args: + prompt_queue (deque): The queue of prompts. + generation_len (int): The generation length. + """ + for decode_batch_id in range(self.full_batch_size): + # Pop the promt as we are processing + _ = prompt_queue.popleft() + + # Get cached vision outputs for this batch slot + cached = self._vision_outputs_cache.get(decode_batch_id) + if cached: + vision_outputs = cached["vision_outputs"] + lang_inputs = cached["lang_inputs"] + num_chunks = cached["num_chunks"] + + # Set vision buffers for THIS prefill + self._session.set_buffers(vision_outputs) + logger.debug(f"Set vision buffers for batch_id {decode_batch_id} prefill") + + # Run prefill with cached inputs + outputs = self._execute_chunked_prefill( + lang_inputs, + num_chunks, + decode_batch_id=np.array(decode_batch_id, dtype=np.int64).reshape(1, 1), + prefill_logit_bs=1, + ) + + self._session.skip_buffers(vision_outputs.keys()) + + # Calculate position_ids for decode + position_ids_decode = np.max(lang_inputs["position_ids"], axis=-1, keepdims=True) + 1 + + # Calculate generation_len + max_gen_len = ( + self._ctx_len - np.where(lang_inputs["position_ids"] != -1, 1, 0).sum(1, keepdims=True).max() + ) + generation_len_final = self._fetch_generation_len(generation_len, max_gen_len) + + # Update decode inputs + if self.is_qwen2_5_vl: + self.update_decode_inputs_qwen2_5_vl( + outputs, position_ids_decode, generation_len_final, decode_batch_id + ) + else: + self.update_decode_input(outputs, position_ids_decode, generation_len_final, decode_batch_id) + else: + logger.error(f"No cached vision outputs for batch_id {decode_batch_id}") + raise RuntimeError(f"Vision outputs not cached for batch_id {decode_batch_id}") + + def prepare_decode_inputs(self): + """ + Override base class to handle vision-specific decode inputs + """ + decode_inputs = super().prepare_decode_inputs() + + # Add image_idx for vision-language models in CB mode during decode only + if self.batch_index is not None and hasattr(self, "_vision_outputs"): + # image_idx should be a single slot selector; decoder expects shape (1,1) + # Query binding dims if available to be robust + try: + if "image_idx" in getattr(self._session, "binding_index_map", {}): + idx = self._session.binding_index_map["image_idx"] + dims = tuple(self._session.bindings[idx].dims) + decode_inputs["image_idx"] = np.zeros(dims, dtype=np.int64) + else: + decode_inputs["image_idx"] = np.array([[0]], dtype=np.int64) + except Exception: + decode_inputs["image_idx"] = np.array([[0]], dtype=np.int64) + + # Include cross_attention_mask during decode if present/required + if hasattr(self, "_decode_cross_attention_mask") and self._decode_cross_attention_mask is not None: + # Decoder specialization expects a single mask (batch dim = 1) + decode_inputs["cross_attention_mask"] = self._decode_cross_attention_mask + + return decode_inputs + + def _aggregate_batch_results(self, batch_results): + """Aggregate results from multiple batches""" + if not batch_results: + raise ValueError("No batch results to aggregate") + + if len(batch_results) == 1: + return batch_results[0] + + # Aggregate multiple batch results + all_generated_texts = [] + all_generated_ids = [] + all_metrics = [] + + for result in batch_results: + if isinstance(result.generated_texts[0], list): + # Flatten nested lists + all_generated_texts.extend([text for batch in result.generated_texts for text in batch]) + else: + all_generated_texts.extend(result.generated_texts) + + if isinstance(result.generated_ids, list): + 
all_generated_ids.extend(result.generated_ids) + else: + all_generated_ids.append(result.generated_ids) + + all_metrics.append(result.perf_metrics) + + # Average metrics + avg_metrics = PerfMetrics( + prefill_time=np.mean([m.prefill_time for m in all_metrics]), + decode_perf=np.mean([m.decode_perf for m in all_metrics]), + total_perf=np.mean([m.total_perf for m in all_metrics]), + total_time=np.mean([m.total_time for m in all_metrics]), + ) + + return CloudAI100ExecInfo( + batch_size=batch_results[0].batch_size, + generated_texts=all_generated_texts, + generated_ids=all_generated_ids, + perf_metrics=avg_metrics, + ) + + def generate_stream_tokens( + self, images: List[str], prompts: List[str], generation_len: Optional[int] = None, **kwargs + ): + """ + Enable token-by-token streaming for vision models (new capability) + + Args: + images: List of image URLs/paths + prompts: List of text prompts + generation_len: Max generation length + **kwargs: Additional arguments + + Yields: + List of decoded tokens for each batch position + + Raises: + NotImplementedError: If continuous batching is enabled + """ + if self.full_batch_size is not None: + raise NotImplementedError("Token streaming not supported with continuous batching for VLM") + + if len(images) != len(prompts): + raise ValueError(f"Number of images ({len(images)}) must match number of prompts ({len(prompts)})") + + logger.info(f"Starting token streaming for {len(images)} image-prompt pairs") + + vision_prompts = [(img, prompt) for img, prompt in zip(images, prompts)] + + text_gen = TextGeneration( + tokenizer=self.tokenizer, + qpc_path=self._qpc_path, + ctx_len=self._ctx_len, + device_id=self.device_id, + enable_debug_logs=self.enable_debug_logs, + is_tlm=self.is_tlm, + include_sampler=self.include_sampler, + return_pdfs=self.return_pdfs, + sampling_params=self.sampling_params, + ) + + text_gen._qaic_model = self + + # Yield tokens as they're generated + for tokens in text_gen.generate_stream_tokens(vision_prompts, generation_len, **kwargs): + yield tokens + + def __repr__(self): + """String representation of the class""" + return ( + f"VisionLanguageGeneration(" + f"batch_size={self.batch_size}, " + f"ctx_len={self._ctx_len}, " + f"continuous_batching={'enabled' if self.full_batch_size else 'disabled'}, " + f"sampling={'enabled' if self.include_sampler else 'disabled'})" + ) diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 0d123d25f..7b3c620aa 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -549,3 +549,126 @@ def update( ctx_v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) v_out = torch.where((is_sliding_layer & (position_ids.max() >= (layer_ctx_len - 1))), v_out, ctx_v_out) return k_out, v_out + + +# This is a hack for now, until we get to merging this code with HybridCache class, +# We don't really need to inherit transformers classes as their cache classes are made to work with pytorch and +# ours are made to work with AIC +class QEffHybridCacheForGPTOSS: + def __init__(self, config, batch_size, max_cache_len, sliding_window_len): + self.max_cache_len = max_cache_len + self.batch_size = batch_size + self.sliding_window_len = sliding_window_len + self.key_cache: List[torch.Tensor] = [] + self.value_cache: List[torch.Tensor] = [] + + @classmethod + def from_legacy_cache( + cls, config, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + ) -> "HybridCache": + """Converts a 
cache in the legacy cache format into an equivalent `QEffHybridCacheForGPTOSS`. Used for
+        backward compatibility."""
+        cache = cls(
+            config,
+            batch_size=past_key_values[0][0].shape[0],
+            max_cache_len=past_key_values[1][0].shape[2],
+            sliding_window_len=past_key_values[0][0].shape[2],
+        )
+        if past_key_values is not None:
+            for layer_idx in range(len(past_key_values)):
+                key_states, value_states = past_key_values[layer_idx]
+                cache.update(key_states, value_states, layer_idx)
+        return cache
+
+    def __len__(self):
+        """
+        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
+        to the number of layers in the model.
+        """
+        return len(self.key_cache)
+
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
+        # TODO: deprecate this function in favor of `cache_position`
+        is_empty_layer = (
+            len(self.key_cache) == 0  # no cache in any layer
+            or len(self.key_cache) <= layer_idx  # skipped `layer_idx` and hasn't run a layer with cache after it
+            or len(self.key_cache[layer_idx]) == 0  # the layer has no cache
+        )
+        layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
+        return layer_seq_length
+
+    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
+        """Converts the `QEffHybridCacheForGPTOSS` instance into its equivalent in the legacy cache format. Used for
+        backward compatibility."""
+        legacy_cache = ()
+        for layer_idx in range(len(self)):
+            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
+        return legacy_cache
+
+    def update(
+        self,
+        key_states: torch.Tensor,
+        value_states: torch.Tensor,
+        layer_idx: int,
+        cache_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if len(self.key_cache) <= layer_idx:
+            self.key_cache.append(key_states)
+            self.value_cache.append(value_states)
+            k_out, v_out = key_states, value_states
+        else:
+            position_ids = cache_kwargs.get("position_ids")
+            is_sliding_layer = cache_kwargs.get("is_sliding")
+            sliding_window = cache_kwargs.get("sliding_window")
+            batch_index = cache_kwargs.get("batch_index", None)  # Check and fetch batch index value from the kwargs
+            comp_ctx_len = cache_kwargs.get("CCL")
+
+            if is_sliding_layer:
+                kv_position_ids = torch.where(position_ids == -1, position_ids, position_ids % sliding_window)
+            else:
+                kv_position_ids = position_ids
+
+            if batch_index is not None:
+                if torch.onnx.is_in_onnx_export():
+                    invalid_scatter_index = torch.iinfo(torch.int32).max
+                    scatter_position_ids = torch.where(kv_position_ids < 0, invalid_scatter_index, kv_position_ids)
+                else:
+                    scatter_position_ids = kv_position_ids
+                self.key_cache[layer_idx] = CtxScatterFuncCB.apply(
+                    self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states
+                )
+                self.value_cache[layer_idx] = CtxScatterFuncCB.apply(
+                    self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states
+                )
+            else:
+                self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
+                self.value_cache[layer_idx] = CtxScatterFunc.apply(
+                    self.value_cache[layer_idx], kv_position_ids, value_states
+                )
+
+            k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
+
+            # Gather the KV window that attention will read: the full sliding-window cache for
+            # sliding layers, or the current compute-context-length (CCL) bucket for full-attention layers.
+            if is_sliding_layer:
+                ctx_len = k_out.shape[2]
+            else:
+                ctx_len = comp_ctx_len
+            ctx_indices = torch.arange(ctx_len)[None, None, ...]
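+            # Entries beyond the largest valid position id are masked: their gather index is
+            # replaced (INT32 max during ONNX export, 0 otherwise) and the gathered value rows
+            # are zeroed out further below so they cannot contribute to attention.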
+ gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + if torch.onnx.is_in_onnx_export(): + invalid_idx_value = torch.iinfo(torch.int32).max + else: + invalid_idx_value = 0 + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + + if batch_index is not None: + k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, ctx_len) + v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, ctx_len) + else: + k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len) + v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len) + + v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) + return k_out, v_out diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py new file mode 100644 index 000000000..576326921 --- /dev/null +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -0,0 +1,746 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +from typing import Callable, Optional, Union + +import torch +from torch import nn +from torch.nn import functional as F +from transformers.cache_utils import Cache +from transformers.modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, +) +from transformers.models.gpt_oss.modeling_gpt_oss import ( + GptOssAttention, + GptOssConfig, + GptOssDecoderLayer, + GptOssExperts, + GptOssForCausalLM, + GptOssMLP, + GptOssModel, + GptOssRotaryEmbedding, + repeat_kv, +) +from transformers.processing_utils import Unpack +from transformers.utils import TransformersKwargs + +from QEfficient.transformers.cache_utils import QEffHybridCacheForGPTOSS +from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask +from QEfficient.utils import constants +from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE + + +class QEffGptOssExperts(GptOssExperts): + def __qeff_init__(self): + self.gate_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, self.expert_dim)) + self.up_proj = nn.Parameter(torch.empty(self.num_experts, self.hidden_size, self.expert_dim)) + self.gate_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.expert_dim)) + self.up_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.expert_dim)) + + +class QEffGptOssMLP(GptOssMLP): + def alt_forward(self, hidden: torch.Tensor): + B, S, H = hidden.shape + T = B * S + hidden = hidden.view(T, H) + + # Router computation + router_logits = F.linear(hidden, self.router.weight, self.router.bias) + + # Top-k selection + top_w, top_i = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K] + top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype) + + masked_logits = torch.zeros_like(router_logits) + masked_logits.scatter_(1, top_i, top_w) + + # Routing weights for each expert [T, E] + routing_weights = masked_logits + + # ────────────────── allocate the output tensor ───── + expert_out = hidden.new_zeros((T, H)) # accumulation buffer + + # ───────────────────────── Expert computation loop ───────────────────────────── + for e in range(self.experts.num_experts): + routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1] + + W_g, W_u = self.experts.gate_proj[e], 
self.experts.up_proj[e]  # [H, I], [H, I]
+            b_g, b_u = self.experts.gate_proj_bias[e], self.experts.up_proj_bias[e]  # [I], [I]
+            W_d = self.experts.down_proj[e]  # [I, H]
+            b_d = self.experts.down_proj_bias[e]  # [H]
+
+            # Gate and Up projections
+            gate = (hidden @ W_g) + b_g  # [T, I]
+            up = (hidden @ W_u) + b_u  # [T, I]
+
+            # Apply GptOss activation with clamping
+            gate = gate.clamp(min=None, max=self.experts.limit)
+            up = up.clamp(min=-self.experts.limit, max=self.experts.limit)
+
+            # GLU activation
+            glu = gate * torch.sigmoid(gate * self.experts.alpha)
+            intermediate = (up + 1) * glu  # [T, I]
+
+            # Down projection
+            down_out = (intermediate @ W_d) + b_d  # [T, H]
+
+            # Apply routing weights and accumulate
+            masked_down = torch.where(routing_weight > 0, down_out * routing_weight, torch.zeros_like(expert_out))
+            expert_out += masked_down
+
+        # original shape [B, S, H]
+        return expert_out.view(B, S, H), router_logits
+
+    # ------------------- Gather based, weights as activation approach ---------------
+    def forward_weights_as_activation(self, hidden_states):
+        bs, seq_len, _ = hidden_states.shape
+        hidden_states = hidden_states.view(bs * seq_len, self.experts.hidden_size)
+
+        # Router computation
+        router_logits = F.linear(hidden_states, self.router.weight, self.router.bias)
+        router_top_value, router_indices = torch.topk(router_logits, self.router.top_k, dim=-1)
+        router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
+
+        # GATHER - collect weights for selected experts
+        gate_up_proj = self.experts.gate_up_proj[router_indices.flatten()]
+        gate_up_proj_bias = self.experts.gate_up_proj_bias[router_indices.flatten()]
+        down_proj = self.experts.down_proj[router_indices.flatten()]
+        down_proj_bias = self.experts.down_proj_bias[router_indices.flatten()]
+
+        # Apply Chosen Experts (without routing weights first)
+        # expert_in = hidden_states.repeat_interleave(self.router.top_k, dim=0)
+        # expert_in = expert_in.view(-1, 1, self.experts.hidden_size)
+        # Reshape for bmm: (bs*seq_len*top_k, 1, hidden_size)
+        expert_in = (
+            hidden_states.unsqueeze(1)
+            .expand(-1, self.router.top_k, -1)
+            .contiguous()
+            .view(-1, 1, self.experts.hidden_size)
+        )
+
+        gate_up = torch.bmm(expert_in, gate_up_proj) + gate_up_proj_bias.unsqueeze(1)
+        gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+
+        # Apply activation with clamping
+        gate = gate.clamp(min=None, max=self.experts.limit)
+        up = up.clamp(min=-self.experts.limit, max=self.experts.limit)
+        glu = gate * torch.sigmoid(gate * self.experts.alpha)
+        gated_output = (up + 1) * glu
+
+        experts_out = torch.bmm(gated_output, down_proj) + down_proj_bias.unsqueeze(1)
+        experts_out = experts_out.view(bs * seq_len, self.router.top_k, self.experts.hidden_size)
+
+        # Apply routing weights AFTER expert computation (in Llama4 they are applied before it)
+        experts_out = experts_out * router_top_value.unsqueeze(-1)
+        experts_out = experts_out.sum(dim=1)
+
+        return experts_out, router_logits
+
+    # ------------------- Gather based, weights as activation approach, with separate gate and up projections ---------------
+    def forward(self, hidden_states):
+        # print("Separate split, up, gate projections")
+        bs, seq_len, _ = hidden_states.shape
+        hidden_states = hidden_states.view(bs * seq_len, self.experts.hidden_size)
+
+        # Router computation
+        router_logits = F.linear(hidden_states, self.router.weight, self.router.bias)
+        router_top_value, router_indices = torch.topk(router_logits, self.router.top_k, dim=-1)
+        router_top_value =
torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) + + # GATHER - collect weights for selected experts (separate gate and up projections) + gate_proj = self.experts.gate_proj[router_indices.flatten()] + gate_proj_bias = self.experts.gate_proj_bias[router_indices.flatten()] + up_proj = self.experts.up_proj[router_indices.flatten()] + up_proj_bias = self.experts.up_proj_bias[router_indices.flatten()] + down_proj = self.experts.down_proj[router_indices.flatten()] + down_proj_bias = self.experts.down_proj_bias[router_indices.flatten()] + + # Reshape for bmm: (bs*seq_len*top_k, 1, hidden_size) + expert_in = ( + hidden_states.unsqueeze(1) + .expand(-1, self.router.top_k, -1) + .contiguous() + .view(-1, 1, self.experts.hidden_size) + ) + + # Apply gate and up projections separately using bmm + gate = torch.bmm(expert_in, gate_proj) + gate_proj_bias.unsqueeze(1) + up = torch.bmm(expert_in, up_proj) + up_proj_bias.unsqueeze(1) + + # Apply activation with clamping + gate = gate.clamp(min=None, max=self.experts.limit) + up = up.clamp(min=-self.experts.limit, max=self.experts.limit) + + # GLU activation + glu = gate * torch.sigmoid(gate * self.experts.alpha) + gated_output = (up + 1) * glu + + # Down projection + experts_out = torch.bmm(gated_output, down_proj) + down_proj_bias.unsqueeze(1) + experts_out = experts_out.view(bs * seq_len, self.router.top_k, self.experts.hidden_size) + + # Apply routing weights AFTER expert computation + experts_out = experts_out * router_top_value.unsqueeze(-1) + experts_out = experts_out.sum(dim=1) + + return experts_out, router_logits + + def optimized_moe_forward(self, hidden_states: torch.Tensor): + B, S, H = hidden_states.shape + T = B * S + hidden_states = hidden_states.view(T, H) + + # Router computation + router_logits = F.linear(hidden_states, self.router.weight, self.router.bias) + + # Top-k selection + top_w, selected_experts = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K] + top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype) + + # Creating experts mask and routing weights masked + awesome_experts_mask_1 = ( + torch.nn.functional.one_hot(selected_experts[:, 0], num_classes=self.experts.num_experts) + .bool() + .T.unsqueeze(-1) + ) + awesome_experts_mask_2 = ( + torch.nn.functional.one_hot(selected_experts[:, 1], num_classes=self.experts.num_experts) + .bool() + .T.unsqueeze(-1) + ) + awesome_experts_mask_3 = ( + torch.nn.functional.one_hot(selected_experts[:, 2], num_classes=self.experts.num_experts) + .bool() + .T.unsqueeze(-1) + ) + awesome_experts_mask_4 = ( + torch.nn.functional.one_hot(selected_experts[:, 3], num_classes=self.experts.num_experts) + .bool() + .T.unsqueeze(-1) + ) + + gateupout1 = torch.zeros(hidden_states.shape[0], self.experts.intermediate_size) # T, hs + gateupout2 = torch.zeros(hidden_states.shape[0], self.experts.intermediate_size) # T, hs + gateupout3 = torch.zeros(hidden_states.shape[0], self.experts.intermediate_size) # T, hs + gateupout4 = torch.zeros(hidden_states.shape[0], self.experts.intermediate_size) # T, hs + + # ───────────────────────── Expert computation loop ───────────────────────────── + for e in range(self.experts.num_experts): + W_g, W_u = self.experts.gate_proj[e], self.experts.up_proj[e] # [H, I], [H, I] + b_g, b_u = self.experts.gate_proj_bias[e], self.experts.up_proj_bias[e] # [I], [I] + + # Gate and Up projections + gate = (hidden_states @ W_g) + b_g # [T, I] + up = (hidden_states @ W_u) + b_u # [T, I] + + # Apply GptOss activation with 
clamping + gate = gate.clamp(min=None, max=self.experts.limit) + up = up.clamp(min=-self.experts.limit, max=self.experts.limit) + + # GLU activation + glu = gate * torch.sigmoid(gate * self.experts.alpha) + intermediate = (up + 1) * glu # [T, I] + + gateupout1 += torch.where(awesome_experts_mask_1[e], intermediate, torch.zeros_like(gateupout1)) + gateupout2 += torch.where(awesome_experts_mask_2[e], intermediate, torch.zeros_like(gateupout2)) + gateupout3 += torch.where(awesome_experts_mask_3[e], intermediate, torch.zeros_like(gateupout3)) + gateupout4 += torch.where(awesome_experts_mask_4[e], intermediate, torch.zeros_like(gateupout4)) + + concat_down = torch.zeros((self.router.top_k, T, H)) + concat_mask = torch.cat( + ( + awesome_experts_mask_1.unsqueeze(0), + awesome_experts_mask_2.unsqueeze(0), + awesome_experts_mask_3.unsqueeze(0), + awesome_experts_mask_4.unsqueeze(0), + ), + dim=0, + ) + + concat_gateout = torch.cat( + (gateupout1.unsqueeze(0), gateupout2.unsqueeze(0), gateupout3.unsqueeze(0), gateupout4.unsqueeze(0)), dim=0 + ) + + for e in range(self.experts.num_experts): + W_d = self.experts.down_proj[e] # [I, H] + b_d = self.experts.down_proj_bias[e] # [H] + + # Down projection + down_out = (concat_gateout @ W_d) + b_d # [T, H] + + concat_down += torch.where(concat_mask[:, e, :], down_out, torch.zeros_like(concat_down)) + + downout1, downout2, downout3, downout4 = concat_down[0], concat_down[1], concat_down[2], concat_down[3] + hidden_states = ( + downout1 * top_w[:, 0].unsqueeze(-1) + + downout2 * top_w[:, 1].unsqueeze(-1) + + downout3 * top_w[:, 2].unsqueeze(-1) + + downout4 * top_w[:, 3].unsqueeze(-1) + ).reshape(B, S, H) + + # original shape [B, S, H] + return hidden_states, router_logits + + +# Can be replaced with llama/modeling_llama.py::QEffLlamaRotaryEmbedding but keeping it following transformers ideology +class QEffGptOssRotaryEmbedding(GptOssRotaryEmbedding): + """ + Copied from LlamaForCausalLM: https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py + The only differences are: + - Add static sin/cos computations. + """ + + def __init__(self, config: GptOssConfig, device=None): + super().__init__(config=config) + # Build here to make `torch.jit.trace` work. 
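+        # Precomputing the full cos/sin tables once at init keeps the exported graph static:
+        # at runtime they are indexed by position ids rather than recomputed per step.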
+        self._set_cos_sin_cache(
+            seq_len=self.original_max_seq_len, device=self.inv_freq.device, dtype=torch.get_default_dtype()
+        )
+
+    def _set_cos_sin_cache(self, seq_len, device, dtype):
+        self.max_seq_len_cached = seq_len
+        t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq)
+
+        freqs = torch.outer(t, self.inv_freq)
+
+        emb = torch.cat((freqs, freqs), dim=-1)
+        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
+        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
+
+    def forward(self, x, seq_len=None):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        if seq_len > self.max_seq_len_cached:
+            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
+
+        return (
+            self.cos_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
+            self.sin_cached[:seq_len].to(dtype=x.dtype) * self.attention_scaling,
+        )
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+def qeff_apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`):
+            The position indices of the tokens corresponding to the query and key tensors. For example, this can be
+            used to pass offsetted position ids when working with a KV-cache.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be
+            properly broadcasted to the dimensions of q and k. If q and k have the shape
+            [batch_size, heads, seq_len, head_dim], set unsqueeze_dim=1; if they have the shape
+            [batch_size, seq_len, heads, head_dim], set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+ """ + + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + + return q_embed.to(q.dtype), k_embed.to(k.dtype) + + +def eager_attention_forward( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling + if attention_mask is not None: + attn_weights = torch.where( + attention_mask, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), attn_weights + ) + + sinks = module.sinks.reshape(1, -1, 1, 1).expand(query.shape[0], -1, query.shape[-2], -1) + combined_logits = torch.cat([attn_weights, sinks], dim=-1) + + # This was not in the original implementation and slightly affect results; it prevents overflow in BF16/FP16 + # when training with bsz>1 we clamp max values. + combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + probs = F.softmax(combined_logits, dim=-1, dtype=combined_logits.dtype) + scores = probs[..., :-1] # we drop the sink here + attn_weights = nn.functional.dropout(scores, p=dropout, training=module.training) + attn_output = torch.matmul(attn_weights, value_states) + attn_output = attn_output.transpose(1, 2).contiguous() + return attn_output, attn_weights + + +class QEffGptOssAttention(GptOssAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __qeff_init__(self): + self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + sliding_mask=None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + if comp_ctx_lengths is not None: + attention_mask = attention_mask[:, :, :, : comp_ctx_lengths.shape[-1]] + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + "config": self.config, + "is_sliding": self.sliding_window is not None, + "sliding_window": past_key_value.sliding_window_len, + "CCL": attention_mask.shape[-1], + } + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + if self.sliding_window is not None: + attention_mask = sliding_mask + else: + attention_mask = attention_mask 
+ + attention_interface: Callable = eager_attention_forward + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + s_aux=self.sinks, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, past_key_value + + +class QEffGptOssDecoderLayer(GptOssDecoderLayer): + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC + sliding_mask=None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + sliding_mask=sliding_mask, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states, _ = self.mlp(hidden_states) # diff with llama: router scores + # alth, _ = self.mlp.alt_forward(hidden_states) + hidden_states = hidden_states.reshape(residual.shape) + hidden_states = residual + hidden_states + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class QEffGptOssModel(GptOssModel): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + return_legacy_cache = False + if use_cache and 
not isinstance(past_key_values, Cache): + return_legacy_cache = True + past_key_values = QEffHybridCacheForGPTOSS.from_legacy_cache(self.config, past_key_values) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_key_values.max_cache_len) + sliding_mask = _create_causal_mask( + position_ids=position_ids, + target_length=past_key_values.sliding_window_len, + sliding_window=past_key_values.sliding_window_len, + ) + + hidden_states = inputs_embeds + # position_embeddings = self.rotary_emb(hidden_states, position_ids) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + sliding_mask=sliding_mask, + **kwargs, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + ) + + +class QEffGptOssForCausalLM(GptOssForCausalLM): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + comp_ctx_lengths: Optional[torch.LongTensor] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeCausalLMOutputWithPast: + r""" + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
+ + Example: + + ```python + >>> from transformers import AutoTokenizer, GptOssForCausalLM + + >>> model = GptOssForCausalLM.from_pretrained("mistralai/GptOss-8x7B-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/GptOss-8x7B-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + output_router_logits = ( + output_router_logits if output_router_logits is not None else self.config.output_router_logits + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs: MoeModelOutputWithPast = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + output_router_logits=output_router_logits, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + hidden_states = outputs[0][torch.arange(position_ids.shape[0]).view(-1, 1), logit_index] + logits = self.lm_head(hidden_states) + logits = logits.float() + + return MoeCausalLMOutputWithPast( + loss=None, + aux_loss=None, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + router_logits=outputs.router_logits, + ) + + def get_pkv_dynamic_axes( + self, + ): + pkv_dynamic_axes = [] + for layer_type in self.config.layer_types: + if layer_type == "sliding_attention": + pkv_dynamic_axes.append({0: "batch_size", 2: "sliding_window"}) + elif layer_type == "full_attention": + pkv_dynamic_axes.append({0: "batch_size", 2: "ctx_len"}) + return pkv_dynamic_axes + + def get_specializations( + self, + batch_size: int, + prefill_seq_len: int, + ctx_len: int, + ): + batch_size = batch_size if batch_size else 1 + prefill_seq_len = prefill_seq_len if prefill_seq_len else constants.PROMPT_LEN + ctx_len = ctx_len if ctx_len else constants.CTX_LEN + + specializations = [ + { + "batch_size": batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "sliding_window": 128, + }, + { + "batch_size": batch_size, + "seq_len": 1, + "ctx_len": ctx_len, + "sliding_window": 128, + }, + ] + return specializations diff --git a/QEfficient/transformers/models/llama4/modeling_llama4.py b/QEfficient/transformers/models/llama4/modeling_llama4.py index 0fbdbea5f..78dee9fe5 100644 --- a/QEfficient/transformers/models/llama4/modeling_llama4.py +++ b/QEfficient/transformers/models/llama4/modeling_llama4.py @@ -856,6 +856,7 @@ def forward( position_ids, image_idx, past_key_values, + batch_index: Optional[torch.LongTensor] = None, comp_ctx_lengths: Optional[List[int]] = 
None, ): inputs_embeds = self.model.language_model.get_input_embeddings()(input_ids) @@ -871,6 +872,7 @@ def forward( position_ids=position_ids, past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, use_cache=True, ) next_idx = (indices1.max() + 1).unsqueeze(0).unsqueeze(0) @@ -935,6 +937,9 @@ def get_specializations( **compiler_options, ): max_num_tiles = compiler_options.pop("max_num_tiles", None) + comp_ctx_lengths_prefill = compiler_options.pop("comp_ctx_lengths_prefill", None) + comp_ctx_lengths_decode = compiler_options.pop("comp_ctx_lengths_decode", None) + if max_num_tiles is None: logger.warning( "User should pass `max_num_tiles` to compile API to fix the dynamic axes `pixel_values`, you can get more info by calling get_inputs_info function!, Since its not found setting its value to 17" @@ -1013,29 +1018,87 @@ def get_specializations( } ) - else: - lang = [ - { - "batch_size": batch_size, + if comp_ctx_lengths_prefill is not None: + lang = [] + + for i in range(0, len(comp_ctx_lengths_prefill)): + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], "max_num_tiles": max_num_tiles, "img_size": img_size, "vision_size": vision_size, "chunk_length": prefill_seq_len, "chunk_ctx_len": chunk_ctx_len, - }, - { - "batch_size": batch_size, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang.append(lang_prefill) + + for i in range(0, len(comp_ctx_lengths_decode)): + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, "seq_len": "1", "ctx_len": ctx_len, + "comp_ctx_lengths": comp_ctx_lengths_decode[i], "max_num_tiles": max_num_tiles, "img_size": img_size, "vision_size": vision_size, "chunk_length": prefill_seq_len, "chunk_ctx_len": chunk_ctx_len, - }, - ] + } + + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang.append(lang_decode) + + else: + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "max_num_tiles": max_num_tiles, + "img_size": img_size, + "vision_size": vision_size, + "chunk_length": prefill_seq_len, + "chunk_ctx_len": chunk_ctx_len, + } + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": 1, + "ctx_len": ctx_len, + "max_num_tiles": max_num_tiles, + "img_size": img_size, + "vision_size": vision_size, + "chunk_length": prefill_seq_len, + "chunk_ctx_len": chunk_ctx_len, + } + + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang = [] + lang.append(lang_prefill) + lang.append(lang_decode) specializations = {} @@ -1046,7 +1109,9 @@ def get_specializations( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): + def get_onnx_dynamic_axes( + self, comp_ctx_lengths: Optional[List[int]] = None, 
kv_offload: bool = False, continuous_batching: bool = False + ): # Define dynamic axes vision_dynamic_axes = {} lang_dynamic_axes = {} @@ -1121,7 +1186,9 @@ def get_dummy_pkv_cache(self, config, batch_size, seq_len): past_key_values.append(pkv) return past_key_values - def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): + def get_dummy_inputs( + self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False + ): if vis_cfg := getattr(self.config, "vision_config", None): img_size = getattr(vis_cfg, "image_size", 336) else: @@ -1178,6 +1245,9 @@ def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offl for kv in ["key", "value"]: lang_inputs["past_key_values"][i].append(torch.zeros(past_key_values[0][0].shape, dtype=torch.float32)) + if continuous_batching: + lang_inputs["batch_index"] = torch.arange(bs).view(bs, 1) + if comp_ctx_lengths is not None: lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index d0b420990..160d942a7 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -857,6 +857,8 @@ class _QEffAutoModelForImageTextToTextDualQPC: def __init__( self, model: nn.Module, + continuous_batching: bool = False, + qaic_config: Optional[dict] = None, **kwargs, ): """ @@ -879,7 +881,7 @@ def __init__( self.model = model self.config = model.config - self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, _, _ = process_ccl_specializations(kwargs) + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(qaic_config) self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs) self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs) @@ -901,7 +903,7 @@ def model_name(self) -> str: return mname @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs): + def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Optional[dict] = None, **kwargs): """ Load a QEfficient multimodal model for dual QPC from a pretrained HuggingFace model or local path. @@ -935,9 +937,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs): return cls( model, pretrained_model_name_or_path=pretrained_model_name_or_path, - comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, - comp_ctx_lengths_decode=comp_ctx_lengths_decode, - ctx_len=ctx_len, + qaic_config=qaic_config, **kwargs, ) @@ -994,8 +994,23 @@ def export( List[str] A list containing the paths to the generated ONNX graph files for both components. """ - inputs = self.model.get_dummy_inputs(self.comp_ctx_lengths_decode, kv_offload=True) - dynamic_axes = self.model.get_onnx_dynamic_axes(self.comp_ctx_lengths_decode, kv_offload=True) + # TODO This is a temporary change as continous batching is enabled only for few models. Once support is added for all the models this exception handing can be removed. 
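+        # Models whose get_dummy_inputs / get_onnx_dynamic_axes do not yet accept the
+        # continuous_batching argument raise a TypeError here and fall back to the older call signature.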
+ try: + inputs = self.model.get_dummy_inputs( + kv_offload=True, + continuous_batching=self.continuous_batching, + comp_ctx_lengths=self.comp_ctx_lengths_decode, + ) + dynamic_axes = self.model.get_onnx_dynamic_axes( + kv_offload=True, + continuous_batching=self.continuous_batching, + comp_ctx_lengths=self.comp_ctx_lengths_decode, + ) + except TypeError: + inputs = self.model.get_dummy_inputs(kv_offload=True, comp_ctx_lengths=self.comp_ctx_lengths_decode) + dynamic_axes = self.model.get_onnx_dynamic_axes( + kv_offload=True, comp_ctx_lengths=self.comp_ctx_lengths_decode + ) output_names = self.model.get_output_names(kv_offload=True) self.vision_model.export( @@ -1099,6 +1114,11 @@ def compile( self.comp_ctx_lengths_prefill = compiler_options.pop("comp_ctx_lengths_prefill") self.comp_ctx_lengths_decode = compiler_options.pop("comp_ctx_lengths_decode") + # For supporting VLLM and Disaggregated with CCL + if "comp_ctx_lengths_prefill" in compiler_options: + self.comp_ctx_lengths_prefill = compiler_options.pop("comp_ctx_lengths_prefill") + self.comp_ctx_lengths_decode = compiler_options.pop("comp_ctx_lengths_decode") + specializations, compiler_options = self.model.get_specializations( batch_size=batch_size, prefill_seq_len=prefill_seq_len, @@ -1218,6 +1238,32 @@ def generate( if not runtime_ai100: raise NotImplementedError("PyTorch execution is not supported yet for this model!") + # Use VisionLanguageGeneration for image-prompt pairs + if (processor and images) or (tokenizer and prompts): + # Create VisionLanguageGeneration instance + batch_size_comp, ctx_len_comp, fbs = get_compilation_dims(self.lang_model.qpc_path) + vlm_gen = VisionLanguageGeneration( + qeff_model=self, + lang_qpc_path=self.lang_model.qpc_path, + vision_qpc_path=self.vision_model.qpc_path, + tokenizer=tokenizer, + processor=processor, + device_id=device_ids, # if device_ids is not None else [0], + ctx_len=ctx_len_comp, + full_batch_size=fbs, + comp_ctx_lengths_prefill=self.comp_ctx_lengths_prefill, + comp_ctx_lengths_decode=self.comp_ctx_lengths_decode, + ) + + # Call generate method + return vlm_gen.generate( + images=images, + prompts=prompts, + generation_len=generation_len, + stream=streamer is not None, + ) + + # Fallback to kv_offload_generate for direct inputs (backward compatibility) return self.kv_offload_generate( inputs=inputs, device_ids=device_ids, streamer=streamer, generation_len=generation_len ) @@ -1359,9 +1405,7 @@ def kv_offload_generate( prefill_ccl_id = 0 lang_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_prefill[prefill_ccl_id] - # Prepare inputs for prefill - chunk_inputs = lang_inputs.copy() - prefill_start = perf_counter() + lang_start = perf_counter() # Run prefill chunk_inputs = lang_inputs.copy() @@ -1470,6 +1514,7 @@ class _QEFFAutoModelForImageTextToTextSingleQPC(QEFFTransformersBase, Multimodal def __init__( self, model: nn.Module, + qaic_config: Optional[dict] = None, **kwargs, ): """ @@ -1491,7 +1536,7 @@ def __init__( raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.") super().__init__(model, **kwargs) - self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, _, _ = process_ccl_specializations(kwargs) + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(qaic_config) # to handle internvl models if hasattr(self.model.config, "llm_config") and hasattr(self.model.config, "vision_config"): @@ -1509,6 +1554,7 @@ def __init__( def from_pretrained( cls, pretrained_model_name_or_path, + 
qaic_config: Optional[dict] = None, *args, **kwargs, ): @@ -1554,9 +1600,7 @@ def from_pretrained( return cls( model, pretrained_model_name_or_path=pretrained_model_name_or_path, - comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, - comp_ctx_lengths_decode=comp_ctx_lengths_decode, - ctx_len=ctx_len, + qaic_config=qaic_config, **kwargs, ) @@ -1580,8 +1624,8 @@ def export( str Path to the generated ONNX graph file. """ - inputs = self.model.get_dummy_inputs(self.comp_ctx_lengths_decode) - dynamic_axes = self.model.get_onnx_dynamic_axes(self.comp_ctx_lengths_decode) + inputs = self.model.get_dummy_inputs(comp_ctx_lengths=self.comp_ctx_lengths_decode) + dynamic_axes = self.model.get_onnx_dynamic_axes(comp_ctx_lengths=self.comp_ctx_lengths_decode) output_names = self.model.get_output_names() return self._export(inputs, output_names, dynamic_axes, export_dir=export_dir) @@ -1661,6 +1705,11 @@ def compile( self.comp_ctx_lengths_prefill = compiler_options.pop("comp_ctx_lengths_prefill") self.comp_ctx_lengths_decode = compiler_options.pop("comp_ctx_lengths_decode") + # For supporting VLLM and Disaggregated with CCL + if "comp_ctx_lengths_prefill" in compiler_options: + self.comp_ctx_lengths_prefill = compiler_options.pop("comp_ctx_lengths_prefill") + self.comp_ctx_lengths_decode = compiler_options.pop("comp_ctx_lengths_decode") + # Get specializations from modelling file # TODO: expose this via the auto class as well specializations, compiler_options = self.model.get_specializations( @@ -2030,7 +2079,14 @@ class QEFFAutoModelForImageTextToText: _hf_auto_class = AutoModelForImageTextToText - def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, **kwargs): + def __new__( + self, + model: nn.Module, + kv_offload: Optional[bool] = True, + continuous_batching: bool = False, + qaic_config: Optional[dict] = None, + **kwargs, + ): """ Instantiate the appropriate internal class for single or dual QPC mode. @@ -2054,13 +2110,22 @@ def __new__(self, model: nn.Module, kv_offload: Optional[bool] = True, **kwargs) self.comp_ctx_lengths_decode = kwargs.get("comp_ctx_lengths_decode", None) if kv_offload: - return _QEffAutoModelForImageTextToTextDualQPC(model, **kwargs) + return _QEffAutoModelForImageTextToTextDualQPC( + model, continuous_batching, qaic_config=qaic_config, **kwargs + ) else: - return _QEFFAutoModelForImageTextToTextSingleQPC(model, **kwargs) + return _QEFFAutoModelForImageTextToTextSingleQPC(model, qaic_config=qaic_config, **kwargs) @classmethod @with_replaced_quantizers - def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optional[bool] = None, **kwargs): + def from_pretrained( + cls, + pretrained_model_name_or_path: str, + kv_offload: Optional[bool] = None, + continuous_batching: bool = False, + qaic_config: Optional[dict] = None, + **kwargs, + ): """ Load a QEfficient image-text-to-text model from a pretrained HuggingFace model or local path. 
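For reference, the from_pretrained reshaping above means callers now pass the CCL configuration through qaic_config instead of loose comp_ctx_lengths_* keyword arguments. A minimal usage sketch under that assumption (the checkpoint name, context length, CCL buckets, and compile options below are illustrative only and mirror the examples added later in this patch):

from QEfficient import QEFFAutoModelForImageTextToText

ctx_len = 4096
qeff_model = QEFFAutoModelForImageTextToText.from_pretrained(
    "meta-llama/Llama-4-Scout-17B-16E-Instruct",  # illustrative checkpoint
    attn_implementation="eager",
    kv_offload=True,  # dual-QPC path
    qaic_config={
        "comp_ctx_lengths_prefill": [1024, 2048],    # CCL buckets used during prefill
        "comp_ctx_lengths_decode": [2048, ctx_len],  # CCL buckets used during decode
        "ctx_len": ctx_len,
    },
)
qeff_model.compile(prefill_seq_len=128, ctx_len=ctx_len, num_cores=16, num_devices=4)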
@@ -2100,18 +2165,12 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, kv_offload: Optiona kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False}) - comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len = process_ccl_specializations( - kwargs - ) - model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs) return cls( model, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, - comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, - comp_ctx_lengths_decode=comp_ctx_lengths_decode, - ctx_len=ctx_len, + qaic_config=qaic_config, **kwargs, ) @@ -2212,7 +2271,7 @@ def __init__( self.model, transformed = SpDTransform.apply(self.model, qaic_config, **kwargs) self.is_tlm = transformed - self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, _, _ = process_ccl_specializations(kwargs) + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(qaic_config) self.hash_params["qeff_auto_class"] = self.__class__.__name__ @@ -2322,12 +2381,9 @@ def from_pretrained( if model.__class__.__name__ in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP: return MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model.__class__.__name__]( model, - comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, - comp_ctx_lengths_decode=comp_ctx_lengths_decode, - ctx_len=ctx_len, - prefill_seq_len=prefill_seq_len, kv_offload=kv_offload, pretrained_model_name_or_path=pretrained_model_name_or_path, + qaic_config=qaic_config, **kwargs, ) return cls( @@ -2389,7 +2445,7 @@ def export(self, export_dir: Optional[str] = None) -> str: "position_ids": {0: "batch_size", 1: "seq_len"}, } if self.comp_ctx_lengths_prefill is not None: - example_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + example_inputs["comp_ctx_lengths"] = torch.randint(0, 512, (512,), dtype=torch.long) dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} if len(kv_cache_shape) == 3: # For GPTBigCode arch the pkv is 3d @@ -2563,15 +2619,21 @@ def build_prefill_specialization( Dict[str, Union[int, str]] A dictionary defining the prefill specialization. 
""" - spec = { - "batch_size": 1 if self.continuous_batching else batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "num_logits_to_keep": 1 if self.is_tlm else None, - } + if hasattr(self.model, "get_specializations"): + spec = self.model.get_specializations( + batch_size=1 if self.continuous_batching else batch_size, + prefill_seq_len=prefill_seq_len, + ctx_len=ctx_len, + )[0] + else: + spec = { + "batch_size": 1 if self.continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + } if comp_ctx_lengths is not None: spec["comp_ctx_lengths"] = comp_ctx_lengths - + spec["num_logits_to_keep"] = 1 if self.is_tlm else None if self.continuous_batching: spec["full_batch_size"] = kv_cache_batch_size else: @@ -2618,15 +2680,23 @@ def build_decode_specialization( if not self.continuous_batching: # or batch_size == 1 return None # Avoid duplication with prefill - spec = { - "batch_size": full_batch_size if self.continuous_batching else batch_size, - "seq_len": (num_speculative_tokens + 1) if self.is_tlm else 1, - "ctx_len": ctx_len, - "num_logits_to_keep": (num_speculative_tokens + 1) if self.is_tlm else None, - } + if hasattr(self.model, "get_specializations"): + spec = self.model.get_specializations( + batch_size=full_batch_size if self.continuous_batching else batch_size, + prefill_seq_len=(num_speculative_tokens + 1) if self.is_tlm else 1, + ctx_len=ctx_len, + )[1] + else: + spec = { + "batch_size": full_batch_size if self.continuous_batching else batch_size, + "seq_len": (num_speculative_tokens + 1) if self.is_tlm else 1, + "ctx_len": ctx_len, + } if comp_ctx_lengths is not None: spec["comp_ctx_lengths"] = comp_ctx_lengths + spec["num_logits_to_keep"] = (num_speculative_tokens + 1) if self.is_tlm else None + if self.continuous_batching: spec["full_batch_size"] = kv_cache_batch_size else: @@ -2729,22 +2799,24 @@ def compile( If `prefill_seq_len` is less than `num_speculative_tokens + 1` for TLM models. """ - # For comp_ctx_lengths Disaggregated applications - if self.comp_ctx_lengths_prefill is None: - if comp_ctx_lengths_prefill is not None: + + # For supporting VLLM and Disaggregated with CCL + if "comp_ctx_lengths_prefill" in compiler_options and "comp_ctx_lengths_decode" in compiler_options: + comp_ctx_lengths_prefill = compiler_options.pop("comp_ctx_lengths_prefill") + comp_ctx_lengths_decode = compiler_options.pop("comp_ctx_lengths_decode") + if isinstance(comp_ctx_lengths_prefill, str): import ast - if isinstance(comp_ctx_lengths_prefill, str): - try: - # Safely evaluate the string to a Python list for disaggregated input - self.comp_ctx_lengths_prefill = ast.literal_eval(comp_ctx_lengths_prefill) - self.comp_ctx_lengths_decode = ast.literal_eval(comp_ctx_lengths_decode) - - except (ValueError, SyntaxError): - raise ValueError("Invalid format for comp_ctx_lengths. Expected a list-like string.") - else: - self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill - self.comp_ctx_lengths_decode = comp_ctx_lengths_decode + try: + # Safely evaluate the string to a Python list for disaggregated input + self.comp_ctx_lengths_prefill = ast.literal_eval(comp_ctx_lengths_prefill) + self.comp_ctx_lengths_decode = ast.literal_eval(comp_ctx_lengths_decode) + + except (ValueError, SyntaxError): + raise ValueError("Invalid format for comp_ctx_lengths. 
Expected a list-like string.") + else: + self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill + self.comp_ctx_lengths_decode = comp_ctx_lengths_decode # --- Validation --- if prefill_only is not None and not isinstance(prefill_only, bool): diff --git a/QEfficient/transformers/models/molmo/modeling_molmo.py b/QEfficient/transformers/models/molmo/modeling_molmo.py index a0b20fe28..48a8fbf7b 100644 --- a/QEfficient/transformers/models/molmo/modeling_molmo.py +++ b/QEfficient/transformers/models/molmo/modeling_molmo.py @@ -281,7 +281,6 @@ def attention( if layer_past is not None: if comp_ctx_lengths is not None: attention_bias = attention_bias[:, :, :, : comp_ctx_lengths.shape[-1]] - print(f"attention_bias: {attention_bias.shape}") # sin and cos are specific to RoPE models; cache_position needed for the static cache cache_kwargs = { "sin": sin, diff --git a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py index ac91d5477..2628c51a8 100644 --- a/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +++ b/QEfficient/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py @@ -702,6 +702,7 @@ def forward( position_ids, image_idx, past_key_values, + batch_index: Optional[torch.LongTensor] = None, comp_ctx_lengths: Optional[List[int]] = None, ): inputs_embeds = self.model.get_input_embeddings()(input_ids) @@ -718,6 +719,7 @@ def forward( position_ids=position_ids, past_key_values=past_key_values, comp_ctx_lengths=comp_ctx_lengths, + batch_index=batch_index, use_cache=True, ) @@ -736,7 +738,13 @@ def get_qeff_vision_encoder(self): def get_qeff_language_decoder(self): return QEffQwen_2_5_vl_DecoderWrapper(self) - def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, **kwargs): + def get_dummy_inputs( + self, + comp_ctx_lengths: Optional[List[int]] = None, + kv_offload: bool = False, + continuous_batching: bool = False, + **kwargs, + ): inputs_shapes = {} inputs_shapes["input_ids"] = (constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE, constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN) @@ -787,6 +795,9 @@ def get_dummy_inputs(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offl if comp_ctx_lengths is not None: lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + if comp_ctx_lengths is not None: + lang_inputs["comp_ctx_lengths"] = torch.randint(0, 100, (40,), dtype=torch.long) + inputs = {} if kv_offload: inputs["vision"] = vision_inputs @@ -810,6 +821,9 @@ def get_specializations( kv_offload: bool = False, **compiler_options, ): + comp_ctx_lengths_prefill = compiler_options.pop("comp_ctx_lengths_prefill", None) + comp_ctx_lengths_decode = compiler_options.pop("comp_ctx_lengths_decode", None) + if height is None or width is None: height = 1365 width = 2048 @@ -888,46 +902,77 @@ def smart_resize( "grid_w": grid_w, } ] + if comp_ctx_lengths_prefill is not None: lang = [] for i in range(0, len(comp_ctx_lengths_prefill)): - lang.append( - { - "batch_size": batch_size, - "seq_len": prefill_seq_len, - "ctx_len": ctx_len, - "vision_size": vision_size, - "comp_ctx_lengths": comp_ctx_lengths_prefill[i], - } - ) - - for i in range(0, len(comp_ctx_lengths_decode)): - lang.append( - { - "batch_size": batch_size, - "seq_len": "1", - "ctx_len": ctx_len, - "vision_size": vision_size, - "comp_ctx_lengths": comp_ctx_lengths_decode[i], - } - ) - - else: - lang = [ - { - "batch_size": batch_size, + lang_prefill = { + "batch_size": 1 if continuous_batching else 
batch_size, "seq_len": prefill_seq_len, "ctx_len": ctx_len, "vision_size": vision_size, - }, - { - "batch_size": batch_size, + "comp_ctx_lengths": comp_ctx_lengths_prefill[i], + "vision_batch_size": batch_size, + } + + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang.append(lang_prefill) + + for i in range(0, len(comp_ctx_lengths_decode)): + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, "seq_len": "1", "ctx_len": ctx_len, "vision_size": vision_size, - }, - ] + "comp_ctx_lengths": comp_ctx_lengths_decode[i], + "vision_batch_size": batch_size, + } + + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang.append(lang_decode) + else: + lang_prefill = { + "batch_size": 1 if continuous_batching else batch_size, + "seq_len": prefill_seq_len, + "ctx_len": ctx_len, + "vision_size": vision_size, + "vision_batch_size": batch_size, + } + + if continuous_batching: + lang_prefill["full_batch_size"] = kv_cache_batch_size + else: + lang_prefill["batch_size"] = kv_cache_batch_size + if full_batch_size: + lang_prefill["full_batch_exec_size"] = full_batch_size + + lang_decode = { + "batch_size": full_batch_size if continuous_batching else batch_size, + "seq_len": 1, + "ctx_len": ctx_len, + "vision_size": vision_size, + "vision_batch_size": batch_size, + } + + if continuous_batching: + lang_decode["full_batch_size"] = kv_cache_batch_size + else: + lang_decode["batch_size"] = kv_cache_batch_size + + lang = [] + lang.append(lang_prefill) + lang.append(lang_decode) specializations = {} @@ -938,7 +983,9 @@ def smart_resize( else: return lang, compiler_options - def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False): + def get_onnx_dynamic_axes( + self, comp_ctx_lengths: Optional[List[int]] = None, kv_offload: bool = False, continuous_batching: bool = False + ): # Define dynamic axes num_layers = self.config.num_hidden_layers @@ -960,6 +1007,9 @@ def get_onnx_dynamic_axes(self, comp_ctx_lengths: Optional[List[int]] = None, kv if comp_ctx_lengths is not None: lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + if comp_ctx_lengths is not None: + lang_dynamic_axes["comp_ctx_lengths"] = {0: "comp_ctx_lengths"} + dynamic_axes = {} if kv_offload: dynamic_axes["vision"] = vision_dynamic_axes diff --git a/QEfficient/utils/check_ccl_specializations.py b/QEfficient/utils/check_ccl_specializations.py index 6cb54a6c5..052336fae 100644 --- a/QEfficient/utils/check_ccl_specializations.py +++ b/QEfficient/utils/check_ccl_specializations.py @@ -6,14 +6,14 @@ # ----------------------------------------------------------------------------- -def process_ccl_specializations(kwargs): - ccl_prefill = kwargs.pop("comp_ctx_lengths_prefill", None) - ccl_decode = kwargs.pop("comp_ctx_lengths_decode", None) - ctx_len = kwargs.pop("ctx_len", None) - prefill_seq_len = kwargs.pop("prefill_seq_len", 128) +def process_ccl_specializations(qaic_config): + ccl_prefill = qaic_config.get("comp_ctx_lengths_prefill", None) + ccl_decode = qaic_config.get("comp_ctx_lengths_decode", None) + ctx_len = qaic_config.get("ctx_len", None) + prefill_seq_len = qaic_config.get("prefill_seq_len", 128) if ccl_prefill is None or ccl_decode is None: - return None, None, ctx_len, prefill_seq_len + 
return None, None if ctx_len is None: raise TypeError("`ctx_len` is required when loading the model with CCL.") @@ -22,7 +22,7 @@ def process_ccl_specializations(kwargs): # both prefill and decode ccl can share the same specializations since prefill_seq_len=1. So, a sorted union of both lists can be used for both of them. ccl_union_all = sorted(set(ccl_prefill + ccl_decode)) ccl_union_all = [min(x, ctx_len) for x in ccl_union_all] - return ccl_union_all, ccl_union_all, ctx_len, prefill_seq_len + return ccl_union_all, ccl_union_all # Step 1: Cap values to ctx_len ccl_prefill = [min(x, ctx_len) for x in ccl_prefill] @@ -46,4 +46,4 @@ def process_ccl_specializations(kwargs): updated_prefill.sort() ccl_decode.sort() - return updated_prefill, ccl_decode, ctx_len, prefill_seq_len + return updated_prefill, ccl_decode diff --git a/examples/ccl_gpt_oss.py b/examples/ccl_gpt_oss.py new file mode 100644 index 000000000..0d8583771 --- /dev/null +++ b/examples/ccl_gpt_oss.py @@ -0,0 +1,50 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from transformers import AutoTokenizer, TextStreamer + +from QEfficient import QEFFAutoModelForCausalLM + +model_id = "openai/gpt-oss-20b" # weights are not required to convert to fp32 + +ctx_len = 512 +#Set the list of ccl during prefilling process +comp_ctx_lengths_prefill = [128,500] +#Set the list of ccl during decoding process +comp_ctx_lengths_decode = [256,ctx_len] + + +qeff_model = QEFFAutoModelForCausalLM.from_pretrained( + model_id, + qaic_config={ + "comp_ctx_lengths_prefill":comp_ctx_lengths_prefill, + "comp_ctx_lengths_decode":comp_ctx_lengths_decode, + "ctx_len":ctx_len, + "prefill_seq_len":1, #Passing prefill_seq_len is mandatory for CCL goal in moe models. Currently we can get best perf using PL=1. + }, +) +tokenizer = AutoTokenizer.from_pretrained(model_id) + +onnx_model_path = qeff_model.export() +qpc_path = qeff_model.compile( + prefill_seq_len=1, # Currently we can get best perf using PL=1 i.e. decode-only model, prefill optimizations are being worked on. + ctx_len=ctx_len, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=4, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, +) +print(f"qpc path is {qpc_path}") +streamer = TextStreamer(tokenizer) +exec_info = qeff_model.generate( + tokenizer, + prompts="Who is your creator? 
and What all you are allowed to do?", + generation_len=256, +) diff --git a/examples/ccl_image_text_to_text_inference.py b/examples/ccl_image_text_to_text_inference.py index 932a407b9..2af386338 100644 --- a/examples/ccl_image_text_to_text_inference.py +++ b/examples/ccl_image_text_to_text_inference.py @@ -43,9 +43,11 @@ def run_model( token=token, attn_implementation="eager", kv_offload=kv_offload, - comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, - comp_ctx_lengths_decode=comp_ctx_lengths_decode, - ctx_len=ctx_len, + qaic_config={ + "comp_ctx_lengths_prefill": comp_ctx_lengths_prefill, + "comp_ctx_lengths_decode": comp_ctx_lengths_decode, + "ctx_len": ctx_len, + }, ) ## STEP - 2 Export & Compile the Model diff --git a/examples/ccl_llama4_CB_example_vision_lang.py b/examples/ccl_llama4_CB_example_vision_lang.py new file mode 100644 index 000000000..bc4ad3dd5 --- /dev/null +++ b/examples/ccl_llama4_CB_example_vision_lang.py @@ -0,0 +1,109 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import transformers +from transformers import AutoConfig, AutoProcessor + +from QEfficient import QEFFAutoModelForImageTextToText + +model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" +config = AutoConfig.from_pretrained(model_id) +# For Testing Purpose Only +config.text_config.num_hidden_layers = 4 +config.vision_config.num_hidden_layers = 2 + +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) + +ctx_len = 4096 +#Set the list of ccl during prefilling process +comp_ctx_lengths_prefill = [3072] +#Set the list of ccl during decoding process +comp_ctx_lengths_decode = [ctx_len] + +continious_batching = True +if continious_batching: + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + continuous_batching=True, + qaic_config={ + "comp_ctx_lengths_prefill":comp_ctx_lengths_prefill, + "comp_ctx_lengths_decode":comp_ctx_lengths_decode, + "ctx_len":ctx_len, + }, + ) + + qeff_model.compile( + prefill_seq_len=128, + ctx_len=ctx_len, + img_size=336, + num_cores=16, + num_devices=4, + max_num_tiles=17, + batch_size=1, + full_batch_size=4, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + ) +else: + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + qaic_config={ + "comp_ctx_lengths_prefill":comp_ctx_lengths_prefill, + "comp_ctx_lengths_decode":comp_ctx_lengths_decode, + "ctx_len":ctx_len, + }, + ) + + qeff_model.compile( + prefill_seq_len=128, + ctx_len=ctx_len, + img_size=336, + num_cores=16, + num_devices=4, + max_num_tiles=17, + batch_size=1, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, + ) + +image_urls = [ + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png", + 
"https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg", +] + +prompts = [ + "Can you describe the image in detail?", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", +] + +exec_info = qeff_model.generate( + tokenizer=tokenizer, + prompts=prompts, + processor=processor, + images=image_urls, + device_ids=[0,1,2,3], + generation_len=100, +) + +# print("Generated texts:", exec_info.generated_texts) +print("Generated IDs:", exec_info.generated_ids) +print(exec_info) diff --git a/examples/ccl_llama4_example.py b/examples/ccl_llama4_example.py index 5fc715589..5da29960f 100644 --- a/examples/ccl_llama4_example.py +++ b/examples/ccl_llama4_example.py @@ -7,7 +7,7 @@ import torch import transformers -from transformers import AutoConfig, AutoModelForImageTextToText, AutoProcessor, TextStreamer +from transformers import AutoConfig, AutoProcessor, TextStreamer from QEfficient import QEFFAutoModelForImageTextToText @@ -17,23 +17,25 @@ config.text_config.num_hidden_layers = 4 config.vision_config.num_hidden_layers = 2 -model = AutoModelForImageTextToText.from_pretrained(model_id, attn_implementation="eager", config=config) -model.eval() -tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) -processor = AutoProcessor.from_pretrained(model_id) - -### For running the model in single QPC approach use kv_offload=False. For Dual QPC approach use kv_offload=True ### ctx_len = 8192 +# Set the list of ccl during prefilling process comp_ctx_lengths_prefill = [3072] +# Set the list of ccl during decoding process comp_ctx_lengths_decode = [4096, ctx_len] -qeff_model = QEFFAutoModelForImageTextToText( - model, +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", kv_offload=True, - comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, - comp_ctx_lengths_decode=comp_ctx_lengths_decode, - ctx_len=ctx_len, + qaic_config={ + "comp_ctx_lengths_prefill": comp_ctx_lengths_prefill, + "comp_ctx_lengths_decode": comp_ctx_lengths_decode, + "ctx_len": ctx_len, + }, + config=config, ) +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) ### use skip_vision=Ture, if want to run only text, ow false ### skip_vision = False @@ -75,7 +77,7 @@ ) streamer = TextStreamer(tokenizer) - output = qeff_model.generate(inputs=inputs, generation_len=700) + output = qeff_model.generate(inputs=inputs, device_ids=[0, 1, 2, 3], generation_len=100) print(output.generated_ids) print(tokenizer.batch_decode(output.generated_ids)) print(output) @@ -119,7 +121,7 @@ ) inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) streamer = TextStreamer(tokenizer) - output = qeff_model.generate(inputs=inputs, generation_len=1024) + output = qeff_model.generate(inputs=inputs, device_ids=[0, 1, 2, 3], generation_len=100) print(output.generated_ids) print(tokenizer.batch_decode(output.generated_ids)) print(output) diff --git a/examples/ccl_llama4_multi_image_example.py b/examples/ccl_llama4_multi_image_example.py new file mode 100644 index 000000000..88af4d2d1 --- /dev/null +++ b/examples/ccl_llama4_multi_image_example.py @@ -0,0 +1,89 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import torch +import transformers +from transformers import AutoConfig, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct" +config = AutoConfig.from_pretrained(model_id) +# For Testing Purpose Only +config.text_config.num_hidden_layers = 4 +config.vision_config.num_hidden_layers = 2 + +ctx_len = 8192 +#Set the list of ccl during prefilling process +comp_ctx_lengths_prefill = [5376] +#Set the list of ccl during decoding process +comp_ctx_lengths_decode = [6144, ctx_len] + +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + qaic_config={ + "comp_ctx_lengths_prefill":comp_ctx_lengths_prefill, + "comp_ctx_lengths_decode":comp_ctx_lengths_decode, + "ctx_len":ctx_len, + }, +) +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) + +### For multi-image, the value of max_num_tiles should be the sum of the num_tiles values across all the images ### +qeff_model.compile( + prefill_seq_len=128, + ctx_len=ctx_len, + img_size=336, + num_cores=16, + num_devices=4, + max_num_tiles=34, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, +) + +### Multi_image Prompt ### +image_url_1 = ( + "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/datasets/cat_style_layout.png" +) + + +image_url_2 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg" + +messages = [ + { + "role": "user", + "content": [ + {"type": "image", "url": image_url_1}, + {"type": "image", "url": image_url_2}, + { + "type": "text", + "text": "Analyze the key elements, colors, and objects in the two images. Discuss their similarities, differences, and how they complement or contrast each other. 
Reflect on the emotions or ideas they convey, considering the context, light, shadow, and composition.", + }, + ], + }, +] + +inputs = processor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", +) + +inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) +streamer = TextStreamer(tokenizer) +output = qeff_model.generate(inputs=inputs, device_ids=[0,1,2,3], generation_len=100) +print(output.generated_ids) +print(tokenizer.batch_decode(output.generated_ids)) +print(output) diff --git a/examples/ccl_mistral3_example.py b/examples/ccl_mistral3_example.py index b76227a22..96ed519f5 100644 --- a/examples/ccl_mistral3_example.py +++ b/examples/ccl_mistral3_example.py @@ -42,9 +42,11 @@ def run_model( model_name, kv_offload=kv_offload, config=config, - ctx_len=ctx_len, - comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, - comp_ctx_lengths_decode=comp_ctx_lengths_decode, + qaic_config={ + "comp_ctx_lengths_prefill": comp_ctx_lengths_prefill, + "comp_ctx_lengths_decode": comp_ctx_lengths_decode, + "ctx_len": ctx_len, + }, ) ## STEP - 2 Export & Compile the Model diff --git a/examples/ccl_molmo_example.py b/examples/ccl_molmo_example.py index c52d9172b..dd09fa020 100644 --- a/examples/ccl_molmo_example.py +++ b/examples/ccl_molmo_example.py @@ -19,18 +19,20 @@ # config.num_hidden_layers = 2 # load the model -ctx_len = 32768 +ctx_len = 8192 comp_ctx_lengths_prefill = [3072] -comp_ctx_lengths_decode = [4096, 8192, ctx_len] +comp_ctx_lengths_decode = [4096, 8192] qeff_model = QEFFAutoModelForCausalLM.from_pretrained( model_id, kv_offload=True, - comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, - comp_ctx_lengths_decode=comp_ctx_lengths_decode, - ctx_len=ctx_len, trust_remote_code=True, config=config, + qaic_config={ + "comp_ctx_lengths_prefill": comp_ctx_lengths_prefill, + "comp_ctx_lengths_decode": comp_ctx_lengths_decode, + "ctx_len": ctx_len, + }, ) tokenizer = transformers.AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True) @@ -91,7 +93,7 @@ inputs["valid_idx"] = torch.nonzero(valid)[:, 1].unsqueeze(0) streamer = TextStreamer(tokenizer) - output = qeff_model.generate(inputs=inputs, device_ids=[8, 9, 10, 11], generation_len=100) + output = qeff_model.generate(inputs=inputs, device_ids=[0, 1, 2, 3], generation_len=100) print(output.generated_ids) print(tokenizer.batch_decode(output.generated_ids)) print(output) diff --git a/examples/ccl_qwen2_5_vl_CB.py b/examples/ccl_qwen2_5_vl_CB.py new file mode 100644 index 000000000..b08c72eb0 --- /dev/null +++ b/examples/ccl_qwen2_5_vl_CB.py @@ -0,0 +1,81 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +# If we want to enable QBlocking Run below command:, default is without blocking +# ATTENTION_BLOCKING_MODE=q num_q_blocks=2 python -W ignore qwen2_5_vl_example.py + +import transformers +from transformers import AutoConfig, AutoProcessor, TextStreamer + +from QEfficient import QEFFAutoModelForImageTextToText + +## For AWQ model update pytorch version to 2.8.* +model_id = "Qwen/Qwen2.5-VL-32B-Instruct" +config = AutoConfig.from_pretrained(model_id) +# config.text_config.num_hidden_layers = 2 + +ctx_len = 8192 +comp_ctx_lengths_prefill = [4096] +comp_ctx_lengths_decode = [6144, ctx_len] + +qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( + model_id, + attn_implementation="eager", + kv_offload=True, + config=config, + continuous_batching=True, + qaic_config={ + "comp_ctx_lengths_prefill":comp_ctx_lengths_prefill, + "comp_ctx_lengths_decode":comp_ctx_lengths_decode, + "ctx_len":ctx_len, + }, +) +tokenizer = transformers.AutoTokenizer.from_pretrained(model_id) +processor = AutoProcessor.from_pretrained(model_id) + +batch_size = 1 +## Vision + Text ## +qeff_model.compile( + batch_size=batch_size, + full_batch_size=4, + prefill_seq_len=128, + ctx_len=8192, + num_cores=16, + num_devices=4, + height=354, + width=536, + mxfp6_matmul=True, + mxint8_kv_cache=True, + aic_enable_depth_first=True, + mos=1, +) + +image_urls = [ + "https://picsum.photos/id/237/536/354", + "https://picsum.photos/id/237/536/354", + "https://picsum.photos/id/237/536/354", + "https://picsum.photos/id/237/536/354", +] + +prompts = [ + "Can you describe the image in detail?", + "What are the objects in the image?", + "What is the main subject of the image?", + "What colors are predominant in the image?", +] + +streamer = TextStreamer(tokenizer) +output = qeff_model.generate( + tokenizer=tokenizer, + prompts=prompts, + processor=processor, + images=image_urls, + generation_len=100, +) +print(output.generated_ids) +print(tokenizer.batch_decode(output.generated_ids)) +print(output) diff --git a/examples/ccl_qwen2_5_vl_example.py b/examples/ccl_qwen2_5_vl_example.py index b813462e3..273a18361 100644 --- a/examples/ccl_qwen2_5_vl_example.py +++ b/examples/ccl_qwen2_5_vl_example.py @@ -5,9 +5,10 @@ # # ----------------------------------------------------------------------------- +# If we want to enable QBlocking Run below command:, default is without blocking +# ATTENTION_BLOCKING_MODE=q num_q_blocks=2 python -W ignore qwen2_5_vl_example.py + import requests -import torch -import torch.nn.functional as F import transformers from PIL import Image from qwen_vl_utils import process_vision_info @@ -18,8 +19,7 @@ ## For AWQ model update pytorch version to 2.8.* model_id = "Qwen/Qwen2.5-VL-32B-Instruct" config = AutoConfig.from_pretrained(model_id) - -## Use complete model without changing num_hidden_layers as it will not work for TF version 4.55.0 for Qwen2.5VL model +# config.text_config.num_hidden_layers = 2 ctx_len = 8192 comp_ctx_lengths_prefill = [4096] @@ -27,12 +27,14 @@ qeff_model = QEFFAutoModelForImageTextToText.from_pretrained( model_id, - comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, - comp_ctx_lengths_decode=comp_ctx_lengths_decode, - ctx_len=ctx_len, attn_implementation="eager", kv_offload=True, config=config, + qaic_config={ + "comp_ctx_lengths_prefill": comp_ctx_lengths_prefill, + "comp_ctx_lengths_decode": comp_ctx_lengths_decode, + "ctx_len": ctx_len, + }, ) tokenizer = 
transformers.AutoTokenizer.from_pretrained(model_id) processor = AutoProcessor.from_pretrained(model_id) @@ -44,7 +46,7 @@ ## Only Text ## ## Set Batch_Size ## - batch_size = 2 + batch_size = 1 qeff_model.compile( batch_size=batch_size, prefill_seq_len=128, @@ -78,28 +80,10 @@ return_tensors="pt", ) - pos_ids, rope_deltas = qeff_model.model.get_rope_index( - inputs["input_ids"], - image_grid_thw=None, - video_grid_thw=None, - second_per_grid_ts=None, - attention_mask=inputs["attention_mask"], - ) - - input_ids_length = inputs["input_ids"].shape[1] - - inputs["position_ids"] = torch.cat([pos_ids, pos_ids[0].unsqueeze(0)], dim=0) - - prefill_seq_len = 128 - num_chunks = -(input_ids_length // -prefill_seq_len) # ceil divide without float - padded_len = num_chunks * prefill_seq_len # Convert to a multiple of prompt_len - - inputs["position_ids"] = F.pad( - inputs["position_ids"], pad=(0, padded_len - input_ids_length), mode="constant", value=-1 - ) + inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size) streamer = TextStreamer(tokenizer) - output = qeff_model.generate(inputs=inputs, generation_len=100) + output = qeff_model.generate(inputs=inputs, generation_len=100, device_ids=[0, 1, 2, 3]) print(output.generated_ids) print(tokenizer.batch_decode(output.generated_ids)) print(output) @@ -158,31 +142,11 @@ padding=True, return_tensors="pt", ) - input_ids_length = inputs["input_ids"].shape[1] - - inputs["position_ids"] = torch.arange(input_ids_length).view(1, 1, input_ids_length).expand(-1, batch_size, -1) - - pos_ids, rope_deltas = qeff_model.model.model.get_rope_index( - inputs["input_ids"], - inputs["image_grid_thw"], - video_grid_thw=None, - second_per_grid_ts=None, - attention_mask=inputs["attention_mask"], - ) - inputs["position_ids"] = torch.cat((inputs["position_ids"], pos_ids), dim=0) - - prefill_seq_len = 128 - num_chunks = -(input_ids_length // -prefill_seq_len) # ceil divide without float - padded_len = num_chunks * prefill_seq_len # Convert to a multiple of prompt_len - - inputs["position_ids"] = F.pad( - inputs["position_ids"], pad=(0, padded_len - input_ids_length), mode="constant", value=-1 - ) + inputs = qeff_model.model.prepare_inputs_for_generation(inputs=inputs, prefill_seq_len=128, batch_size=batch_size) - inputs.pop("image_grid_thw") streamer = TextStreamer(tokenizer) - output = qeff_model.generate(inputs=inputs, generation_len=100) + output = qeff_model.generate(inputs=inputs, generation_len=100, device_ids=[0, 1, 2, 3]) print(output.generated_ids) print(tokenizer.batch_decode(output.generated_ids)) print(output) diff --git a/examples/compute_context_length.py b/examples/compute_context_length.py index 3be1a9eab..b7cc8d173 100644 --- a/examples/compute_context_length.py +++ b/examples/compute_context_length.py @@ -16,8 +16,8 @@ ## - The second comp_ctx_lengths_decode list will be used for decoding. During the decoding process, based on the position_id or cache index it will work with the specific compute-context-length in the list. It will start from a proper compute-context-length in the list based on input prompt length and will gradually increase the compute-context-length if the cache index passes the current compute-context-length. 
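## To make the bucket-switching rule described above concrete, here is an illustrative sketch of the selection logic. It is only a mental model of how the generation loop consumes comp_ctx_lengths_decode, not the runtime's actual implementation:

from typing import List

def select_ccl(cache_index: int, comp_ctx_lengths: List[int], ctx_len: int) -> int:
    # Return the smallest compute-context-length bucket that still covers cache_index.
    for ccl in sorted(comp_ctx_lengths):
        if cache_index < ccl:
            return ccl
    return ctx_len  # once every bucket is exceeded, fall back to the full context length

# With the decode list defined just below (comp_ctx_lengths_decode = [512, ctx_len], ctx_len = 1024):
#   select_ccl(300, [512, 1024], 1024) -> 512   (attention is computed over a 512-entry KV window)
#   select_ccl(600, [512, 1024], 1024) -> 1024  (the cache index passed 512, so the next bucket is used)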
## ctx_len = 1024 -comp_ctx_lengths_prefill = [256] # None -comp_ctx_lengths_decode = [ctx_len] # None +comp_ctx_lengths_prefill = [256, 1000] # None +comp_ctx_lengths_decode = [512, ctx_len] # None # model_name = "google/gemma-7b" # model_name = "google/gemma-2-2b" @@ -27,10 +27,10 @@ # model_name = "microsoft/phi-1_5" # model_name = "microsoft/Phi-3-mini-4k-instruct" # model_name = "Qwen/Qwen2.5-7B-Instruct" -# model_name = "meta-llama/Llama-3.2-1B" +model_name = "meta-llama/Llama-3.2-1B" # model_name = "Qwen/Qwen3-1.7B" # model_name = "allenai/OLMo-2-0425-1B" -model_name = "ibm-granite/granite-3.3-2b-base" +# model_name = "ibm-granite/granite-3.3-2b-base" # model_name = "ibm-granite/granite-3.2-8b-instruct" # model_name = "meta-llama/Llama-3.3-70B-Instruct" # model_name = "Salesforce/codegen-350M-mono" @@ -40,10 +40,12 @@ model = QEFFAutoModelForCausalLM.from_pretrained( model_name, - continuous_batching=True, - comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, - comp_ctx_lengths_decode=comp_ctx_lengths_decode, - ctx_len=ctx_len, + continuous_batching=False, + qaic_config={ + "comp_ctx_lengths_prefill": comp_ctx_lengths_prefill, + "comp_ctx_lengths_decode": comp_ctx_lengths_decode, + "ctx_len": ctx_len, + }, ) # model compilation for either continuous or static batching. For continuous batching full_batch_size is needed. @@ -52,10 +54,10 @@ ctx_len=ctx_len, num_cores=16, num_devices=4, - full_batch_size=1, mxint8_kv_cache=True, mxfp6_matmul=True, ) +# full_batch_size=1, # Create tokenizer and run model.generate and passes the input prompts to it. tokenizer = AutoTokenizer.from_pretrained(model_name) @@ -64,5 +66,5 @@ "My name is ", ], tokenizer=tokenizer, - generation_len=128, + generation_len=512, ) diff --git a/examples/gemma3_example/ccl_gemma3_mm.py b/examples/gemma3_example/ccl_gemma3_mm.py index 484c0f8ce..9bf6e9c5a 100644 --- a/examples/gemma3_example/ccl_gemma3_mm.py +++ b/examples/gemma3_example/ccl_gemma3_mm.py @@ -31,9 +31,11 @@ config=config, attn_implementation="eager", kv_offload=True, - comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, - comp_ctx_lengths_decode=comp_ctx_lengths_decode, - ctx_len=ctx_len, + qaic_config={ + "comp_ctx_lengths_prefill": comp_ctx_lengths_prefill, + "comp_ctx_lengths_decode": comp_ctx_lengths_decode, + "ctx_len": ctx_len, + }, ) ### use skip_vision=Ture, if want to run only text, or false ### diff --git a/examples/granite_example/ccl_granite_vision_inference.py b/examples/granite_example/ccl_granite_vision_inference.py index e03b94a5e..64ecaf948 100644 --- a/examples/granite_example/ccl_granite_vision_inference.py +++ b/examples/granite_example/ccl_granite_vision_inference.py @@ -43,9 +43,11 @@ def run_model( model_name, token=token, kv_offload=kv_offload, - comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, - comp_ctx_lengths_decode=comp_ctx_lengths_decode, - ctx_len=ctx_len, + qaic_config={ + "comp_ctx_lengths_prefill": comp_ctx_lengths_prefill, + "comp_ctx_lengths_decode": comp_ctx_lengths_decode, + "ctx_len": ctx_len, + }, ) ## STEP - 2 Export & Compile the Model diff --git a/examples/intern_example/ccl_internvl_inference.py b/examples/intern_example/ccl_internvl_inference.py index 0828b1d41..827d50c97 100644 --- a/examples/intern_example/ccl_internvl_inference.py +++ b/examples/intern_example/ccl_internvl_inference.py @@ -189,10 +189,12 @@ def run_intern_on_aic( model = QEFFAutoModelForCausalLM.from_pretrained( model_name, kv_offload=kv_offload, - comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, - 
comp_ctx_lengths_decode=comp_ctx_lengths_decode, - ctx_len=ctx_len, trust_remote_code=True, + qaic_config={ + "comp_ctx_lengths_prefill": comp_ctx_lengths_prefill, + "comp_ctx_lengths_decode": comp_ctx_lengths_decode, + "ctx_len": ctx_len, + }, ) ## STEP 2 -- EXPORT & COMPILE THE MODEL diff --git a/examples/qwen3moe_example/ccl_qwen3moe_inference.py b/examples/qwen3moe_example/ccl_qwen3moe_inference.py index f200c6fa6..d2fa208df 100644 --- a/examples/qwen3moe_example/ccl_qwen3moe_inference.py +++ b/examples/qwen3moe_example/ccl_qwen3moe_inference.py @@ -24,11 +24,13 @@ model = QEFFAutoModelForCausalLM.from_pretrained( model_name, - comp_ctx_lengths_prefill=comp_ctx_lengths_prefill, - comp_ctx_lengths_decode=comp_ctx_lengths_decode, - ctx_len=ctx_len, continuous_batching=False, - prefill_seq_len=prefill_seq_len, + qaic_config={ + "comp_ctx_lengths_prefill": comp_ctx_lengths_prefill, + "comp_ctx_lengths_decode": comp_ctx_lengths_decode, + "ctx_len": ctx_len, + "prefill_seq_len": prefill_seq_len, + }, ) model.compile(