Skip to content

Commit 9f3e202

Browse files
aryan26roy and adam2392
authored
[ENH] Add the ability to convert a DAG to an MAG (#96)
* Add is_maximal function * add DAG to MAG function --------- Signed-off-by: Aryan Roy <aryanroy5678@gmail.com> Co-authored-by: Adam Li <adam2392@gmail.com>
1 parent 37ede7f commit 9f3e202

File tree

4 files changed

+246
-0
lines changed

4 files changed

+246
-0
lines changed

doc/api.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ causal graph operations.
4343
.. autosummary::
4444
:toctree: generated/
4545

46+
dag_to_mag
4647
valid_mag
4748
has_adc
4849
inducing_path

doc/whats_new/v0.2.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ Version 0.2
2626
Changelog
2727
---------
2828
- |Feature| Implement and test functions to validate an MAG and check the presence of almost directed cycles, by `Aryan Roy`_ (:pr:`91`)
29+
- |Feature| Implement and test functions to convert a DAG to an MAG, by `Aryan Roy`_ (:pr:`96`)
2930

3031
Code and Documentation Contributors
3132
-----------------------------------

pywhy_graphs/algorithms/generic.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
"inducing_path",
1616
"has_adc",
1717
"valid_mag",
18+
"dag_to_mag",
19+
"is_maximal",
1820
]
1921

2022

@@ -567,6 +569,9 @@ def inducing_path(G, node_x: Node, node_y: Node, L: Set = None, S: Set = None):
567569
if node_x == node_y:
568570
raise ValueError("The source and destination nodes are the same.")
569571

572+
if (node_x in L) or (node_y in L) or (node_x in S) or (node_y in S):
573+
return (False, [])
574+
570575
edges = G.edges()
571576

572577
# XXX: fix this when graphs are refactored to only check for directed/bidirected edge types
@@ -703,3 +708,118 @@ def valid_mag(G: ADMG, L: set = None, S: set = None):
703708
return False
704709

705710
return True
711+
712+
713+
def dag_to_mag(G, L: Set = None, S: Set = None):
    """Convert a DAG to a valid MAG.

    The algorithm is defined in :footcite:`Zhang2008` on page 1877.

    Two nodes are made adjacent in the MAG exactly when ``inducing_path``
    reports an inducing path between them relative to ``L`` and ``S``. Each
    adjacency ``A, B`` is then oriented from ancestral relations in the
    directed part of ``G``:

    - ``A -> B`` if A is an ancestor of ``{B} U S`` and B is not an ancestor
      of ``{A} U S``;
    - ``A <- B`` in the mirrored case;
    - ``A <-> B`` if neither is an ancestor of the other's set;
    - ``A - B`` if both are.

    Parameters
    ----------
    G : Graph
        The graph.
    L : Set
        Nodes that are ignored on the path. Defaults to an empty set.
    S : Set
        Nodes that are always conditioned on. Defaults to an empty set.

    Returns
    -------
    mag : Graph
        The MAG. It is built purely from the oriented adjacencies found
        above.

    References
    ----------
    .. footbibliography::
    """
    # Normalise to real sets so that set algebra below works even when the
    # caller passes another container (for example an empty ``{}``).
    L = set() if L is None else set(L)
    S = set() if S is None else set(S)

    # For each pair of nodes find out whether they have an inducing path
    # between them; only then will they be adjacent in the MAG.
    all_nodes = list(G.nodes)
    adjacent = set()  # frozenset pairs already known to be adjacent
    adj_pairs = []  # the same pairs as ordered tuples, in discovery order

    for source in all_nodes:
        for dest in all_nodes:
            if source == dest:
                continue
            pair = frozenset((source, dest))
            if pair in adjacent:
                continue
            # Both query directions are still tried when the first one
            # fails, matching the original search.
            if inducing_path(G, source, dest, L, S)[0] is True:
                adjacent.add(pair)
                adj_pairs.append((source, dest))

    # The ancestors of S are shared by every pair, so compute them once.
    anc_of_S: Set = set()
    for node in S:
        anc_of_S.update(_directed_sub_graph_ancestors(G, node))

    mag = ADMG()

    for A, B in adj_pairs:
        # Ancestors of {A} U S and {B} U S. The node is unioned as a
        # singleton set: ``S.union(A)`` would iterate over A itself (splitting
        # a string node into characters, or failing for a non-iterable node).
        ansA = anc_of_S.union(_directed_sub_graph_ancestors(G, A))
        ansB = anc_of_S.union(_directed_sub_graph_ancestors(G, B))

        a_is_anc = A in ansB
        b_is_anc = B in ansA

        if a_is_anc and not b_is_anc:
            # A is in ansB and B is not in ansA: A -> B
            mag.add_edge(A, B, mag.directed_edge_name)
        elif b_is_anc and not a_is_anc:
            # B is in ansA and A is not in ansB: A <- B
            mag.add_edge(B, A, mag.directed_edge_name)
        elif not a_is_anc and not b_is_anc:
            # neither is an ancestor of the other's set: A <-> B
            mag.add_edge(B, A, mag.bidirected_edge_name)
        else:
            # both are ancestors of the other's set: A - B
            mag.add_edge(B, A, mag.undirected_edge_name)

    return mag
