Skip to content

Commit 35c7f3e

Browse files
Zeroto521Copilot
andauthored
BUG: MatrixExpr can't be compared with Expr (#1069)
* Replace Variable with Expr in MatrixExpr Updated type hints and isinstance checks in MatrixExpr comparison methods to use Expr instead of Variable. This change improves compatibility with broader expression types in matrix operations. * add test case * Replace Variable with Expr in MatrixExprCons Updated type hints and isinstance checks in MatrixExprCons.__le__ and __ge__ methods to use Expr instead of Variable. This change improves consistency with the expected types for matrix expression constraints. * add test case * Update CHANGELOG.md * Add test for ranged matrix constraint Introduces test_ranged_matrix_cons to verify correct behavior when adding a ranged matrix constraint to the model. Ensures that the matrix variable x is set to ones as expected. * Refactor matrix comparison operators using helper Introduced a shared _matrixexpr_richcmp helper to handle rich comparison logic for MatrixExpr and MatrixExprCons, reducing code duplication and improving maintainability. Updated __le__, __ge__, and __eq__ methods to use this helper, and removed redundant code. * Replace TypeError with NotImplementedError in __eq__ The __eq__ method of MatrixExprCons now raises NotImplementedError with a descriptive message instead of TypeError, clarifying that '==' comparison is not supported. * Add tests for matrix constraint operators Added tests for '<=', '>=', and '==' operators in matrix constraints. Verified correct exception is raised for unsupported '==' operator. * Update CHANGELOG.md * BUG: fix circular imports Relocated the _is_number utility from expr.pxi to matrix.pxi for better modularity. Updated _matrixexpr_richcmp to use a local _richcmp helper for comparison operations. * Fix matrix comparison shape handling Replaces usage of undefined 'shape' variable with 'self.shape' when creating the result array in _matrixexpr_richcmp, ensuring correct array dimensions. * Fix redundant .all() calls in matrix variable tests Removed unnecessary double .all() calls in assertions for matrix variable tests, simplifying the checks for equality with np.ones(3). * Fix matrix variable test assertions to use getVal Updated assertions in test_matrix_variable.py to use m.getVal(x) and m.getVal(y) instead of direct variable comparison. This ensures the tests check the evaluated values from the model rather than the symbolic variables. * let MatrixExprCons support <= and >= operator * Refactor matrix comparison tests to optimize assertions and remove redundant checks * let MatrixExprCons support <= and >= operator * find what type it is * align with `__add__` * test "==" first * Revert "let MatrixExprCons support <= and >= operator" This reverts commit f69ce7e. * Revert "let MatrixExprCons support <= and >= operator" This reverts commit b6dcf42. * find what type it is * test expr * Change the order * Remove ExprCons Can't add with ExprCons * Ranged ExprCons requires number * Update CHANGELOG.md * Lint codes with 4 spaces * keep `_is_number` in expr.pxi * Update CHANGELOG.md Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Add quotes for annotations * Simplify loop via `numpy.ndarray.flat` Replaced explicit loops with list comprehensions for element-wise comparison, improving readability and potentially performance in matrix expression comparisons. * Add shape check for ndarray comparison in _matrixexpr_richcmp Raises a ValueError if the shapes of self and the other ndarray do not match during comparison, preventing invalid element-wise operations. * TST: test MatrixExprCons vs Variable --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent 803e447 commit 35c7f3e

File tree

4 files changed

+106
-98
lines changed

4 files changed

+106
-98
lines changed

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@
99
### Changed
1010
- Add a PEP 735 dependency group for test dependencies in `pyproject.toml`
1111
- Speed up MatrixVariable.sum(axis=None) via quicksum
12+
- MatrixVariable now supports comparison with Expr
1213
### Removed
1314

14-
## v5.6.0 - 2025.08.26
15+
## 5.6.0 - 2025.08.26
1516
### Added
1617
- More support for AND-Constraints
1718
- Added support for knapsack constraints

src/pyscipopt/expr.pxi

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
# Modifying the expression directly would be a bug, given that the expression might be re-used by the user. </pre>
4545
include "matrix.pxi"
4646

