@@ -1523,10 +1523,10 @@ def aten_broadcast_tensors(tensors: Sequence[TensorType]) -> TensorType:
15231523 raise NotImplementedError ()
15241524
15251525
@torch_op("aten::broadcast_to", trace_only=True)
def aten_broadcast_to(self: TTensor, size: Sequence[INT64]) -> TTensor:
    """broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a)"""
    # merge_dims folds the mixed list of python ints / symbolic dims into a
    # single INT64 shape tensor that ONNX Expand accepts.
    target_shape = common_ops.merge_dims(size)
    return op.Expand(self, target_shape)
15311531
15321532
@@ -3286,20 +3286,20 @@ def aten_embedding_sparse_backward(
32863286
@torch_op("aten::empty.memory_format", trace_only=True)
def aten_empty(
    size: Sequence[INT64],
    dtype: int = FLOAT.dtype,
    layout: str = "",
    device: str = "",
    pin_memory: bool = False,
    memory_format: str = "",
) -> TensorType:  # type: ignore[type-var]
    """empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor"""
    # The exporter encodes "dtype=None" as -1; default to float32 like PyTorch.
    if dtype == -1:
        dtype = FLOAT.dtype

    # ONNX has no uninitialized-memory op, so emulate empty() by expanding a
    # scalar zero of the requested dtype to the target shape. layout/device/
    # pin_memory/memory_format have no ONNX equivalent and are ignored.
    zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype)))
    shape = common_ops.merge_dims(size)
    return op.Expand(zero, shape)
33053305
@@ -3334,17 +3334,18 @@ def aten_empty_quantized(
33343334
@torch_op("aten::empty_strided", trace_only=True)
def aten_empty_strided(
    size: Sequence[INT64],
    stride: INT64,
    dtype: int = FLOAT.dtype,
    layout: str = "",
    device: str = "",
    pin_memory: bool = False,
) -> TTensor:  # type: ignore[type-var]
    """empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor

    `dtype` is placed before `layout` to match the aten schema and the sibling
    ops (aten_empty, aten_zeros); the exporter binds these arguments by name.
    `stride`, `layout`, `device` and `pin_memory` have no ONNX equivalent and
    are ignored: ONNX tensors are always contiguous.
    """
    # The exporter encodes "dtype=None" as -1; default to float32 like
    # PyTorch. Without this guard ir.DataType(-1) would be invalid.
    if dtype == -1:
        dtype = FLOAT.dtype

    # ONNX has no uninitialized-memory op, so emulate empty() by expanding a
    # scalar zero of the requested dtype to the target shape.
    zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype)))
    shape = common_ops.merge_dims(size)
    return op.Expand(zero, shape)
33503351
@@ -3392,13 +3393,14 @@ def aten_exp2(self: TFloat) -> TFloat:
33923393
33933394
@torch_op("aten::expand", trace_only=True)
def aten_expand(self: TTensor, size: Sequence[INT64], implicit: bool = False) -> TTensor:
    """expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)"""
    # NOTE: PyTorch supports `not changing dim` by -1, but ONNX supports `not changing dim` by 1.
    # To support -1 dim, we need to convert -1 to 1.
    # Even though in theory a dynamic dim can still be -1, in practice it is very unlikely
    # and isn't expected to appear from correct usages of SymInt.
    normalized = []
    for dim in size:
        if isinstance(dim, int) and dim == -1:
            normalized.append(1)
        else:
            normalized.append(dim)
    return op.Expand(self, common_ops.merge_dims(normalized))
34023404
34033405
34043406@torch_op ("aten::expand_as" , trace_only = True )
@@ -7409,12 +7411,10 @@ def aten_repeat_interleave_Tensor(
74097411 )
74107412
74117413
@torch_op("aten::reshape", trace_only=True)
def aten_reshape(self: TTensor, shape: Sequence[INT64]) -> TTensor:
    """reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)"""
    # Fold the python-int / symbolic-dim list into one INT64 shape tensor,
    # which is the only shape input ONNX Reshape accepts.
    return op.Reshape(self, common_ops.merge_dims(shape))
74197419
74207420
@@ -9153,23 +9153,22 @@ def aten_vdot(self: TensorType, other: TensorType) -> TensorType:
91539153
91549154
@torch_op(("aten::view", "aten::_unsafe_view"), trace_only=True)
def aten_view(self: TTensor, size: Sequence[INT64]) -> TTensor:
    """view(Tensor(a) self, SymInt[] size) -> Tensor(a)"""
    # allowzero=True keeps an explicit 0 in `size` as a zero-sized dim,
    # matching PyTorch view semantics instead of ONNX's "copy input dim".
    new_shape = common_ops.merge_dims(size)
    return op.Reshape(self, new_shape, allowzero=True)
91619161
91629162
@torch_op(("aten::view", "aten::_unsafe_view"), complex=True, trace_only=True)
def aten_view_complex(self: TTensor, size: Sequence[INT64]) -> TTensor:
    """view(Tensor(a) self, SymInt[] size) -> Tensor(a)"""
    # Complex tensors are represented with a trailing real/imag axis of
    # length 2, so append that dim to the requested shape before reshaping.
    return op.Reshape(self, common_ops.merge_dims([*size, 2]), allowzero=True)
91709169
91719170
9172- @torch_op ("aten::view_as" )
9171+ @torch_op ("aten::view_as" , trace_only = True )
91739172def aten_view_as (self : TTensor , other : TTensor2 ) -> TTensor :
91749173 """view_as(Tensor(a) self, Tensor other) -> Tensor(a)"""
91759174
@@ -9213,11 +9212,11 @@ def aten_view_as_real_copy(self: TTensor) -> TTensor:
92139212 return op .Identity (self )
92149213
92159214
@torch_op("aten::view_copy", trace_only=True)
def aten_view_copy(self: TTensor, size: Sequence[INT64]) -> TTensor:
    """view_copy(Tensor self, SymInt[] size) -> Tensor"""
    # Unlike aten_view this emits Reshape with the default allowzero=0;
    # a 0 in `size` copies the corresponding input dimension.
    target = common_ops.merge_dims(size)
    return op.Reshape(self, target)
92229221
92239222
@@ -9245,7 +9244,8 @@ def reshape_to_2d(tensor):
92459244 "aten::where.ScalarSelf" ,
92469245 "aten::where.ScalarOther" ,
92479246 "aten::where.self" ,
9248- )
9247+ ),
9248+ trace_only = True ,
92499249)
92509250def aten_where (condition : BOOL , self : TTensor , other : TTensor ) -> TTensor :
92519251 """where.self(Tensor condition, Tensor self, Tensor other) -> Tensor"""
@@ -9261,7 +9261,7 @@ def aten_xor(self: TensorType, other: TensorType) -> TensorType:
92619261
@torch_op("aten::zeros", trace_only=True)
def aten_zeros(
    size: Sequence[INT64],
    dtype: int = FLOAT.dtype,
    layout: str = "",
    device: str = "",
    pin_memory: bool = False,
) -> TensorType:
    """zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor"""
    # The exporter encodes "dtype=None" as -1; default to float32 like PyTorch.
    if dtype == -1:
        dtype = FLOAT.dtype

    # Expand a scalar zero of the requested dtype to the target shape.
    # layout/device/pin_memory have no ONNX equivalent and are ignored.
    zero = op.Constant(value=ir.tensor(0, dtype=ir.DataType(dtype)))
    return op.Expand(zero, common_ops.merge_dims(size))
92789278
0 commit comments