 from executorch.backends.arm._passes.arm_pass_utils import (
     create_node,
     get_first_fake_tensor,
+    is_param_node,
 )
 from executorch.backends.arm.tosa_utils import is_consumer_node_depthwise_conv2d
+from executorch.exir import ExportedProgram
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult


-class AnnotateChannelsLastDimOrder(ExportPass):
+def _is_input(node: torch.fx.Node, exported_program: ExportedProgram) -> bool:
+    """
+    Returns True if the node is a user input, i.e. a placeholder that is not a
+    lifted parameter or buffer.
+    """
+    return node.op == "placeholder" and not is_param_node(exported_program, node)
+
+
+class ToTosaMemoryFormatPass(ExportPass):
     """
     Annotates each node with a tosa_dim_order. tosa_dim_order can be seen as a channels-last dim-order
     that in most cases will be (0, 2, 3, 1) for nodes with 4D-shapes. The pass also inserts backend.tosa.TRANSPOSE
@@ -30,6 +39,10 @@ class AnnotateChannelsLastDimOrder(ExportPass):
     NNHWC_order = (0, 1, 3, 4, 2)
     NNHWC_inverse_order = (0, 1, 4, 2, 3)

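+    # The ExportedProgram is needed to tell user inputs apart from lifted
+    # parameters/buffers (see _is_input above).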
+    def __init__(self, exported_program: ExportedProgram) -> None:
+        self.exported_program = exported_program
+        super().__init__()
+
     def is_weight_node_for_depthwise_conv2d(self, node: torch.fx.Node):
         """
         returns True for w in the following sequence;
@@ -92,25 +105,30 @@ def is_channel_reshape(input_shape, output_shape):

     @staticmethod
     def insert_input_transpose(node, input_node, graph_module):
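+        # If the input is itself a TOSA TRANSPOSE, the transpose we would
+        # insert here is assumed to be its inverse, so the pair cancels out:
+        # bypass the existing transpose instead of stacking a new one.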
+        if input_node.target == exir_ops.backend.tosa.TRANSPOSE.default:
+            pre_permute_node = input_node.all_input_nodes[0]
+            node.replace_input_with(input_node, pre_permute_node)
+            return
+
         with graph_module.graph.inserting_before(node):
             permute_node = create_node(
                 graph_module.graph,
                 exir_ops.backend.tosa.TRANSPOSE.default,
                 args=(
                     input_node,
                     list(
-                        AnnotateChannelsLastDimOrder.NNHWC_inverse_order
+                        ToTosaMemoryFormatPass.NNHWC_inverse_order
                         if len(get_first_fake_tensor(input_node).size()) == 5
-                        else AnnotateChannelsLastDimOrder.NHWC_inverse_order
+                        else ToTosaMemoryFormatPass.NHWC_inverse_order
                     ),
                 ),
+                from_node=node,
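+                # from_node presumably carries provenance/debug metadata from
+                # the original node over to the newly created transpose.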
             )
             node.replace_input_with(input_node, permute_node)

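+        # The inserted transpose hands the consumer (N)NCHW-ordered data, so
+        # it keeps the natural (identity) dim order itself.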
         permute_node.meta["tosa_dim_order"] = tuple(
             range(len(input_node.meta["val"].size()))
         )
-        permute_node.meta["val"] = input_node.meta["val"]

     @staticmethod
     def insert_output_transpose(node, graph_module):
@@ -121,25 +139,23 @@ def insert_output_transpose(node, graph_module):
                 args=(
                     node,
                     list(
-                        AnnotateChannelsLastDimOrder.NNHWC_order
+                        ToTosaMemoryFormatPass.NNHWC_order
                         if len(get_first_fake_tensor(node).size()) == 5
-                        else AnnotateChannelsLastDimOrder.NHWC_order
+                        else ToTosaMemoryFormatPass.NHWC_order
                     ),
                 ),
+                from_node=node,
             )
+
         permute_node.meta["tosa_dim_order"] = (
-            AnnotateChannelsLastDimOrder.NNHWC_order
+            ToTosaMemoryFormatPass.NNHWC_order
             if len(get_first_fake_tensor(node).size()) == 5
-            else AnnotateChannelsLastDimOrder.NHWC_order
-        )
-        permute_node.meta["val"] = get_first_fake_tensor(node).permute(
-            AnnotateChannelsLastDimOrder.NNHWC_order
-            if len(get_first_fake_tensor(node).size()) == 5
-            else AnnotateChannelsLastDimOrder.NHWC_order
+            else ToTosaMemoryFormatPass.NHWC_order
         )
         node.meta["tosa_dim_order"] = tuple(
             range(len(get_first_fake_tensor(node).size()))
         )
+
         users = [user for user in node.users if user != permute_node]
         for user in users:
             user.replace_input_with(node, permute_node)
@@ -150,20 +166,23 @@ def _insert_view_transpose(
     ):
         nchw_to_nhwc = len(input_shape) < 4 and len(output_shape) >= 4
        nhwc_to_nchw = len(input_shape) >= 4 and len(output_shape) < 4
-        channel_reshape = AnnotateChannelsLastDimOrder.is_channel_reshape(
+        channel_reshape = ToTosaMemoryFormatPass.is_channel_reshape(
             output_shape, input_shape
         )

         if (
             channel_reshape or nhwc_to_nchw
-        ) and AnnotateChannelsLastDimOrder.memory_format_differs(input_shape):
-            AnnotateChannelsLastDimOrder.insert_input_transpose(
+        ) and ToTosaMemoryFormatPass.memory_format_differs(input_shape):
+
+            ToTosaMemoryFormatPass.insert_input_transpose(
                 node, input_node, graph_module
             )
+
         if (
             channel_reshape or nchw_to_nhwc
-        ) and AnnotateChannelsLastDimOrder.memory_format_differs(output_shape):
-            AnnotateChannelsLastDimOrder.insert_output_transpose(node, graph_module)
+        ) and ToTosaMemoryFormatPass.memory_format_differs(output_shape):
+
+            ToTosaMemoryFormatPass.insert_output_transpose(node, graph_module)

     def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
         """
@@ -181,9 +200,10 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:
             # call_function and placeholder allowed due to
             # index.Tensor being able to come in as both
-            if node.op not in ["call_function", "placeholder"]:
+            if node.op not in ["call_function", "placeholder", "output"]:
                 continue

+            # Transpose views
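+            # Reshapes that cross the 4D boundary, or that mix the channel
+            # dimension with other dimensions, change the memory layout and
+            # must therefore be fenced with explicit transposes.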
             elif node.target in (
                 exir_ops.edge.aten.view_copy.default,
                 exir_ops.edge.aten.index.Tensor,
@@ -194,25 +214,48 @@ def insert_tosa_transposes(self, graph_module: torch.fx.GraphModule):
                 input_node = node.args[0]
                 input_shape = input_node.meta["val"].shape
                 output_shape = node.meta["val"].shape
-
                 self._insert_view_transpose(
-                    input_shape, output_shape, node, input_node, graph_module
+                    input_shape,
+                    output_shape,
+                    node,
+                    input_node,
+                    graph_module,
                 )

+            # Transpose inputs
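+            # Graph inputs arrive in (N)NCHW, so insert a transpose after the
+            # placeholder to hand its users (N)NHWC-ordered data.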
+            elif _is_input(node, self.exported_program):
+                input_shape = get_first_fake_tensor(node).size()
+                if len(input_shape) in (4, 5):
+                    ToTosaMemoryFormatPass.insert_output_transpose(node, graph_module)
+
+            # Transpose outputs
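+            # Graph outputs must leave the graph in (N)NCHW, so insert a
+            # transpose before every tensor feeding the output node.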
+            elif node.op == "output":
+                output_shape = get_first_fake_tensor(node).size()
+
+                if len(output_shape) in (4, 5):
+                    for input_node in node.all_input_nodes:
+                        ToTosaMemoryFormatPass.insert_input_transpose(
+                            node, input_node, graph_module
+                        )
+
     def call(self, graph_module: torch.fx.GraphModule):
         for node in graph_module.graph.nodes:
             node_data = get_first_fake_tensor(node).data

-            if node_data.dim() == 4:
+            # Inputs and outputs are always in (N)NCHW format
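+            # and therefore keep the identity dim order; the transposes
+            # inserted by insert_tosa_transposes do the conversion.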
+            if _is_input(node, self.exported_program) or node.op == "output":
+                dim_order = tuple(range(node_data.dim()))
+            elif node_data.dim() == 4:
                 dim_order = self.NHWC_order
                 if self.is_weight_node_for_depthwise_conv2d(node):
                     # The weights of TOSA DEPTHWISE_CONV2D have shape (H, W, C, M) which corresponds to
                     # dim_order = (2, 3, 0, 1) (https://www.mlplatform.org/tosa/tosa_spec.html#_depthwise_conv2d).
                     dim_order = self.HWCM_order
             elif node_data.dim() == 5:
-                dim_order = self.NNHWC_order  # type: ignore[assignment]
+                dim_order = self.NNHWC_order
             else:
                 dim_order = tuple(range(node_data.dim()))  # type: ignore[assignment]
+
             node.meta["tosa_dim_order"] = dim_order
             # Insert TOSA transposes to convert between (N)NCHW and (N)NHWC format.
             # See insert_tosa_transposes for insertion conditions.