1111from __future__ import annotations
1212
1313from collections .abc import Callable
14-
1514from functools import partial
1615from typing import Union
1716
2625ScipyConstraintDict = dict [
2726 str , Union [str , Callable [[np .ndarray ], float ], Callable [[np .ndarray ], np .ndarray ]]
2827]
29- NLC_TOL = - 1e-6
28+ CONST_TOL = 1e-6
3029
3130
3231def make_scipy_bounds (
@@ -511,9 +510,12 @@ def f_grad(X):
511510
512511
513512def nonlinear_constraint_is_feasible (
514- nonlinear_inequality_constraint : Callable , is_intrapoint : bool , x : Tensor
513+ nonlinear_inequality_constraint : Callable ,
514+ is_intrapoint : bool ,
515+ x : Tensor ,
516+ tolerance : float = CONST_TOL ,
515517) -> Tensor :
516- """Checks if a nonlinear inequality constraint is fulfilled.
518+ """Checks if a nonlinear inequality constraint is fulfilled (within tolerance) .
517519
518520 Args:
519521 nonlinear_inequality_constraint: Callable to evaluate the
@@ -523,14 +525,17 @@ def nonlinear_constraint_is_feasible(
523525 constraint has to evaluated over the whole q-batch and is a an
524526 inter-point constraint.
525527 x: Tensor of shape (batch x q x d).
528+ tolerance: Rather than using the exact `const(x) >= 0` constraint, this helper
529+ checks feasibility of `const(x) >= -tolerance`. This avoids marking the
530+ candidates as infeasible due to tiny violations.
526531
527532 Returns:
528533 A boolean tensor of shape (batch) indicating if the constraint is
529534 satified by the corresponding batch of `x`.
530535 """
531536
532537 def check_x (x : Tensor ) -> bool :
533- return _arrayify (nonlinear_inequality_constraint (x )).item () >= NLC_TOL
538+ return _arrayify (nonlinear_inequality_constraint (x )).item () >= - tolerance
534539
535540 x_flat = x .view (- 1 , * x .shape [- 2 :])
536541 is_feasible = torch .ones (x_flat .shape [0 ], dtype = torch .bool , device = x .device )
@@ -603,3 +608,82 @@ def make_scipy_nonlinear_inequality_constraints(
603608 shapeX = shapeX ,
604609 )
605610 return scipy_nonlinear_inequality_constraints
611+
612+
613+ def evaluate_feasibility (
614+ X : Tensor ,
615+ inequality_constraints : list [tuple [Tensor , Tensor , float ]] | None = None ,
616+ equality_constraints : list [tuple [Tensor , Tensor , float ]] | None = None ,
617+ nonlinear_inequality_constraints : list [tuple [Callable , bool ]] | None = None ,
618+ tolerance : float = CONST_TOL ,
619+ ) -> Tensor :
620+ r"""Evaluate feasibility of candidate points (within a tolerance).
621+
622+ Args:
623+ X: The candidate tensor of shape `batch x q x d`.
624+ inequality_constraints: A list of tuples (indices, coefficients, rhs),
625+ with each tuple encoding an inequality constraint of the form
626+ `\sum_i (X[indices[i]] * coefficients[i]) >= rhs`. `indices` and
627+ `coefficients` should be torch tensors. See the docstring of
628+ `make_scipy_linear_constraints` for an example. When q=1, or when
629+ applying the same constraint to each candidate in the batch
630+ (intra-point constraint), `indices` should be a 1-d tensor.
631+ For inter-point constraints, in which the constraint is applied to the
632+ whole batch of candidates, `indices` must be a 2-d tensor, where
633+ in each row `indices[i] =(k_i, l_i)` the first index `k_i` corresponds
634+ to the `k_i`-th element of the `q`-batch and the second index `l_i`
635+ corresponds to the `l_i`-th feature of that element.
636+ equality_constraints: A list of tuples (indices, coefficients, rhs),
637+ with each tuple encoding an equality constraint of the form
638+ `\sum_i (X[indices[i]] * coefficients[i]) = rhs`. See the docstring of
639+ `make_scipy_linear_constraints` for an example.
640+ nonlinear_inequality_constraints: A list of tuples representing the nonlinear
641+ inequality constraints. The first element in the tuple is a callable
642+ representing a constraint of the form `callable(x) >= 0`. In case of an
643+ intra-point constraint, `callable()`takes in an one-dimensional tensor of
644+ shape `d` and returns a scalar. In case of an inter-point constraint,
645+ `callable()` takes a two dimensional tensor of shape `q x d` and again
646+ returns a scalar. The second element is a boolean, indicating if it is an
647+ intra-point or inter-point constraint (`True` for intra-point. `False` for
648+ inter-point). For more information on intra-point vs inter-point
649+ constraints, see the docstring of the `inequality_constraints` argument.
650+ tolerance: The tolerance used to check the feasibility of equality constraints
651+ and non-linear inequality constraints. For equality constraints, we check
652+ if `abs(const(X) - rhs) < tolerance`. For non-linear inequality constraints,
653+ we check if `const(X) >= -tolerance`. This avoids marking the candidates as
654+ infeasible due to tiny violations.
655+
656+ Returns:
657+ A boolean tensor of shape `batch` indicating if the corresponding candidate of
658+ shape `q x d` is feasible.
659+ """
660+ is_feasible = torch .ones (X .shape [:- 2 ], device = X .device , dtype = torch .bool )
661+ if inequality_constraints is not None :
662+ for idx , coef , rhs in inequality_constraints :
663+ if idx .ndim == 1 :
664+ # Intra-point constraints.
665+ is_feasible &= ((X [..., idx ] * coef ).sum (dim = - 1 ) >= rhs ).all (dim = - 1 )
666+ else :
667+ # Inter-point constraints.
668+ is_feasible &= (X [..., idx [:, 0 ], idx [:, 1 ]] * coef ).sum (dim = - 1 ) >= rhs
669+ if equality_constraints is not None :
670+ for idx , coef , rhs in equality_constraints :
671+ if idx .ndim == 1 :
672+ # Intra-point constraints.
673+ is_feasible &= (
674+ ((X [..., idx ] * coef ).sum (dim = - 1 ) - rhs ).abs () < tolerance
675+ ).all (dim = - 1 )
676+ else :
677+ # Inter-point constraints.
678+ is_feasible &= (
679+ (X [..., idx [:, 0 ], idx [:, 1 ]] * coef ).sum (dim = - 1 ) - rhs
680+ ).abs () < tolerance
681+ if nonlinear_inequality_constraints is not None :
682+ for const , intra in nonlinear_inequality_constraints :
683+ is_feasible &= nonlinear_constraint_is_feasible (
684+ nonlinear_inequality_constraint = const ,
685+ is_intrapoint = intra ,
686+ x = X ,
687+ tolerance = tolerance ,
688+ )
689+ return is_feasible
0 commit comments