|
| 1 | +# ___________________________________________________________________________ |
| 2 | +# |
| 3 | +# Pyomo: Python Optimization Modeling Objects |
| 4 | +# Copyright (c) 2008-2025 |
| 5 | +# National Technology and Engineering Solutions of Sandia, LLC |
| 6 | +# Under the terms of Contract DE-NA0003525 with National Technology and |
| 7 | +# Engineering Solutions of Sandia, LLC, the U.S. Government retains certain |
| 8 | +# rights in this software. |
| 9 | +# This software is distributed under the 3-clause BSD License. |
| 10 | +# ___________________________________________________________________________ |
| 11 | + |
| 12 | +from pyomo.core.expr.visitor import StreamBasedExpressionVisitor |
| 13 | +from pyomo.core.expr.numeric_expr import ( |
| 14 | + ExternalFunctionExpression, |
| 15 | + NegationExpression, |
| 16 | + PowExpression, |
| 17 | + MaxExpression, |
| 18 | + MinExpression, |
| 19 | + ProductExpression, |
| 20 | + MonomialTermExpression, |
| 21 | + DivisionExpression, |
| 22 | + SumExpression, |
| 23 | + Expr_ifExpression, |
| 24 | + UnaryFunctionExpression, |
| 25 | + AbsExpression, |
| 26 | +) |
| 27 | +from pyomo.core.expr.relational_expr import ( |
| 28 | + RangedExpression, |
| 29 | + InequalityExpression, |
| 30 | + EqualityExpression, |
| 31 | +) |
| 32 | +from pyomo.core.base.var import VarData, ScalarVar |
| 33 | +from pyomo.core.base.param import ParamData, ScalarParam |
| 34 | +from pyomo.core.base.expression import ExpressionData, ScalarExpression |
| 35 | +from pyomo.repn.util import ExitNodeDispatcher |
| 36 | +from pyomo.common.collections import ComponentSet |
| 37 | + |
| 38 | + |
| 39 | +def handle_var(node, collector): |
| 40 | + collector.variables.add(node) |
| 41 | + return None |
| 42 | + |
| 43 | + |
| 44 | +def handle_param(node, collector): |
| 45 | + collector.params.add(node) |
| 46 | + return None |
| 47 | + |
| 48 | + |
| 49 | +def handle_named_expression(node, collector): |
| 50 | + collector.named_expressions.add(node) |
| 51 | + return None |
| 52 | + |
| 53 | + |
| 54 | +def handle_external_function(node, collector): |
| 55 | + collector.external_functions.add(node) |
| 56 | + return None |
| 57 | + |
| 58 | + |
| 59 | +def handle_skip(node, collector): |
| 60 | + return None |
| 61 | + |
| 62 | + |
| 63 | +collector_handlers = ExitNodeDispatcher() |
| 64 | +collector_handlers[VarData] = handle_var |
| 65 | +collector_handlers[ParamData] = handle_param |
| 66 | +collector_handlers[ExpressionData] = handle_named_expression |
| 67 | +collector_handlers[ScalarExpression] = handle_named_expression |
| 68 | +collector_handlers[ExternalFunctionExpression] = handle_external_function |
| 69 | +collector_handlers[NegationExpression] = handle_skip |
| 70 | +collector_handlers[PowExpression] = handle_skip |
| 71 | +collector_handlers[MaxExpression] = handle_skip |
| 72 | +collector_handlers[MinExpression] = handle_skip |
| 73 | +collector_handlers[ProductExpression] = handle_skip |
| 74 | +collector_handlers[MonomialTermExpression] = handle_skip |
| 75 | +collector_handlers[DivisionExpression] = handle_skip |
| 76 | +collector_handlers[SumExpression] = handle_skip |
| 77 | +collector_handlers[Expr_ifExpression] = handle_skip |
| 78 | +collector_handlers[UnaryFunctionExpression] = handle_skip |
| 79 | +collector_handlers[AbsExpression] = handle_skip |
| 80 | +collector_handlers[RangedExpression] = handle_skip |
| 81 | +collector_handlers[InequalityExpression] = handle_skip |
| 82 | +collector_handlers[EqualityExpression] = handle_skip |
| 83 | +collector_handlers[int] = handle_skip |
| 84 | +collector_handlers[float] = handle_skip |
| 85 | + |
| 86 | + |
| 87 | +class _ComponentFromExprCollector(StreamBasedExpressionVisitor): |
| 88 | + def __init__(self, **kwds): |
| 89 | + self.named_expressions = ComponentSet() |
| 90 | + self.variables = ComponentSet() |
| 91 | + self.params = ComponentSet() |
| 92 | + self.external_functions = ComponentSet() |
| 93 | + super().__init__(**kwds) |
| 94 | + |
| 95 | + def exitNode(self, node, data): |
| 96 | + return collector_handlers[node.__class__](node, self) |
| 97 | + |
| 98 | + def beforeChild(self, node, child, child_idx): |
| 99 | + if child in self.named_expressions: |
| 100 | + return False, None |
| 101 | + return True, None |
| 102 | + |
| 103 | + |
| 104 | +_visitor = _ComponentFromExprCollector() |
| 105 | + |
| 106 | + |
| 107 | +def collect_components_from_expr(expr): |
| 108 | + _visitor.__init__() |
| 109 | + _visitor.walk_expression(expr) |
| 110 | + return ( |
| 111 | + _visitor.named_expressions, |
| 112 | + _visitor.variables, |
| 113 | + _visitor.params, |
| 114 | + _visitor.external_functions, |
| 115 | + ) |
0 commit comments