diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py
index 6ecbf0fc0..02394f934 100644
--- a/QEfficient/base/modeling_qeff.py
+++ b/QEfficient/base/modeling_qeff.py
@@ -21,6 +21,7 @@
 from QEfficient.base.onnx_transforms import OnnxTransform
 from QEfficient.base.pytorch_transforms import PytorchTransform
 from QEfficient.compile.qnn_compiler import compile as qnn_compile
+from QEfficient.customop.ctx_scatter_gather import custom_translation_table
 from QEfficient.generation.cloud_infer import QAICInferenceSession
 from QEfficient.utils import (
     constants,
@@ -35,6 +36,12 @@
 logger = logging.getLogger(__name__)
 
+import torch._dynamo.config  # noqa
+
+# Allow dynamo to capture data-dependent ops instead of graph-breaking on them
+torch._dynamo.config.capture_scalar_outputs = True
+torch._dynamo.config.capture_dynamic_output_shape_ops = True
+
 
 class QEFFBaseModel(ABC):
     """
@@ -179,6 +186,7 @@ def _export(
         onnx_transform_kwargs: Optional[Dict[str, any]] = None,
         export_dir: Optional[str] = None,
         offload_pt_weights: bool = True,
+        dynamic_shapes: Optional[Dict[str, Dict[int, any]]] = None,
     ) -> str:
         """
         Export the PyTorch model to ONNX and apply ONNX transforms
@@ -244,16 +252,27 @@ def _export(
         try:
             export_kwargs = {} if export_kwargs is None else export_kwargs
-            torch.onnx.export(
+            import time
+
+            start = time.perf_counter()
+            onnx_program = torch.onnx.export(
                 self.model,
                 (example_inputs,),
-                str(tmp_onnx_path),
+                # str(tmp_onnx_path),
                 input_names=input_names,
                 output_names=output_names,
-                dynamic_axes=dynamic_axes,
+                # dynamic_axes=dynamic_axes,
+                dynamic_shapes=dynamic_shapes,
+                custom_translation_table=custom_translation_table,
                 opset_version=constants.ONNX_EXPORT_OPSET,
+                dynamo=True,
+                report=True,
                 **export_kwargs,
             )
+            end = time.perf_counter()
+            print("Dynamo-enabled in-memory ONNX export time (s):", round(end - start, 2))
+
+            onnx_program.save(str(tmp_onnx_path))
             logger.info("PyTorch export successful")
 
             _ = self._offload_model_weights(offload_pt_weights)
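For reference, a minimal sketch (toy module and file name are illustrative, not from this patch) of the export flow this hunk switches to: with `dynamo=True`, `torch.onnx.export` returns an in-memory `ONNXProgram` and takes `torch.export`-style `dynamic_shapes` in place of `dynamic_axes`, with serialization as a separate `.save()` step. The two `torch._dynamo.config` flags above keep data-dependent ops (e.g. `.item()` calls, ops with data-dependent output shapes) from graph-breaking during capture.

```python
import torch

class Toy(torch.nn.Module):  # hypothetical stand-in for self.model
    def forward(self, x):
        return x * 2

onnx_program = torch.onnx.export(
    Toy(),
    (torch.randn(2, 8),),
    input_names=["x"],
    output_names=["y"],
    dynamic_shapes={"x": {0: torch.export.Dim.DYNAMIC}},  # replaces dynamic_axes
    opset_version=17,  # mirrors the opset bump in this patch
    dynamo=True,
)
onnx_program.save("toy.onnx")  # serialization is now a second, explicit step
```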
diff --git a/QEfficient/customop/__init__.py b/QEfficient/customop/__init__.py
index ff0709f82..015001855 100644
--- a/QEfficient/customop/__init__.py
+++ b/QEfficient/customop/__init__.py
@@ -5,7 +5,7 @@
 #
 # -----------------------------------------------------------------------------
 
-from QEfficient.customop.ctx_scatter_gather import CtxGatherFunc, CtxGatherFunc3D, CtxScatterFunc, CtxScatterFunc3D
+from QEfficient.customop.ctx_scatter_gather import CtxGatherFunc, CtxGatherFunc3D, CtxScatterFunc3D
 from QEfficient.customop.ctx_scatter_gather_cb import (
     CtxGatherFuncCB,
     CtxGatherFuncCB3D,
@@ -16,7 +16,7 @@
 
 __all__ = [
     "CtxGatherFunc",
-    "CtxScatterFunc",
+    # "CtxScatterFunc",
     "CtxGatherFunc3D",
     "CtxScatterFunc3D",
     "CustomRMSNormAIC",
diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py
index c4f5a7bbd..8d68ae3a1 100644
--- a/QEfficient/customop/ctx_scatter_gather.py
+++ b/QEfficient/customop/ctx_scatter_gather.py
@@ -13,7 +13,7 @@
 
 ops = getattr(onnxscript, "opset" + str(constants.ONNX_EXPORT_OPSET))
 
 
-@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
+# @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
 def CtxScatter(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates: onnxscript.FLOAT) -> onnxscript.FLOAT:
     # Find dims
     batch_size = ops.Gather(ops.Shape(data), [0])
@@ -34,26 +34,43 @@ def CtxScatter(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates:
     return ops.ScatterND(data, indices, updates)
 
 
-class CtxScatterFunc(torch.autograd.Function):
-    """
-    Function to scatter the current key values into KV-cache.
-    """
+@torch.library.custom_op("qefficient::ctx_scatter", mutates_args=())
+def ctx_scatter_op(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor) -> torch.Tensor:
+    """Custom context scatter operation"""
+    result = data.clone()
+    batch_idx = torch.arange(result.shape[0]).view(-1, 1, 1)
+    head_idx = torch.arange(result.shape[1]).view(1, -1, 1)
+    ctx_idx = position_ids.unsqueeze(1)
+    result[batch_idx, head_idx, ctx_idx] = updates
+    return result
 
-    @staticmethod
-    def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor):
-        batch_idx = torch.arange(data.shape[0]).view(-1, 1, 1)
-        head_idx = torch.arange(data.shape[1]).view(1, -1, 1)
-        ctx_idx = position_ids.unsqueeze(1)
-        data[batch_idx, head_idx, ctx_idx] = updates
-        return data
 
-    @staticmethod
-    def setup_context(ctx, inputs, outputs):
-        pass
+@ctx_scatter_op.register_fake
+def _(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor) -> torch.Tensor:
+    """Fake implementation for torch.export - just returns data tensor with same shape/dtype"""
+    return data.clone()
 
-    @staticmethod
-    def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updates: torch.Value) -> torch.Value:
-        return g.onnxscript_op(CtxScatter, data, position_ids, updates).setTypeAs(data)
+
+# class CtxScatterFunc(torch.autograd.Function):
+#     """
+#     Function to scatter the current key values into KV-cache.
+#     """
+
+#     @staticmethod
+#     def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor):
+#         # batch_idx = torch.arange(data.shape[0]).view(-1, 1, 1)
+#         # head_idx = torch.arange(data.shape[1]).view(1, -1, 1)
+#         # ctx_idx = position_ids.unsqueeze(1)
+#         # data[batch_idx, head_idx, ctx_idx] = updates
+#         # return data
+
+#     @staticmethod
+#     def setup_context(ctx, inputs, outputs):
+#         pass
+
+#     @staticmethod
+#     def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updates: torch.Value) -> torch.Value:
+#         return g.onnxscript_op(CtxScatter, data, position_ids, updates).setTypeAs(data)
 
 
 @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
@@ -139,3 +156,8 @@ def setup_context(ctx, inputs, outputs):
     @staticmethod
     def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value:
         return g.onnxscript_op(CtxGather, data, ctx_indices).setTypeAs(data)
+
+
+custom_translation_table = {
+    torch.ops.qefficient.ctx_scatter.default: CtxScatter,
+}
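A sketch of how the three new pieces cooperate (the `CacheWrite` wrapper module is hypothetical, for illustration only): `torch.library.custom_op` registers `qefficient::ctx_scatter` as an opaque op, the `register_fake` rule tells `torch.export`'s FakeTensor tracing that the output has `data`'s shape and dtype, and `custom_translation_table` later maps the captured node to the ONNX-script `CtxScatter` when the graph is serialized.

```python
import torch
import QEfficient.customop.ctx_scatter_gather  # noqa: F401  # registers qefficient::ctx_scatter

class CacheWrite(torch.nn.Module):  # hypothetical wrapper for illustration
    def forward(self, cache, position_ids, updates):
        return torch.ops.qefficient.ctx_scatter(cache, position_ids, updates)

ep = torch.export.export(
    CacheWrite(),
    (torch.zeros(1, 2, 8, 4), torch.arange(3).view(1, 3), torch.ones(1, 2, 3, 4)),
)
# the exported graph keeps a single opaque qefficient.ctx_scatter node, which
# the custom_translation_table lowers to CtxScatter (ScatterND) in ONNX
print(ep.graph)
```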
diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py
index bbd937d52..92f73dbd6 100644
--- a/QEfficient/transformers/cache_utils.py
+++ b/QEfficient/transformers/cache_utils.py
@@ -17,7 +17,7 @@
     CtxGatherFunc3D,
     CtxGatherFuncCB,
     CtxGatherFuncCB3D,
-    CtxScatterFunc,
+    # CtxScatterFunc,
     CtxScatterFunc3D,
     CtxScatterFuncCB,
     CtxScatterFuncCB3D,
@@ -90,8 +90,10 @@ def write_only(self, key_states, value_states, cache_kwargs):
             self.keys = CtxScatterFuncCB.apply(self.keys, batch_index, scatter_position_ids, key_states)
             self.values = CtxScatterFuncCB.apply(self.values, batch_index, scatter_position_ids, value_states)
         else:
-            self.keys = CtxScatterFunc.apply(self.keys, position_ids, key_states)
-            self.values = CtxScatterFunc.apply(self.values, position_ids, value_states)
+            # self.keys = CtxScatterFunc.apply(self.keys, position_ids, key_states)
+            # self.values = CtxScatterFunc.apply(self.values, position_ids, value_states)
+            self.keys = torch.ops.qefficient.ctx_scatter(self.keys, position_ids, key_states)
+            self.values = torch.ops.qefficient.ctx_scatter(self.values, position_ids, value_states)
 
     def update(
         self,
@@ -131,8 +133,10 @@ def update(
             self.values = CtxScatterFuncCB.apply(self.values, batch_index, scatter_position_ids, value_states)
 
         else:
-            self.keys = CtxScatterFunc.apply(self.keys, position_ids, key_states)
-            self.values = CtxScatterFunc.apply(self.values, position_ids, value_states)
+            # self.keys = CtxScatterFunc.apply(self.keys, position_ids, key_states)
+            # self.values = CtxScatterFunc.apply(self.values, position_ids, value_states)
+            self.keys = torch.ops.qefficient.ctx_scatter(self.keys, position_ids, key_states)
+            self.values = torch.ops.qefficient.ctx_scatter(self.values, position_ids, value_states)
 
         k_out, v_out = self.keys, self.values
 
@@ -407,10 +411,15 @@ def update(
         valid_mask = (kv_position_ids != -1).unsqueeze(1).unsqueeze(-1)
         key_states = torch.where(valid_mask == 1, key_states, torch.zeros_like(key_states))
         value_states = torch.where(valid_mask == 1, value_states, torch.zeros_like(value_states))
-        self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
-        self.value_cache[layer_idx] = CtxScatterFunc.apply(
+        # self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
+        # self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], kv_position_ids, value_states)
+        self.key_cache[layer_idx] = torch.ops.qefficient.ctx_scatter(
+            self.key_cache[layer_idx], kv_position_ids, key_states
+        )
+        self.value_cache[layer_idx] = torch.ops.qefficient.ctx_scatter(
             self.value_cache[layer_idx], kv_position_ids, value_states
         )
+
         k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
 
         # Original Gather
@@ -509,10 +518,15 @@ def update(
         valid_mask = (kv_position_ids != -1).unsqueeze(1).unsqueeze(-1)
         key_states = torch.where(valid_mask == 1, key_states, torch.zeros_like(key_states))
         value_states = torch.where(valid_mask == 1, value_states, torch.zeros_like(value_states))
-        self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
-        self.value_cache[layer_idx] = CtxScatterFunc.apply(
+        # self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
+        # self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], kv_position_ids, value_states)
+        self.key_cache[layer_idx] = torch.ops.qefficient.ctx_scatter(
+            self.key_cache[layer_idx], kv_position_ids, key_states
+        )
+        self.value_cache[layer_idx] = torch.ops.qefficient.ctx_scatter(
            self.value_cache[layer_idx], kv_position_ids, value_states
         )
+
         k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
 
         # Original Gather
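A quick eager sanity check of the call-site swap above (shapes are illustrative). One behavioral nuance worth noting: the old `CtxScatterFunc.apply` wrote into `data` in place, while the new op clones first and returns a fresh tensor, so the reassignment pattern at these call sites keeps working unchanged.

```python
import torch
import QEfficient.customop.ctx_scatter_gather  # noqa: F401  # registers the op

data = torch.zeros(1, 2, 8, 4)       # (batch, heads, ctx_len, head_dim)
updates = torch.ones(1, 2, 3, 4)     # three new positions to write
position_ids = torch.arange(3).view(1, 3)

out = torch.ops.qefficient.ctx_scatter(data, position_ids, updates)
assert out[0, :, :3].eq(1).all()     # positions 0..2 updated
assert data.eq(0).all()              # input untouched: the op is non-mutating
```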
diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py
index 633a0b29d..777feda0a 100644
--- a/QEfficient/transformers/models/modeling_auto.py
+++ b/QEfficient/transformers/models/modeling_auto.py
@@ -59,6 +59,99 @@
 from QEfficient.utils.logging_utils import logger
 
 
+def convert_dynamic_axes_to_dynamic_shapes(dynamic_axes: Dict[str, Dict[int, str]]) -> Dict[str, any]:
+    """
+    Convert ONNX dynamic_axes format to torch.export dynamic_shapes format
+
+    Args:
+        dynamic_axes: ONNX format like {"input_ids": {0: "batch_size", 1: "seq_len"}}
+
+    Returns:
+        dynamic_shapes: torch.export format with Dim objects matching model forward args
+    """
+    from torch.export import Dim
+
+    # Create dimension registry to reuse Dim objects with same names
+    dim_registry = {}
+    dynamic_shapes = {}
+
+    # Handle regular model inputs (not past_key_values)
+    # These match the QEffLlamaForCausalLM forward signature:
+    # input_ids, attention_mask, position_ids, past_key_values, batch_index, etc.
+    for input_name, axes_map in dynamic_axes.items():
+        if not input_name.startswith("past_"):
+            input_dynamic_shapes = {}
+            for axis_idx, dim_name in axes_map.items():
+                # Create or reuse Dim object for this dimension name
+                # NOTE: every named axis currently maps to Dim.DYNAMIC; the
+                # branches below are placeholders for future per-axis bounds
+                if dim_name not in dim_registry:
+                    if dim_name == "batch_size":
+                        dim_registry[dim_name] = Dim.DYNAMIC
+                    elif "seq_len" in dim_name:
+                        dim_registry[dim_name] = Dim.DYNAMIC
+                    else:
+                        dim_registry[dim_name] = Dim.DYNAMIC
+
+                input_dynamic_shapes[axis_idx] = dim_registry[dim_name]
+
+            dynamic_shapes[input_name] = input_dynamic_shapes
+
+    # Handle past_key_values specially - collect all past_key.X and past_value.X
+    past_keys = {}
+    past_values = {}
+
+    for input_name, axes_map in dynamic_axes.items():
+        if input_name.startswith("past_key."):
+            layer_idx = int(input_name.split(".")[1])
+            layer_dynamic_shapes = {}
+            for axis_idx, dim_name in axes_map.items():
+                if dim_name not in dim_registry:
+                    if dim_name == "batch_size":
+                        dim_registry[dim_name] = Dim.DYNAMIC
+                    elif "seq_len" in dim_name:
+                        dim_registry[dim_name] = Dim.DYNAMIC
+                    else:
+                        dim_registry[dim_name] = Dim.DYNAMIC
+                layer_dynamic_shapes[axis_idx] = dim_registry[dim_name]
+            past_keys[layer_idx] = layer_dynamic_shapes
+
+        elif input_name.startswith("past_value."):
+            layer_idx = int(input_name.split(".")[1])
+            layer_dynamic_shapes = {}
+            for axis_idx, dim_name in axes_map.items():
+                if dim_name not in dim_registry:
+                    if dim_name == "batch_size":
+                        dim_registry[dim_name] = Dim.DYNAMIC
+                    elif "seq_len" in dim_name:
+                        dim_registry[dim_name] = Dim.DYNAMIC
+                    else:
+                        dim_registry[dim_name] = Dim.DYNAMIC
+                layer_dynamic_shapes[axis_idx] = dim_registry[dim_name]
+            past_values[layer_idx] = layer_dynamic_shapes
+
+    # Reconstruct past_key_values as nested structure if we have past keys/values
+    if past_keys or past_values:
+        max_layer = max(list(past_keys.keys()) + list(past_values.keys()))
+        past_kv_shapes = []
+
+        for layer_idx in range(max_layer + 1):
+            layer_shapes = []
+            if layer_idx in past_keys:
+                layer_shapes.append(past_keys[layer_idx])
+            else:
+                layer_shapes.append({})
+
+            if layer_idx in past_values:
+                layer_shapes.append(past_values[layer_idx])
+            else:
+                layer_shapes.append({})
+
+            past_kv_shapes.append(layer_shapes)
+
+        dynamic_shapes["past_key_values"] = past_kv_shapes
+
+    return dynamic_shapes
+
+
 class QEFFTransformersBase(QEFFBaseModel):
     """
     Base class for QEfficient wrappers around HuggingFace transformer models.
@@ -2310,11 +2403,14 @@ def export(self, export_dir: Optional[str] = None) -> str:
             dynamic_axes=dynamic_axes,
         )
 
+        dynamic_shapes = convert_dynamic_axes_to_dynamic_shapes(dynamic_axes)
+
         return self._export(
             example_inputs,
             output_names,
             dynamic_axes,
             export_dir=export_dir,
+            dynamic_shapes=dynamic_shapes,
         )
 
     def get_sampling_inputs_and_outputs(
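A usage sketch for `convert_dynamic_axes_to_dynamic_shapes` (the axis names mirror the repo's existing `dynamic_axes` conventions; the exact dict is illustrative). Every named axis currently collapses to `Dim.DYNAMIC`, and `past_key.N`/`past_value.N` entries are regrouped into the nested `past_key_values` structure that `torch.export` expects:

```python
from torch.export import Dim

dynamic_axes = {
    "input_ids": {0: "batch_size", 1: "seq_len"},
    "position_ids": {0: "batch_size", 1: "seq_len"},
    "past_key.0": {0: "batch_size", 2: "ctx_len"},
    "past_value.0": {0: "batch_size", 2: "ctx_len"},
}
shapes = convert_dynamic_axes_to_dynamic_shapes(dynamic_axes)
# shapes == {
#     "input_ids": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC},
#     "position_ids": {0: Dim.DYNAMIC, 1: Dim.DYNAMIC},
#     "past_key_values": [
#         [{0: Dim.DYNAMIC, 2: Dim.DYNAMIC}, {0: Dim.DYNAMIC, 2: Dim.DYNAMIC}],
#     ],
# }
```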
diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py
index 5f7a4db7b..5aab14520 100644
--- a/QEfficient/utils/constants.py
+++ b/QEfficient/utils/constants.py
@@ -17,7 +17,7 @@
 ONNX_EXPORT_EXAMPLE_SEQ_LEN = 32
 ONNX_EXPORT_EXAMPLE_FBS = 4
 ONNX_EXPORT_EXAMPLE_NLK = 2  # Number of Logits to Keep
-ONNX_EXPORT_OPSET = 13
+ONNX_EXPORT_OPSET = 17
 ONNX_EXPORT_MAX_NUM_IMAGES = 1
 ONNX_EXPORT_MAX_IMAGE_TILES = 4
 ONNX_EXPORT_IMAGE_WIDTH = 560
@@ -84,7 +84,7 @@ def get_models_dir():
 ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS = 512
 ONNX_EXPORT_EXAMPLE_TOP_PS = 0.80
 ONNX_EXPORT_EXAMPLE_MIN_PS = 0.99
-ONNX_EXPORT_OPSET = 13
+ONNX_EXPORT_OPSET = 17
 COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw"]
 DEFAULT_AIC_HW_VERSION = "ai100"
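A small verification sketch (file name illustrative): after an export with the bumped opset, the model's opset imports should reflect 17, plus the `com.qualcomm.cloud` domain if the `CtxScatter` translation is emitted as a custom-domain function:

```python
import onnx

model = onnx.load("model.onnx")
print({imp.domain or "ai.onnx": imp.version for imp in model.opset_import})
# expected to include {"ai.onnx": 17, "com.qualcomm.cloud": 1} if the custom
# domain was emitted; an inlined translation would show only "ai.onnx"
```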