Skip to content

Commit c66f37b

Browse files
authored
Make CausalEstimand and Constraint callable (#92)
* Reinstate check after renaming caused bugs in sampling * Remove old CausalProblem class * Add barebones classes to be populated later * Update two normal example test to use new infrastructure * Rework so that the lagrangian can be passed model parameters and the multiplier values as separate args * ruffing * Refactor out g.model argument from the Lagrangian call * Make TODOs obvious so I don't forget to do them * Add docstrings and more TODOs * Todo resolution and addition * Make _CPConstraint callable * Hide _CPComponent attributes that we don't expect to change * Test __call__ for _CPComponents * Add note about __call__ in docstring * Fix bug in how handlers are applied * Write tests for features
1 parent 0577ae1 commit c66f37b

File tree

3 files changed

+205
-11
lines changed

3 files changed

+205
-11
lines changed

src/causalprog/causal_problem/causal_estimand.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -34,21 +34,34 @@ class _CPComponent:
3434
@property
3535
def requires_model_adaption(self) -> bool:
3636
"""Return True if effect handlers need to be applied to model."""
37-
return len(self.effect_handlers) > 0
37+
return len(self._effect_handlers) > 0
38+
39+
def __call__(self, samples: dict[str, npt.ArrayLike]) -> npt.ArrayLike:
40+
"""
41+
Evaluate the estimand or constraint, given sample values.
42+
43+
Args:
44+
samples: Mapping of RV (node) labels to samples of that RV.
45+
46+
Returns:
47+
Value of the estimand or constraint, given the samples.
48+
49+
"""
50+
return self._do_with_samples(**samples)
3851

3952
def __init__(
4053
self,
4154
*effect_handlers: ModelMask,
4255
do_with_samples: Callable[..., npt.ArrayLike],
4356
) -> None:
44-
self.effect_handlers = tuple(effect_handlers)
45-
self.do_with_samples = do_with_samples
57+
self._effect_handlers = tuple(effect_handlers)
58+
self._do_with_samples = do_with_samples
4659

4760
def apply_effects(self, model: Model) -> Model:
4861
"""Apply any necessary effect handlers prior to evaluating."""
4962
adapted_model = model
50-
for handler, handler_options in self.effect_handlers:
51-
adapted_model = handler(adapted_model, **handler_options)
63+
for handler, handler_options in self._effect_handlers:
64+
adapted_model = handler(adapted_model, handler_options)
5265
return adapted_model
5366

5467

@@ -90,3 +103,4 @@ class Constraint(_CPComponent):
90103
# full constraint that will need to be called in the Lagrangian.
91104
# - $g$ still needs to be scalar valued? Allow a wrapper function to be applied in
92105
# the event $g$ is vector-valued.
106+
# If we do this, will also need to override __call__...

src/causalprog/causal_problem/causal_problem.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -105,14 +105,10 @@ def _inner(
105105
)
106106
all_samples = sample_model(predictive_model, rng_key, parameter_values)
107107

108-
# TODO: https://github.com/UCL/causalprog/issues/86
109-
value = maximisation_prefactor * self.causal_estimand.do_with_samples(
110-
**all_samples
111-
)
108+
value = maximisation_prefactor * self.causal_estimand(all_samples)
112109
# TODO: https://github.com/UCL/causalprog/issues/87
113110
value += sum(
114-
l_mult[i] * c.do_with_samples(**all_samples)
115-
for i, c in enumerate(self.constraints)
111+
l_mult[i] * c(all_samples) for i, c in enumerate(self.constraints)
116112
)
117113
return value
118114

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
from collections.abc import Callable
2+
3+
import jax.numpy as jnp
4+
import numpy.typing as npt
5+
import pytest
6+
from numpyro.handlers import condition, do
7+
8+
from causalprog.causal_problem.causal_estimand import Model, ModelMask, _CPComponent
9+
from causalprog.graph import Graph
10+
11+
12+
@pytest.mark.parametrize(
13+
("expression", "samples", "expect_error"),
14+
[
15+
pytest.param(
16+
lambda **pv: jnp.atleast_1d(0.0), {}, None, id="Constant expression"
17+
),
18+
pytest.param(
19+
lambda **pv: jnp.atleast_1d(0.0),
20+
{"not_needed": jnp.atleast_1d(0.0)},
21+
None,
22+
id="Un-needed samples",
23+
),
24+
pytest.param(
25+
lambda **pv: pv["a"],
26+
{"a": jnp.atleast_1d(1.0)},
27+
None,
28+
id="All needed samples given",
29+
),
30+
pytest.param(
31+
lambda **pv: pv["b"],
32+
{"a": jnp.atleast_1d(1.0)},
33+
KeyError("b"),
34+
id="Missing sample",
35+
),
36+
],
37+
)
38+
def test_call(
39+
expression: Callable,
40+
samples: dict[str, npt.ArrayLike],
41+
expect_error: Exception | None,
42+
raises_context,
43+
) -> None:
44+
"""Check that _CPComponent correctly calls its _do_with_samples attribute."""
45+
46+
component = _CPComponent(do_with_samples=expression)
47+
48+
assert callable(component)
49+
50+
if expect_error:
51+
with raises_context(expect_error):
52+
component(samples)
53+
else:
54+
assert jnp.allclose(component(samples), expression(**samples))
55+
56+
57+
@pytest.fixture
58+
def conditioned_on_x_1(
59+
two_normal_graph: Callable[..., Graph],
60+
) -> Callable[..., Callable[..., None]]:
61+
"""
62+
Only intended for use in test_apply_handlers.
63+
64+
Builds the model expected when we condition on X=1.
65+
"""
66+
67+
def _inner(**two_normal_graph_options: float) -> Callable[..., None]:
68+
return condition(
69+
two_normal_graph(**two_normal_graph_options).model,
70+
{"X": jnp.atleast_1d(1.0)},
71+
)
72+
73+
return _inner
74+
75+
76+
@pytest.fixture
77+
def double_condition(
78+
two_normal_graph: Callable[..., Graph],
79+
) -> Callable[..., Callable[..., None]]:
80+
"""
81+
Only intended for use in test_apply_handlers.
82+
83+
Builds the model expected when we condition on UX=-10, then again on
84+
UX=10 (which should override the first action).
85+
"""
86+
87+
def _inner(**two_normal_graph_options: float) -> Callable[..., None]:
88+
return condition(
89+
condition(
90+
two_normal_graph(**two_normal_graph_options).model,
91+
{"UX": jnp.atleast_1d(-10.0)},
92+
),
93+
{"UX": jnp.atleast_1d(10.0)},
94+
)
95+
96+
return _inner
97+
98+
99+
@pytest.fixture
100+
def condition_then_do(
101+
two_normal_graph: Callable[..., Graph],
102+
) -> Callable[..., Callable[..., None]]:
103+
"""
104+
Only intended for use in test_apply_handlers.
105+
106+
Builds the model expected when we first condition on UX=0, and then
107+
apply do(X = 10). When sampling, we should still draw samples from
108+
X as per a N(UX, 1.0).
109+
"""
110+
111+
def _inner(**two_normal_graph_options: float) -> Callable[..., None]:
112+
return do(
113+
condition(
114+
two_normal_graph(**two_normal_graph_options).model,
115+
{"UX": jnp.atleast_1d(0.0)},
116+
),
117+
{"X": jnp.atleast_1d(10.0)},
118+
)
119+
120+
return _inner
121+
122+
123+
@pytest.mark.parametrize(
124+
("handlers", "expected_model"),
125+
[
126+
pytest.param(
127+
((condition, {"X": jnp.atleast_1d(1.0)}),),
128+
"conditioned_on_x_1",
129+
id="Condition X to 1",
130+
),
131+
# Should condition on UX=-10, then OVERRIDE this with UX=10.
132+
pytest.param(
133+
(
134+
(condition, {"UX": jnp.atleast_1d(-10.0)}),
135+
(condition, {"UX": jnp.atleast_1d(10.0)}),
136+
),
137+
"double_condition",
138+
id="Condition twice on same variable",
139+
),
140+
# Condition UX=0, but then do X=10.
141+
# Should still observe samples of X given by N(0, 1).
142+
pytest.param(
143+
(
144+
(condition, {"UX": jnp.atleast_1d(0.0)}),
145+
(do, {"X": jnp.atleast_1d(10.0)}),
146+
),
147+
"condition_then_do",
148+
id="Condition then do",
149+
),
150+
],
151+
)
152+
def test_apply_handlers(
153+
handlers: tuple[ModelMask],
154+
expected_model: Model,
155+
two_normal_graph: Callable[..., Graph],
156+
request: pytest.FixtureRequest,
157+
assert_samples_are_identical,
158+
run_default_nuts_mcmc,
159+
two_normal_graph_params: dict[str, float] | None = None,
160+
do_with_samples: Callable[..., npt.ArrayLike] = lambda **pv: pv["X"].mean(),
161+
) -> None:
162+
"""
163+
Test that model handlers are correctly applied to graphs.
164+
165+
Note that the order of the handlers is important, as it dictates
166+
which effects are applied first.
167+
"""
168+
if two_normal_graph_params is None:
169+
two_normal_graph_params = {"mean": 0.0, "cov": 1.0, "cov2": 1.0}
170+
if isinstance(expected_model, str):
171+
expected_model = request.getfixturevalue(expected_model)(
172+
**two_normal_graph_params
173+
)
174+
175+
g = two_normal_graph(**two_normal_graph_params)
176+
177+
cp = _CPComponent(*handlers, do_with_samples=do_with_samples)
178+
179+
handled_model = cp.apply_effects(g.model)
180+
181+
handled_mcmc = run_default_nuts_mcmc(handled_model)
182+
expected_mcmc = run_default_nuts_mcmc(expected_model)
183+
184+
assert_samples_are_identical(handled_mcmc, expected_mcmc)

0 commit comments

Comments
 (0)