 _INT64_MAX = 9223372036854775807
 _INT64_MIN = -9223372036854775808
 _MATH_PI = math.pi
-Rank = common_ops.Rank


 @torch_op("aten::_local_scalar_dense", trace_only=True)
@@ -947,11 +946,11 @@ def reshape_to_1d(tensor):
     return op.SequenceMap(self, body=reshape_to_1d)


-@torch_op("aten::atleast_2d")
+@torch_op("aten::atleast_2d", trace_only=True)
 def aten_atleast_2d(self: TTensor) -> TTensor:
     """atleast_2d(Tensor self) -> Tensor"""

-    if Rank(self) <= 1:
+    if len(self.shape) <= 1:
         self = op.Reshape(self, op.Constant(value_ints=[1, -1]))
     return op.Identity(self)

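Note (illustration only, not part of this change): the trace-only rewrite above can read the rank from the Python-side shape instead of the in-graph Rank helper. The intended behavior matches numpy.atleast_2d, sketched here with numpy as a stand-in:

import numpy as np

# Rank-0 and rank-1 inputs are reshaped to (1, -1); higher ranks pass through unchanged.
assert np.atleast_2d(np.array(5.0)).shape == (1, 1)
assert np.atleast_2d(np.arange(3)).shape == (1, 3)
assert np.atleast_2d(np.zeros((2, 3))).shape == (2, 3)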
@@ -975,7 +974,7 @@ def reshape_to_2d(tensor):
 def aten_atleast_3d(self: TTensor) -> TTensor:
     """atleast_3d(Tensor self) -> Tensor"""

-    rank = Rank(self)
+    rank = len(self.shape)
     if rank <= 1:
         self = op.Reshape(self, op.Constant(value_ints=[1, -1, 1]))
     elif rank == 2:
@@ -1820,39 +1819,21 @@ def aten_conj_physical(self: TensorType) -> TensorType:
     raise NotImplementedError()


-@torch_op("aten::constant_pad_nd")
-def aten_constant_pad_nd(self: TTensor, pad: INT64, value: float = 0.0) -> TTensor:
+@torch_op("aten::constant_pad_nd", trace_only=True)
+def aten_constant_pad_nd(self: TTensor, pad: Sequence[INT64], value: float = 0.0) -> TTensor:
     """constant_pad_nd(Tensor self, SymInt[] pad, Scalar value=0) -> Tensor"""

     # The desired order of paddings is
     # dim_0_begin, dim_1_begin, ... , dim_0_end, ..., dim_n_end.
     # n is the dimension of input.
     # assume zero-dimensions in the beginning
-    # rank = len(self.shape) # rank must be scalar
-    # paddings = list(pad[:]) + [0] * (rank * 2 - len(pad))
+    rank = len(self.shape)
+    paddings = list(pad) + [0] * (rank * 2 - len(pad))
     # reverse order and collate first beginnings and then ends
-    # paddings = paddings[-2::-2] + paddings[-1::-2]
-
-    neg_1 = op.Constant(value_ints=[-1])
-
-    zero_count = op.Sub(op.Mul(Rank(self), 2), op.Size(pad))
-    zero_count = op.Reshape(zero_count, neg_1)
-    zero = op.Constant(value_ints=[0])
-    zeros = op.Expand(zero, zero_count)
-    torch_paddings = op.Concat(pad, zeros, axis=0)
-    size_d = op.Size(torch_paddings)
-    steps = op.Constant(value_ints=[-2])
-
-    starts = steps
-    ends = op.Sub(starts, size_d)
-    odd_elements = op.Slice(torch_paddings, starts, ends, zero, steps)
-
-    starts = neg_1
-    ends = op.Sub(starts, size_d)
-    even_elements = op.Slice(torch_paddings, starts, ends, zero, steps)
+    paddings = paddings[-2::-2] + paddings[-1::-2]
+    constant_value = op.Constant(value=ir.tensor(value, dtype=self.dtype))

-    onnx_padding = op.Concat(odd_elements, even_elements, axis=0)
-    return op.Pad(self, onnx_padding, value)
+    return op.Pad(self, paddings, constant_value)


 @torch_op("aten::contiguous", trace_only=True)
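Note (illustration only, not part of this change): the trace-time reorder above converts torch's pad layout (last dimension first, begin/end interleaved) into ONNX Pad's layout (all begins, then all ends). A small worked example, assuming rank 3 and pad = [1, 2, 3, 4]:

rank = 3
pad = [1, 2, 3, 4]                                   # last dim: begin 1 / end 2, second-to-last: 3 / 4
paddings = list(pad) + [0] * (rank * 2 - len(pad))   # [1, 2, 3, 4, 0, 0]
onnx_pads = paddings[-2::-2] + paddings[-1::-2]      # begins first, then ends
assert onnx_pads == [0, 3, 1, 0, 4, 2]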
@@ -3996,7 +3977,7 @@ def reshape_to_atleast_2d(tensor):
     result = op.ConcatFromSequence(tensors_atleast_2d, axis=1, new_axis=0)

     # hstack expects a non-empty sequence of tensors. So we don't need to check for length
-    rank_1d_or_less = op.Less(Rank(op.SequenceAt(tensors, 0)), 2)
+    rank_1d_or_less = op.Less(op.Size(op.Shape(op.SequenceAt(tensors, 0))), 2)
     if rank_1d_or_less:
         result = op.Reshape(result, op.Constant(value_ints=[-1]))
     return result
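Note (illustration only): with the Rank helper removed, the rank is recovered in-graph here: Shape returns a 1-D tensor of the dimensions and Size counts its elements, which equals the rank. A quick numpy stand-in for that identity:

import numpy as np

x = np.zeros((2, 3, 4))
assert np.asarray(x.shape).size == x.ndim == 3   # Size(Shape(x)) == rank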
@@ -6076,7 +6057,7 @@ def aten_native_group_norm(
     norm = op.Reshape(norm, op.Shape(input), allowzero=True)
     # Using the input weight and bias to do affine
     # But need to unsqueeze to the target shape for easy broadcasting
-    input_rank = Rank(input)
+    input_rank = len(input.shape)
     axes_unsqueeze = op.Range(1, input_rank - 1, 1)
     weight_full_shape = op.Unsqueeze(weight, axes_unsqueeze)
     bias_full_shape = op.Unsqueeze(bias, axes_unsqueeze)
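Note (illustration only, not part of this change): the axes produced by Range(1, input_rank - 1, 1) unsqueeze the per-channel weight and bias so they broadcast against the normalized input. A minimal sketch, assuming a rank-4 NCHW input:

input_rank = 4                                    # e.g. an NCHW input
axes_unsqueeze = list(range(1, input_rank - 1))   # [1, 2]
# A weight of shape [C] unsqueezed at axes [1, 2] becomes [C, 1, 1],
# which broadcasts against the [N, C, H, W] normalized tensor.
assert axes_unsqueeze == [1, 2]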
@@ -8229,7 +8210,7 @@ def aten_symeig(
 def aten_t(self: TTensor) -> TTensor:
     """t(Tensor(a) self) -> Tensor(a)"""

-    rank = Rank(self)
+    rank = len(self.shape)
     if rank == 2:
         result = op.Transpose(self, perm=[1, 0])
     else:
@@ -8312,26 +8293,24 @@ def aten_threshold_backward(
     raise NotImplementedError()


-@torch_op("aten::tile")
-def aten_tile(self: TTensor, dims: INT64) -> TTensor:
+@torch_op("aten::tile", trace_only=True)
+def aten_tile(self: TTensor, dims: Sequence[int]) -> TTensor:
     """tile(Tensor self, int[] dims) -> Tensor"""

-    self_rank = Rank(self)
-    dims_rank = op.Size(dims)
-    diff = op.Sub(self_rank, dims_rank)
+    self_rank = len(self.shape)
+    dims_rank = len(dims)
+    diff = self_rank - dims_rank

     if diff > 0:
         # dims is shorter than self.shape
         # pad dims with 1
-        diff_1d = op.Reshape(diff, op.Constant(value_ints=[1]))
-        exapnd_ones = op.Expand(op.Constant(value_ints=[1]), diff_1d)
-        dims = op.Concat(exapnd_ones, dims, axis=0)
+        exapnd_ones = [1] * diff
+        dims = [*exapnd_ones, *dims]

-    if diff < 0:
+    elif diff < 0:
         # dims is longer than self.shape
         # pad self.shape with 1
-        diff_1d = op.Reshape(op.Abs(diff), op.Constant(value_ints=[1]))
-        exapnd_ones = op.Expand(op.Constant(value_ints=[1]), diff_1d)
+        exapnd_ones = op.Constant(value_ints=[1] * (-diff))
         self_shape = op.Shape(self)
         self_final_shape = op.Concat(exapnd_ones, self_shape, axis=0)
         self = op.Reshape(self, self_final_shape, allowzero=True)
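Note (illustration only, not part of this change): a worked example of the trace-time dims padding above, assuming a shape-(2, 3) input and a dims argument shorter than the rank:

self_shape = (2, 3)                  # hypothetical input shape
dims = [4]
diff = len(self_shape) - len(dims)   # 1: dims is shorter than self.shape
if diff > 0:
    dims = [1] * diff + list(dims)   # pad dims with leading 1s -> [1, 4]
assert dims == [1, 4]                # Tile repeats rows once, columns 4 times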