Skip to content

Commit 6cdb132

Browse files
Zeroto521Joao-DionisioCopilot
authored
Feature: MatrixVariable supports numpy broadcast feature (#1092)
* Add broadcast test for matrix constraints * Feature: support matrix broadcast * lint via ruff * Update CHANGELOG.md * Update CHANGELOG.md * Update CHANGELOG.md * Update CHANGELOG.md * miss launch optimization * Set up bound and target * lint codes Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update CHANGELOG.md --------- Co-authored-by: João Dionísio <57299939+Joao-Dionisio@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent f335b56 commit 6cdb132

File tree

3 files changed

+19
-12
lines changed

3 files changed

+19
-12
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
### Added
55
- Added possibility of having variables in exponent.
66
- Added basic type stubs to help with IDE autocompletion and type checking.
7+
- MatrixVariable comparisons (<=, >=, ==) now support numpy's broadcast feature.
78
### Fixed
89
- Implemented all binary operations between MatrixExpr and GenExpr
910
- Fixed the type of @ matrix operation result from MatrixVariable to MatrixExpr.

src/pyscipopt/matrix.pxi

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,14 @@ def _matrixexpr_richcmp(self, other, op):
2828
else:
2929
raise NotImplementedError("Can only support constraints with '<=', '>=', or '=='.")
3030

31-
res = np.empty(self.shape, dtype=object)
3231
if _is_number(other) or isinstance(other, Expr):
32+
res = np.empty(self.shape, dtype=object)
3333
res.flat = [_richcmp(i, other, op) for i in self.flat]
3434

3535
elif isinstance(other, np.ndarray):
36-
if self.shape != other.shape:
37-
raise ValueError("Shapes do not match for comparison.")
38-
39-
res.flat = [_richcmp(i, j, op) for i, j in zip(self.flat, other.flat)]
36+
out = np.broadcast(self, other)
37+
res = np.empty(out.shape, dtype=object)
38+
res.flat = [_richcmp(i, j, op) for i, j in out]
4039

4140
else:
4241
raise TypeError(f"Unsupported type {type(other)}")

tests/test_matrix_variable.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,4 @@
11
import operator
2-
import pdb
3-
import pprint
4-
import pytest
5-
from pyscipopt import Model, Variable, log, exp, cos, sin, sqrt
6-
from pyscipopt import Expr, MatrixExpr, MatrixVariable, MatrixExprCons, MatrixConstraint, ExprCons
7-
from pyscipopt.scip import GenExpr
82
from time import time
93

104
import numpy as np
@@ -22,10 +16,10 @@
2216
cos,
2317
exp,
2418
log,
25-
quicksum,
2619
sin,
2720
sqrt,
2821
)
22+
from pyscipopt.scip import GenExpr
2923

3024

3125
def test_catching_errors():
@@ -525,3 +519,16 @@ def test_matrix_matmul_return_type():
525519
y = m.addMatrixVar((2, 3))
526520
z = m.addMatrixVar((3, 4))
527521
assert type(y @ z) is MatrixExpr
522+
523+
524+
def test_broadcast():
525+
# test #1065
526+
m = Model()
527+
x = m.addMatrixVar((2, 3), ub=10)
528+
529+
m.addMatrixCons(x == np.zeros((2, 1)))
530+
531+
m.setObjective(x.sum(), "maximize")
532+
m.optimize()
533+
534+
assert (m.getVal(x) == np.zeros((2, 3))).all()

0 commit comments

Comments
 (0)