Skip to content

Commit 04a9da4

Browse files
authored
Unsqueeze unbatched input of avg_pool (#2646)
Onnx's `AveragePool` require input shape as `N,C,H,W`, but torch accept both `N,C,H,W` and `C,H,W`. Unsqueeze if input is unbatched, just like what `max_pool` does.
1 parent 8a94ad6 commit 04a9da4

File tree

2 files changed

+54
-30
lines changed

2 files changed

+54
-30
lines changed

onnxscript/function_libs/torch_lib/ops/nn.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,33 @@ def _adjust_attributes_of_avg_pool(
114114
return (kernel_shape, strides, pads)
115115

116116

117+
def _aten_avg_pool_onnx(
118+
self: TFloat,
119+
kernel_shape: Sequence[int],
120+
strides: Sequence[int],
121+
pads: Sequence[int],
122+
ceil_mode: bool,
123+
count_include_pad: bool,
124+
) -> TFloat:
125+
self_rank_is_unbatched_rank = len(self.shape) == len(kernel_shape) + 1
126+
if self_rank_is_unbatched_rank: # C,H,W -> N,C,H,W and N=1
127+
self = op.Unsqueeze(self, [0])
128+
129+
result = op.AveragePool(
130+
self,
131+
ceil_mode=ceil_mode,
132+
count_include_pad=count_include_pad,
133+
kernel_shape=kernel_shape,
134+
pads=pads,
135+
strides=strides,
136+
)
137+
138+
if self_rank_is_unbatched_rank:
139+
result = op.Squeeze(result, [0])
140+
141+
return result
142+
143+
117144
@torch_op("aten::avg_pool1d", trace_only=True)
118145
def aten_avg_pool1d(
119146
self: TFloat,
@@ -134,16 +161,7 @@ def aten_avg_pool1d(
134161
expand_size, kernel_size, stride, padding
135162
)
136163

137-
result = op.AveragePool(
138-
self,
139-
ceil_mode=ceil_mode,
140-
count_include_pad=count_include_pad,
141-
kernel_shape=kernel_shape,
142-
pads=pads,
143-
strides=strides,
144-
)
145-
146-
return result
164+
return _aten_avg_pool_onnx(self, kernel_shape, strides, pads, ceil_mode, count_include_pad)
147165

148166

149167
@torch_op("aten::avg_pool2d", trace_only=True)
@@ -167,15 +185,6 @@ def aten_avg_pool2d(
167185
expand_size, kernel_size, stride, padding
168186
)
169187

170-
result = op.AveragePool(
171-
self,
172-
ceil_mode=ceil_mode,
173-
count_include_pad=count_include_pad,
174-
kernel_shape=kernel_shape,
175-
pads=pads,
176-
strides=strides,
177-
)
178-
179188
# TODO: if want to support divisor_override argument, need to op.Mul(result, mask)
180189
# mask = [
181190
# 1, 2, 3, S,..3, 2, 1
@@ -189,7 +198,7 @@ def aten_avg_pool2d(
189198
# S is stride size, in this case S=4,
190199
# S may dup lot of times according to the image size
191200

192-
return result
201+
return _aten_avg_pool_onnx(self, kernel_shape, strides, pads, ceil_mode, count_include_pad)
193202

194203

195204
def aten_avg_pool2d_backward(
@@ -228,15 +237,6 @@ def aten_avg_pool3d(
228237
expand_size, kernel_size, stride, padding
229238
)
230239

231-
result = op.AveragePool(
232-
self,
233-
kernel_shape=kernel_shape,
234-
strides=strides,
235-
pads=pads,
236-
count_include_pad=count_include_pad,
237-
ceil_mode=ceil_mode,
238-
)
239-
240240
# TODO: if want to support divisor_override argument, need to op.Mul(result, mask)
241241
# mask = [
242242
# 1, 2, 3, S,..3, 2, 1
@@ -250,7 +250,7 @@ def aten_avg_pool3d(
250250
# S is stride size, in this case S=4,
251251
# S may dup lot of times according to the image size
252252

253-
return result
253+
return _aten_avg_pool_onnx(self, kernel_shape, strides, pads, ceil_mode, count_include_pad)
254254

255255

256256
def aten_avg_pool3d_backward(

tests/function_libs/torch_lib/e2e_ops_tests.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,30 @@ def forward(self, x):
238238
)
239239
_testing.assert_onnx_program(onnx_program)
240240

241+
def test_avg_pool(self):
242+
class Model(torch.nn.Module):
243+
def forward(self, x2d, x3d, x4d, x5d):
244+
return (
245+
torch.nn.functional.avg_pool1d(x2d, 2), # pylint: disable=not-callable
246+
torch.nn.functional.avg_pool1d(x3d, 2), # pylint: disable=not-callable
247+
torch.nn.functional.avg_pool2d(x3d, 2), # pylint: disable=not-callable
248+
torch.nn.functional.avg_pool2d(x4d, 2), # pylint: disable=not-callable
249+
torch.nn.functional.avg_pool3d(x4d, 2), # pylint: disable=not-callable
250+
torch.nn.functional.avg_pool3d(x5d, 2), # pylint: disable=not-callable
251+
)
252+
253+
x2d = torch.randn(10, 10)
254+
x3d = torch.randn(10, 10, 10)
255+
x4d = torch.randn(10, 10, 10, 10)
256+
x5d = torch.randn(10, 10, 10, 10, 10)
257+
onnx_program = torch.onnx.export(
258+
Model(),
259+
(x2d, x3d, x4d, x5d),
260+
dynamo=True,
261+
verbose=False,
262+
)
263+
_testing.assert_onnx_program(onnx_program)
264+
241265

242266
if __name__ == "__main__":
243267
unittest.main()

0 commit comments

Comments
 (0)