Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions QEfficient/base/modeling_qeff.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from QEfficient.base.onnx_transforms import OnnxTransform
from QEfficient.base.pytorch_transforms import PytorchTransform
from QEfficient.compile.qnn_compiler import compile as qnn_compile
from QEfficient.customop.ctx_scatter_gather import custom_translation_table
from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.utils import (
constants,
Expand All @@ -35,6 +36,12 @@

logger = logging.getLogger(__name__)

import torch._dynamo.config # noqa

# Allow custom ops to be treated as black boxes
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True


class QEFFBaseModel(ABC):
"""
Expand Down Expand Up @@ -179,6 +186,7 @@ def _export(
onnx_transform_kwargs: Optional[Dict[str, any]] = None,
export_dir: Optional[str] = None,
offload_pt_weights: bool = True,
dynamic_shapes: Optional[Dict[str, Dict[int, any]]] = None,
) -> str:
"""
Export the PyTorch model to ONNX and apply ONNX transforms
Expand Down Expand Up @@ -244,16 +252,27 @@ def _export(

try:
export_kwargs = {} if export_kwargs is None else export_kwargs
torch.onnx.export(
import time

start = time.perf_counter()
onnx_program = torch.onnx.export(
self.model,
(example_inputs,),
str(tmp_onnx_path),
# str(tmp_onnx_path),
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
# dynamic_axes=dynamic_axes,
dynamic_shapes=dynamic_shapes,
custom_translation_table=custom_translation_table,
opset_version=constants.ONNX_EXPORT_OPSET,
dynamo=True,
report=True,
**export_kwargs,
)
end = time.perf_counter()
print("Dynamo enabled onnx export in memory time in sec", round(end - start, 2))

onnx_program.save(str(tmp_onnx_path))
logger.info("PyTorch export successful")

_ = self._offload_model_weights(offload_pt_weights)
Expand Down
4 changes: 2 additions & 2 deletions QEfficient/customop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#
# -----------------------------------------------------------------------------

from QEfficient.customop.ctx_scatter_gather import CtxGatherFunc, CtxGatherFunc3D, CtxScatterFunc, CtxScatterFunc3D
from QEfficient.customop.ctx_scatter_gather import CtxGatherFunc, CtxGatherFunc3D, CtxScatterFunc3D
from QEfficient.customop.ctx_scatter_gather_cb import (
CtxGatherFuncCB,
CtxGatherFuncCB3D,
Expand All @@ -16,7 +16,7 @@

__all__ = [
"CtxGatherFunc",
"CtxScatterFunc",
# "CtxScatterFunc",
"CtxGatherFunc3D",
"CtxScatterFunc3D",
"CustomRMSNormAIC",
Expand Down
58 changes: 40 additions & 18 deletions QEfficient/customop/ctx_scatter_gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
ops = getattr(onnxscript, "opset" + str(constants.ONNX_EXPORT_OPSET))


@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
# @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
def CtxScatter(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates: onnxscript.FLOAT) -> onnxscript.FLOAT:
# Find dims
batch_size = ops.Gather(ops.Shape(data), [0])
Expand All @@ -34,26 +34,43 @@ def CtxScatter(data: onnxscript.FLOAT, position_ids: onnxscript.INT32, updates:
return ops.ScatterND(data, indices, updates)


class CtxScatterFunc(torch.autograd.Function):
"""
Function to scatter the current key values into KV-cache.
"""
@torch.library.custom_op("qefficient::ctx_scatter", mutates_args=())
def ctx_scatter_op(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor) -> torch.Tensor:
"""Custom context scatter operation"""
result = data.clone()
batch_idx = torch.arange(result.shape[0]).view(-1, 1, 1)
head_idx = torch.arange(result.shape[1]).view(1, -1, 1)
ctx_idx = position_ids.unsqueeze(1)
result[batch_idx, head_idx, ctx_idx] = updates
return result

@staticmethod
def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor):
batch_idx = torch.arange(data.shape[0]).view(-1, 1, 1)
head_idx = torch.arange(data.shape[1]).view(1, -1, 1)
ctx_idx = position_ids.unsqueeze(1)
data[batch_idx, head_idx, ctx_idx] = updates
return data

@staticmethod
def setup_context(ctx, inputs, outputs):
pass
@ctx_scatter_op.register_fake
def _(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor) -> torch.Tensor:
"""Fake implementation for torch.export - just returns data tensor with same shape/dtype"""
return data.clone()

@staticmethod
def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updates: torch.Value) -> torch.Value:
return g.onnxscript_op(CtxScatter, data, position_ids, updates).setTypeAs(data)

# class CtxScatterFunc(torch.autograd.Function):
# """
# Function to scatter the current key values into KV-cache.
# """

# @staticmethod
# def forward(data: torch.Tensor, position_ids: torch.Tensor, updates: torch.Tensor):
# # batch_idx = torch.arange(data.shape[0]).view(-1, 1, 1)
# # head_idx = torch.arange(data.shape[1]).view(1, -1, 1)
# # ctx_idx = position_ids.unsqueeze(1)
# # data[batch_idx, head_idx, ctx_idx] = updates
# # return data

# @staticmethod
# def setup_context(ctx, inputs, outputs):
# pass

# @staticmethod
# def symbolic(g: torch.Graph, data: torch.Value, position_ids: torch.Value, updates: torch.Value) -> torch.Value:
# return g.onnxscript_op(CtxScatter, data, position_ids, updates).setTypeAs(data)


@onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
Expand Down Expand Up @@ -139,3 +156,8 @@ def setup_context(ctx, inputs, outputs):
@staticmethod
def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value:
return g.onnxscript_op(CtxGather, data, ctx_indices).setTypeAs(data)


custom_translation_table = {
torch.ops.qefficient.ctx_scatter.default: CtxScatter,
}
32 changes: 23 additions & 9 deletions QEfficient/transformers/cache_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
CtxGatherFunc3D,
CtxGatherFuncCB,
CtxGatherFuncCB3D,
CtxScatterFunc,
# CtxScatterFunc,
CtxScatterFunc3D,
CtxScatterFuncCB,
CtxScatterFuncCB3D,
Expand Down Expand Up @@ -90,8 +90,10 @@ def write_only(self, key_states, value_states, cache_kwargs):
self.keys = CtxScatterFuncCB.apply(self.keys, batch_index, scatter_position_ids, key_states)
self.values = CtxScatterFuncCB.apply(self.values, batch_index, scatter_position_ids, value_states)
else:
self.keys = CtxScatterFunc.apply(self.keys, position_ids, key_states)
self.values = CtxScatterFunc.apply(self.values, position_ids, value_states)
# self.keys = CtxScatterFunc.apply(self.keys, position_ids, key_states)
# self.values = CtxScatterFunc.apply(self.values, position_ids, value_states)
self.keys = torch.ops.qefficient.ctx_scatter(self.keys, position_ids, key_states)
self.values = torch.ops.qefficient.ctx_scatter(self.values, position_ids, value_states)

def update(
self,
Expand Down Expand Up @@ -131,8 +133,10 @@ def update(

self.values = CtxScatterFuncCB.apply(self.values, batch_index, scatter_position_ids, value_states)
else:
self.keys = CtxScatterFunc.apply(self.keys, position_ids, key_states)
self.values = CtxScatterFunc.apply(self.values, position_ids, value_states)
# self.keys = CtxScatterFunc.apply(self.keys, position_ids, key_states)
# self.values = CtxScatterFunc.apply(self.values, position_ids, value_states)
self.keys = torch.ops.qefficient.ctx_scatter(self.keys, position_ids, key_states)
self.values = torch.ops.qefficient.ctx_scatter(self.values, position_ids, value_states)

k_out, v_out = self.keys, self.values

Expand Down Expand Up @@ -407,10 +411,15 @@ def update(
valid_mask = (kv_position_ids != -1).unsqueeze(1).unsqueeze(-1)
key_states = torch.where(valid_mask == 1, key_states, torch.zeros_like(key_states))
value_states = torch.where(valid_mask == 1, value_states, torch.zeros_like(value_states))
self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
self.value_cache[layer_idx] = CtxScatterFunc.apply(
# self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
# self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], kv_position_ids, value_states)
self.key_cache[layer_idx] = torch.ops.qefficient.ctx_scatter(
self.key_cache[layer_idx], kv_position_ids, key_states
)
self.value_cache[layer_idx] = torch.ops.qefficient.ctx_scatter(
self.value_cache[layer_idx], kv_position_ids, value_states
)

k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]

# Original Gather
Expand Down Expand Up @@ -509,10 +518,15 @@ def update(
valid_mask = (kv_position_ids != -1).unsqueeze(1).unsqueeze(-1)
key_states = torch.where(valid_mask == 1, key_states, torch.zeros_like(key_states))
value_states = torch.where(valid_mask == 1, value_states, torch.zeros_like(value_states))
self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
self.value_cache[layer_idx] = CtxScatterFunc.apply(
# self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
# self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], kv_position_ids, value_states)
self.key_cache[layer_idx] = torch.ops.qefficient.ctx_scatter(
self.key_cache[layer_idx], kv_position_ids, key_states
)
self.value_cache[layer_idx] = torch.ops.qefficient.ctx_scatter(
self.value_cache[layer_idx], kv_position_ids, value_states
)

k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]

# Original Gather
Expand Down
96 changes: 96 additions & 0 deletions QEfficient/transformers/models/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,99 @@
from QEfficient.utils.logging_utils import logger


def convert_dynamic_axes_to_dynamic_shapes(dynamic_axes: Dict[str, Dict[int, str]]) -> Dict[str, any]:
"""
Convert ONNX dynamic_axes format to torch.export dynamic_shapes format

Args:
dynamic_axes: ONNX format like {"input_ids": {0: "batch_size", 1: "seq_len"}}

Returns:
dynamic_shapes: torch.export format with Dim objects matching model forward args
"""
from torch.export import Dim

# Create dimension registry to reuse Dim objects with same names
dim_registry = {}
dynamic_shapes = {}

# Handle regular model inputs (not past_key_values)
# These match the QEffLlamaForCausalLM forward signature:
# input_ids, attention_mask, position_ids, past_key_values, batch_index, etc.
for input_name, axes_map in dynamic_axes.items():
if not input_name.startswith("past_"):
input_dynamic_shapes = {}
for axis_idx, dim_name in axes_map.items():
# Create or reuse Dim object for this dimension name
if dim_name not in dim_registry:
if dim_name == "batch_size":
dim_registry[dim_name] = Dim.DYNAMIC
elif "seq_len" in dim_name:
dim_registry[dim_name] = Dim.DYNAMIC
else:
dim_registry[dim_name] = Dim.DYNAMIC

input_dynamic_shapes[axis_idx] = dim_registry[dim_name]

dynamic_shapes[input_name] = input_dynamic_shapes

# Handle past_key_values specially - collect all past_key.X and past_value.X
past_keys = {}
past_values = {}

for input_name, axes_map in dynamic_axes.items():
if input_name.startswith("past_key."):
layer_idx = int(input_name.split(".")[1])
layer_dynamic_shapes = {}
for axis_idx, dim_name in axes_map.items():
if dim_name not in dim_registry:
if dim_name == "batch_size":
dim_registry[dim_name] = Dim.DYNAMIC
elif "seq_len" in dim_name:
dim_registry[dim_name] = Dim.DYNAMIC
else:
dim_registry[dim_name] = Dim.DYNAMIC
layer_dynamic_shapes[axis_idx] = dim_registry[dim_name]
past_keys[layer_idx] = layer_dynamic_shapes

elif input_name.startswith("past_value."):
layer_idx = int(input_name.split(".")[1])
layer_dynamic_shapes = {}
for axis_idx, dim_name in axes_map.items():
if dim_name not in dim_registry:
if dim_name == "batch_size":
dim_registry[dim_name] = Dim.DYNAMIC
elif "seq_len" in dim_name:
dim_registry[dim_name] = Dim.DYNAMIC
else:
dim_registry[dim_name] = Dim.DYNAMIC
layer_dynamic_shapes[axis_idx] = dim_registry[dim_name]
past_values[layer_idx] = layer_dynamic_shapes

# Reconstruct past_key_values as nested structure if we have past keys/values
if past_keys or past_values:
max_layer = max(list(past_keys.keys()) + list(past_values.keys()))
past_kv_shapes = []

for layer_idx in range(max_layer + 1):
layer_shapes = []
if layer_idx in past_keys:
layer_shapes.append(past_keys[layer_idx])
else:
layer_shapes.append({})

if layer_idx in past_values:
layer_shapes.append(past_values[layer_idx])
else:
layer_shapes.append({})

past_kv_shapes.append(layer_shapes)

dynamic_shapes["past_key_values"] = past_kv_shapes

return dynamic_shapes


class QEFFTransformersBase(QEFFBaseModel):
"""
Base class for QEfficient wrappers around HuggingFace transformer models.
Expand Down Expand Up @@ -2310,11 +2403,14 @@ def export(self, export_dir: Optional[str] = None) -> str:
dynamic_axes=dynamic_axes,
)

dynamic_shapes = convert_dynamic_axes_to_dynamic_shapes(dynamic_axes)

return self._export(
example_inputs,
output_names,
dynamic_axes,
export_dir=export_dir,
dynamic_shapes=dynamic_shapes,
)

def get_sampling_inputs_and_outputs(
Expand Down
4 changes: 2 additions & 2 deletions QEfficient/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
ONNX_EXPORT_EXAMPLE_SEQ_LEN = 32
ONNX_EXPORT_EXAMPLE_FBS = 4
ONNX_EXPORT_EXAMPLE_NLK = 2 # Number of Logits to Keep
ONNX_EXPORT_OPSET = 13
ONNX_EXPORT_OPSET = 17
ONNX_EXPORT_MAX_NUM_IMAGES = 1
ONNX_EXPORT_MAX_IMAGE_TILES = 4
ONNX_EXPORT_IMAGE_WIDTH = 560
Expand Down Expand Up @@ -84,7 +84,7 @@ def get_models_dir():
ONNX_EXPORT_EXAMPLE_MAX_TOP_K_IDS = 512
ONNX_EXPORT_EXAMPLE_TOP_PS = 0.80
ONNX_EXPORT_EXAMPLE_MIN_PS = 0.99
ONNX_EXPORT_OPSET = 13
ONNX_EXPORT_OPSET = 17

COMPILER = ["/opt/qti-aic/exec/qaic-exec", "-aic-hw"]
DEFAULT_AIC_HW_VERSION = "ai100"
Expand Down
Loading