47+
4748
def _is_number(e):
4849
try:
4950
f = float(e)
@@ -52,7 +53,8 @@ def _is_number(e):
5253
return False
5354
except TypeError: # for other types (Variable, Expr)
5455
return False
55-
56+
57+
5658
def _expr_richcmp(self, other, op):
5759
if op == 1: # <=
5860
if isinstance(other, Expr) or isinstance(other, GenExpr):
@@ -62,7 +64,7 @@ def _expr_richcmp(self, other, op):
6264
elif isinstance(other, MatrixExpr):
6365
return _expr_richcmp(other, self, 5)
6466
else:
65-
raise NotImplementedError
67+
raise TypeError(f"Unsupported type {type(other)}")
6668
elif op == 5: # >=
6769
if isinstance(other, Expr) or isinstance(other, GenExpr):
6870
return (self - other) >= 0.0
@@ -71,7 +73,7 @@ def _expr_richcmp(self, other, op):
7173
elif isinstance(other, MatrixExpr):
7274
return _expr_richcmp(other, self, 1)
7375
else:
74-
raise NotImplementedError
76+
raise TypeError(f"Unsupported type {type(other)}")
7577
elif op == 2: # ==
7678
if isinstance(other, Expr) or isinstance(other, GenExpr):
7779
return (self - other) == 0.0
@@ -80,7 +82,7 @@ def _expr_richcmp(self, other, op):
8082
elif isinstance(other, MatrixExpr):
8183
return _expr_richcmp(other, self, 2)
8284
else:
83-
raise NotImplementedError
85+
raise TypeError(f"Unsupported type {type(other)}")
8486
else:
8587
raise NotImplementedError("Can only support constraints with '<=', '>=', or '=='.")
8688

@@ -201,7 +203,7 @@ cdef class Expr:
201203
elif isinstance(right, MatrixExpr):
202204
return right + left
203205
else:
204-
raise NotImplementedError
206+
raise TypeError(f"Unsupported type {type(right)}")
205207

206208
return Expr(terms)
207209

@@ -218,7 +220,7 @@ cdef class Expr:
218220
# TypeError: Cannot convert pyscipopt.scip.SumExpr to pyscipopt.scip.Expr
219221
return buildGenExprObj(self) + other
220222
else:
221-
raise NotImplementedError
223+
raise TypeError(f"Unsupported type {type(other)}")
222224

223225
return self
224226

@@ -337,26 +339,26 @@ cdef class ExprCons:
337339
def __richcmp__(self, other, op):
338340
'''turn it into a constraint'''
339341
if op == 1: # <=
340-
if not self._rhs is None:
341-
raise TypeError('ExprCons already has upper bound')
342-
assert not self._lhs is None
342+
if not self._rhs is None:
343+
raise TypeError('ExprCons already has upper bound')
344+
assert not self._lhs is None
343345

344-
if not _is_number(other):
345-
raise TypeError('Ranged ExprCons is not well defined!')
346+
if not _is_number(other):
347+
raise TypeError('Ranged ExprCons is not well defined!')
346348

347-
return ExprCons(self.expr, lhs=self._lhs, rhs=float(other))
349+
return ExprCons(self.expr, lhs=self._lhs, rhs=float(other))
348350
elif op == 5: # >=
349-
if not self._lhs is None:
350-
raise TypeError('ExprCons already has lower bound')
351-
assert self._lhs is None
352-
assert not self._rhs is None
351+
if not self._lhs is None:
352+
raise TypeError('ExprCons already has lower bound')
353+
assert self._lhs is None
354+
assert not self._rhs is None
353355

354-
if not _is_number(other):
355-
raise TypeError('Ranged ExprCons is not well defined!')
356+
if not _is_number(other):
357+
raise TypeError('Ranged ExprCons is not well defined!')
356358

357-
return ExprCons(self.expr, lhs=float(other), rhs=self._rhs)
359+
return ExprCons(self.expr, lhs=float(other), rhs=self._rhs)
358360
else:
359-
raise TypeError
361+
raise NotImplementedError("Ranged ExprCons can only support with '<=' or '>='.")
360362

361363
def __repr__(self):
362364
return 'ExprCons(%s, %s, %s)' % (self.expr, self._lhs, self._rhs)

src/pyscipopt/matrix.pxi

Lines changed: 40 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
from typing import Union
88

