Skip to content

Commit f159427

Browse files
Lagrangian respects effect handlers (#94)
* Reinstate check after renaming caused bugs in sampling * Remove old CausalProblem class * Add barebones classes to be populated later * Update two normal example test to use new infrastructure * Rework so that the lagrangian can be passed model parameters and the multiplier values as separate args * ruffing * Refactor out g.model argument from the Lagrangian call * Make TODOs obvious so I don't forget to do them * Add docstrings and more TODOs * Todo resolution and addition * Make _CPConstraint callable * Hide _CPComponent attributes that we don't expect to change * Test __call__ for _CPComponents * Add note about __call__ in docstring * Fix bug in how handlers are applied * Write tests for features * Edit Constraint so it is created in pieces * Rework Constraint.__init__ and docstring to match new format * Update two_normal_example integration test * Remove todo note * Create wrapper class to make handlers easier. __eq__ placeholder for now * Tidy eq docstring * Write necessary comparison methods * Write model association method and implement in Lagrangian * Docstirngs and breakout HandlerToApply class to submodule * Docstirngs and breakout HandlerToApply class to submodule * Tests for HandlerToApply class * Tests for can_use_same_model * Tests for associating models to components of the CP * Reorganise classes now that structure is somewhat rigid * Add module-level import for user-facing functions * Update src/causalprog/causal_problem/causal_problem.py Co-authored-by: Matthew Scroggs <matthew.w.scroggs@gmail.com> * Reinstate if not else fix --------- Co-authored-by: Matthew Scroggs <matthew.w.scroggs@gmail.com>
1 parent 1414f8e commit f159427

File tree

10 files changed

+612
-70
lines changed

10 files changed

+612
-70
lines changed
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,7 @@
11
"""Classes for defining causal problems."""
2+
3+
from .causal_problem import CausalProblem
4+
from .components import CausalEstimand, Constraint
5+
from .handlers import HandlerToApply
6+
7+
__all__ = ("CausalEstimand", "CausalProblem", "Constraint", "HandlerToApply")
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
"""Base class for components of causal problems."""
2+
3+
from collections.abc import Callable
4+
from typing import Any
5+
6+
import numpy.typing as npt
7+
8+
from causalprog.causal_problem.handlers import EffectHandler, HandlerToApply, Model
9+
10+
11+
class _CPComponent:
12+
"""
13+
Base class for components of a Causal Problem.
14+
15+
A _CPComponent has an attached method that it can apply to samples
16+
(`do_with_samples`), which will be passed sample values of the RVs
17+
during solution of a Causal Problem and used to evaluate the causal
18+
estimand or constraint the instance represents.
19+
20+
It also has a sequence of effect handlers that need to be applied
21+
to the sampling model before samples can be drawn to evaluate this
22+
component. For example, if a component requires conditioning on the
23+
value of a RV, the `condition` handler needs to be applied to the
24+
underlying model, before generating samples to pass to the
25+
`do_with_sample` method. `effect_handlers` will be applied to the model
26+
in the order they are given.
27+
"""
28+
29+
do_with_samples: Callable[..., npt.ArrayLike]
30+
effect_handlers: tuple[HandlerToApply, ...]
31+
32+
@property
33+
def requires_model_adaption(self) -> bool:
34+
"""Return True if effect handlers need to be applied to model."""
35+
return len(self.effect_handlers) > 0
36+
37+
def __call__(self, samples: dict[str, npt.ArrayLike]) -> npt.ArrayLike:
38+
"""
39+
Evaluate the estimand or constraint, given sample values.
40+
41+
Args:
42+
samples: Mapping of RV (node) labels to samples of that RV.
43+
44+
Returns:
45+
Value of the estimand or constraint, given the samples.
46+
47+
"""
48+
return self._do_with_samples(**samples)
49+
50+
def __init__(
51+
self,
52+
*effect_handlers: HandlerToApply | tuple[EffectHandler, dict[str, Any]],
53+
do_with_samples: Callable[..., npt.ArrayLike],
54+
) -> None:
55+
self.effect_handlers = tuple(
56+
h if isinstance(h, HandlerToApply) else HandlerToApply.from_pair(h)
57+
for h in effect_handlers
58+
)
59+
self._do_with_samples = do_with_samples
60+
61+
def apply_effects(self, model: Model) -> Model:
62+
"""Apply any necessary effect handlers prior to evaluating."""
63+
adapted_model = model
64+
for handler in self.effect_handlers:
65+
adapted_model = handler.handler(adapted_model, handler.options)
66+
return adapted_model
67+
68+
def can_use_same_model_as(self, other: "_CPComponent") -> bool:
69+
"""
70+
Determine if two components use the same (predictive) model.
71+
72+
Two components rely on the same model if they apply the same handlers
73+
to the model, which occurs if and only if `self.effect_handlers` and
74+
`other.effect_handlers` contain identical entries, in the same order.
75+
"""
76+
if (not isinstance(other, _CPComponent)) or (
77+
len(self.effect_handlers) != len(other.effect_handlers)
78+
):
79+
return False
80+
81+
return all(
82+
my_handler == their_handler
83+
for my_handler, their_handler in zip(
84+
self.effect_handlers, other.effect_handlers, strict=True
85+
)
86+
)

src/causalprog/causal_problem/causal_problem.py

Lines changed: 100 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66
import numpy.typing as npt
77
from numpyro.infer import Predictive
88

9-
from causalprog.causal_problem.causal_estimand import CausalEstimand, Constraint
9+
from causalprog.causal_problem._base_component import _CPComponent
10+
from causalprog.causal_problem.components import (
11+
CausalEstimand,
12+
Constraint,
13+
)
1014
from causalprog.graph import Graph
1115

1216

@@ -40,6 +44,11 @@ class CausalProblem:
4044
causal_estimand: CausalEstimand
4145
constraints: list[Constraint]
4246

47+
@property
48+
def _ordered_components(self) -> list[_CPComponent]:
49+
"""Internal ordering for components of the causal problem."""
50+
return [*self.constraints, self.causal_estimand]
51+
4352
def __init__(
4453
self,
4554
graph: Graph,
@@ -51,6 +60,83 @@ def __init__(
5160
self.causal_estimand = causal_estimand
5261
self.constraints = list(constraints)
5362

63+
def _associate_models_to_components(
64+
self, n_samples: int
65+
) -> tuple[list[Predictive], list[int]]:
66+
"""
67+
Create models to be used by components of the problem.
68+
69+
Depending on how many constraints (and the causal estimand) require effect
70+
handlers to wrap `self._underlying_graph.model`, we will need to create several
71+
predictive models to sample from. However, we also want to minimise the number
72+
of such models we have to make, in order to minimise the time we spend
73+
actually computing samples.
74+
75+
As such, in this method we determine:
76+
- How many models we will need to build, by grouping the constraints and the
77+
causal estimand by the handlers they use.
78+
- Build these models, returning them in a list called `models`.
79+
- Build another list that maps the index of components in
80+
`self._ordered_components` to the index of the model in `models` that they
81+
use. The causal estimand is by convention the component at index -1 of this
82+
returned list.
83+
84+
Args:
85+
n_samples: Value to be passed to `numpyro.Predictive`'s `num_samples`
86+
argument for each of the models that are constructed from the underlying
87+
graph.
88+
89+
Returns:
90+
list[Predictive]: List of Predictive models, whose elements contain all the
91+
models needed by the components.
92+
list[int]: Mapping of component indexes (as per `self_ordered_components`)
93+
to the index of the model in the first return argument that the
94+
component uses.
95+
96+
"""
97+
models: list[Predictive] = []
98+
grouped_component_indexes: list[list[int]] = []
99+
for index, component in enumerate(self._ordered_components):
100+
# Determine if this constraint uses the same handlers as those of any of
101+
# the other sets.
102+
belongs_to_existing_group = False
103+
for group in grouped_component_indexes:
104+
# Pull any element from the group to compare models to.
105+
# Items in a group are known to have the same model, so we can just
106+
# pull out the first one.
107+
group_element = self._ordered_components[group[0]]
108+
# Check if the current constraint can also use this model.
109+
if component.can_use_same_model_as(group_element):
110+
group.append(index)
111+
belongs_to_existing_group = True
112+
break
113+
114+
# If the component does not fit into any existing group, create a new
115+
# group for it. And add the model corresponding to the group to the
116+
# list of models.
117+
if not belongs_to_existing_group:
118+
grouped_component_indexes.append([index])
119+
120+
models.append(
121+
Predictive(
122+
component.apply_effects(self._underlying_graph.model),
123+
num_samples=n_samples,
124+
)
125+
)
126+
127+
# Now "invert" the grouping, creating a mapping that maps the index of a
128+
# component to the (index of the) model it uses.
129+
component_index_to_model_index = []
130+
for index in range(len(self._ordered_components)):
131+
for group_index, group in enumerate(grouped_component_indexes):
132+
if index in group:
133+
component_index_to_model_index.append(group_index)
134+
break
135+
# All indexes should belong to at least one group (worst case scenario,
136+
# their own individual group). Thus, it is safe to do the above to create
137+
# the mapping from component index -> model (group) index.
138+
return models, component_index_to_model_index
139+
54140
def lagrangian(
55141
self, n_samples: int = 1000, *, maximum_problem: bool = False
56142
) -> Callable[[dict[str, npt.ArrayLike], npt.ArrayLike, jax.Array], npt.ArrayLike]:
@@ -89,26 +175,28 @@ def lagrangian(
89175
"""
90176
maximisation_prefactor = -1.0 if maximum_problem else 1.0
91177

178+
# Build association between self.constraints and the model-samples that each
179+
# one needs to use. We do this here, since once it is constructed, it is
180+
# fixed, and doesn't need to be done each time we call the Lagrangian.
181+
models, component_to_index_mapping = self._associate_models_to_components(
182+
n_samples
183+
)
184+
92185
def _inner(
93186
parameter_values: dict[str, npt.ArrayLike],
94187
l_mult: jax.Array,
95188
rng_key: jax.Array,
96189
) -> npt.ArrayLike:
97-
# In general, we will need to check which of our CE/CONs require masking,
98-
# and do multiple predictive models to account for this...
99-
# We can always pre-build the predictive models too, so we should replace
100-
# the "model" input with something that can map the right predictive models
101-
# to the CE/CONS that need them.
102-
# TODO: https://github.com/UCL/causalprog/issues/90
103-
predictive_model = Predictive(
104-
model=self._underlying_graph.model, num_samples=n_samples
190+
# Draw samples from all models
191+
all_samples = tuple(
192+
sample_model(model, rng_key, parameter_values) for model in models
105193
)
106-
all_samples = sample_model(predictive_model, rng_key, parameter_values)
107194

108-
value = maximisation_prefactor * self.causal_estimand(all_samples)
195+
value = maximisation_prefactor * self.causal_estimand(all_samples[-1])
109196
# TODO: https://github.com/UCL/causalprog/issues/87
110197
value += sum(
111-
l_mult[i] * c(all_samples) for i, c in enumerate(self.constraints)
198+
l_mult[i] * c(all_samples[component_to_index_mapping[i]])
199+
for i, c in enumerate(self.constraints)
112200
)
113201
return value
114202

src/causalprog/causal_problem/causal_estimand.py renamed to src/causalprog/causal_problem/components.py

Lines changed: 2 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -6,66 +6,13 @@
66
import jax.numpy as jnp
77
import numpy.typing as npt
88

9+
from causalprog.causal_problem._base_component import _CPComponent
10+
911
Model: TypeAlias = Callable[..., Any]
1012
EffectHandler: TypeAlias = Callable[Concatenate[Model, ...], Model]
1113
ModelMask: TypeAlias = tuple[EffectHandler, dict]
1214

1315

14-
class _CPComponent:
15-
"""
16-
Base class for components of a Causal Problem.
17-
18-
A _CPComponent has an attached method that it can apply to samples
19-
(`do_with_samples`), which will be passed sample values of the RVs
20-
during solution of a Causal Problem and used to evaluate the causal
21-
estimand or constraint the instance represents.
22-
23-
It also has a sequence of effect handlers that need to be applied
24-
to the sampling model before samples can be drawn to evaluate this
25-
component. For example, if a component requires conditioning on the
26-
value of a RV, the `condition` handler needs to be applied to the
27-
underlying model, before generating samples to pass to the
28-
`do_with_sample` method. `effect_handlers` will be applied to the model
29-
in the order they are given.
30-
"""
31-
32-
do_with_samples: Callable[..., npt.ArrayLike]
33-
effect_handlers: tuple[ModelMask, ...]
34-
35-
@property
36-
def requires_model_adaption(self) -> bool:
37-
"""Return True if effect handlers need to be applied to model."""
38-
return len(self._effect_handlers) > 0
39-
40-
def __call__(self, samples: dict[str, npt.ArrayLike]) -> npt.ArrayLike:
41-
"""
42-
Evaluate the estimand or constraint, given sample values.
43-
44-
Args:
45-
samples: Mapping of RV (node) labels to samples of that RV.
46-
47-
Returns:
48-
Value of the estimand or constraint, given the samples.
49-
50-
"""
51-
return self._do_with_samples(**samples)
52-
53-
def __init__(
54-
self,
55-
*effect_handlers: ModelMask,
56-
do_with_samples: Callable[..., npt.ArrayLike],
57-
) -> None:
58-
self._effect_handlers = tuple(effect_handlers)
59-
self._do_with_samples = do_with_samples
60-
61-
def apply_effects(self, model: Model) -> Model:
62-
"""Apply any necessary effect handlers prior to evaluating."""
63-
adapted_model = model
64-
for handler, handler_options in self._effect_handlers:
65-
adapted_model = handler(adapted_model, handler_options)
66-
return adapted_model
67-
68-
6916
class CausalEstimand(_CPComponent):
7017
"""
7118
A Causal Estimand.

0 commit comments

Comments
 (0)