Commit 79414da

Incorporated all subfunction changes
Signed-off-by: Abhishek Kumar Singh <sabhis@qti.qualcomm.com>
1 parent cb7da87 commit 79414da

File tree: 6 files changed (+195, -6 lines)

QEfficient/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -11,6 +11,7 @@
import QEfficient.utils.model_registery  # noqa: F401
from QEfficient.utils import custom_format_warning
from QEfficient.utils.logging_utils import logger
+from QEfficient.utils.patches import apply_torch_patches, is_patched

# For faster downloads via hf_transfer
# This code is put above import statements as this needs to be executed before
@@ -22,6 +23,9 @@
# custom warning for the better logging experience
warnings.formatwarning = custom_format_warning

+# Apply patches
+# TODO: Find a better way to do this; this is a temporary fix.
+apply_torch_patches()

def check_qaic_sdk():
    """Check if QAIC SDK is installed"""
@@ -70,6 +74,8 @@ def check_qaic_sdk():
    "QEFFAutoModelForImageTextToText",
    "QEFFAutoModelForSpeechSeq2Seq",
    "QEFFCommonLoader",
+    "apply_torch_patches",
+    "is_patched",
]

else:
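
Since the patch is applied at import time, a quick sanity check (a minimal sketch, assuming a standard install of this package) is:

import QEfficient  # noqa: F401 - importing the package runs apply_torch_patches()

from QEfficient import is_patched

assert is_patched()  # the patched _setup_trace_module_map is now active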

QEfficient/base/modeling_qeff.py

Lines changed: 15 additions & 3 deletions
@@ -18,10 +18,13 @@
import onnx
import torch

-from QEfficient.base.onnx_transforms import OnnxTransform
+from QEfficient.base.onnx_transforms import CustomOpTransform, OnnxTransform, rename_function_outputs
from QEfficient.base.pytorch_transforms import PytorchTransform
from QEfficient.compile.qnn_compiler import compile as qnn_compile
+from QEfficient.customop.ctx_scatter_gather import CtxGather, CtxGatherFunc, CtxScatter, CtxScatterFunc
+from QEfficient.customop.rms_norm import CustomRMSNorm, CustomRMSNormFunc
from QEfficient.generation.cloud_infer import QAICInferenceSession
+from QEfficient.transformers.models.pytorch_transforms import get_decoder_layer_classes_for_export
from QEfficient.utils import (
    constants,
    create_json,
@@ -243,22 +246,31 @@ def _export(
                input_names.append(param)

        try:
+            CustomOpTransform.register_custom_op("CustomRMSNormFunc", CustomRMSNormFunc, CustomRMSNorm)
+            CustomOpTransform.register_custom_op("CtxScatterFunc", CtxScatterFunc, CtxScatter)
+            CustomOpTransform.register_custom_op("CtxGatherFunc", CtxGatherFunc, CtxGather)
+            decoder_layer_classes = get_decoder_layer_classes_for_export(self.model)
            export_kwargs = {} if export_kwargs is None else export_kwargs
+
            torch.onnx.export(
                self.model,
                (example_inputs,),
                str(tmp_onnx_path),
                input_names=input_names,
                output_names=output_names,
                dynamic_axes=dynamic_axes,
-                opset_version=constants.ONNX_EXPORT_OPSET,
+                opset_version=17,
+                export_modules_as_functions=decoder_layer_classes,
+                do_constant_folding=True,
+                verbose=True,
                **export_kwargs,
            )
            logger.info("PyTorch export successful")

            _ = self._offload_model_weights(offload_pt_weights)
-
            model = onnx.load(tmp_onnx_path, load_external_data=False)
+            model, transformed = rename_function_outputs(model)
+
            transform_kwargs = {
                "onnx_base_dir": str(tmp_onnx_dir),
                "model_name": self.model_name,

QEfficient/base/onnx_transforms.py

Lines changed: 26 additions & 0 deletions
@@ -99,3 +99,29 @@ def apply(
            current_file_size = tsize
            external_data_helper.set_external_data(tensor, f"{model_name}_{file_num}.onnx.data")
        return model, transformed
+
+def rename_function_outputs(model):
+    graph = model.graph
+    op_type_to_func_map = {func.name: func for func in model.functions}
+    decoder_layer_patterns = ["DecoderLayer", "Block", "Layer"]
+    transformed = False
+    model_graph_outputs = [val.name for val in model.graph.output]
+    node_count = 0
+    for node in graph.node:
+        if any(pattern in node.name or pattern in node.op_type for pattern in decoder_layer_patterns):
+            func = op_type_to_func_map[node.op_type]
+            for i, out_name in enumerate(func.output):
+                if "_InternalRetainedState" in out_name:
+                    transformed = True
+                    tmp = node.output[i]
+                    if "key" in func.output[i]:
+                        new_name = f"past_key.{node_count}_RetainedState"
+                    elif "value" in func.output[i]:
+                        new_name = f"past_value.{node_count}_RetainedState"
+                    else:
+                        raise NotImplementedError()
+                    print(f"renaming {node.output[i]} to {new_name}")
+                    node.output[i] = new_name
+                    model.graph.output[model_graph_outputs.index(tmp)].name = new_name
+            node_count += 1
+    return model, transformed
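
A hedged usage sketch for the new helper (the paths are placeholders; it assumes the model was exported with export_modules_as_functions, so decoder layers appear as function-call nodes whose outputs carry the _InternalRetainedState sentinel):

import onnx

from QEfficient.base.onnx_transforms import rename_function_outputs

model = onnx.load("model.onnx", load_external_data=False)  # placeholder path
model, transformed = rename_function_outputs(model)
if transformed:
    onnx.save(model, "model.renamed.onnx")  # external tensor data is left untouched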

QEfficient/transformers/models/modeling_auto.py

Lines changed: 3 additions & 3 deletions
@@ -347,7 +347,7 @@ def export(self, export_dir: Optional[str] = None) -> str:
            dynamic_axes,
            export_dir=export_dir,
        )
-
+
    def compile(
        self,
        onnx_path: Optional[str] = None,
@@ -2285,14 +2285,14 @@ def export(self, export_dir: Optional[str] = None) -> str:
                for kv in ["key", "value"]:
                    example_inputs["past_key_values"][i].append(torch.zeros(pkv_cache[0][0].shape, dtype=torch.float32))
                    dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
-                    output_names.append(f"past_{kv}.{i}_RetainedState")
+                    output_names.append(f"past_{kv}.{i}_InternalRetainedState")

        else:
            for i in range(self.num_layers):
                for kv in ["key", "value"]:
                    example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
                    dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes
-                    output_names.append(f"past_{kv}.{i}_RetainedState")
+                    output_names.append(f"past_{kv}.{i}_InternalRetainedState")

        if self.continuous_batching:
            example_inputs["batch_index"] = torch.arange(bs).view(bs, 1)

QEfficient/transformers/models/pytorch_transforms.py

Lines changed: 25 additions & 0 deletions
@@ -788,3 +788,28 @@ def apply(cls, model: nn.Module, pooling: Union[str, Callable]) -> Tuple[nn.Modu
        model = PooledModel(model, pooling_method)
        warnings.warn("Pooling is applied to the model.")
        return model, transformed
+
+def get_decoder_layer_classes_for_export(model: nn.Module) -> set:
+    """
+    Dynamically determine which DecoderLayer classes should be exported as functions,
+    based on the model's architecture, using the existing KVCacheTransform mapping.
+    """
+    # Patterns that identify decoder layer classes
+    DECODER_LAYER_PATTERNS = ["DecoderLayer", "Block", "Layer"]
+
+    # Get all QEff classes from the existing mapping that are decoder layers
+    decoder_layer_classes = set()
+
+    for original_class, qeff_class in KVCacheTransform._module_mapping.items():
+        # Check if the QEff class name contains a decoder layer pattern
+        qeff_class_name = qeff_class.__name__
+        if any(pattern in qeff_class_name for pattern in DECODER_LAYER_PATTERNS):
+            decoder_layer_classes.add(qeff_class)
+
+    # Filter to only the classes actually used in the current model
+    model_decoder_classes = set()
+    for module in model.modules():
+        if module.__class__ in decoder_layer_classes:
+            model_decoder_classes.add(module.__class__)
+
+    return model_decoder_classes
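
A hedged usage sketch (the checkpoint and the resulting class are illustrative; it assumes KVCacheTransform follows the repo's usual apply(cls, model) -> (model, transformed) signature and has already swapped the decoder layers for their QEff counterparts):

from transformers import AutoModelForCausalLM

from QEfficient.transformers.models.pytorch_transforms import (
    KVCacheTransform,
    get_decoder_layer_classes_for_export,
)

model = AutoModelForCausalLM.from_pretrained("gpt2")
model, _ = KVCacheTransform.apply(model)
decoder_classes = get_decoder_layer_classes_for_export(model)
print(decoder_classes)  # e.g. {<class '...QEffGPT2Block'>}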

QEfficient/utils/patches.py

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+"""Monkey patches for torch.onnx.utils to fix ONNX export issues."""
+
+from typing import Collection, Set, Type, Union
+
+import torch
+import torch.onnx.utils as onnx_utils
+from torch import _C
+
+
+def _setup_trace_module_map_patched(
+    model: Union[torch.nn.Module, torch.jit.ScriptModule],
+    export_modules_as_functions: Union[bool, Collection[Type[torch.nn.Module]]],
+) -> Set[str]:
+    """Patched version of _setup_trace_module_map that fixes an onnx_attrs type mismatch."""
+
+    def __register_attribute_hook():
+        attr_name = "_onnx_attrs"
+
+        def _track_module_attributes_forward_pre_hook(module, input):
+            setattr(module, attr_name, _get_module_attributes(module))
+
+        def _track_module_attributes_forward_hook(module, input, output):
+            tracing_state = _C._get_tracing_state()
+            if not tracing_state:
+                return
+            graph = tracing_state.graph()
+            onnx_attrs = {}
+            if hasattr(module, attr_name):
+                onnx_attrs = getattr(module, attr_name)
+                delattr(module, attr_name)
+            # FIX: pass an empty dict to avoid a type mismatch in
+            # _jit_pass_onnx_track_scope_attributes, observed with transformers v4.55 and above.
+            onnx_attrs = {}
+            _C._jit_pass_onnx_track_scope_attributes(graph, onnx_attrs)
+
+        for m in model.modules():
+            m.register_forward_hook(_track_module_attributes_forward_hook)
+            m.register_forward_pre_hook(_track_module_attributes_forward_pre_hook)
+
+    def _unqualified_variable_name(qualified_name: str) -> str:
+        """
+        Parse a qualified variable name and return the unqualified version.
+        Pure numeric atoms are considered inadequate, so this function looks past them,
+        starting from the first non-numeric atom.
+        """
+        name_atoms = qualified_name.split(".")
+        for i, atom in reversed(list(enumerate(name_atoms))):
+            if not atom.isnumeric():
+                return ".".join(name_atoms[i:])
+        return qualified_name
+
+    trace_module_map = {
+        _m: torch._C._jit_onnx_create_full_scope_name(torch.typename(type(_m)), _unqualified_variable_name(_n))
+        for _n, _m in model.named_modules()
+    }
+    torch.jit._trace._trace_module_map = trace_module_map
+
+    if isinstance(export_modules_as_functions, bool) and export_modules_as_functions:
+        module_typenames = {torch.typename(type(module)) for module in trace_module_map}
+    elif isinstance(export_modules_as_functions, set) and export_modules_as_functions:
+
+        def _find_typename(v):
+            if isinstance(v, type):
+                return torch.typename(v)
+            else:
+                raise RuntimeError(
+                    "Only type of the `nn.Module` should be "
+                    "passed in the set for argument `export_modules_as_functions`. "
+                    f"Got `{type(v).__name__}`."
+                )
+
+        module_typenames = {_find_typename(v) for v in export_modules_as_functions}
+    else:
+        module_typenames = set()
+
+    if module_typenames:
+        __register_attribute_hook()
+
+    return module_typenames
+
+
+def _get_module_attributes(module):
+    """Helper function to get module attributes safely."""
+    import typing
+
+    annotations = typing.get_type_hints(type(module))
+    base_m_annotations = typing.get_type_hints(torch.nn.Module)
+    [annotations.pop(k, None) for k in base_m_annotations]
+
+    attrs = {}
+    for k in annotations:
+        try:
+            attrs[k] = getattr(module, k)
+        except AttributeError:
+            _C._jit_onnx_log(f"Skipping module attribute '{k}'")
+            continue
+    return attrs
+
+
+def apply_torch_patches():
+    """Apply all necessary torch patches for ONNX export."""
+    # Monkey-patch the internal setup function
+    onnx_utils._setup_trace_module_map = _setup_trace_module_map_patched
+
+    if hasattr(onnx_utils, "_get_module_attributes"):
+        onnx_utils._get_module_attributes = _get_module_attributes
+
+    print("Applied torch ONNX export patches for export_modules_as_functions compatibility")
+
+
+def is_patched():
+    """Check if patches have been applied."""
+    return onnx_utils._setup_trace_module_map == _setup_trace_module_map_patched
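
For intuition about the scope-name helper: _unqualified_variable_name keeps trailing numeric atoms attached to their first non-numeric parent. The inner function is copied out below so its behavior can be checked standalone:

def _unqualified_variable_name(qualified_name: str) -> str:
    # Copy of the inner helper above, extracted for demonstration.
    name_atoms = qualified_name.split(".")
    for i, atom in reversed(list(enumerate(name_atoms))):
        if not atom.isnumeric():
            return ".".join(name_atoms[i:])
    return qualified_name


print(_unqualified_variable_name("model.layers.0"))    # layers.0
print(_unqualified_variable_name("transformer.h.11"))  # h.11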
