Skip to content

Commit dceceb1

Browse files
authored
Merge pull request #3567 from jsiirola/fbbt-numpy
FBBT: resolve bug registering native type handlers
2 parents c5a50b7 + 7b7c7cd commit dceceb1

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

pyomo/contrib/fbbt/fbbt.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1105,7 +1105,9 @@ def _before_external_function(visitor, child):
11051105
def _register_new_before_child_handler(visitor, child):
11061106
handlers = _before_child_handlers
11071107
child_type = child.__class__
1108-
if child.is_variable_type():
1108+
if child_type in native_types:
1109+
handlers[child_type] = _before_constant
1110+
elif child.is_variable_type():
11091111
handlers[child_type] = _before_var
11101112
elif not child.is_potentially_variable():
11111113
handlers[child_type] = _before_NPV

pyomo/contrib/fbbt/tests/test_fbbt.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
import pyomo.common.unittest as unittest
1313
import pyomo.environ as pyo
14-
from pyomo.contrib.fbbt.fbbt import fbbt, compute_bounds_on_expr
14+
from pyomo.contrib.fbbt.fbbt import fbbt, compute_bounds_on_expr, _before_child_handlers
1515
from pyomo.common.dependencies import numpy as np, numpy_available
1616
from pyomo.common.fileutils import find_library
1717
from pyomo.common.log import LoggingIntercept
@@ -1363,3 +1363,30 @@ def test_ranged_expression(self):
13631363
self.assertEqual(m.l.bounds, (2, 7))
13641364
self.assertEqual(m.x.bounds, (3, 7))
13651365
self.assertEqual(m.u.bounds, (3, 8))
1366+
1367+
@unittest.skipUnless(numpy_available, "Test requires numpy")
1368+
def test_numpy_leaves(self):
1369+
m = pyo.ConcreteModel()
1370+
m.l = pyo.Var(bounds=(2, None))
1371+
m.x = pyo.Var()
1372+
m.u = pyo.Var(bounds=(None, 8))
1373+
m.c = pyo.Constraint(
1374+
expr=pyo.inequality(m.l + np.int32(1), m.x, m.u - np.float64(1))
1375+
)
1376+
1377+
# Remove the numpy types so we can test that automatic numeric
1378+
# type registrations
1379+
old = [(t, _before_child_handlers.pop(t, None)) for t in (np.int32, np.float64)]
1380+
1381+
try:
1382+
self.tightener(m)
1383+
self.tightener(m)
1384+
self.assertEqual(m.l.bounds, (2, 6.0))
1385+
self.assertEqual(m.x.bounds, (3, 7.0))
1386+
self.assertEqual(m.u.bounds, (4, 8.0))
1387+
finally:
1388+
for t, fcn in old:
1389+
if fcn is None:
1390+
_before_child_handlers.pop(t, None)
1391+
else:
1392+
_before_child_handlers[t] = fcn

0 commit comments

Comments
 (0)