@@ -114,6 +114,33 @@ def _adjust_attributes_of_avg_pool(
114114 return (kernel_shape , strides , pads )
115115
116116
117+ def _aten_avg_pool_onnx (
118+ self : TFloat ,
119+ kernel_shape : Sequence [int ],
120+ strides : Sequence [int ],
121+ pads : Sequence [int ],
122+ ceil_mode : bool ,
123+ count_include_pad : bool ,
124+ ) -> TFloat :
125+ self_rank_is_unbatched_rank = len (self .shape ) == len (kernel_shape ) + 1
126+ if self_rank_is_unbatched_rank : # C,H,W -> N,C,H,W and N=1
127+ self = op .Unsqueeze (self , [0 ])
128+
129+ result = op .AveragePool (
130+ self ,
131+ ceil_mode = ceil_mode ,
132+ count_include_pad = count_include_pad ,
133+ kernel_shape = kernel_shape ,
134+ pads = pads ,
135+ strides = strides ,
136+ )
137+
138+ if self_rank_is_unbatched_rank :
139+ result = op .Squeeze (result , [0 ])
140+
141+ return result
142+
143+
117144@torch_op ("aten::avg_pool1d" , trace_only = True )
118145def aten_avg_pool1d (
119146 self : TFloat ,
@@ -134,16 +161,7 @@ def aten_avg_pool1d(
134161 expand_size , kernel_size , stride , padding
135162 )
136163
137- result = op .AveragePool (
138- self ,
139- ceil_mode = ceil_mode ,
140- count_include_pad = count_include_pad ,
141- kernel_shape = kernel_shape ,
142- pads = pads ,
143- strides = strides ,
144- )
145-
146- return result
164+ return _aten_avg_pool_onnx (self , kernel_shape , strides , pads , ceil_mode , count_include_pad )
147165
148166
149167@torch_op ("aten::avg_pool2d" , trace_only = True )
@@ -167,15 +185,6 @@ def aten_avg_pool2d(
167185 expand_size , kernel_size , stride , padding
168186 )
169187
170- result = op .AveragePool (
171- self ,
172- ceil_mode = ceil_mode ,
173- count_include_pad = count_include_pad ,
174- kernel_shape = kernel_shape ,
175- pads = pads ,
176- strides = strides ,
177- )
178-
179188 # TODO: if want to support divisor_override argument, need to op.Mul(result, mask)
180189 # mask = [
181190 # 1, 2, 3, S,..3, 2, 1
@@ -189,7 +198,7 @@ def aten_avg_pool2d(
189198 # S is stride size, in this case S=4,
190199 # S may dup lot of times according to the image size
191200
192- return result
201+ return _aten_avg_pool_onnx ( self , kernel_shape , strides , pads , ceil_mode , count_include_pad )
193202
194203
195204def aten_avg_pool2d_backward (
@@ -228,15 +237,6 @@ def aten_avg_pool3d(
228237 expand_size , kernel_size , stride , padding
229238 )
230239
231- result = op .AveragePool (
232- self ,
233- kernel_shape = kernel_shape ,
234- strides = strides ,
235- pads = pads ,
236- count_include_pad = count_include_pad ,
237- ceil_mode = ceil_mode ,
238- )
239-
240240 # TODO: if want to support divisor_override argument, need to op.Mul(result, mask)
241241 # mask = [
242242 # 1, 2, 3, S,..3, 2, 1
@@ -250,7 +250,7 @@ def aten_avg_pool3d(
250250 # S is stride size, in this case S=4,
251251 # S may dup lot of times according to the image size
252252
253- return result
253+ return _aten_avg_pool_onnx ( self , kernel_shape , strides , pads , ceil_mode , count_include_pad )
254254
255255
256256def aten_avg_pool3d_backward (
0 commit comments