@@ -6691,34 +6691,41 @@ def aten_pinverse(self: TensorType, rcond: float = 1e-15) -> TensorType:
66916691 raise NotImplementedError ()
66926692
66936693
6694- @torch_op ("aten::pixel_shuffle" )
6694+ @torch_op ("aten::pixel_shuffle" , trace_only = True )
66956695def aten_pixel_shuffle (self : TReal , upscale_factor : int ) -> TReal :
66966696 """pixel_shuffle(Tensor self, int upscale_factor) -> Tensor"""
6697- self_shape = op . Shape (self )
6698- batch_dims = self_shape [: - 3 ]
6699- chw_in_dims = self_shape [ - 3 :]
6697+ if len (self . shape ) == 4 :
6698+ return op . DepthToSpace ( self , blocksize = upscale_factor , mode = "CRD" )
6699+
67006700 # Reshaping input by collapsing all leading dimensions to match ONNX op requirement (4D)
6701+ batch_dims = op .Shape (self , end = - 3 )
6702+ chw_in_dims = op .Shape (self , start = - 3 )
6703+
67016704 reshaped_self = op .Reshape (
67026705 self , op .Concat (op .Constant (value_ints = [- 1 ]), chw_in_dims , axis = 0 )
67036706 )
67046707 depth_to_space = op .DepthToSpace (reshaped_self , blocksize = upscale_factor , mode = "CRD" )
6705- output_shape = op .Concat (batch_dims , op .Shape (depth_to_space )[1 :], axis = 0 )
6708+ final_dims = op .Shape (depth_to_space , start = 1 )
6709+ output_shape = op .Concat (batch_dims , final_dims , axis = 0 )
67066710 return op .Reshape (depth_to_space , output_shape , allowzero = True )
67076711
67086712
6709- @torch_op ("aten::pixel_unshuffle" )
6713+ @torch_op ("aten::pixel_unshuffle" , trace_only = True )
67106714def aten_pixel_unshuffle (self : TReal , downscale_factor : int ) -> TReal :
67116715 """pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor"""
6716+ if len (self .shape ) == 4 :
6717+ return op .SpaceToDepth (self , blocksize = downscale_factor )
67126718
6713- self_shape = op .Shape (self )
6714- batch_dims = self_shape [:- 3 ]
6715- chw_in_dims = self_shape [- 3 :]
67166719 # Reshaping input by collapsing all leading dimensions to match ONNX op requirement (4D)
6720+ batch_dims = op .Shape (self , end = - 3 )
6721+ chw_in_dims = op .Shape (self , start = - 3 )
6722+
67176723 reshaped_self = op .Reshape (
67186724 self , op .Concat (op .Constant (value_ints = [- 1 ]), chw_in_dims , axis = 0 )
67196725 )
67206726 space_to_depth = op .SpaceToDepth (reshaped_self , blocksize = downscale_factor )
6721- output_shape = op .Concat (batch_dims , op .Shape (space_to_depth )[1 :], axis = 0 )
6727+ final_dims = op .Shape (space_to_depth , start = 1 )
6728+ output_shape = op .Concat (batch_dims , final_dims , axis = 0 )
67226729 return op .Reshape (space_to_depth , output_shape , allowzero = True )
67236730
67246731
0 commit comments