@@ -7332,16 +7332,25 @@ def aten_repeat_interleave_self_int(
73327332 self_rank = len (self .shape )
73337333 pos_dim = (dim + self_rank ) % self_rank
73347334 unsqueezed = op .Unsqueeze (self , [pos_dim + 1 ])
7335- tiles = [1 ] * (self_rank + 1 )
7336- tiles [pos_dim + 1 ] = repeats
7337- tile_repeat = op .Constant (value = ir .tensor (tiles , dtype = INT64 .dtype ))
7338- tiled = op .Tile (unsqueezed , tile_repeat )
7335+ if isinstance (repeats , int ):
7336+ tiles = [1 ] * (self_rank + 1 )
7337+ tiles [pos_dim + 1 ] = repeats
7338+ tile_repeat = op .Constant (value = ir .tensor (tiles , dtype = INT64 .dtype ))
7339+ else :
7340+ # repeats is a symbolic tensor
7341+ tile_repeat = op .Concat (
7342+ op .Constant (value = ir .tensor ([1 ] * pos_dim , dtype = INT64 .dtype )),
7343+ op .Reshape (repeats , op .Constant (value = ir .tensor ([- 1 ], dtype = INT64 .dtype ))),
7344+ op .Constant (value = ir .tensor ([1 ] * (self_rank - pos_dim ), dtype = INT64 .dtype )),
7345+ axis = 0 ,
7346+ )
7347+ tiled = op .Expand (unsqueezed , tile_repeat )
73397348 if self_rank == 1 :
73407349 return op .Identity (tiled )
73417350 final_shape = op .Concat (
73427351 op .Shape (self , start = 0 , end = dim ),
73437352 op .Constant (value_ints = [- 1 ]),
7344- op .Shape (self , start = dim + 1 ),
7353+ op .Shape (self , start = pos_dim + 1 ),
73457354 axis = 0 ,
73467355 )
73477356 return op .Reshape (tiled , final_shape )
@@ -7380,20 +7389,22 @@ def aten_repeat_interleave_Tensor(
73807389 if dim is None :
73817390 # flatten
73827391 self = op .Reshape (self , [- 1 ])
7383- rk = 1
7392+ rank = 1
73847393 else :
7385- rk = len (self .shape )
7394+ rank = len (self .shape )
73867395
7387- if rk > 2 :
7396+ if rank > 2 :
73887397 shape_x0 = op .Shape (self , start = 0 , end = 1 )
73897398 shape_x = op .Shape (self , start = 1 )
73907399 self = op .Reshape (self , op .Concat (shape_x0 , [- 1 ], axis = 0 ))
7391- elif rk == 1 :
7400+ elif rank == 1 :
73927401 shape_x = None
73937402 self = op .Reshape (self , [- 1 , 1 ])
73947403 else :
7395- if rk != 2 :
7396- raise NotImplementedError (f"rank(self)={ rk } not implemented for repeat_interleave" )
7404+ if rank != 2 :
7405+ raise NotImplementedError (
7406+ f"rank(self)={ rank } not implemented for repeat_interleave"
7407+ )
73977408 shape_x = None
73987409
73997410 ci = op .CumSum (repeats , [0 ])
@@ -7406,7 +7417,7 @@ def aten_repeat_interleave_Tensor(
74067417 )
74077418 indices = op .Reshape (srows , [- 1 ])
74087419 values = op .GatherND (self , op .Unsqueeze (indices , [- 1 ]))
7409- if rk == 2 :
7420+ if rank == 2 :
74107421 return values
74117422 # shape_x is None at this stage.
74127423 assert shape_x is None # for mypy
0 commit comments