788+
789+
790+
def is_maximal(G, L: Set = None, S: Set = None):
    """Check whether the graph is maximal.

    The graph is reported as maximal when no pair of distinct, non-adjacent
    nodes is connected by an inducing path (as reported by
    ``inducing_path``) relative to ``L`` and ``S``.

    Parameters
    ----------
    G : Graph
        The graph.
    L : Set
        Passed through to ``inducing_path`` as its ``L`` argument.
        Defaults to an empty set.
    S : Set
        Passed through to ``inducing_path`` as its ``S`` argument.
        Defaults to an empty set.

    Returns
    -------
    is_maximal : bool
        A boolean indicating whether the provided graph is maximal or not.
    """
    if L is None:
        L = set()

    if S is None:
        S = set()

    all_nodes = set(G.nodes)
    checked = set()  # unordered pairs that were already tested

    for source in all_nodes:
        # Subtract ``{source}`` instead of calling ``remove``: if the node is
        # listed among its own neighbours it is already gone, and ``remove``
        # would raise ``KeyError``.
        non_adjacent = all_nodes - set(G.neighbors(source)) - {source}
        for dest in non_adjacent:
            current_pair = frozenset({source, dest})
            if current_pair in checked:
                continue
            checked.add(current_pair)
            # A non-adjacent pair joined by an inducing path violates
            # maximality.
            if inducing_path(G, source, dest, L, S)[0] is True:
                return False

    return True

pywhy_graphs/algorithms/tests/test_generic.py

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,30 @@ def test_inducing_path_corner_cases():
214214

215215
assert pywhy_graphs.inducing_path(admg, "X", "Y", L, S)[0]
216216

217+
# X -> Z <- Y, A <- B <- Z
218+
admg = ADMG()
219+
admg.add_edge("X", "Z", admg.directed_edge_name)
220+
admg.add_edge("Y", "Z", admg.directed_edge_name)
221+
admg.add_edge("Z", "B", admg.directed_edge_name)
222+
admg.add_edge("B", "A", admg.directed_edge_name)
223+
224+
L = {"X"}
225+
S = {"A"}
226+
227+
assert not pywhy_graphs.inducing_path(admg, "X", "Y", L, S)[0]
228+
229+
# X -> Z <- Y, A <- B <- Z
230+
admg = ADMG()
231+
admg.add_edge("X", "Z", admg.directed_edge_name)
232+
admg.add_edge("Y", "Z", admg.directed_edge_name)
233+
admg.add_edge("Z", "B", admg.directed_edge_name)
234+
admg.add_edge("B", "A", admg.directed_edge_name)
235+
236+
L = {}
237+
S = {"A", "Y"}
238+
239+
assert not pywhy_graphs.inducing_path(admg, "X", "Y", L, S)[0]
240+
217241

218242
def test_is_collider():
219243
# Z -> X -> A <- B -> Y; H -> A
@@ -348,3 +372,103 @@ def test_valid_mag():
348372
admg.add_edge("H", "J", admg.undirected_edge_name)
349373

