Skip to content

Commit 0bbda09

Browse files
Fix binary operations between GenExpr and MatrixExpr (#1071)
* Add tests for all binary operations Currently failing: FAILED tests/test_matrix_variable.py::test_binop[add-genexpr-matvar] - AttributeError: 'MatrixExpr' object has no attribute 'getOp' FAILED tests/test_matrix_variable.py::test_binop[sub-genexpr-matvar] - AttributeError: 'MatrixExpr' object has no attribute 'getOp' FAILED tests/test_matrix_variable.py::test_binop[mul-var-matvar] - NotImplementedError FAILED tests/test_matrix_variable.py::test_binop[mul-genexpr-matvar] - AttributeError: 'MatrixExpr' object has no attribute 'getOp' FAILED tests/test_matrix_variable.py::test_binop[truediv-var-matvar] - AttributeError: 'MatrixExpr' object has no attribute 'getOp' FAILED tests/test_matrix_variable.py::test_binop[truediv-genexpr-matvar] - AttributeError: 'MatrixExpr' object has no attribute 'getOp' * Fix binops in all cases * Add changelog entry * Fix model for tests must be at module level --------- Co-authored-by: João Dionísio <57299939+Joao-Dionisio@users.noreply.github.com>
1 parent a6dd8ab commit 0bbda09

File tree

3 files changed

+39
-3
lines changed

3 files changed

+39
-3
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
### Added
55
- Added basic type stubs to help with IDE autocompletion and type checking.
66
### Fixed
7+
- Implemented all binary operations between MatrixExpr and GenExpr
78
- Fixed the type of @ matrix operation result from MatrixVariable to MatrixExpr.
89
### Changed
910
- Speed up MatrixVariable.sum(axis=None) via quicksum

src/pyscipopt/expr.pxi

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ def buildGenExprObj(expr):
146146
GenExprs = np.empty(expr.shape, dtype=object)
147147
for idx in np.ndindex(expr.shape):
148148
GenExprs[idx] = buildGenExprObj(expr[idx])
149-
return GenExprs
149+
return GenExprs.view(MatrixExpr)
150150

151151
else:
152152
assert isinstance(expr, GenExpr)
@@ -223,6 +223,9 @@ cdef class Expr:
223223
return self
224224

225225
def __mul__(self, other):
226+
if isinstance(other, MatrixExpr):
227+
return other * self
228+
226229
if _is_number(other):
227230
f = float(other)
228231
return Expr({v:f*c for v,c in self.terms.items()})
@@ -420,6 +423,9 @@ cdef class GenExpr:
420423
return UnaryExpr(Operator.fabs, self)
421424

422425
def __add__(self, other):
426+
if isinstance(other, MatrixExpr):
427+
return other + self
428+
423429
left = buildGenExprObj(self)
424430
right = buildGenExprObj(other)
425431
ans = SumExpr()
@@ -475,6 +481,9 @@ cdef class GenExpr:
475481
# return self
476482

477483
def __mul__(self, other):
484+
if isinstance(other, MatrixExpr):
485+
return other * self
486+
478487
left = buildGenExprObj(self)
479488
right = buildGenExprObj(other)
480489
ans = ProdExpr()
@@ -537,7 +546,7 @@ cdef class GenExpr:
537546
def __truediv__(self,other):
538547
divisor = buildGenExprObj(other)
539548
# we can't divide by 0
540-
if divisor.getOp() == Operator.const and divisor.number == 0.0:
549+
if isinstance(divisor, GenExpr) and divisor.getOp() == Operator.const and divisor.number == 0.0:
541550
raise ZeroDivisionError("cannot divide by 0")
542551
return self * divisor**(-1)
543552

tests/test_matrix_variable.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,10 @@
1+
import operator
2+
import pdb
3+
import pprint
4+
import pytest
5+
from pyscipopt import Model, Variable, log, exp, cos, sin, sqrt
6+
from pyscipopt import Expr, MatrixExpr, MatrixVariable, MatrixExprCons, MatrixConstraint, ExprCons
7+
from pyscipopt.scip import GenExpr
18
from time import time
29

310
import numpy as np
@@ -209,7 +216,7 @@ def test_matrix_sum_argument():
209216
assert (m.getVal(x) == np.full((2, 3), 4)).all().all()
210217
assert (m.getVal(y) == np.full((2, 4), 3)).all().all()
211218

212-
219+
@pytest.mark.skip(reason="Performance test")
213220
def test_sum_performance():
214221
n = 1000
215222
model = Model()
@@ -442,6 +449,25 @@ def test_matrix_cons_indicator():
442449
assert m.getVal(z) == 1
443450

444451

452+
_binop_model = Model()
453+
454+
def var():
455+
return _binop_model.addVar()
456+
457+
def genexpr():
458+
return _binop_model.addVar() ** 0.6
459+
460+
def matvar():
461+
return _binop_model.addMatrixVar((1,))
462+
463+
@pytest.mark.parametrize("right", [var(), genexpr(), matvar()], ids=["var", "genexpr", "matvar"])
464+
@pytest.mark.parametrize("left", [var(), genexpr(), matvar()], ids=["var", "genexpr", "matvar"])
465+
@pytest.mark.parametrize("op", [operator.add, operator.sub, operator.mul, operator.truediv])
466+
def test_binop(op, left, right):
467+
res = op(left, right)
468+
assert isinstance(res, (Expr, GenExpr, MatrixExpr))
469+
470+
445471
def test_matrix_matmul_return_type():
446472
# test #1058, require returning type is MatrixExpr not MatrixVariable
447473
m = Model()

0 commit comments

Comments
 (0)