Skip to content

Commit f48c551

Browse files
authored
Merge pull request #3636 from jsiirola/numpy-expr
Support using Expressions in numpy matrix operations
2 parents ec73186 + da298bd commit f48c551

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

pyomo/core/base/expression.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,11 @@
3030
import pyomo.core.expr.numeric_expr as numeric_expr
3131
from pyomo.core.base.component import ComponentData, ModelComponentFactory
3232
from pyomo.core.base.global_set import UnindexedComponent_index
33-
from pyomo.core.base.indexed_component import IndexedComponent, UnindexedComponent_set
33+
from pyomo.core.base.indexed_component import (
34+
IndexedComponent,
35+
UnindexedComponent_set,
36+
IndexedComponent_NDArrayMixin,
37+
)
3438
from pyomo.core.expr.numvalue import as_numeric
3539
from pyomo.core.base.initializer import Initializer
3640

@@ -235,7 +239,7 @@ class _GeneralExpressionData(metaclass=RenamedClass):
235239
@ModelComponentFactory.register(
236240
"Named expressions that can be used in other expressions."
237241
)
238-
class Expression(IndexedComponent):
242+
class Expression(IndexedComponent, IndexedComponent_NDArrayMixin):
239243
"""A shared expression container, which may be defined over an index.
240244
241245
Parameters

pyomo/core/tests/unit/test_expr_numpy.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import pyomo.common.unittest as unittest
1313

1414
from pyomo.common.dependencies import numpy as np, numpy_available
15-
from pyomo.environ import ConcreteModel, Var, Constraint
15+
from pyomo.environ import ConcreteModel, Var, Constraint, Param, Expression
1616

1717

1818
@unittest.skipUnless(numpy_available, "tests require numpy")
@@ -37,7 +37,7 @@ def test_scalar_operations(self):
3737
self.assertExpressionsEqual(np.array([5, 6]) * m.x, [5 * m.x, 6 * m.x])
3838
self.assertExpressionsEqual(np.array([8, m.x]) * m.x, [8 * m.x, m.x * m.x])
3939

40-
def test_vector_operations(self):
40+
def test_variable_vector_operations(self):
4141
m = ConcreteModel()
4242
m.x = Var()
4343
m.y = Var([0, 1, 2])
@@ -90,6 +90,21 @@ def test_vector_operations(self):
9090
m.x * m.y * a, [5 * m.y[0] * m.x, 5 * m.y[1] * m.x, 5 * m.y[2] * m.x]
9191
)
9292

93+
def test_expression_vector_operations(self):
94+
m = ConcreteModel()
95+
m.p = Param(range(3), range(2), initialize=lambda m, i, j: 10 * i + j)
96+
97+
m.e = Expression(range(3))
98+
m.f = Expression(range(2))
99+
100+
expr = np.transpose(m.e) @ m.p @ m.f
101+
print(expr)
102+
self.assertExpressionsEqual(
103+
expr,
104+
(m.e[0] * 0 + m.e[1] * 10 + m.e[2] * 20) * m.f[0]
105+
+ (m.e[0] * 1 + m.e[1] * 11 + m.e[2] * 21) * m.f[1],
106+
)
107+
93108

94109
if __name__ == "__main__":
95110
unittest.main()

0 commit comments

Comments
 (0)