350374
assert not pywhy_graphs.valid_mag(admg) # there is an undirected edge between H and J
375+
376+
377+
def test_dag_to_mag():
    """Check ``dag_to_mag`` on three small DAGs with latent/selection nodes."""
    # Scenario 1
    # A -> E -> S
    # H -> E , H -> R
    admg = ADMG()
    admg.add_edge("A", "E", admg.directed_edge_name)
    admg.add_edge("E", "S", admg.directed_edge_name)
    admg.add_edge("H", "E", admg.directed_edge_name)
    admg.add_edge("H", "R", admg.directed_edge_name)

    S = {"S"}
    L = {"H"}

    out_mag = pywhy_graphs.dag_to_mag(admg, L, S)
    assert pywhy_graphs.is_maximal(out_mag)
    assert not pywhy_graphs.has_adc(out_mag)
    out_edges = out_mag.edges()
    # One assertion per condition so a failure names the offending edge.
    assert ("A", "R") in out_edges["directed"]
    assert ("E", "R") in out_edges["directed"]
    assert len(out_edges["directed"]) == 2
    # A and E are both ancestors of the selection node S, hence undirected.
    assert ("A", "E") in out_edges["undirected"]

    # Without latent or selection nodes every original edge must be present.
    out_mag = pywhy_graphs.dag_to_mag(admg)
    dir_edges = list(out_mag.edges()["directed"])

    assert ("A", "E") in dir_edges
    assert ("E", "S") in dir_edges
    assert ("H", "E") in dir_edges
    assert ("H", "R") in dir_edges

    # Scenario 2 (no E -> S edge here)
    # A -> E <- H -> S
    # H -> R
    admg = ADMG()
    admg.add_edge("A", "E", admg.directed_edge_name)
    admg.add_edge("H", "S", admg.directed_edge_name)
    admg.add_edge("H", "E", admg.directed_edge_name)
    admg.add_edge("H", "R", admg.directed_edge_name)

    S = {"S"}
    L = {"H"}

    out_mag = pywhy_graphs.dag_to_mag(admg, L, S)
    assert pywhy_graphs.is_maximal(out_mag)
    assert not pywhy_graphs.has_adc(out_mag)
    out_edges = out_mag.edges()

    assert ("A", "E") in out_edges["directed"]
    assert len(out_edges["directed"]) == 1
    # E and R share only the latent parent H, hence bidirected.
    assert ("E", "R") in out_edges["bidirected"]

    # Scenario 3
    # P -> S -> L <- G
    # G -> S <- I <- J
    # J -> S
    admg = ADMG()
    admg.add_edge("P", "S", admg.directed_edge_name)
    admg.add_edge("S", "L", admg.directed_edge_name)
    admg.add_edge("G", "S", admg.directed_edge_name)
    admg.add_edge("G", "L", admg.directed_edge_name)
    admg.add_edge("I", "S", admg.directed_edge_name)
    admg.add_edge("J", "I", admg.directed_edge_name)
    admg.add_edge("J", "S", admg.directed_edge_name)

    S = set()
    L = {"J"}

    out_mag = pywhy_graphs.dag_to_mag(admg, L, S)
    assert pywhy_graphs.is_maximal(out_mag)
    assert not pywhy_graphs.has_adc(out_mag)
    out_edges = out_mag.edges()
    dir_edges = list(out_edges["directed"])

    assert ("G", "S") in dir_edges
    assert ("G", "L") in dir_edges
    assert ("S", "L") in dir_edges
    assert ("I", "S") in dir_edges
    assert ("P", "S") in dir_edges
    assert len(dir_edges) == 5
462+
463+
464+
def test_is_maximal():
    """A graph is reported as not maximal for the given ``L`` and ``S``."""
    # X <- Y <-> Z <-> H; Z -> X
    admg = ADMG()
    admg.add_edge("Y", "X", admg.directed_edge_name)
    admg.add_edge("Z", "X", admg.directed_edge_name)
    admg.add_edge("Z", "Y", admg.bidirected_edge_name)
    admg.add_edge("Z", "H", admg.bidirected_edge_name)

    # ``{}`` would be an empty dict; the API documents S as a set.
    S = set()
    L = {"Y"}
    assert not pywhy_graphs.is_maximal(admg, L, S)

0 commit comments

Comments
 (0)