Skip to content

Commit f83c05b

Browse files
authored
Add __rtruediv__ and __rfloordiv__ to Scalar variables (#1701)
Also remove legacy `__div__`
1 parent 112b325 commit f83c05b

File tree

4 files changed

+24
-22
lines changed

4 files changed

+24
-22
lines changed

pytensor/scalar/basic.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -943,6 +943,12 @@ def __rsub__(self, other):
943943
def __rmul__(self, other):
944944
return mul(other, self)
945945

946+
def __rtruediv__(self, other):
947+
return true_div(other, self)
948+
949+
def __rfloordiv__(self, other):
950+
return int_div(other, self)
951+
946952
def __rmod__(self, other):
947953
return mod(other, self)
948954

pytensor/tensor/variable.py

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from pytensor.graph.utils import MetaType
1414
from pytensor.scalar import (
1515
ComplexError,
16-
IntegerDivisionError,
1716
)
1817
from pytensor.tensor import _get_vector_length
1918
from pytensor.tensor.exceptions import AdvancedIndexingError
@@ -138,18 +137,6 @@ def __mul__(self, other):
138137
except (NotImplementedError, TypeError):
139138
return NotImplemented
140139

141-
def __div__(self, other):
142-
# See explanation in __add__ for the error caught
143-
# and the return value in that case
144-
try:
145-
return pt.math.div_proxy(self, other)
146-
except IntegerDivisionError:
147-
# This is to raise the exception that occurs when trying to divide
148-
# two integer arrays (currently forbidden).
149-
raise
150-
except (NotImplementedError, TypeError):
151-
return NotImplemented
152-
153140
def __pow__(self, other):
154141
# See explanation in __add__ for the error caught
155142
# and the return value in that case
@@ -210,9 +197,6 @@ def __rsub__(self, other):
210197
def __rmul__(self, other):
211198
return pt.math.mul(other, self)
212199

213-
def __rdiv__(self, other):
214-
return pt.math.div_proxy(other, self)
215-
216200
def __rmod__(self, other):
217201
return pt.math.mod(other, self)
218202

pytensor/xtensor/type.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -308,9 +308,6 @@ def __sub__(self, other):
308308
def __mul__(self, other):
309309
return px.math.mul(self, other)
310310

311-
def __div__(self, other):
312-
return px.math.div(self, other)
313-
314311
def __pow__(self, other):
315312
return px.math.pow(self, other)
316313

@@ -341,9 +338,6 @@ def __rsub__(self, other):
341338
def __rmul__(self, other):
342339
return px.math.mul(other, self)
343340

344-
def __rdiv__(self, other):
345-
return px.math.div_proxy(other, self)
346-
347341
def __rmod__(self, other):
348342
return px.math.mod(other, self)
349343

tests/scalar/test_basic.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
EQ,
1111
ComplexError,
1212
Composite,
13+
IntDiv,
1314
ScalarType,
15+
TrueDiv,
1416
add,
1517
and_,
1618
arccos,
@@ -531,3 +533,19 @@ def test_scalar_hash_default_output_type_preference():
531533
del old_eq.output_types_preference # mimic old Op
532534
assert new_eq == old_eq
533535
assert hash(new_eq) == hash(old_eq)
536+
537+
538+
def test_rtruediv():
539+
x = ScalarType(dtype="float64")()
540+
y = 1.0 / x
541+
assert isinstance(y.owner.op, TrueDiv)
542+
assert isinstance(y.type, ScalarType)
543+
assert y.eval({x: 2.0}) == 0.5
544+
545+
546+
def test_rfloordiv():
547+
x = ScalarType(dtype="float64")()
548+
y = 5.0 // x
549+
assert isinstance(y.owner.op, IntDiv)
550+
assert isinstance(y.type, ScalarType)
551+
assert y.eval({x: 2.0}) == 2.0

0 commit comments

Comments
 (0)