@@ -3688,23 +3688,27 @@ def python_math_floor(self: TFloat) -> TInt:
36883688
36893689
36903690@torch_op ("aten::floor_divide" , trace_only = True )
3691- def aten_floor_divide (self : TFloat , other : TFloat ) -> TFloat :
3691+ def aten_floor_divide (self : TTensor , other : TTensor ) -> TTensor :
36923692 """floor_divide(Tensor self, Tensor other) -> Tensor"""
36933693
3694- return op .Floor (op .Div (self , other ))
3694+ if self .dtype .is_floating_point ():
3695+ return op .Floor (op .Div (self , other ))
36953696
3697+ assert self .dtype .is_integer ()
36963698
3697- @torch_op ("aten::floor_divide" , trace_only = True )
3698- def aten_floor_divide_int (self : TInt , other : TInt ) -> TInt :
3699- """floor_divide(Tensor self, Tensor other) -> Tensor"""
3699+ if not self .dtype .is_signed ():
3700+ return op .Div (self , other )
37003701
3701- # TODO(justinchuby): This can be simplified if we can constrain the
3702- # inputs to be positive integers. Consider how we can embed constraints in the model.
3703- dtype = self .dtype
3704- self = op .Cast (self , to = FLOAT .dtype )
3705- other = op .Cast (other , to = FLOAT .dtype )
3706- result = op .Floor (op .Div (self , other ))
3707- return op .Cast (result , to = dtype )
3702+ # Convert truncation to flooring
3703+ # Reference: https://github.com/pytorch/pytorch/blob/ffc645c870f0abd368606ba1e2b3b58cacb03046/torch/_refs/__init__.py#L1401C1-L1409C70
3704+ # offset = (torch.signbit(a) != torch.signbit(b)).logical_and(torch.fmod(a, b) != 0)
3705+ # return prims.div(a, b) - _maybe_convert_to_dtype(offset, a.dtype)
3706+ offset = op .And (
3707+ op .Not (op .Equal (op .Sign (self ), op .Sign (other ))),
3708+ op .Cast (op .Mod (self , other ), to = BOOL .dtype ),
3709+ )
3710+ offset = op .Cast (offset , to = self .dtype )
3711+ return op .Sub (op .Div (self , other ), offset )
37083712
37093713
37103714@torch_op ("_operator::floordiv" , trace_only = True )
0 commit comments