Skip to content

Commit 0577ae1

Browse files
Improve testing of do, and fix various quirks in implementation (#81)
* improve do test * don't add edges that include the do node * restructure do code, make it remove predecessors of targets of do * add test that edges are correct after do * make code more readable, upset ruff * check for label, not node object * noqa * test graph in #80 * further simplify do function * pop not del * Apply suggestions from code review Co-authored-by: Will Graham <32364977+willGraham01@users.noreply.github.com> * update docstring, variable names, ruff * move test_algorithms.py to its own folder * make graph functions return tuples * make algorithm functions return tuples * docstring * tuple ... and other fixes * mypy * ruff * tuples in tests * more ruff --------- Co-authored-by: Will Graham <32364977+willGraham01@users.noreply.github.com>
1 parent e64ba98 commit 0577ae1

File tree

7 files changed

+233
-118
lines changed

7 files changed

+233
-118
lines changed

src/causalprog/algorithms/do.py

Lines changed: 84 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,55 @@
22

33
from copy import deepcopy
44

5-
from causalprog.graph import Graph
5+
from causalprog.graph import Graph, Node
6+
7+
8+
def get_included_excluded_successors(
9+
graph: Graph, node_list: dict[str, Node], successors_of: str
10+
) -> tuple[tuple[str, ...], tuple[str, ...]]:
11+
"""
12+
Split successors of a node into nodes included and not included in a list.
13+
14+
Split the successorts of a node into a list of nodes that are included in
15+
the input node list and a list of nodes that are not in the list.
16+
17+
Args:
18+
graph: The graph
19+
node_list: A dictionary of nodes, indexed by label
20+
successors_of: The node to check the successors of
21+
22+
Returns:
23+
Lists of included and excluded nodes
24+
25+
"""
26+
included = []
27+
excluded = []
28+
for n in graph.successors[graph.get_node(successors_of)]:
29+
if n.label in node_list:
30+
included.append(n)
31+
else:
32+
excluded.append(n)
33+
return tuple(included), tuple(excluded)
34+
35+
36+
def removable_nodes(graph: Graph, nodes: dict[str, Node]) -> tuple[str, ...]:
37+
"""
38+
Generate list of nodes that can be removed from the graph.
39+
40+
Args:
41+
graph: The graph
42+
nodes: A dictionary of nodes, indexed by label
43+
44+
Returns:
45+
List of labels of removable nodes
46+
47+
"""
48+
removable: list[str] = []
49+
for n in nodes:
50+
included, excluded = get_included_excluded_successors(graph, nodes, n)
51+
if len(excluded) > 0 and len(included) == 0:
52+
removable.append(n)
53+
return tuple(removable)
654

755

856
def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph:
@@ -22,46 +70,47 @@ def do(graph: Graph, node: str, value: float, label: str | None = None) -> Graph
2270
if label is None:
2371
label = f"{graph.label}|do({node}={value})"
2472

25-
old_g = graph._graph # noqa: SLF001
26-
g = deepcopy(old_g)
27-
28-
nodes_by_label = {n.label: n for n in g.nodes}
29-
g.remove_node(nodes_by_label[node])
73+
nodes = {n.label: deepcopy(n) for n in graph.nodes if n.label != node}
3074

31-
new_nodes = {}
3275
# Search through the old graph, identifying nodes that had parameters which were
3376
# defined by the node being fixed in the DO operation.
3477
# We recreate these nodes, but replace each such parameter we encounter with
3578
# a constant parameter equal that takes the fixed value given as an input.
36-
for original_node in old_g.nodes:
37-
new_n = None
38-
for parameter_name, parameter_target_node in original_node.parameters.items():
39-
if parameter_target_node == node:
40-
# If this parameter in the original_node was determined by the node we
41-
# are fixing with DO.
42-
if new_n is None:
43-
new_n = deepcopy(original_node)
79+
for n in nodes.values():
80+
params = tuple(n.parameters.keys())
81+
for parameter_name in params:
82+
if n.parameters[parameter_name] == node:
4483
# Swap the parameter to a constant parameter, giving it the fixed value
45-
new_n.constant_parameters[parameter_name] = value
84+
n.constant_parameters[parameter_name] = value
4685
# Remove the parameter from the node's record of non-constant parameters
47-
new_n.parameters.pop(parameter_name)
48-
# If we had to recreate a new node, add it to the new (Di)Graph.
49-
# Also record the name of the node that it is set to replace
50-
if new_n is not None:
51-
g.add_node(new_n)
52-
new_nodes[original_node.label] = new_n
53-
54-
# Any new_nodes whose counterparts connect to other nodes in the network need
55-
# to mimic these links.
56-
for edge in old_g.edges:
57-
if edge[0].label in new_nodes or edge[1].label in new_nodes:
58-
g.add_edge(
59-
new_nodes.get(edge[0].label, edge[0]),
60-
new_nodes.get(edge[1].label, edge[1]),
86+
n.parameters.pop(parameter_name)
87+
88+
# Recursively remove nodes that are predecessors of removed nodes
89+
nodes_to_remove: tuple[str, ...] = (node,)
90+
while len(nodes_to_remove) > 0:
91+
nodes_to_remove = removable_nodes(graph, nodes)
92+
for n in removable_nodes(graph, nodes):
93+
nodes.pop(n)
94+
95+
# Check for nodes that are predecessors of both a removed node and a remaining node
96+
# and throw an error if one of these is found
97+
for n in nodes:
98+
_, excluded = get_included_excluded_successors(graph, nodes, n)
99+
if len(excluded) > 0:
100+
msg = (
101+
"Node that is predecessor of node set by do and "
102+
f'nodes that are not removed found ("{n}")'
61103
)
62-
# Now that the new_nodes are present in the graph, and correctly connected, remove
63-
# their counterparts from the graph.
64-
for original_node in new_nodes:
65-
g.remove_node(nodes_by_label[original_node])
104+
raise ValueError(msg)
105+
106+
g = Graph(label=f"{label}|do[{node}={value}]")
107+
for n in nodes.values():
108+
g.add_node(n)
109+
110+
# Any nodes whose counterparts connect to other nodes in the network need
111+
# to mimic these links.
112+
for edge in graph.edges:
113+
if edge[0].label in nodes and edge[1].label in nodes:
114+
g.add_edge(edge[0].label, edge[1].label)
66115

67-
return Graph(label=label, graph=g)
116+
return g

src/causalprog/graph/graph.py

Lines changed: 27 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy.typing as npt
55

66
from causalprog._abc.labelled import Labelled
7-
from causalprog.graph.node import DistributionNode, Node, ParameterNode
7+
from causalprog.graph.node import Node
88

99

1010
class Graph(Labelled):
@@ -89,7 +89,7 @@ def add_edge(self, start_node: Node | str, end_node: Node | str) -> None:
8989
self._graph.add_edge(start_node, end_node)
9090

9191
@property
92-
def parameter_nodes(self) -> tuple[ParameterNode, ...]:
92+
def parameter_nodes(self) -> tuple[Node, ...]:
9393
"""
9494
Returns all parameter nodes in the graph.
9595
@@ -107,31 +107,31 @@ def parameter_nodes(self) -> tuple[ParameterNode, ...]:
107107
return tuple(node for node in self.ordered_nodes if node.is_parameter)
108108

109109
@property
110-
def predecessors(self) -> dict[Node, list[Node]]:
110+
def predecessors(self) -> dict[Node, tuple[Node, ...]]:
111111
"""
112112
Get predecessors of every node.
113113
114114
Returns:
115115
Mapping of each Node to its predecessor Nodes
116116
117117
"""
118-
return {node: list(self._graph.predecessors(node)) for node in self.nodes}
118+
return {node: tuple(self._graph.predecessors(node)) for node in self.nodes}
119119

120120
@property
121-
def successors(self) -> dict[Node, list[Node]]:
121+
def successors(self) -> dict[Node, tuple[Node, ...]]:
122122
"""
123123
Get successors of every node.
124124
125125
Returns:
126126
Mapping of each Node to its successor Nodes.
127127
128128
"""
129-
return {node: list(self._graph.successors(node)) for node in self.nodes}
129+
return {node: tuple(self._graph.successors(node)) for node in self.nodes}
130130

131131
@property
132-
def nodes(self) -> list[Node]:
132+
def nodes(self) -> tuple[Node, ...]:
133133
"""
134-
Get the nodes of the graph, with no enforeced ordering.
134+
Get the nodes of the graph, with no enforced ordering.
135135
136136
Returns:
137137
A list of all the nodes in the graph.
@@ -140,10 +140,21 @@ def nodes(self) -> list[Node]:
140140
ordered_nodes: Fetch an ordered list of the nodes in the graph.
141141
142142
"""
143-
return list(self._graph.nodes())
143+
return tuple(self._graph.nodes())
144144

145145
@property
146-
def ordered_nodes(self) -> list[Node]:
146+
def edges(self) -> tuple[tuple[Node, Node], ...]:
147+
"""
148+
Get the edges of the graph.
149+
150+
Returns:
151+
A tuple of all the edges in the graph.
152+
153+
"""
154+
return tuple(self._graph.edges())
155+
156+
@property
157+
def ordered_nodes(self) -> tuple[Node, ...]:
147158
"""
148159
Nodes ordered so that each node appears after its dependencies.
149160
@@ -155,23 +166,23 @@ def ordered_nodes(self) -> list[Node]:
155166
if not nx.is_directed_acyclic_graph(self._graph):
156167
msg = "Graph is not acyclic."
157168
raise RuntimeError(msg)
158-
return list(nx.topological_sort(self._graph))
169+
return tuple(nx.topological_sort(self._graph))
159170

160171
@property
161-
def ordered_dist_nodes(self) -> list[DistributionNode]:
172+
def ordered_dist_nodes(self) -> tuple[Node, ...]:
162173
"""
163174
`DistributionNode`s in dependency order.
164175
165176
Each `DistributionNode` in the returned list appears after all its
166177
dependencies. Order is derived from `self.ordered_nodes`, selecting
167178
only those nodes where `is_distribution` is `True`.
168179
"""
169-
return [node for node in self.ordered_nodes if node.is_distribution]
180+
return tuple(node for node in self.ordered_nodes if node.is_distribution)
170181

171182
def roots_down_to_outcome(
172183
self,
173184
outcome_node_label: str,
174-
) -> list[Node]:
185+
) -> tuple[Node, ...]:
175186
"""
176187
Get ordered list of nodes that outcome depends on.
177188
@@ -186,9 +197,9 @@ def roots_down_to_outcome(
186197
"""
187198
outcome = self.get_node(outcome_node_label)
188199
ancestors = nx.ancestors(self._graph, outcome)
189-
return [
200+
return tuple(
190201
node for node in self.ordered_nodes if node == outcome or node in ancestors
191-
]
202+
)
192203

193204
def model(self, **parameter_values: npt.ArrayLike) -> dict[str, npt.ArrayLike]:
194205
"""

src/causalprog/graph/node.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,12 @@ def copy(self) -> Node:
201201

202202
@override
203203
def __repr__(self) -> str:
204-
return f'DistributionNode("{self.label}")'
204+
r = f'DistributionNode({self._dist.__name__}, label="{self.label}"'
205+
if len(self._parameters) > 0:
206+
r += f", parameters={self._parameters}"
207+
if len(self._constant_parameters) > 0:
208+
r += f", constant_parameters={self._constant_parameters}"
209+
return r
205210

206211
@override
207212
@property
@@ -283,7 +288,7 @@ def copy(self) -> Node:
283288

284289
@override
285290
def __repr__(self) -> str:
286-
return f'ParameterNode("{self.label}")'
291+
return f'ParameterNode(label="{self.label}")'
287292

288293
@override
289294
@property

tests/test_algorithms/test_do.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""Tests for the do algorithm."""
2+
3+
from causalprog import algorithms
4+
from causalprog.graph import Graph, ParameterNode
5+
6+
max_samples = 10**5
7+
8+
9+
def test_do(two_normal_graph, raises_context):
10+
graph = two_normal_graph(5.0, 1.2, 0.8)
11+
graph2 = algorithms.do(graph, "UX", 4.0)
12+
13+
assert "loc" in graph.get_node("X").parameters
14+
assert "loc" not in graph.get_node("X").constant_parameters
15+
assert "loc" not in graph2.get_node("X").parameters
16+
assert "loc" in graph2.get_node("X").constant_parameters
17+
18+
graph.get_node("UX")
19+
with raises_context(KeyError('Node not found with label "UX"')):
20+
graph2.get_node("UX")
21+
22+
23+
def test_do_removes_dependencies(two_normal_graph, raises_context):
24+
graph = two_normal_graph()
25+
graph2 = algorithms.do(graph, "UX", 4.0)
26+
27+
for node in ["UX", "mean", "cov"]:
28+
graph.get_node(node)
29+
with raises_context(KeyError(f'Node not found with label "{node}"')):
30+
graph2.get_node(node)
31+
32+
33+
def test_do_edges(two_normal_graph):
34+
graph = two_normal_graph()
35+
graph2 = algorithms.do(graph, "UX", 4.0)
36+
37+
edges = [(e[0].label, e[1].label) for e in graph.edges]
38+
edges2 = [(e[0].label, e[1].label) for e in graph2.edges]
39+
40+
# Check that correct edges are removed
41+
for e in [
42+
("UX", "X"),
43+
("mean", "UX"),
44+
("cov", "UX"),
45+
]:
46+
assert e in edges
47+
assert e not in edges2
48+
49+
# Check that correct edges remain
50+
for e in [
51+
("cov2", "X"),
52+
]:
53+
assert e in edges
54+
assert e in edges2
55+
56+
57+
def test_do_error(raises_context):
58+
graph = Graph(label="ABC")
59+
graph.add_node(ParameterNode(label="A"))
60+
graph.add_node(ParameterNode(label="B1"))
61+
graph.add_node(ParameterNode(label="B2"))
62+
graph.add_node(ParameterNode(label="C"))
63+
graph.add_edge("A", "B1")
64+
graph.add_edge("A", "B2")
65+
graph.add_edge("B1", "C")
66+
graph.add_edge("B2", "C")
67+
68+
# Currently, applying to do to a node that had predecessors that cannot be removed
69+
# raises an error, see https://github.com/UCL/causalprog/issues/80
70+
with raises_context(
71+
ValueError(
72+
"Node that is predecessor of node set by do and nodes that are not removed "
73+
"found"
74+
)
75+
):
76+
algorithms.do(graph, "B1", 1.0)

0 commit comments

Comments
 (0)