11from dataclasses import dataclass
2- from typing import Any , Callable , Dict , Type
2+ from typing import Any , Callable , Dict , Optional , Type , Union
33import torch
44import logging
55
88
99
1010@dataclass (frozen = True )
11- class ModuleReplacement :
11+ class Substitution :
1212 """Class to store key functionality for module replacement"""
1313
1414 # torch.ops.___ name for replacement function for module
1515 new_operator : torch ._ops .OpOverload
1616
17- # Function taking a containing graph, a submodule, and a 'call_module' node and returning
18- # a replacement node, with type 'call_function', or raising an Error if incompatibility is detected
17+ # Function taking a containing graph, a node, and optionally a submodule (if replacing a module)
18+ # and returning a replacement node, with type 'call_function', or raising an Error if
19+ # incompatibility is detected
1920 # Note: subgraph_insertion_fn should NOT delete nodes or recompile the graph
2021 subgraph_insertion_fn : Callable [
21- [torch .fx .GraphModule , torch .nn . Module , torch .fx . Node ], torch .fx .Node
22+ [torch .fx .GraphModule , torch .fx . Node , Optional [ torch .nn . Module ] ], torch .fx .Node
2223 ]
2324
2425
25- # Dictionary mapping module to ModuleReplacement instance
26- MODULE_SUBSTITUTION_REGISTRY : Dict [Type [torch .nn .Module ], ModuleReplacement ] = dict ()
26+ # Dictionary mapping module to Substitution instance
27+ SUBSTITUTION_REGISTRY : Dict [
28+ Union [Type [torch .nn .Module ], Callable ], Substitution
29+ ] = dict ()
2730
2831
29- def module_substitution (
30- module_to_replace : Type [torch .nn .Module ],
32+ def register_substitution (
33+ module_or_function_to_replace : Union [ Type [torch .nn .Module ], Callable ],
3134 new_operator : torch ._ops .OpOverload ,
3235 enabled : bool = True ,
3336) -> Callable [[Any ], Any ]:
3437 """Decorator to register subgraph insertion functions
3538
3639 Args:
37- module_to_replace : nn.Module to replace
40+ module_or_function_to_replace : nn.Module or node target Callable to replace
3841 new_operator: Custom torch operator to replace with
3942 enabled: Whether the substitution is enabled or disabled
4043 Returns:
4144 torch.fx.GraphModule
4245 """
4346
44- def register_substitution (subgraph_insertion_fn ):
47+ def enable_substitution (subgraph_insertion_fn ):
4548 """Function for use if substitution is enabled"""
46- module_replacement = ModuleReplacement (
49+ replacement = Substitution (
4750 new_operator = new_operator , subgraph_insertion_fn = subgraph_insertion_fn
4851 )
49- MODULE_SUBSTITUTION_REGISTRY [ module_to_replace ] = module_replacement
52+ SUBSTITUTION_REGISTRY [ module_or_function_to_replace ] = replacement
5053 return subgraph_insertion_fn
5154
5255 def disable_substitution (subgraph_insertion_fn ):
5356 """Function for use if substitution is disabled"""
5457 return subgraph_insertion_fn
5558
56- return register_substitution if enabled else disable_substitution
59+ return enable_substitution if enabled else disable_substitution
5760
5861
59- def pre_aot_module_replacement (gm : torch .fx .GraphModule ):
60- """Perform module-level graph replacement prior to AOT tracing
62+ def pre_aot_substitutions (gm : torch .fx .GraphModule ):
63+ """Perform graph substitutions prior to AOT tracing
6164
6265 Args:
63- gm: FX GraphModule to perform module replacement on
66+ gm: FX GraphModule to perform substitution on
6467 Returns:
6568 torch.fx.GraphModule
6669
@@ -73,48 +76,58 @@ def pre_aot_module_replacement(gm: torch.fx.GraphModule):
7376
7477 # Iterate over graph nodes, extracting module calls, to check for interceptions
7578 for n in gm .graph .nodes :
79+ exists_in_registry = False
80+ to_replace = None
81+
7682 if n .op == "call_module" :
77- # Extract submodule from graph
83+ # Extract submodule from graph, validate in registry
7884 submodule = gm .get_submodule (n .target )
79-
80- # If submodule is a member of the substitution registry, replace it
81- if type (submodule ) in MODULE_SUBSTITUTION_REGISTRY :
82-
83- try :
84- replacement = MODULE_SUBSTITUTION_REGISTRY [type (submodule )]
85- op , insertion_fn = (
86- replacement .new_operator ,
87- replacement .subgraph_insertion_fn ,
88- )
89- logger .debug (
90- f"Replacing module of type { type (submodule )} with { op } "
85+ to_replace = type (submodule )
86+ exists_in_registry = to_replace in SUBSTITUTION_REGISTRY
87+ elif n .op == "call_function" :
88+ # Extract function from graph, validate in registry
89+ to_replace = n .target
90+ exists_in_registry = n .target in SUBSTITUTION_REGISTRY
91+
92+ # If submodule/function is a member of the substitution registry, replace it
93+ if exists_in_registry :
94+ try :
95+ replacement = SUBSTITUTION_REGISTRY [to_replace ]
96+ op , insertion_fn = (
97+ replacement .new_operator ,
98+ replacement .subgraph_insertion_fn ,
99+ )
100+ logger .debug (f"Replacing node of type { to_replace } with { op } " )
101+
102+ # Insert new node prior to older node
103+ with gm .graph .inserting_before (n ):
104+ new_node = insertion_fn (
105+ gm , n , submodule if n .op == "call_module" else None
91106 )
92107
93- # Insert new node prior to older node
94- with gm .graph .inserting_before (n ):
95- new_node = insertion_fn (gm , submodule , n )
96-
97- # If submodule is not a native torch.nn module, it must be manually excluded
98- # from Dynamo tracing
99- if not type (submodule ).__module__ .startswith ("torch.nn" ):
100- torch ._dynamo .allowed_functions ._allowed_function_ids .add (
101- id (type (submodule ))
102- )
103-
104- # Replace all original node uses and clean up graph
105- n .replace_all_uses_with (new_node )
106- gm .graph .eliminate_dead_code ()
107- gm .graph .lint ()
108- gm .recompile ()
109-
110- # A module replacement can fail in the event that the specific instance of the submodule cannot
111- # be replaced
112- except Exception :
113- logger .debug (
114- f"Encountered error while replacing { type (submodule )} " ,
115- exc_info = True ,
108+ # If submodule is not a native torch.nn module, it must be manually excluded
109+ # from Dynamo tracing
110+ if n .op == "call_module" and not type (submodule ).__module__ .startswith (
111+ "torch.nn"
112+ ):
113+ torch ._dynamo .allowed_functions ._allowed_function_ids .add (
114+ id (to_replace )
116115 )
117- continue
116+
117+ # Replace all original node uses and clean up graph
118+ n .replace_all_uses_with (new_node )
119+ gm .graph .eliminate_dead_code ()
120+ gm .graph .lint ()
121+ gm .recompile ()
122+
123+ # A replacement can fail in the event that the specific instance of the submodule/function
124+ # cannot be replaced
125+ except Exception :
126+ logger .debug (
127+ f"Encountered error while replacing { to_replace } " ,
128+ exc_info = True ,
129+ )
130+ continue
118131
119132 # Perform cleanup and recompilation before returning module
120133 gm .graph .eliminate_dead_code ()
0 commit comments