Skip to content

Commit 96e58b9

Browse files
aryan26royadam2392
andauthored
[ENH] Add the ability to check the validity of an MAG (#91)
* Added inducing path checking to MAG check * made find_adc public and added tests --------- Signed-off-by: “Aryan <“aryanroy5678@gmail.com”> Co-authored-by: Adam Li <adam2392@gmail.com>
1 parent 013513b commit 96e58b9

File tree

4 files changed

+223
-1
lines changed

4 files changed

+223
-1
lines changed

doc/api.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ causal graph operations.
4343
.. autosummary::
4444
:toctree: generated/
4545

46+
valid_mag
47+
has_adc
4648
inducing_path
4749
is_valid_mec_graph
4850
possible_ancestors

doc/whats_new/v0.2.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ Version 0.2
2525

2626
Changelog
2727
---------
28-
-
28+
- |Feature| Implement and test functions to validate an MAG and check the presence of almost directed cycles, by `Aryan Roy`_ (:pr:`91`)
2929

3030
Code and Documentation Contributors
3131
-----------------------------------
@@ -34,4 +34,5 @@ Thanks to everyone who has contributed to the maintenance and improvement of
3434
the project since version inception, including:
3535

3636
* `Adam Li`_
37+
* `Aryan Roy`_
3738

pywhy_graphs/algorithms/generic.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
"set_nodes_as_latent_confounders",
1414
"is_valid_mec_graph",
1515
"inducing_path",
16+
"has_adc",
17+
"valid_mag",
1618
]
1719

1820

@@ -604,3 +606,100 @@ def inducing_path(G, node_x: Node, node_y: Node, L: Set = None, S: Set = None):
604606
break
605607

606608
return (path_exists, path)
609+
610+
611+
def has_adc(G):
    """Check if a graph has an almost directed cycle (adc).

    An almost directed cycle is a directed path whose two end points are
    additionally joined by a bidirected edge. For example,
    ``A -> B -> C <-> A`` is an adc, and so is ``A -> B`` together with
    ``A <-> B``.

    Parameters
    ----------
    G : Graph
        The graph. It must expose ``bidirected_edges`` and
        ``sub_directed_graph()``.

    Returns
    -------
    adc_present : bool
        A boolean indicating whether an almost directed cycle is present or not.
    """
    # Build the directed sub-graph once; it does not change while we search.
    directed_graph = G.sub_directed_graph()

    # Cache of node -> set of directed descendants, so that a node touched by
    # several bidirected edges is only traversed once.
    descendants_of = {}

    def _descendants(node):
        if node not in descendants_of:
            descendants_of[node] = nx.descendants(directed_graph, node)
        return descendants_of[node]

    for edge in G.bidirected_edges:
        u, v = edge[0], edge[1]
        # ``u <-> v`` closes an almost directed cycle exactly when one end
        # point reaches the other through directed edges only. Checking the
        # end points directly (rather than looking for an intermediate node
        # between them) also catches a directed path of length one.
        if v in _descendants(u) or u in _descendants(v):
            return True

    return False
