@@ -7292,12 +7292,114 @@ def aten_repeat(self: TTensor, repeats: Sequence[TInt]) -> TTensor:
72927292 return op .Tile (self_expanded , repeats )
72937293
72947294
@torch_op("aten::repeat_interleave.self_int", trace_only=True)
def aten_repeat_interleave_self_int(
    self: TensorType, repeats: int, dim: Optional[int] = None
) -> TensorType:
    """repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor

    The trick is to tile along a fresh axis inserted right after ``dim``,
    then collapse that axis back into ``dim`` with a reshape.

    .. code-block:: python

        x = torch.tensor([[0, 1, 2], [3, 4, 5]])
        x.repeat_interleave(2, dim=0)

    is equivalent to:

    .. code-block:: python

        x = torch.tensor([[0, 1, 2], [3, 4, 5]])
        x.repeat((1, 2)).reshape((-1, x.shape[1]))

    Args:
        self: Input tensor.
        repeats: Number of times each slice along ``dim`` is repeated.
        dim: Dimension along which to repeat; may be negative.

    Raises:
        NotImplementedError: If ``dim`` is None (flatten-first semantics
            are not converted yet).
    """
    if dim is None:
        raise NotImplementedError("No conversion available yet when dim is None.")

    self_rank = len(self.shape)
    # Normalize a possibly-negative dim once; every index below uses pos_dim.
    pos_dim = (dim + self_rank) % self_rank

    # Insert a new axis after pos_dim and tile it `repeats` times.
    unsqueezed = op.Unsqueeze(self, [pos_dim + 1])
    tiles = [1] * (self_rank + 1)
    tiles[pos_dim + 1] = repeats
    tile_repeat = op.Constant(value=ir.tensor(tiles, dtype=INT64.dtype))
    tiled = op.Tile(unsqueezed, tile_repeat)

    # Collapse the tiled axis back into pos_dim. Using pos_dim (not dim) in
    # the Shape slices is essential for negative dims: with dim=-1,
    # start=dim+1 == 0 would wrongly select the whole shape. For rank-1
    # inputs both Shape slices are empty, so the reshape target is [-1] and
    # the result is flattened to 1-D exactly as PyTorch does (the previous
    # rank-1 special case returned the 2-D tiled tensor unmodified).
    final_shape = op.Concat(
        op.Shape(self, start=0, end=pos_dim),
        op.Constant(value_ints=[-1]),
        op.Shape(self, start=pos_dim + 1),
        axis=0,
    )
    return op.Reshape(tiled, final_shape)
7334+
7335+
@torch_op("aten::repeat_interleave.Tensor", trace_only=True)
def aten_repeat_interleave_Tensor(
    self: TensorType, repeats: Optional[TensorType] = None, dim: Optional[int] = None
) -> TensorType:
    """repeat_interleave.Tensor(Tensor repeats, *, int? output_size=None) -> Tensor

    When `repeats` is a tensor, each line is multiplied
    by a different number.
    There are multiple strategies. Here is one.

    .. code-block:: python

        import torch

        x = torch.tensor([[0, 1, 2], [3, 4, 5]])
        times = torch.tensor([2, 3], dtype=torch.int64)
        y = x.repeat_interleave(times, dim=0)
        print("repeat_interleave")
        print(y)

        ci = times.cumsum(dim=0)
        rows = torch.arange(ci[-1], dtype=torch.int64) < ci.reshape((-1, 1))
        srows = times.shape[0] - rows.to(torch.int64).sum(axis=0)
        indices = srows.reshape((-1, ))
        print("decomposed")
        print(x[indices, :])
    """
    if repeats is None:
        # Pure aten::repeat_interleave.Tensor overload: `self` IS the repeats
        # tensor and the values to repeat are arange(len(repeats)).
        repeats = self
        self = op.Range(0, op.Squeeze(op.Shape(repeats, start=-1), [0]), 1)
    if dim is None:
        # flatten, then interleave along the single remaining axis
        self = op.Reshape(self, [-1])
        rk = 1
    else:
        # NOTE(review): the gather below always runs along axis 0, so only
        # dim in (None, 0) appears to be handled — confirm against callers.
        rk = len(self.shape)

    # Bring `self` to a rank-2 view so the row gather below applies uniformly.
    if rk > 2:
        # Collapse trailing dims; shape_x remembers them for the final reshape.
        shape_x0 = op.Shape(self, start=0, end=1)
        shape_x = op.Shape(self, start=1)
        self = op.Reshape(self, op.Concat(shape_x0, [-1], axis=0))
    elif rk == 1:
        shape_x = None
        self = op.Reshape(self, [-1, 1])
    elif rk == 2:
        shape_x = None
    else:
        raise NotImplementedError(f"rank(self)={rk} not implemented for repeat_interleave")

    # Map every output position t to its source row: column t of `rows` is
    # True for each i with cumsum(repeats)[i] > t, so
    # n_rows - colsum(t) == #{i : cumsum(repeats)[i] <= t}, the row index.
    ci = op.CumSum(repeats, [0])
    last_ci = op.Gather(ci, [-1])
    trange = op.Range(0, op.Squeeze(last_ci, [0]), 1)
    rows = op.Less(trange, op.Unsqueeze(ci, [-1]))
    srows = op.Sub(
        op.Shape(self, start=0, end=1),
        op.ReduceSum(op.Cast(rows, to=INT64.dtype), [0]),
    )
    indices = op.Reshape(srows, [-1])
    values = op.GatherND(self, op.Unsqueeze(indices, [-1]))
    if rk == 2:
        return values
    # BUG FIX: the previous `assert shape_x is None` fired for every rank>2
    # input, where shape_x necessarily holds the collapsed trailing dims set
    # above. Restore those dims here; rank-1 inputs flatten back to 1-D.
    # (`is not None` avoids relying on the truthiness of a traced value.)
    return op.Reshape(
        values,
        op.Concat([-1], shape_x, axis=0) if shape_x is not None else [-1],
    )
73017403
73027404
73037405@torch_op ("aten::reshape" )
0 commit comments