From 0efb3129b0145a4b6b215c5c7df6b1ba94f0a5ac Mon Sep 17 00:00:00 2001 From: Sharvari Medhe Date: Mon, 3 Nov 2025 07:34:02 +0000 Subject: [PATCH 1/3] modified code for onnx export with dynamo enabled Signed-off-by: Sharvari Medhe --- QEfficient/base/modeling_qeff.py | 5 +- QEfficient/customop/ctx_scatter_gather.py | 14 ++- .../transformers/models/modeling_auto.py | 98 +++++++++++++++++++ 3 files changed, 111 insertions(+), 6 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 6ecbf0fc0..9bbfec9f4 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -179,6 +179,7 @@ def _export( onnx_transform_kwargs: Optional[Dict[str, any]] = None, export_dir: Optional[str] = None, offload_pt_weights: bool = True, + dynamic_shapes: Optional[Dict[str, Dict[int, any]]] = None, ) -> str: """ Export the PyTorch model to ONNX and apply ONNX transforms @@ -250,8 +251,10 @@ def _export( str(tmp_onnx_path), input_names=input_names, output_names=output_names, - dynamic_axes=dynamic_axes, + dynamic_shapes=dynamic_shapes, opset_version=constants.ONNX_EXPORT_OPSET, + dynamo=True, + report=True, **export_kwargs, ) logger.info("PyTorch export successful") diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py index c4f5a7bbd..c9f700575 100644 --- a/QEfficient/customop/ctx_scatter_gather.py +++ b/QEfficient/customop/ctx_scatter_gather.py @@ -41,11 +41,15 @@ class CtxScatterFunc(torch.autograd.Function): @staticmethod def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor): - batch_idx = torch.arange(data.shape[0]).view(-1, 1, 1) - head_idx = torch.arange(data.shape[1]).view(1, -1, 1) - ctx_idx = position_ids.unsqueeze(1) - data[batch_idx, head_idx, ctx_idx] = updates - return data + # batch_idx = torch.arange(data.shape[0]).view(-1, 1, 1) + # head_idx = torch.arange(data.shape[1]).view(1, -1, 1) + # ctx_idx = position_ids.unsqueeze(1) + # data[batch_idx, head_idx, ctx_idx] = updates + # return data + B, H, T, D = data.shape + idx = position_ids.unsqueeze(1).unsqueeze(-1).expand(B, H, T, D).to(dtype=torch.long) + out = data.scatter_reduce(dim=2, index=idx, src=updates, reduce="sum", include_self=False) + return out @staticmethod def setup_context(ctx, inputs, outputs): diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 633a0b29d..dc467643c 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -59,6 +59,101 @@ from QEfficient.utils.logging_utils import logger +def convert_dynamic_axes_to_dynamic_shapes(dynamic_axes: Dict[str, Dict[int, str]]) -> Dict[str, any]: + """ + Convert ONNX dynamic_axes format to torch.export dynamic_shapes format + + Args: + dynamic_axes: ONNX format like {"input_ids": {0: "batch_size", 1: "seq_len"}} + + Returns: + dynamic_shapes: torch.export format with Dim objects matching model forward args + """ + from torch.export import Dim + + # Create dimension registry to reuse Dim objects with same names + dim_registry = {} + dynamic_shapes = {} + + # Handle regular model inputs (not past_key_values) + # These match the QEffLlamaForCausalLM forward signature: + # input_ids, attention_mask, position_ids, past_key_values, batch_index, etc. 
+ for input_name, axes_map in dynamic_axes.items(): + if not input_name.startswith("past_"): + input_dynamic_shapes = {} + for axis_idx, dim_name in axes_map.items(): + # Create or reuse Dim object for this dimension name + if dim_name not in dim_registry: + if dim_name == "batch_size": + dim_registry[dim_name] = Dim(dim_name, min=1, max=64) # Support realistic batch sizes + elif "seq_len" in dim_name: + dim_registry[dim_name] = Dim(dim_name, min=1, max=4096) # Conservative seq range + else: + dim_registry[dim_name] = Dim(dim_name, min=1, max=4096) # Generic conservative range + + input_dynamic_shapes[axis_idx] = dim_registry[dim_name] + + dynamic_shapes[input_name] = input_dynamic_shapes + + # Handle past_key_values specially - collect all past_key.X and past_value.X + past_keys = {} + past_values = {} + + for input_name, axes_map in dynamic_axes.items(): + if input_name.startswith("past_key."): + layer_idx = int(input_name.split(".")[1]) + layer_dynamic_shapes = {} + for axis_idx, dim_name in axes_map.items(): + if dim_name not in dim_registry: + # Create Dim with conservative constraints to avoid conflicts + if dim_name == "batch_size": + dim_registry[dim_name] = Dim(dim_name, min=1, max=64) # Support realistic batch sizes + elif "seq_len" in dim_name: + dim_registry[dim_name] = Dim(dim_name, min=1, max=4096) # Conservative seq range + else: + dim_registry[dim_name] = Dim(dim_name, min=1, max=4096) # Generic conservative range + layer_dynamic_shapes[axis_idx] = dim_registry[dim_name] + past_keys[layer_idx] = layer_dynamic_shapes + + elif input_name.startswith("past_value."): + layer_idx = int(input_name.split(".")[1]) + layer_dynamic_shapes = {} + for axis_idx, dim_name in axes_map.items(): + if dim_name not in dim_registry: + # Create Dim with conservative constraints to avoid conflicts + if dim_name == "batch_size": + dim_registry[dim_name] = Dim(dim_name, min=1, max=64) # Support realistic batch sizes + elif "seq_len" in dim_name: + dim_registry[dim_name] = Dim(dim_name, min=1, max=4096) # Conservative seq range + else: + dim_registry[dim_name] = Dim(dim_name, min=1, max=4096) # Generic conservative range + layer_dynamic_shapes[axis_idx] = dim_registry[dim_name] + past_values[layer_idx] = layer_dynamic_shapes + + # Reconstruct past_key_values as nested structure if we have past keys/values + if past_keys or past_values: + max_layer = max(list(past_keys.keys()) + list(past_values.keys())) + past_kv_shapes = [] + + for layer_idx in range(max_layer + 1): + layer_shapes = [] + if layer_idx in past_keys: + layer_shapes.append(past_keys[layer_idx]) + else: + layer_shapes.append({}) + + if layer_idx in past_values: + layer_shapes.append(past_values[layer_idx]) + else: + layer_shapes.append({}) + + past_kv_shapes.append(layer_shapes) + + dynamic_shapes["past_key_values"] = past_kv_shapes + + return dynamic_shapes + + class QEFFTransformersBase(QEFFBaseModel): """ Base class for QEfficient wrappers around HuggingFace transformer models. 
@@ -2310,11 +2405,14 @@ def export(self, export_dir: Optional[str] = None) -> str: dynamic_axes=dynamic_axes, ) + dynamic_shapes = convert_dynamic_axes_to_dynamic_shapes(dynamic_axes) + return self._export( example_inputs, output_names, dynamic_axes, export_dir=export_dir, + dynamic_shapes=dynamic_shapes, ) def get_sampling_inputs_and_outputs( From cfa4578b74e251ee6ea870d1c0d51c0c67e6ff4c Mon Sep 17 00:00:00 2001 From: Sharvari Medhe Date: Fri, 7 Nov 2025 08:13:03 +0000 Subject: [PATCH 2/3] adding custom_translation_table Signed-off-by: Sharvari Medhe --- QEfficient/base/modeling_qeff.py | 17 ++++- QEfficient/customop/__init__.py | 4 +- QEfficient/customop/ctx_scatter_gather.py | 68 +++++++++++-------- QEfficient/transformers/cache_utils.py | 40 +++++++---- .../transformers/models/modeling_auto.py | 20 +++--- QEfficient/utils/constants.py | 4 +- 6 files changed, 95 insertions(+), 58 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 9bbfec9f4..3deb463c3 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -22,6 +22,7 @@ from QEfficient.base.pytorch_transforms import PytorchTransform from QEfficient.compile.qnn_compiler import compile as qnn_compile from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.customop.ctx_scatter_gather import custom_translation_table from QEfficient.utils import ( constants, create_json, @@ -35,7 +36,11 @@ logger = logging.getLogger(__name__) +import torch._dynamo.config +# Allow custom ops to be treated as black boxes +torch._dynamo.config.capture_scalar_outputs = True +torch._dynamo.config.capture_dynamic_output_shape_ops = True class QEFFBaseModel(ABC): """ Base class for all the model classes (i.e. LLMs, SD, quantized etc.). 
@@ -245,18 +250,26 @@ def _export( try: export_kwargs = {} if export_kwargs is None else export_kwargs - torch.onnx.export( + import time + start = time.perf_counter() + onnx_program = torch.onnx.export( self.model, (example_inputs,), - str(tmp_onnx_path), + # str(tmp_onnx_path), input_names=input_names, output_names=output_names, + # dynamic_axes=dynamic_axes, dynamic_shapes=dynamic_shapes, + custom_translation_table=custom_translation_table, opset_version=constants.ONNX_EXPORT_OPSET, dynamo=True, report=True, **export_kwargs, ) + end = time.perf_counter() + print("Dynamo enabled onnx export in memory time in sec", round(end - start,2)) + + onnx_program.save(str(tmp_onnx_path)) logger.info("PyTorch export successful") _ = self._offload_model_weights(offload_pt_weights) diff --git a/QEfficient/customop/__init__.py b/QEfficient/customop/__init__.py index ff0709f82..015001855 100644 --- a/QEfficient/customop/__init__.py +++ b/QEfficient/customop/__init__.py @@ -5,7 +5,7 @@ # # ----------------------------------------------------------------------------- -from QEfficient.customop.ctx_scatter_gather import CtxGatherFunc, CtxGatherFunc3D, CtxScatterFunc, CtxScatterFunc3D +from QEfficient.customop.ctx_scatter_gather import CtxGatherFunc, CtxGatherFunc3D, CtxScatterFunc3D from QEfficient.customop.ctx_scatter_gather_cb import ( CtxGatherFuncCB, CtxGatherFuncCB3D, @@ -16,7 +16,7 @@ __all__ = [ "CtxGatherFunc", - "CtxScatterFunc", + # "CtxScatterFunc", "CtxGatherFunc3D", "CtxScatterFunc3D", "CustomRMSNormAIC", diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py index c9f700575..99a6533ca 100644 --- a/QEfficient/customop/ctx_scatter_gather.py +++ b/QEfficient/customop/ctx_scatter_gather.py @@ -7,13 +7,11 @@ import onnxscript import torch - from QEfficient.utils import constants ops = getattr(onnxscript, "opset" + str(constants.ONNX_EXPORT_OPSET)) - -@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) +# @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) def CtxScatter(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates: onnxscript.FLOAT) -> onnxscript.FLOAT: # Find dims batch_size = ops.Gather(ops.Shape(data), [0]) @@ -34,31 +32,42 @@ def CtxScatter(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates: return ops.ScatterND(data, indices, updates) -class CtxScatterFunc(torch.autograd.Function): - """ - Function to scatter the current key values into KV-cache. 
- """ - - @staticmethod - def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor): - # batch_idx = torch.arange(data.shape[0]).view(-1, 1, 1) - # head_idx = torch.arange(data.shape[1]).view(1, -1, 1) - # ctx_idx = position_ids.unsqueeze(1) - # data[batch_idx, head_idx, ctx_idx] = updates - # return data - B, H, T, D = data.shape - idx = position_ids.unsqueeze(1).unsqueeze(-1).expand(B, H, T, D).to(dtype=torch.long) - out = data.scatter_reduce(dim=2, index=idx, src=updates, reduce="sum", include_self=False) - return out - - @staticmethod - def setup_context(ctx, inputs, outputs): - pass - - @staticmethod - def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updates: torch.Value) -> torch.Value: - return g.onnxscript_op(CtxScatter, data, position_ids, updates).setTypeAs(data) - +@torch.library.custom_op("qefficient::ctx_scatter", mutates_args=()) +def ctx_scatter_op(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor) -> torch.Tensor: + """Custom context scatter operation""" + result = data.clone() + batch_idx = torch.arange(result.shape[0]).view(-1, 1, 1) + head_idx = torch.arange(result.shape[1]).view(1, -1, 1) + ctx_idx = position_ids.unsqueeze(1) + result[batch_idx, head_idx, ctx_idx] = updates + return result + +@ctx_scatter_op.register_fake +def _(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor) -> torch.Tensor: + """Fake implementation for torch.export - just returns data tensor with same shape/dtype""" + return data.clone() + +# class CtxScatterFunc(torch.autograd.Function): +# """ +# Function to scatter the current key values into KV-cache. +# """ + +# @staticmethod +# def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor): +# # batch_idx = torch.arange(data.shape[0]).view(-1, 1, 1) +# # head_idx = torch.arange(data.shape[1]).view(1, -1, 1) +# # ctx_idx = position_ids.unsqueeze(1) +# # data[batch_idx, head_idx, ctx_idx] = updates +# # return data + +# @staticmethod +# def setup_context(ctx, inputs, outputs): +# pass + +# @staticmethod +# def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updates: torch.Value) -> torch.Value: +# return g.onnxscript_op(CtxScatter, data, position_ids, updates).setTypeAs(data) + @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) def CtxScatter3D(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates: onnxscript.FLOAT) -> onnxscript.FLOAT: @@ -143,3 +152,6 @@ def setup_context(ctx, inputs, outputs): @staticmethod def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value: return g.onnxscript_op(CtxGather, data, ctx_indices).setTypeAs(data) + + +custom_translation_table = {torch.ops.qefficient.ctx_scatter.default : CtxScatter, } diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index bbd937d52..2e1f1abdb 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -17,7 +17,7 @@ CtxGatherFunc3D, CtxGatherFuncCB, CtxGatherFuncCB3D, - CtxScatterFunc, + # CtxScatterFunc, CtxScatterFunc3D, CtxScatterFuncCB, CtxScatterFuncCB3D, @@ -90,8 +90,14 @@ def write_only(self, key_states, value_states, cache_kwargs): self.keys = CtxScatterFuncCB.apply(self.keys, batch_index, scatter_position_ids, key_states) self.values = CtxScatterFuncCB.apply(self.values, batch_index, scatter_position_ids, value_states) else: - self.keys = CtxScatterFunc.apply(self.keys, position_ids, key_states) - 
self.values = CtxScatterFunc.apply(self.values, position_ids, value_states) + # self.keys = CtxScatterFunc.apply(self.keys, position_ids, key_states) + # self.values = CtxScatterFunc.apply(self.values, position_ids, value_states) + self.keys = torch.ops.qefficient.ctx_scatter(self.keys, position_ids, key_states) + self.values = torch.ops.qefficient.ctx_scatter(self.values, position_ids, value_states) + + + + def update( self, @@ -131,8 +137,10 @@ def update( self.values = CtxScatterFuncCB.apply(self.values, batch_index, scatter_position_ids, value_states) else: - self.keys = CtxScatterFunc.apply(self.keys, position_ids, key_states) - self.values = CtxScatterFunc.apply(self.values, position_ids, value_states) + # self.keys = CtxScatterFunc.apply(self.keys, position_ids, key_states) + # self.values = CtxScatterFunc.apply(self.values, position_ids, value_states) + self.keys = torch.ops.qefficient.ctx_scatter(self.keys, position_ids, key_states) + self.values = torch.ops.qefficient.ctx_scatter(self.values, position_ids, value_states) k_out, v_out = self.keys, self.values @@ -152,6 +160,7 @@ def update( k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices) v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices) else: + k_out = CtxGatherFunc.apply(k_out, ctx_indices) v_out = CtxGatherFunc.apply(v_out, ctx_indices) v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) @@ -334,6 +343,7 @@ def from_legacy_cache( # TODO:This function will be depercated in future. class QEffHybridCache(HybridCache): + def __init__(self, config, batch_size, max_cache_len): super().__init__(config, batch_size, max_cache_len=max_cache_len) self.key_cache: List[torch.Tensor] = [] @@ -407,10 +417,12 @@ def update( valid_mask = (kv_position_ids != -1).unsqueeze(1).unsqueeze(-1) key_states = torch.where(valid_mask == 1, key_states, torch.zeros_like(key_states)) value_states = torch.where(valid_mask == 1, value_states, torch.zeros_like(value_states)) - self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) - self.value_cache[layer_idx] = CtxScatterFunc.apply( - self.value_cache[layer_idx], kv_position_ids, value_states - ) + # self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) + # self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], kv_position_ids, value_states) + self.key_cache[layer_idx] = torch.ops.qefficient.ctx_scatter(self.key_cache[layer_idx], kv_position_ids, key_states) + self.value_cache[layer_idx] = torch.ops.qefficient.ctx_scatter(self.value_cache[layer_idx], kv_position_ids, value_states) + + k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] # Original Gather @@ -509,10 +521,12 @@ def update( valid_mask = (kv_position_ids != -1).unsqueeze(1).unsqueeze(-1) key_states = torch.where(valid_mask == 1, key_states, torch.zeros_like(key_states)) value_states = torch.where(valid_mask == 1, value_states, torch.zeros_like(value_states)) - self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) - self.value_cache[layer_idx] = CtxScatterFunc.apply( - self.value_cache[layer_idx], kv_position_ids, value_states - ) + # self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) + # self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], kv_position_ids, value_states) + self.key_cache[layer_idx] = 
torch.ops.qefficient.ctx_scatter(self.key_cache[layer_idx], kv_position_ids, key_states) + self.value_cache[layer_idx] = torch.ops.qefficient.ctx_scatter(self.value_cache[layer_idx], kv_position_ids, value_states) + + k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] # Original Gather diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index dc467643c..45d8272a3 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -85,11 +85,11 @@ def convert_dynamic_axes_to_dynamic_shapes(dynamic_axes: Dict[str, Dict[int, str # Create or reuse Dim object for this dimension name if dim_name not in dim_registry: if dim_name == "batch_size": - dim_registry[dim_name] = Dim(dim_name, min=1, max=64) # Support realistic batch sizes + dim_registry[dim_name] = Dim.DYNAMIC elif "seq_len" in dim_name: - dim_registry[dim_name] = Dim(dim_name, min=1, max=4096) # Conservative seq range + dim_registry[dim_name] = Dim.DYNAMIC else: - dim_registry[dim_name] = Dim(dim_name, min=1, max=4096) # Generic conservative range + dim_registry[dim_name] = Dim.DYNAMIC input_dynamic_shapes[axis_idx] = dim_registry[dim_name] @@ -105,13 +105,12 @@ def convert_dynamic_axes_to_dynamic_shapes(dynamic_axes: Dict[str, Dict[int, str layer_dynamic_shapes = {} for axis_idx, dim_name in axes_map.items(): if dim_name not in dim_registry: - # Create Dim with conservative constraints to avoid conflicts if dim_name == "batch_size": - dim_registry[dim_name] = Dim(dim_name, min=1, max=64) # Support realistic batch sizes + dim_registry[dim_name] = Dim.DYNAMIC elif "seq_len" in dim_name: - dim_registry[dim_name] = Dim(dim_name, min=1, max=4096) # Conservative seq range + dim_registry[dim_name] = Dim.DYNAMIC else: - dim_registry[dim_name] = Dim(dim_name, min=1, max=4096) # Generic conservative range + dim_registry[dim_name] = Dim.DYNAMIC layer_dynamic_shapes[axis_idx] = dim_registry[dim_name] past_keys[layer_idx] = layer_dynamic_shapes @@ -120,13 +119,12 @@ def convert_dynamic_axes_to_dynamic_shapes(dynamic_axes: Dict[str, Dict[int, str layer_dynamic_shapes = {} for axis_idx, dim_name in axes_map.items(): if dim_name not in dim_registry: - # Create Dim with conservative constraints to avoid conflicts if dim_name == "batch_size": - dim_registry[dim_name] = Dim(dim_name, min=1, max=64) # Support realistic batch sizes + dim_registry[dim_name] = Dim.DYNAMIC elif "seq_len" in dim_name: - dim_registry[dim_name] = Dim(dim_name, min=1, max=4096) # Conservative seq range + dim_registry[dim_name] = Dim.DYNAMIC else: - dim_registry[dim_name] = Dim(dim_name, min=1, max=4096) # Generic conservative range + dim_registry[dim_name] = Dim.DYNAMIC layer_dynamic_shapes[axis_idx] = dim_registry[dim_name] past_values[layer_idx] = layer_dynamic_shapes diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 5f7a4db7b..5aab14520 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -17,7 +17,7 @@ ONNX_EXPORT_EXAMPLE_SEQ_LEN = 32 ONNX_EXPORT_EXAMPLE_FBS = 4 ONNX_EXPORT_EXAMPLE_NLK = 2 # Number of Logits to Keep -ONNX_EXPORT_OPSET = 13 +ONNX_EXPORT_OPSET = 17 ONNX_EXPORT_MAX_NUM_IMAGES = 1 ONNX_EXPORT_MAX_IMAGE_TILES = 4 ONNX_EXPORT_IMAGE_WIDTH = 560 @@ -84,7 +84,7 @@ def get_models_dir(): ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS = 512 ONNX_EXPORT_EXAMPLE_TOP_PS = 0.80 ONNX_EXPORT_EXAMPLE_MIN_PS = 0.99 -ONNX_EXPORT_OPSET = 13 +ONNX_EXPORT_OPSET = 17 COMPILER = ["/opt/qti-aic/exec/qaic-exec", 
"-aic-hw"] DEFAULT_AIC_HW_VERSION = "ai100" From bb88fb87f28fcfb1e3f1371a456ae2740577382f Mon Sep 17 00:00:00 2001 From: Sharvari Medhe Date: Fri, 7 Nov 2025 08:15:21 +0000 Subject: [PATCH 3/3] quickfix: formatting changes Signed-off-by: Sharvari Medhe --- QEfficient/base/modeling_qeff.py | 11 +++++---- QEfficient/customop/ctx_scatter_gather.py | 12 +++++++--- QEfficient/transformers/cache_utils.py | 24 +++++++++---------- .../transformers/models/modeling_auto.py | 18 +++++++------- 4 files changed, 37 insertions(+), 28 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 3deb463c3..02394f934 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -21,8 +21,8 @@ from QEfficient.base.onnx_transforms import OnnxTransform from QEfficient.base.pytorch_transforms import PytorchTransform from QEfficient.compile.qnn_compiler import compile as qnn_compile -from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.customop.ctx_scatter_gather import custom_translation_table +from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.utils import ( constants, create_json, @@ -36,11 +36,13 @@ logger = logging.getLogger(__name__) -import torch._dynamo.config +import torch._dynamo.config # noqa # Allow custom ops to be treated as black boxes torch._dynamo.config.capture_scalar_outputs = True torch._dynamo.config.capture_dynamic_output_shape_ops = True + + class QEFFBaseModel(ABC): """ Base class for all the model classes (i.e. LLMs, SD, quantized etc.). @@ -251,8 +253,9 @@ def _export( try: export_kwargs = {} if export_kwargs is None else export_kwargs import time + start = time.perf_counter() - onnx_program = torch.onnx.export( + onnx_program = torch.onnx.export( self.model, (example_inputs,), # str(tmp_onnx_path), @@ -267,7 +270,7 @@ def _export( **export_kwargs, ) end = time.perf_counter() - print("Dynamo enabled onnx export in memory time in sec", round(end - start,2)) + print("Dynamo enabled onnx export in memory time in sec", round(end - start, 2)) onnx_program.save(str(tmp_onnx_path)) logger.info("PyTorch export successful") diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py index 99a6533ca..8d68ae3a1 100644 --- a/QEfficient/customop/ctx_scatter_gather.py +++ b/QEfficient/customop/ctx_scatter_gather.py @@ -7,10 +7,12 @@ import onnxscript import torch + from QEfficient.utils import constants ops = getattr(onnxscript, "opset" + str(constants.ONNX_EXPORT_OPSET)) + # @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) def CtxScatter(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates: onnxscript.FLOAT) -> onnxscript.FLOAT: # Find dims @@ -41,12 +43,14 @@ def ctx_scatter_op(data: torch.Tensor, position_ids: torch.Tensor, updates: torc ctx_idx = position_ids.unsqueeze(1) result[batch_idx, head_idx, ctx_idx] = updates return result - + + @ctx_scatter_op.register_fake def _(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor) -> torch.Tensor: """Fake implementation for torch.export - just returns data tensor with same shape/dtype""" return data.clone() + # class CtxScatterFunc(torch.autograd.Function): # """ # Function to scatter the current key values into KV-cache. 
@@ -67,7 +71,7 @@ def _(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor) -> # @staticmethod # def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updates: torch.Value) -> torch.Value: # return g.onnxscript_op(CtxScatter, data, position_ids, updates).setTypeAs(data) - + @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1)) def CtxScatter3D(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates: onnxscript.FLOAT) -> onnxscript.FLOAT: @@ -154,4 +158,6 @@ def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> tor return g.onnxscript_op(CtxGather, data, ctx_indices).setTypeAs(data) -custom_translation_table = {torch.ops.qefficient.ctx_scatter.default : CtxScatter, } +custom_translation_table = { + torch.ops.qefficient.ctx_scatter.default: CtxScatter, +} diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 2e1f1abdb..92f73dbd6 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -95,10 +95,6 @@ def write_only(self, key_states, value_states, cache_kwargs): self.keys = torch.ops.qefficient.ctx_scatter(self.keys, position_ids, key_states) self.values = torch.ops.qefficient.ctx_scatter(self.values, position_ids, value_states) - - - - def update( self, key_states: torch.Tensor, @@ -160,7 +156,6 @@ def update( k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices) v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices) else: - k_out = CtxGatherFunc.apply(k_out, ctx_indices) v_out = CtxGatherFunc.apply(v_out, ctx_indices) v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) @@ -343,7 +338,6 @@ def from_legacy_cache( # TODO:This function will be depercated in future. 
class QEffHybridCache(HybridCache): - def __init__(self, config, batch_size, max_cache_len): super().__init__(config, batch_size, max_cache_len=max_cache_len) self.key_cache: List[torch.Tensor] = [] @@ -419,10 +413,13 @@ def update( value_states = torch.where(valid_mask == 1, value_states, torch.zeros_like(value_states)) # self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) # self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], kv_position_ids, value_states) - self.key_cache[layer_idx] = torch.ops.qefficient.ctx_scatter(self.key_cache[layer_idx], kv_position_ids, key_states) - self.value_cache[layer_idx] = torch.ops.qefficient.ctx_scatter(self.value_cache[layer_idx], kv_position_ids, value_states) + self.key_cache[layer_idx] = torch.ops.qefficient.ctx_scatter( + self.key_cache[layer_idx], kv_position_ids, key_states + ) + self.value_cache[layer_idx] = torch.ops.qefficient.ctx_scatter( + self.value_cache[layer_idx], kv_position_ids, value_states + ) - k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] # Original Gather @@ -523,10 +520,13 @@ def update( value_states = torch.where(valid_mask == 1, value_states, torch.zeros_like(value_states)) # self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) # self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], kv_position_ids, value_states) - self.key_cache[layer_idx] = torch.ops.qefficient.ctx_scatter(self.key_cache[layer_idx], kv_position_ids, key_states) - self.value_cache[layer_idx] = torch.ops.qefficient.ctx_scatter(self.value_cache[layer_idx], kv_position_ids, value_states) + self.key_cache[layer_idx] = torch.ops.qefficient.ctx_scatter( + self.key_cache[layer_idx], kv_position_ids, key_states + ) + self.value_cache[layer_idx] = torch.ops.qefficient.ctx_scatter( + self.value_cache[layer_idx], kv_position_ids, value_states + ) - k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] # Original Gather diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 45d8272a3..777feda0a 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -85,11 +85,11 @@ def convert_dynamic_axes_to_dynamic_shapes(dynamic_axes: Dict[str, Dict[int, str # Create or reuse Dim object for this dimension name if dim_name not in dim_registry: if dim_name == "batch_size": - dim_registry[dim_name] = Dim.DYNAMIC + dim_registry[dim_name] = Dim.DYNAMIC elif "seq_len" in dim_name: - dim_registry[dim_name] = Dim.DYNAMIC + dim_registry[dim_name] = Dim.DYNAMIC else: - dim_registry[dim_name] = Dim.DYNAMIC + dim_registry[dim_name] = Dim.DYNAMIC input_dynamic_shapes[axis_idx] = dim_registry[dim_name] @@ -106,11 +106,11 @@ def convert_dynamic_axes_to_dynamic_shapes(dynamic_axes: Dict[str, Dict[int, str for axis_idx, dim_name in axes_map.items(): if dim_name not in dim_registry: if dim_name == "batch_size": - dim_registry[dim_name] = Dim.DYNAMIC + dim_registry[dim_name] = Dim.DYNAMIC elif "seq_len" in dim_name: - dim_registry[dim_name] = Dim.DYNAMIC + dim_registry[dim_name] = Dim.DYNAMIC else: - dim_registry[dim_name] = Dim.DYNAMIC + dim_registry[dim_name] = Dim.DYNAMIC layer_dynamic_shapes[axis_idx] = dim_registry[dim_name] past_keys[layer_idx] = layer_dynamic_shapes @@ -120,11 +120,11 @@ def convert_dynamic_axes_to_dynamic_shapes(dynamic_axes: Dict[str, Dict[int, str for axis_idx, 
dim_name in axes_map.items(): if dim_name not in dim_registry: if dim_name == "batch_size": - dim_registry[dim_name] = Dim.DYNAMIC + dim_registry[dim_name] = Dim.DYNAMIC elif "seq_len" in dim_name: - dim_registry[dim_name] = Dim.DYNAMIC + dim_registry[dim_name] = Dim.DYNAMIC else: - dim_registry[dim_name] = Dim.DYNAMIC + dim_registry[dim_name] = Dim.DYNAMIC layer_dynamic_shapes[axis_idx] = dim_registry[dim_name] past_values[layer_idx] = layer_dynamic_shapes
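
For reference, a minimal standalone sketch of the export pattern this series adopts: a torch.library custom op with a fake (meta) implementation so torch.export can trace it, an ONNX-side lowering written with onnxscript ops, and a custom_translation_table handed to the dynamo-based torch.onnx.export together with dynamic_shapes. The op name demo::zero_negatives, the TinyModel module, and the version assumptions are illustrative only and not part of the patch; the exact form the exporter accepts for the lowering (plain callable vs. an @onnxscript.script-decorated function) and the minimum torch/onnxscript versions may differ from what is shown.

import torch
import onnxscript

# ONNX opset used for the hand-written lowering (the series bumps ONNX_EXPORT_OPSET to 17).
op17 = onnxscript.opset17


@torch.library.custom_op("demo::zero_negatives", mutates_args=())
def zero_negatives(x: torch.Tensor) -> torch.Tensor:
    # Eager reference implementation: clamp negative entries to zero.
    return torch.where(x < 0, torch.zeros_like(x), x)


@zero_negatives.register_fake
def _(x: torch.Tensor) -> torch.Tensor:
    # Fake implementation for torch.export: only shape/dtype propagation matters here.
    return torch.empty_like(x)


def onnx_zero_negatives(x):
    # ONNX-side lowering supplied via custom_translation_table; equivalent to the
    # eager op above (Relu(x) == max(x, 0)). Written with onnxscript ops, like the
    # patch's CtxScatter.
    return op17.Relu(x)


class TinyModel(torch.nn.Module):
    def forward(self, x):
        # Calling through torch.ops makes the custom op visible to the exporter.
        return torch.ops.demo.zero_negatives(x) + 1.0


if __name__ == "__main__":
    example_inputs = (torch.randn(2, 4),)
    onnx_program = torch.onnx.export(
        TinyModel(),
        example_inputs,
        dynamo=True,
        report=True,
        # Mark dim 0 of `x` as dynamic, analogous to Dim.DYNAMIC in
        # convert_dynamic_axes_to_dynamic_shapes.
        dynamic_shapes={"x": {0: torch.export.Dim.DYNAMIC}},
        # Map the captured custom op to its ONNX decomposition, mirroring
        # {torch.ops.qefficient.ctx_scatter.default: CtxScatter}.
        custom_translation_table={torch.ops.demo.zero_negatives.default: onnx_zero_negatives},
    )
    onnx_program.save("tiny_model.onnx")

Registering the fake implementation is what lets torch.export trace through the op without running a real kernel; the translation table entry is then the single place that decides how the op is spelled in the exported ONNX graph.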