642+
643+
644+
def valid_mag(G: ADMG, L: set = None, S: set = None):
    """Checks if the provided graph is a valid maximal ancestral graph (MAG).

    A valid MAG as defined in :footcite:`Zhang2008` is a mixed edge graph that
    only has directed and bi-directed edges, no directed or almost directed
    cycles and no inducing paths between any two non-adjacent pair of nodes.

    Parameters
    ----------
    G : Graph
        The graph.
    L : set, optional
        Forwarded unchanged to ``inducing_path`` as its ``L`` argument
        (presumably the set of latent nodes -- confirm against
        ``inducing_path``). Defaults to an empty set.
    S : set, optional
        Forwarded unchanged to ``inducing_path`` as its ``S`` argument
        (presumably the set of selection nodes -- confirm against
        ``inducing_path``). Defaults to an empty set.

    Returns
    -------
    is_valid : bool
        A boolean indicating whether the provided graph is a valid MAG or not.

    """
    if L is None:
        L = set()

    if S is None:
        S = set()

    directed_sub_graph = G.sub_directed_graph()
    all_nodes = set(G.nodes)

    # check if there are any undirected edges or more than one edge b/w two nodes
    for node in all_nodes:
        for neighbor in set(G.neighbors(node)):
            edge_data = G.get_edge_data(node, neighbor)
            if edge_data["undirected"] is not None:
                return False
            if (edge_data["bidirected"] is not None) and (edge_data["directed"] is not None):
                return False

    # check if there are any directed cycles
    try:
        nx.find_cycle(directed_sub_graph)  # raises a NetworkXNoCycle error
    except nx.NetworkXNoCycle:
        pass
    else:
        # find_cycle returned normally, so a directed cycle exists
        return False

    # check if there are any almost directed cycles
    if has_adc(G):  # if there is an ADC, it's not a valid MAG
        return False

    # check if there are any inducing paths between non-adjacent nodes
    for source in all_nodes:
        non_adjacent = all_nodes - set(G.neighbors(source))
        # ``discard`` rather than ``remove``: if ``source`` is reported as its
        # own neighbor it is already gone from the set and must not raise.
        non_adjacent.discard(source)
        for dest in non_adjacent:
            path_exists = inducing_path(G, source, dest, L, S)[0]
            if path_exists is True:
                return False

    return True

pywhy_graphs/algorithms/tests/test_generic.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,3 +228,123 @@ def test_is_collider():
228228
S = {"A"}
229229

230230
assert pywhy_graphs.inducing_path(admg, "Z", "Y", L, S)[0]
231+
232+
233+
def test_has_adc():
    """Exercise ``has_adc`` on graphs with and without almost directed cycles."""
    # K -> H -> Z -> X -> Y -> J <- K
    admg = ADMG()
    admg.add_edge("Z", "X", admg.directed_edge_name)
    admg.add_edge("X", "Y", admg.directed_edge_name)
    admg.add_edge("Y", "J", admg.directed_edge_name)
    admg.add_edge("H", "Z", admg.directed_edge_name)
    admg.add_edge("K", "H", admg.directed_edge_name)
    admg.add_edge("K", "J", admg.directed_edge_name)

    assert not pywhy_graphs.has_adc(admg)  # there is no cycle completed by a bidirected edge

    # K -> H -> Z -> X -> Y -> J <-> K
    admg = ADMG()
    admg.add_edge("Z", "X", admg.directed_edge_name)
    admg.add_edge("X", "Y", admg.directed_edge_name)
    admg.add_edge("Y", "J", admg.directed_edge_name)
    admg.add_edge("H", "Z", admg.directed_edge_name)
    admg.add_edge("K", "H", admg.directed_edge_name)
    admg.add_edge("K", "J", admg.bidirected_edge_name)

    assert pywhy_graphs.has_adc(admg)  # there is a bidirected edge from J to K, completing a cycle

    # K -> H -> Z -> X -> Y <- J <-> K
    admg = ADMG()
    admg.add_edge("Z", "X", admg.directed_edge_name)
    admg.add_edge("X", "Y", admg.directed_edge_name)
    admg.add_edge("J", "Y", admg.directed_edge_name)
    admg.add_edge("H", "Z", admg.directed_edge_name)
    admg.add_edge("K", "H", admg.directed_edge_name)
    admg.add_edge("K", "J", admg.bidirected_edge_name)

    assert not pywhy_graphs.has_adc(admg)  # Y <- J is not correctly oriented

    # K -> H -> Z -> X -> Y -> J <-> K
    # Y -> H  (adds the directed cycle H -> Z -> X -> Y -> H)
    admg = ADMG()
    admg.add_edge("Z", "X", admg.directed_edge_name)
    admg.add_edge("X", "Y", admg.directed_edge_name)
    admg.add_edge("Y", "J", admg.directed_edge_name)
    admg.add_edge("H", "Z", admg.directed_edge_name)
    admg.add_edge("K", "H", admg.directed_edge_name)
    admg.add_edge("Y", "H", admg.directed_edge_name)
    admg.add_edge("K", "J", admg.bidirected_edge_name)

    assert pywhy_graphs.has_adc(admg)  # J <-> K completes an otherwise directed cycle
