Skip to content

Commit a925acc

Browse files
justinchubyCopilot
andauthored
[torchlib] Improve pixel_shuffle (#2537)
Simplify the graph when input rank is 4, in which case we don't need to do any shape manipulation. Fix pytorch/pytorch#162061 --------- Signed-off-by: Justin Chu <justinchuby@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 54de741 commit a925acc

File tree

2 files changed

+19
-22
lines changed

2 files changed

+19
-22
lines changed

onnxscript/function_libs/torch_lib/ops/core.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6691,34 +6691,41 @@ def aten_pinverse(self: TensorType, rcond: float = 1e-15) -> TensorType:
66916691
raise NotImplementedError()
66926692

66936693

6694-
@torch_op("aten::pixel_shuffle")
6694+
@torch_op("aten::pixel_shuffle", trace_only=True)
66956695
def aten_pixel_shuffle(self: TReal, upscale_factor: int) -> TReal:
66966696
"""pixel_shuffle(Tensor self, int upscale_factor) -> Tensor"""
6697-
self_shape = op.Shape(self)
6698-
batch_dims = self_shape[:-3]
6699-
chw_in_dims = self_shape[-3:]
6697+
if len(self.shape) == 4:
6698+
return op.DepthToSpace(self, blocksize=upscale_factor, mode="CRD")
6699+
67006700
# Reshaping input by collapsing all leading dimensions to match ONNX op requirement (4D)
6701+
batch_dims = op.Shape(self, end=-3)
6702+
chw_in_dims = op.Shape(self, start=-3)
6703+
67016704
reshaped_self = op.Reshape(
67026705
self, op.Concat(op.Constant(value_ints=[-1]), chw_in_dims, axis=0)
67036706
)
67046707
depth_to_space = op.DepthToSpace(reshaped_self, blocksize=upscale_factor, mode="CRD")
6705-
output_shape = op.Concat(batch_dims, op.Shape(depth_to_space)[1:], axis=0)
6708+
final_dims = op.Shape(depth_to_space, start=1)
6709+
output_shape = op.Concat(batch_dims, final_dims, axis=0)
67066710
return op.Reshape(depth_to_space, output_shape, allowzero=True)
67076711

67086712

6709-
@torch_op("aten::pixel_unshuffle")
6713+
@torch_op("aten::pixel_unshuffle", trace_only=True)
67106714
def aten_pixel_unshuffle(self: TReal, downscale_factor: int) -> TReal:
67116715
"""pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor"""
6716+
if len(self.shape) == 4:
6717+
return op.SpaceToDepth(self, blocksize=downscale_factor)
67126718

6713-
self_shape = op.Shape(self)
6714-
batch_dims = self_shape[:-3]
6715-
chw_in_dims = self_shape[-3:]
67166719
# Reshaping input by collapsing all leading dimensions to match ONNX op requirement (4D)
6720+
batch_dims = op.Shape(self, end=-3)
6721+
chw_in_dims = op.Shape(self, start=-3)
6722+
67176723
reshaped_self = op.Reshape(
67186724
self, op.Concat(op.Constant(value_ints=[-1]), chw_in_dims, axis=0)
67196725
)
67206726
space_to_depth = op.SpaceToDepth(reshaped_self, blocksize=downscale_factor)
6721-
output_shape = op.Concat(batch_dims, op.Shape(space_to_depth)[1:], axis=0)
6727+
final_dims = op.Shape(space_to_depth, start=1)
6728+
output_shape = op.Concat(batch_dims, final_dims, axis=0)
67226729
return op.Reshape(space_to_depth, output_shape, allowzero=True)
67236730

67246731

tests/function_libs/torch_lib/ops_test_data.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1084,26 +1084,16 @@ def _where_input_wrangler(
10841084
TorchLibOpInfo(
10851085
"nn.functional.pixel_shuffle",
10861086
core_ops.aten_pixel_shuffle,
1087-
)
1088-
.xfail(
1087+
).xfail(
10891088
dtypes=(torch.int32, torch.int64),
10901089
reason="fixme: ONNX Runtime does not support int32/64 inputs",
1091-
)
1092-
.xfail(
1093-
matcher=lambda sample: sample.input.numel() == 0,
1094-
reason="fixme: ORT does not support empty tensor as input",
10951090
),
10961091
TorchLibOpInfo(
10971092
"nn.functional.pixel_unshuffle",
10981093
core_ops.aten_pixel_unshuffle,
1099-
)
1100-
.xfail(
1094+
).xfail(
11011095
dtypes=(torch.int32, torch.int64),
11021096
reason="fixme: ONNX Runtime does not support int32/64 inputs",
1103-
)
1104-
.xfail(
1105-
matcher=lambda sample: sample.input.numel() == 0,
1106-
reason="fixme: ORT does not support empty tensor as input",
11071097
),
11081098
TorchLibOpInfo(
11091099
"ops.aten.reflection_pad1d",

0 commit comments

Comments
 (0)