Commit 1414f8e

Constraint creation changes (#93)
* Reinstate check after renaming caused bugs in sampling
* Remove old CausalProblem class
* Add barebones classes to be populated later
* Update two normal example test to use new infrastructure
* Rework so that the lagrangian can be passed model parameters and the multiplier values as separate args
* ruffing
* Refactor out g.model argument from the Lagrangian call
* Make TODOs obvious so I don't forget to do them
* Add docstrings and more TODOs
* Todo resolution and addition
* Make _CPConstraint callable
* Hide _CPComponent attributes that we don't expect to change
* Test __call__ for _CPComponents
* Add note about __call__ in docstring
* Fix bug in how handlers are applied
* Write tests for features
* Edit Constraint so it is created in pieces
* Rework Constraint.__init__ and docstring to match new format
* Update two_normal_example integration test
* Remove todo note
* Update src/causalprog/causal_problem/causal_estimand.py

Co-authored-by: Matthew Scroggs <matthew.w.scroggs@gmail.com>

---------

Co-authored-by: Matthew Scroggs <matthew.w.scroggs@gmail.com>
1 parent d547c73 commit 1414f8e

2 files changed, +75 -11 lines changed

src/causalprog/causal_problem/causal_estimand.py

Lines changed: 72 additions & 10 deletions
@@ -3,6 +3,7 @@
 from collections.abc import Callable
 from typing import Any, Concatenate, TypeAlias
 
+import jax.numpy as jnp
 import numpy.typing as npt
 
 Model: TypeAlias = Callable[..., Any]
@@ -94,13 +95,74 @@ class Constraint(_CPComponent):
     and $\epsilon$ is some tolerance.
     """
 
-    # TODO: (https://github.com/UCL/causalprog/issues/89)
-    # Should explain that Constraint needs more inputs and slightly different
-    # interpretation of the `do_with_samples` object.
-    # Inputs:
-    # - include epsilon as an input (allows constraints to have different tolerances)
-    # - `do_with_samples` should just be $g(\theta)$. Then have the instance build the
-    # full constraint that will need to be called in the Lagrangian.
-    # - $g$ still needs to be scalar valued? Allow a wrapper function to be applied in
-    # the event $g$ is vector-valued.
-    # If we do this, will also need to override __call__...
+    data: npt.ArrayLike
+    tolerance: npt.ArrayLike
+    _outer_norm: Callable[[npt.ArrayLike], float]
+
+    def __init__(
+        self,
+        *effect_handlers: ModelMask,
+        model_quantity: Callable[..., npt.ArrayLike],
+        outer_norm: Callable[[npt.ArrayLike], float] | None = None,
+        data: npt.ArrayLike = 0.0,
+        tolerance: float = 1.0e-6,
+    ) -> None:
+        r"""
+        Create a new constraint.
+
+        Constraints have the form
+
+        $$ c(\theta) :=
+        \mathrm{norm}\left( g(\theta)
+        - g_{\mathrm{data}} \right)
+        - \epsilon $$
+
+        where;
+        - $\mathrm{norm}$ is the outer norm of the constraint (`outer_norm`),
+        - $g(\theta)$ is the model quantity involved in the constraint
+        (`model_quantity`),
+        - $g_{\mathrm{data}}$ is the observed data (`data`),
+        - $\epsilon$ is the tolerance in the data (`tolerance`).
+
+        In a causal problem, each constraint appears as the condition $c(\theta)\leq 0$
+        in the minimisation / maximisation (hence the inclusion of the $-\epsilon$
+        term within $c(\theta)$ itself).
+
+        $g$ should be a (possibly vector-valued) function that acts on (a subset of)
+        samples from the random variables of the causal problem. It must accept
+        variable keyword-arguments only, and should access the samples for each random
+        variable by indexing via the RV names (node labels). It should return the
+        model quantity as computed from the samples, that $g_{\mathrm{data}}$ observed.
+
+        $g_{\mathrm{data}}$ should be a fixed value whose shape is broadcast-able with
+        the return shape of $g$. It defaults to $0$ if not explicitly set.
+
+        $\mathrm{norm}$ should be a suitable norm to take on the difference between the
+        model quantity as predicted by the samples ($g$) and the observed data
+        ($g_{\mathrm{data}}$). It must return a scalar value. The default is the 2-norm.
+        """
+        super().__init__(*effect_handlers, do_with_samples=model_quantity)
+
+        if outer_norm is None:
+            self._outer_norm = jnp.linalg.vector_norm
+        else:
+            self._outer_norm = outer_norm
+
+        self.data = data
+        self.tolerance = tolerance
+
+    def __call__(self, samples: dict[str, npt.ArrayLike]) -> npt.ArrayLike:
+        """
+        Evaluate the constraint, given RV samples.
+
+        Args:
+            samples: Mapping of RV (node) labels to drawn samples.
+
+        Returns:
+            Value of the constraint.
+
+        """
+        return (
+            self._outer_norm(self._do_with_samples(**samples) - self.data)
+            - self.tolerance
+        )

tests/test_integration/test_two_normal_example.py

Lines changed: 3 additions & 1 deletion
@@ -78,7 +78,9 @@ def test_two_normal_example(
     # Setup the optimisation problem from the graph
     ce = CausalEstimand(do_with_samples=lambda **pv: pv["X"].mean())
     con = Constraint(
-        do_with_samples=lambda **pv: jnp.abs(pv["UX"].mean() - phi_observed) - epsilon
+        model_quantity=lambda **pv: pv["UX"].mean(),
+        data=phi_observed,
+        tolerance=epsilon,
     )
     cp = CausalProblem(
         two_normal_graph(cov=1.0),
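Reviewer's sketch (not part of the commit): for a scalar model quantity the 2-norm reduces to the absolute value, so the rewritten test constraint should evaluate to the same number as the old hand-written one. The check below uses placeholder values for `phi_observed` and `epsilon` and made-up samples. It rebuilds the new form by hand with `jnp.linalg.norm` rather than calling `Constraint`, so it illustrates the arithmetic only.

```python
import jax.numpy as jnp

# Placeholder values, not those used in the integration test.
phi_observed, epsilon = 0.0, 1.0
samples = {"UX": jnp.array([0.5, -1.5, 2.0, 1.0])}


def old_constraint(**pv):
    # Old form: the whole constraint was written inside do_with_samples.
    return jnp.abs(pv["UX"].mean() - phi_observed) - epsilon


def model_quantity(**pv):
    # New form: only g is supplied; data and tolerance are passed separately.
    return pv["UX"].mean()


def new_constraint(**pv):
    # Hand-built equivalent of norm(g - data) - tolerance with a 2-norm.
    diff = jnp.ravel(model_quantity(**pv) - phi_observed)
    return jnp.linalg.norm(diff) - epsilon


old_value = old_constraint(**samples)
new_value = new_constraint(**samples)
```

The two values agree here because the difference has a single entry; with a vector-valued `model_quantity` the choice of `outer_norm` would start to matter.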