280+
281+
282+
def test_valid_mag():
    """Exercise ``valid_mag`` against each condition that invalidates a MAG."""
    # K -> H -> Z -> X -> Y -> J <- K
    admg = ADMG()
    admg.add_edge("Z", "X", admg.directed_edge_name)
    admg.add_edge("X", "Y", admg.directed_edge_name)
    admg.add_edge("Y", "J", admg.directed_edge_name)
    admg.add_edge("H", "Z", admg.directed_edge_name)
    admg.add_edge("K", "H", admg.directed_edge_name)
    admg.add_edge("K", "J", admg.directed_edge_name)

    S = {"J"}
    L = set()  # an empty *set*; a bare ``{}`` would be an empty dict

    assert not pywhy_graphs.valid_mag(
        admg, L, S  # J is in S and is a collider on the path Y -> J <- K
    )

    S = set()

    assert pywhy_graphs.valid_mag(admg, L, S)  # there are no valid inducing paths

    # K -> H -> Z -> X -> Y -> J -> K
    admg = ADMG()
    admg.add_edge("Z", "X", admg.directed_edge_name)
    admg.add_edge("X", "Y", admg.directed_edge_name)
    admg.add_edge("Y", "J", admg.directed_edge_name)
    admg.add_edge("H", "Z", admg.directed_edge_name)
    admg.add_edge("K", "H", admg.directed_edge_name)
    admg.add_edge("J", "K", admg.directed_edge_name)

    L = set()

    assert not pywhy_graphs.valid_mag(admg, L, S)  # there is a directed cycle

    # K -> H -> Z -> X -> Y -> J <- K
    # H <-> J
    admg = ADMG()
    admg.add_edge("Z", "X", admg.directed_edge_name)
    admg.add_edge("X", "Y", admg.directed_edge_name)
    admg.add_edge("Y", "J", admg.directed_edge_name)
    admg.add_edge("H", "Z", admg.directed_edge_name)
    admg.add_edge("K", "H", admg.directed_edge_name)
    admg.add_edge("K", "J", admg.directed_edge_name)
    admg.add_edge("H", "J", admg.bidirected_edge_name)

    assert not pywhy_graphs.valid_mag(admg)  # there is an almost directed cycle

    # Same graph as above, plus a directed edge H -> J alongside H <-> J
    admg = ADMG()
    admg.add_edge("Z", "X", admg.directed_edge_name)
    admg.add_edge("X", "Y", admg.directed_edge_name)
    admg.add_edge("Y", "J", admg.directed_edge_name)
    admg.add_edge("H", "Z", admg.directed_edge_name)
    admg.add_edge("K", "H", admg.directed_edge_name)
    admg.add_edge("K", "J", admg.directed_edge_name)
    admg.add_edge("H", "J", admg.bidirected_edge_name)
    admg.add_edge("H", "J", admg.directed_edge_name)

    assert not pywhy_graphs.valid_mag(admg)  # there are two edges between H and J

    # K -> H -> Z -> X -> Y -> J <- K
    # H - J (undirected)
    admg = ADMG()
    admg.add_edge("Z", "X", admg.directed_edge_name)
    admg.add_edge("X", "Y", admg.directed_edge_name)
    admg.add_edge("Y", "J", admg.directed_edge_name)
    admg.add_edge("H", "Z", admg.directed_edge_name)
    admg.add_edge("K", "H", admg.directed_edge_name)
    admg.add_edge("K", "J", admg.directed_edge_name)
    admg.add_edge("H", "J", admg.undirected_edge_name)

    assert not pywhy_graphs.valid_mag(admg)  # there is an undirected edge between H and J

0 commit comments

Comments
 (0)