9+
910
def _is_number(e):
1011
try:
1112
f = float(e)
@@ -15,6 +16,34 @@ def _is_number(e):
1516
except TypeError: # for other types (Variable, Expr)
1617
return False
1718

19+
20+
def _matrixexpr_richcmp(self, other, op):
21+
def _richcmp(self, other, op):
22+
if op == 1: # <=
23+
return self.__le__(other)
24+
elif op == 5: # >=
25+
return self.__ge__(other)
26+
elif op == 2: # ==
27+
return self.__eq__(other)
28+
else:
29+
raise NotImplementedError("Can only support constraints with '<=', '>=', or '=='.")
30+
31+
res = np.empty(self.shape, dtype=object)
32+
if _is_number(other) or isinstance(other, Expr):
33+
res.flat = [_richcmp(i, other, op) for i in self.flat]
34+
35+
elif isinstance(other, np.ndarray):
36+
if self.shape != other.shape:
37+
raise ValueError("Shapes do not match for comparison.")
38+
39+
res.flat = [_richcmp(i, j, op) for i, j in zip(self.flat, other.flat)]
40+
41+
else:
42+
raise TypeError(f"Unsupported type {type(other)}")
43+
44+
return res.view(MatrixExprCons)
45+
46+
1847
class MatrixExpr(np.ndarray):
1948
def sum(self, **kwargs):
2049
"""
@@ -27,51 +56,15 @@ class MatrixExpr(np.ndarray):
2756
return quicksum(self.flat)
2857
return super().sum(**kwargs)
2958

30-
def __le__(self, other: Union[float, int, Variable, np.ndarray, 'MatrixExpr']) -> np.ndarray:
31-
32-
expr_cons_matrix = np.empty(self.shape, dtype=object)
33-
if _is_number(other) or isinstance(other, Variable):
34-
for idx in np.ndindex(self.shape):
35-
expr_cons_matrix[idx] = self[idx] <= other
36-
37-
elif isinstance(other, np.ndarray):
38-
for idx in np.ndindex(self.shape):
39-
expr_cons_matrix[idx] = self[idx] <= other[idx]
40-
else:
41-
raise TypeError(f"Unsupported type {type(other)}")
42-
43-
return expr_cons_matrix.view(MatrixExprCons)
44-
45-
def __ge__(self, other: Union[float, int, Variable, np.ndarray, 'MatrixExpr']) -> np.ndarray:
46-
47-
expr_cons_matrix = np.empty(self.shape, dtype=object)
48-
if _is_number(other) or isinstance(other, Variable):
49-
for idx in np.ndindex(self.shape):
50-
expr_cons_matrix[idx] = self[idx] >= other
51-
52-
elif isinstance(other, np.ndarray):
53-
for idx in np.ndindex(self.shape):
54-
expr_cons_matrix[idx] = self[idx] >= other[idx]
55-
else:
56-
raise TypeError(f"Unsupported type {type(other)}")
59+
def __le__(self, other: Union[float, int, "Expr", np.ndarray, "MatrixExpr"]) -> MatrixExprCons:
60+
return _matrixexpr_richcmp(self, other, 1)
5761

58-
return expr_cons_matrix.view(MatrixExprCons)
62+
def __ge__(self, other: Union[float, int, "Expr", np.ndarray, "MatrixExpr"]) -> MatrixExprCons:
63+
return _matrixexpr_richcmp(self, other, 5)
5964

60-
def __eq__(self, other: Union[float, int, Variable, np.ndarray, 'MatrixExpr']) -> np.ndarray:
61-
62-
expr_cons_matrix = np.empty(self.shape, dtype=object)
63-
if _is_number(other) or isinstance(other, Variable):
64-
for idx in np.ndindex(self.shape):
65-
expr_cons_matrix[idx] = self[idx] == other
66-
67-
elif isinstance(other, np.ndarray):
68-
for idx in np.ndindex(self.shape):
69-
expr_cons_matrix[idx] = self[idx] == other[idx]
70-
else:
71-
raise TypeError(f"Unsupported type {type(other)}")
65+
def __eq__(self, other: Union[float, int, "Expr", np.ndarray, "MatrixExpr"]) -> MatrixExprCons:
66+
return _matrixexpr_richcmp(self, other, 2)
7267

73-
return expr_cons_matrix.view(MatrixExprCons)
74-
7568
def __add__(self, other):
7669
return super().__add__(other).view(MatrixExpr)
7770

@@ -110,41 +103,11 @@ class MatrixGenExpr(MatrixExpr):
110103

111104
class MatrixExprCons(np.ndarray):
112105

113-
def __le__(self, other: Union[float, int, Variable, MatrixExpr]) -> np.ndarray:
114-
115-
if not _is_number(other) or not isinstance(other, MatrixExpr):
116-
raise TypeError('Ranged MatrixExprCons is not well defined!')
117-
118-
expr_cons_matrix = np.empty(self.shape, dtype=object)
119-
if _is_number(other) or isinstance(other, Variable):
120-
for idx in np.ndindex(self.shape):
121-
expr_cons_matrix[idx] = self[idx] <= other
122-
123-
elif isinstance(other, np.ndarray):
124-
for idx in np.ndindex(self.shape):
125-
expr_cons_matrix[idx] = self[idx] <= other[idx]
126-
else:
127-
raise TypeError(f"Unsupported type {type(other)}")
128-
129-
return expr_cons_matrix.view(MatrixExprCons)
130-
131-
def __ge__(self, other: Union[float, int, Variable, MatrixExpr]) -> np.ndarray:
132-
133-
if not _is_number(other) or not isinstance(other, MatrixExpr):
134-
raise TypeError('Ranged MatrixExprCons is not well defined!')
135-
136-
expr_cons_matrix = np.empty(self.shape, dtype=object)
137-
if _is_number(other) or isinstance(other, Variable):
138-
for idx in np.ndindex(self.shape):
139-
expr_cons_matrix[idx] = self[idx] >= other
140-
141-
elif isinstance(other, np.ndarray):
142-
for idx in np.ndindex(self.shape):
143-
expr_cons_matrix[idx] = self[idx] >= other[idx]
144-
else:
145-
raise TypeError(f"Unsupported type {type(other)}")
106+
def __le__(self, other: Union[float, int, np.ndarray]) -> MatrixExprCons:
107+
return _matrixexpr_richcmp(self, other, 1)
146108

147-
return expr_cons_matrix.view(MatrixExprCons)
109+
def __ge__(self, other: Union[float, int, np.ndarray]) -> MatrixExprCons:
110+
return _matrixexpr_richcmp(self, other, 5)
148111

149112
def __eq__(self, other):
150-
raise TypeError
113+
raise NotImplementedError("Cannot compare MatrixExprCons with '=='.")

tests/test_matrix_variable.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,48 @@ def test_matrix_cons_indicator():
449449
assert m.getVal(z) == 1
450450

451451

452+
def test_matrix_compare_with_expr():
453+
m = Model()
454+
var = m.addVar(vtype="B", ub=0)
455+
456+
# test "<=" and ">=" operator
457+
x = m.addMatrixVar(3)
458+
m.addMatrixCons(x <= var + 1)
459+
m.addMatrixCons(x >= var + 1)
460+
461+
# test "==" operator
462+
y = m.addMatrixVar(3)
463+
m.addMatrixCons(y == var + 1)
464+
465+
m.setObjective(x.sum() + y.sum())
466+
m.optimize()
467+
468+
assert (m.getVal(x) == np.ones(3)).all()
469+
assert (m.getVal(y) == np.ones(3)).all()
470+
471+
472+
def test_ranged_matrix_cons_with_expr():
473+
m = Model()
474+
x = m.addMatrixVar(3)
475+
var = m.addVar(vtype="B", ub=0)
476+
477+
# test MatrixExprCons vs Variable
478+
with pytest.raises(TypeError):
479+
m.addMatrixCons((x <= 1) >= var)
480+
481+
# test "==" operator
482+
with pytest.raises(NotImplementedError):
483+
m.addMatrixCons((x <= 1) == 1)
484+
485+
# test "<=" and ">=" operator
486+
m.addMatrixCons((x <= 1) >= 1)
487+
488+
m.setObjective(x.sum())
489+
m.optimize()
490+
491+
assert (m.getVal(x) == np.ones(3)).all()
492+
493+
452494
_binop_model = Model()
453495

454496
def var():

0 commit comments

Comments
 (0)