1+ # -----------------------------------------------------------------------------
2+ #
3+ # Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
4+ # SPDX-License-Identifier: BSD-3-Clause
5+ #
6+ # -----------------------------------------------------------------------------
7+
8+ """Monkey patches for torch.onnx.utils to fix ONNX export issues."""
9+
10+ from typing import Collection , Set , Type , Union
11+
12+ import torch
13+ import torch .onnx .utils as onnx_utils
14+ from torch import _C
15+
16+
17+ def _setup_trace_module_map_patched (
18+ model : Union [torch .nn .Module , torch .jit .ScriptModule ],
19+ export_modules_as_functions : Union [bool , Collection [Type [torch .nn .Module ]]],
20+ ) -> Set [str ]:
21+ """Patched version of _setup_trace_module_map that fixes onnx_attrs type mismatch."""
22+
23+ def __register_attribute_hook ():
24+ attr_name = "_onnx_attrs"
25+
26+ def _track_module_attributes_forward_pre_hook (module , input ):
27+ setattr (module , attr_name , _get_module_attributes (module ))
28+
29+ def _track_module_attributes_forward_hook (module , input , output ):
30+ tracing_state = _C ._get_tracing_state ()
31+ if not tracing_state :
32+ return
33+ graph = tracing_state .graph ()
34+ onnx_attrs = {}
35+ if hasattr (module , attr_name ):
36+ onnx_attrs = getattr (module , attr_name )
37+ delattr (module , attr_name )
38+ # FIX: use empty dict to avoid type mismatch with _jit_pass_onnx_track_scope_attributes
39+ # Observed in transformers v4.55 and above
40+ onnx_attrs = {}
41+ _C ._jit_pass_onnx_track_scope_attributes (graph , onnx_attrs )
42+
43+ for m in model .modules ():
44+ m .register_forward_hook (_track_module_attributes_forward_hook )
45+ m .register_forward_pre_hook (_track_module_attributes_forward_pre_hook )
46+
47+ def _unqualified_variable_name (qualified_name : str ) -> str :
48+ """
49+ Parse qualified variable name and return the unqualified version.
50+ Pure numeric atoms are considered inadequate, so this function will look past them,
51+ and start from the first non-numeric atom.
52+ """
53+ name_atoms = qualified_name .split ("." )
54+ for i , atom in reversed (list (enumerate (name_atoms ))):
55+ if not atom .isnumeric ():
56+ return "." .join (name_atoms [i :])
57+ return qualified_name
58+
59+ trace_module_map = {
60+ _m : torch ._C ._jit_onnx_create_full_scope_name (torch .typename (type (_m )), _unqualified_variable_name (_n ))
61+ for _n , _m in model .named_modules ()
62+ }
63+ torch .jit ._trace ._trace_module_map = trace_module_map
64+
65+ if isinstance (export_modules_as_functions , bool ) and export_modules_as_functions :
66+ module_typenames = {torch .typename (type (module )) for module in trace_module_map }
67+ elif isinstance (export_modules_as_functions , set ) and export_modules_as_functions :
68+
69+ def _find_typename (v ):
70+ if isinstance (v , type ):
71+ return torch .typename (v )
72+ else :
73+ raise RuntimeError (
74+ "Only type of the `nn.Module` should be "
75+ "passed in the set for argument `export_modules_as_functions`. "
76+ f"Got `{ type (v ).__name__ } `."
77+ )
78+
79+ module_typenames = {_find_typename (v ) for v in export_modules_as_functions }
80+ else :
81+ module_typenames = set ()
82+
83+ if module_typenames :
84+ __register_attribute_hook ()
85+
86+ return module_typenames
87+
88+
89+ def _get_module_attributes (module ):
90+ """Helper function to get module attributes safely."""
91+ import typing
92+
93+ annotations = typing .get_type_hints (type (module ))
94+ base_m_annotations = typing .get_type_hints (torch .nn .Module )
95+ [annotations .pop (k , None ) for k in base_m_annotations ]
96+
97+ attrs = {}
98+ for k in annotations :
99+ try :
100+ attrs [k ] = getattr (module , k )
101+ except AttributeError :
102+ _C ._jit_onnx_log (f"Skipping module attribute '{ k } '" )
103+ continue
104+ return attrs
105+
106+
107+ def apply_torch_patches ():
108+ """Apply all necessary torch patches for ONNX export."""
109+ # Monkey patch the function
110+ onnx_utils ._setup_trace_module_map = _setup_trace_module_map_patched
111+
112+ if hasattr (onnx_utils , "_get_module_attributes" ):
113+ onnx_utils ._get_module_attributes = _get_module_attributes
114+
115+ print ("Applied torch ONNX export patches for export_modules_as_functions compatibility" )
116+
117+
118+ def is_patched ():
119+ """Check if patches have been applied."""
120+ return onnx_utils ._setup_trace_module_map == _setup_trace_module_map_patched
0 commit comments