import unittest

import torch
import torchdynamo
import torchvision

from functorch import make_fx as make_fx_pk
from functorch.experimental import functionalize
from fx2trt_oss.tracer.dispatch_tracer.tracer import make_fx
from torch.library import Library
from torchdynamo.optimizations.normalize import normalize_ir
from torchdynamo.optimizations.python_key import fake_signature

torch.manual_seed(0)

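# "wrap" is a custom operator namespace; leaf modules are registered here so
# the tracers can treat them as single operators instead of tracing into them.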
wrap_lib = Library("wrap", "DEF")
17+ """
18+ There are two methods for setting leaf_module. leaf(op registeration) and leaf(override call_module)
19+ Only leaf(op registeration) can work together with functionalize.
20+ If you do not need funcitonalize, you can choose any of the leaf module methods.
21+
22+ Test coverage:
23+ PythonkeyTracerTest.test_leaf_operator_reg: python_key tracer + functionalize + leaf(op registeration)
24+
25+ DispatchTracerTest.test_leaf_operator_reg: dispatch tracer + functionalize + leaf(op registeration)
26+ DispatchTracerTest.test_leaf: dispatch tracer + leaf(override call_module)
27+ DispatchTracerTest.test_non_tensor_input: dispatch tracer
28+ DispatchTracerTest.test_resnet18: dispatch tracer
29+ DispatchTracerTest.test_reference_copy: dispatch tracer + functionalize
30+ DispatchTracerTest.test_reference_copy_torchdynamo: dispatcher tracer + torchdynamo + functionalize
31+ """


class PythonkeyTracerTest(unittest.TestCase):
    def test_leaf_operator_reg(self):
        class Leaf(torch.nn.Module):
            def forward(self, x, y):
                return x + y + torch.nn.Parameter(torch.ones(5))

        leaf = Leaf()
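        # register Leaf as the CPU implementation of a custom "wrapped_foo" op;
        # calls to torch.ops.wrap.wrapped_foo dispatch to Leaf.forward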
        wrap_lib.define("wrapped_foo(Tensor x, Tensor y) -> Tensor")
        wrap_lib.impl("wrapped_foo", leaf, "CPU")

        class Bar(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.foo = torch.ops.wrap.wrapped_foo
                self.other = torch.nn.Parameter(torch.ones(5))

            def forward(self, x, y):
                x = self.foo(x, y)
                x = x + self.other
                return x

        mod = Bar()

        def f(x, y):
            return mod(x, y)

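        # functionalize removes mutations before the python_key tracer records the graph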
        gm = make_fx_pk(functionalize(f))(torch.ones(5), torch.ones(5))
        inputs = [torch.ones(5) + 5, torch.ones(5) + 8]
        output = gm(*inputs)
        ref_output = f(*inputs)
        torch.testing.assert_close(output, ref_output)


class DispatchTracerTest(unittest.TestCase):
    def test_leaf_operator_reg(self):
        class Leaf(torch.nn.Module):
            def forward(self, x, y):
                return x + y + torch.nn.Parameter(torch.ones(5))

        leaf = Leaf()
        wrap_lib.define("wrapped_leaf(Tensor x, Tensor y) -> Tensor")
        wrap_lib.impl("wrapped_leaf", leaf, "CPU")

        class Bar(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.leaf = torch.ops.wrap.wrapped_leaf
                self.other = torch.nn.Parameter(torch.ones(5))

            def forward(self, x, y):
                x = self.leaf(x, y)
                x = x + self.other
                return x

        mod = Bar()

        def f(x, y):
            return mod(x, y)

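        # trace with the dispatch tracer; the registered op should survive
        # functionalization as a single call_function node (checked below)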
        gm = make_fx(functionalize(f))(torch.ones(5), torch.ones(5))
        inputs = [torch.ones(5) + 5, torch.ones(5) + 8]
        output = gm(*inputs)
        ref_output = f(*inputs)
        torch.testing.assert_close(output, ref_output)
        # with the op registration method, the leaf module appears as a call_function node
        call_function_node = None
        for node in gm.graph.nodes:
            if (
                node.op == "call_function"
                and node.target == torch.ops.wrap.wrapped_leaf
            ):
                call_function_node = node
        self.assertIsNotNone(call_function_node)

    def test_leaf(self):
        class TestModuleLeaf(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 10, 1)
                self.relu = torch.nn.ReLU(inplace=True)

            def forward(self, x):
                x = self.conv(x)
                return self.relu(x)

        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

                self.relu = torch.nn.ReLU(inplace=True)
                self.leaf = TestModuleLeaf()

            def forward(self, x):
                x = self.leaf(x)
                return self.relu(x)

        mod = TestModule()

        def f(x):
            return mod(x)

        a = torch.randn(1, 3, 1, 1)
        ref_output = f(a)
        func = make_fx(f, leaf_module_list={"test_dispatch_tracer.TestModuleLeaf"})
        gm = func(a)
        output = gm(a)
        torch.testing.assert_close(output, ref_output)
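
        # the leaf module should be preserved as a single call_module node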
        call_module_node = None
        for node in gm.graph.nodes:
            if node.op == "call_module":
                call_module_node = node
        self.assertIsNotNone(call_module_node)
        self.assertEqual(call_module_node.target, "TestModuleLeaf_0")

    def test_non_tensor_input(self):
        def foo(x):
            a = x["a"]
            b = x["b"]
            return a + b

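        # the tracer has to accept a dict (non-tensor container) as input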
        x = {"a": torch.randn(2, 2), "b": torch.randn(2, 2)}
        ref_output = foo(x)
        func = make_fx(foo)
        gm = func(x)
        output = gm(x)
        torch.testing.assert_close(output, ref_output)

    def test_resnet18(self):
        mod = torchvision.models.resnet18(pretrained=False)

        def f(x):
            return mod(x)

        a = torch.randn(1, 3, 224, 224)
        ref_output = f(a)
        gm = make_fx(f)(a)
        output = gm(a)
        torch.testing.assert_close(output, ref_output)

    def test_reference_copy(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, y):
                y[:, 0] = x[:, 0]
                return y

        mod = TestModule()

        def f(x, y):
            return mod(x, y)

        a = torch.ones(2, 2) + 2
        b = torch.ones(2, 2)
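        # f mutates y in place, so the eager call below consumes b; run the
        # functionalized graph on a fresh copy to compare against ref_output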
        b_copy = torch.ones(2, 2)
        ref_output = f(a, b)
        gm = make_fx(functionalize(f))(a, b)
        output = gm(a, b_copy)
        torch.testing.assert_close(output, ref_output)

    def test_reference_copy_torchdynamo(self):
        class TestModule(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.relu = torch.nn.ReLU(inplace=True)

            def forward(self, x, y):
                y = y + 3
                y = self.relu(y)
                y[:, 0] = x[:, 0]
                return y

        mod = TestModule()

        def f(x, y):
            return mod(x, y)

        a = torch.ones(2, 2) + 2
        b = torch.ones(2, 2)
        inputs = [a, b]
        ref_output = f(*inputs)

        def compile_dispatch(gm, example_inputs):
            # after normalization, the in-place relu is removed
            gm = normalize_ir(gm, example_inputs)
            # dispatch tracer
            nargs = len(example_inputs)
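            # fake_signature gives gm an explicit positional signature,
            # since make_fx cannot trace a *args signature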
            gm = make_fx(functionalize(fake_signature(gm, nargs)))(*example_inputs)
            return gm

        optimize_ctx = torchdynamo.optimize(
            compile_dispatch,
            nopython=True,
        )

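        # nopython=True makes torchdynamo raise on a graph break instead of falling back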
        with optimize_ctx:
            output = mod(*inputs)

        torch.testing.assert_close(output, ref_output)