Skip to content

Commit a6dd8ab

Browse files
Feature: Speed up MatrixVariable.sum(axis=None) via quicksum (#1078)
* Speed up `matrix.sum` * Add test for sum with axis in matrix variable Added a test to verify that summing a matrix variable with axis=0 returns a MatrixExpr with the expected shape. This improves coverage for matrix sum operations with axis specified. * Add matrix sum performance test Introduced a new test to compare performance of matrix sum versus element-wise sum. Refactored imports for clarity and consistency. Renamed performance test for better description. * Use a more large size data * Fix variable naming in test_sum_performance Renamed the Model instance from 'm' to 'model' to avoid confusion with the integer variable 'm' and improve code clarity in the test_sum_performance function. * Fix variable name in performance test assertion Replaces incorrect usage of 'm' with 'model' in the assertion within test_sum_performance to ensure the correct object is referenced. * Fix expected shape in matrix sum test Updated the assertion in test_matrix_sum_argument to expect shape (1,) instead of (1, 1) when summing along axis 0. This aligns the test with the actual output of the sum operation. * Update CHANGELOG.md * Adjust performance test assertion in matrix variable tests Modified the assertion in test_sum_performance to compare orig_time + 1 with matrix_time instead of orig_time. This may address timing precision or test flakiness. * Try a bigger data size * Compare `np.sum` and `quicksum` * Try a bigger size --------- Co-authored-by: João Dionísio <57299939+Joao-Dionisio@users.noreply.github.com>
1 parent 53289a0 commit a6dd8ab

File tree

3 files changed

+50
-10
lines changed

3 files changed

+50
-10
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
### Fixed
77
- Fixed the type of @ matrix operation result from MatrixVariable to MatrixExpr.
88
### Changed
9+
- Speed up MatrixVariable.sum(axis=None) via quicksum
910
### Removed
1011

1112
## v5.6.0 - 2025.08.26

src/pyscipopt/matrix.pxi

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,14 @@ def _is_number(e):
1818
class MatrixExpr(np.ndarray):
1919
def sum(self, **kwargs):
2020
"""
21-
Based on `numpy.ndarray.sum`, but returns a scalar if the result is a single value.
22-
This is useful for matrix expressions where the sum might reduce to a single value.
21+
Based on `numpy.ndarray.sum`, but returns a scalar if `axis=None`.
22+
This is useful for matrix expressions to compare with a matrix or a scalar.
2323
"""
24-
res = super().sum(**kwargs)
25-
return res if res.size > 1 else res.item()
24+
25+
if kwargs.get("axis") is None:
26+
# Speed up `.sum()` #1070
27+
return quicksum(self.flat)
28+
return super().sum(**kwargs)
2629

2730
def __le__(self, other: Union[float, int, Variable, np.ndarray, 'MatrixExpr']) -> np.ndarray:
2831

tests/test_matrix_variable.py

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,24 @@
1-
import pdb
2-
import pprint
3-
import pytest
4-
from pyscipopt import Model, Variable, log, exp, cos, sin, sqrt
5-
from pyscipopt import Expr, MatrixExpr, MatrixVariable, MatrixExprCons, MatrixConstraint, ExprCons
61
from time import time
72

83
import numpy as np
4+
import pytest
5+
6+
from pyscipopt import (
7+
Expr,
8+
ExprCons,
9+
MatrixConstraint,
10+
MatrixExpr,
11+
MatrixExprCons,
12+
MatrixVariable,
13+
Model,
14+
Variable,
15+
cos,
16+
exp,
17+
log,
18+
quicksum,
19+
sin,
20+
sqrt,
21+
)
922

1023

1124
def test_catching_errors():
@@ -170,6 +183,10 @@ def test_expr_from_matrix_vars():
170183
def test_matrix_sum_argument():
171184
m = Model()
172185

186+
# Return a array when axis isn't None
187+
res = m.addMatrixVar((3, 1)).sum(axis=0)
188+
assert isinstance(res, MatrixExpr) and res.shape == (1,)
189+
173190
# compare the result of summing 2d array to a scalar with a scalar
174191
x = m.addMatrixVar((2, 3), "x", "I", ub=4)
175192
m.addMatrixCons(x.sum() == 24)
@@ -192,6 +209,25 @@ def test_matrix_sum_argument():
192209
assert (m.getVal(x) == np.full((2, 3), 4)).all().all()
193210
assert (m.getVal(y) == np.full((2, 4), 3)).all().all()
194211

212+
213+
def test_sum_performance():
214+
n = 1000
215+
model = Model()
216+
x = model.addMatrixVar((n, n))
217+
218+
# Original sum via `np.sum`
219+
start_orig = time()
220+
np.sum(x)
221+
end_orig = time()
222+
223+
# Optimized sum via `quicksum`
224+
start_matrix = time()
225+
x.sum()
226+
end_matrix = time()
227+
228+
assert model.isGT(end_orig - start_orig, end_matrix - start_matrix)
229+
230+
195231
def test_add_cons_matrixVar():
196232
m = Model()
197233
matrix_variable = m.addMatrixVar(shape=(3, 3), vtype="B", name="A", obj=1)
@@ -339,7 +375,7 @@ def test_MatrixVariable_attributes():
339375
assert x.varMayRound().tolist() == [[True, True], [True, True]]
340376

341377
@pytest.mark.skip(reason="Performance test")
342-
def test_performance():
378+
def test_add_cons_performance():
343379
start_orig = time()
344380
m = Model()
345381
x = {}

0 commit comments

Comments
 (0)