Skip to content

Commit 608206e

Browse files
aryan26royadam2392
andauthored
[ENH] Add the ability to convert a PAG to MAG (#93)
* Added function for converting pag to mag --------- Signed-off-by: Aryan Roy <aryanroy5678@gmail.com> Co-authored-by: Adam Li <adam2392@gmail.com>
1 parent 9f3e202 commit 608206e

File tree

3 files changed

+348
-2
lines changed

3 files changed

+348
-2
lines changed

doc/whats_new/v0.2.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ Changelog
2727
---------
2828
- |Feature| Implement and test functions to validate an MAG and check the presence of almost directed cycles, by `Aryan Roy`_ (:pr:`91`)
2929
- |Feature| Implement and test functions to convert a DAG to MAG, by `Aryan Roy`_ (:pr:`96`)
30+
- |Feature| Implement and test functions to convert a PAG to MAG, by `Aryan Roy`_ (:pr:`93`)
3031

3132
Code and Documentation Contributors
3233
-----------------------------------

pywhy_graphs/algorithms/pag.py

Lines changed: 270 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import logging
22
from collections import deque
3-
from itertools import chain
3+
from itertools import chain, combinations, permutations
44
from typing import List, Optional, Set, Tuple
55

66
import networkx as nx
77
import numpy as np
88

9-
from pywhy_graphs import PAG, StationaryTimeSeriesPAG
9+
from pywhy_graphs import ADMG, CPDAG, PAG, StationaryTimeSeriesPAG
1010
from pywhy_graphs.algorithms.generic import single_source_shortest_mixed_path
1111
from pywhy_graphs.typing import Node, TsNode
1212

@@ -22,6 +22,7 @@
2222
"pds_t",
2323
"pds_t_path",
2424
"is_definite_noncollider",
25+
"pag_to_mag",
2526
]
2627

2728

@@ -908,3 +909,270 @@ def _check_ts_node(node):
908909
)
909910
if node[1] > 0:
910911
raise ValueError(f"All lag points should be 0, or less. You passed in {node}.")
912+
913+
914+
def _apply_meek_rules(graph: CPDAG) -> None:
915+
"""Orient edges in a skeleton graph to estimate the causal DAG, or CPDAG.
916+
These are known as the Meek rules :footcite:`Meek1995`. They are deterministic
917+
in the sense that they are logical characterizations of what edges must be
918+
present given the rest of the local graph structure.
919+
Parameters
920+
----------
921+
graph : CPDAG
922+
A graph containing directed and undirected edges.
923+
"""
924+
# For all the combination of nodes i and j, apply the following
925+
# rules.
926+
completed = False
927+
while not completed: # type: ignore
928+
change_flag = False
929+
for i in graph.nodes:
930+
for j in graph.neighbors(i):
931+
if i == j:
932+
continue
933+
# Rule 1: Orient i-j into i->j whenever there is an arrow k->i
934+
# such that k and j are nonadjacent.
935+
r1_add = _meek_rule1(graph, i, j)
936+
937+
# Rule 2: Orient i-j into i->j whenever there is a chain
938+
# i->k->j.
939+
r2_add = _meek_rule2(graph, i, j)
940+
941+
# Rule 3: Orient i-j into i->j whenever there are two chains
942+
# i-k->j and i-l->j such that k and l are nonadjacent.
943+
r3_add = _meek_rule3(graph, i, j)
944+
945+
# Rule 4: Orient i-j into i->j whenever there are two chains
946+
# i-k->l and k->l->j such that k and j are nonadjacent.
947+
#
948+
r4_add = _meek_rule4(graph, i, j)
949+
950+
if any([r1_add, r2_add, r3_add, r4_add]) and not change_flag:
951+
change_flag = True
952+
if not change_flag:
953+
completed = True
954+
break
955+
956+
957+
def _meek_rule1(graph: CPDAG, i: str, j: str) -> bool:
958+
"""Apply rule 1 of Meek's rules.
959+
Looks for i - j such that k -> i, such that (k,i,j)
960+
is an unshielded triple. Then can orient i - j as i -> j.
961+
"""
962+
added_arrows = False
963+
964+
# Check if i-j.
965+
if graph.has_edge(i, j, graph.undirected_edge_name):
966+
for k in graph.predecessors(i):
967+
# Skip if k and j are adjacent because then it is a
968+
# shielded triple
969+
if j in graph.neighbors(k):
970+
continue
971+
972+
# check if the triple is in the graph's excluded triples
973+
if frozenset((k, i, j)) in graph.excluded_triples:
974+
continue
975+
976+
# Make i-j into i->j
977+
graph.orient_uncertain_edge(i, j)
978+
979+
added_arrows = True
980+
break
981+
return added_arrows
982+
983+
984+
def _meek_rule2(graph: CPDAG, i: str, j: str) -> bool:
985+
"""Apply rule 2 of Meek's rules.
986+
Check for i - j, and then looks for i -> k -> j
987+
triple, to orient i - j as i -> j.
988+
"""
989+
added_arrows = False
990+
991+
# Check if i-j.
992+
if graph.has_edge(i, j, graph.undirected_edge_name):
993+
# Find nodes k where k is i->k
994+
child_i = set()
995+
for k in graph.successors(i):
996+
if not graph.has_edge(k, i, graph.directed_edge_name):
997+
child_i.add(k)
998+
# Find nodes j where j is k->j.
999+
parent_j = set()
1000+
for k in graph.predecessors(j):
1001+
if not graph.has_edge(j, k, graph.directed_edge_name):
1002+
parent_j.add(k)
1003+
1004+
# Check if there is any node k where i->k->j.
1005+
candidate_k = child_i.intersection(parent_j)
1006+
# if the graph has excluded triples, we would check at this point
1007+
if graph.excluded_triples:
1008+
# check if the triple is in the graph's excluded triples
1009+
# if so, remove them from the candidates
1010+
for k in candidate_k:
1011+
if frozenset((i, k, j)) in graph.excluded_triples:
1012+
candidate_k.remove(k)
1013+
1014+
# if there are candidate 'k' nodes, then orient the edge accordingly
1015+
if len(candidate_k) > 0:
1016+
# Make i-j into i->j
1017+
graph.orient_uncertain_edge(i, j)
1018+
added_arrows = True
1019+
return added_arrows
1020+
1021+
1022+
def _meek_rule3(graph: CPDAG, i: str, j: str) -> bool:
1023+
"""Apply rule 3 of Meek's rules.
1024+
Check for i - j, and then looks for k -> j <- l
1025+
collider, and i - k and i - l, then orient i -> j.
1026+
"""
1027+
added_arrows = False
1028+
1029+
# Check if i-j first
1030+
if graph.has_edge(i, j, graph.undirected_edge_name):
1031+
# For all the pairs of nodes adjacent to i,
1032+
# look for (k, l), such that j -> l and k -> l
1033+
for (k, l) in combinations(graph.neighbors(i), 2):
1034+
# Skip if k and l are adjacent.
1035+
if l in graph.neighbors(k):
1036+
continue
1037+
# Skip if not k->j.
1038+
if graph.has_edge(j, k, graph.directed_edge_name) or (
1039+
not graph.has_edge(k, j, graph.directed_edge_name)
1040+
):
1041+
continue
1042+
# Skip if not l->j.
1043+
if graph.has_edge(j, l, graph.directed_edge_name) or (
1044+
not graph.has_edge(l, j, graph.directed_edge_name)
1045+
):
1046+
continue
1047+
1048+
# check if the triple is inside graph's excluded triples
1049+
if frozenset((l, i, k)) in graph.excluded_triples:
1050+
continue
1051+
1052+
# if i - k and i - l, then at this point, we have a valid path
1053+
# to orient
1054+
if graph.has_edge(k, i, graph.undirected_edge_name) and graph.has_edge(
1055+
l, i, graph.undirected_edge_name
1056+
):
1057+
graph.orient_uncertain_edge(i, j)
1058+
added_arrows = True
1059+
break
1060+
return added_arrows
1061+
1062+
1063+
def _meek_rule4(graph: CPDAG, i: str, j: str) -> bool:
1064+
"""Apply rule 4 of Meek's rules.
1065+
Check for i - j, and then looks for i - k -> l -> j, to orient i - j as i -> j.
1066+
"""
1067+
added_arrows = False
1068+
1069+
# Check if i-j.
1070+
if graph.has_edge(i, j, graph.undirected_edge_name):
1071+
# Find nodes k where k is i-k
1072+
adj_i = set()
1073+
for k in graph.neighbors(i):
1074+
if not graph.has_edge(k, i, graph.directed_edge_name):
1075+
adj_i.add(k)
1076+
1077+
# Find nodes l where j is l->j.
1078+
parent_j = set()
1079+
for k in graph.predecessors(j):
1080+
if not graph.has_edge(j, k, graph.directed_edge_name):
1081+
parent_j.add(k)
1082+
1083+
# generate all permutations of sets containing neighbors of i and parents of j
1084+
permut = permutations(adj_i, len(parent_j))
1085+
unq = set() # type: ignore
1086+
for comb in permut:
1087+
zipped = zip(comb, parent_j)
1088+
unq.update(zipped)
1089+
1090+
# check if these pairs have a directed edge between them and that k-j does not exist
1091+
dedges = set(graph.directed_edges)
1092+
undedges = set(graph.undirected_edges)
1093+
candidate_k = set()
1094+
for pair in unq:
1095+
if pair in dedges:
1096+
if (pair[0], j) not in undedges:
1097+
candidate_k.add(pair)
1098+
1099+
# if there are candidate 'k->l' pairs, then orient the edge accordingly
1100+
if len(candidate_k) > 0:
1101+
# Make i-j into i->j
1102+
# logger.info(f"R2: Removing edge {i}-{j} to form {i}->{j}.")
1103+
graph.orient_uncertain_edge(i, j)
1104+
added_arrows = True
1105+
return added_arrows
1106+
1107+
1108+
def pag_to_mag(graph):
1109+
"""Sample a MAG from a PAG using Zhang's algorithm.
1110+
1111+
Using the algorithm defined in Theorem 2 of :footcite:`Zhang2008`, which turns all
1112+
o-> edges to -> and -o edges to ->, then it converts the graph into a DAG with
1113+
no unshielded colliders using the meek rules.
1114+
1115+
Parameters
1116+
----------
1117+
G : Graph
1118+
The PAG.
1119+
1120+
Returns
1121+
-------
1122+
mag : Graph
1123+
The MAG constructed from the PAG.
1124+
"""
1125+
copy_graph = graph.copy()
1126+
1127+
cedges = set(copy_graph.circle_edges)
1128+
dedges = set(copy_graph.directed_edges)
1129+
1130+
temp_cpdag = CPDAG()
1131+
1132+
to_remove = []
1133+
to_reorient = []
1134+
to_add = []
1135+
1136+
for u, v in cedges:
1137+
if (v, u) in dedges: # remove the circle end from a 'o-->' edge to make a '-->' edge
1138+
to_remove.append((u, v))
1139+
elif (v, u) not in cedges: # reorient a '--o' edge to '-->'
1140+
to_reorient.append((u, v))
1141+
elif (v, u) in cedges and (
1142+
v,
1143+
u,
1144+
) not in to_add: # add all 'o--o' edges to the cpdag
1145+
to_add.append((u, v))
1146+
for u, v in to_remove:
1147+
copy_graph.remove_edge(u, v, copy_graph.circle_edge_name)
1148+
for u, v in to_reorient:
1149+
copy_graph.orient_uncertain_edge(u, v)
1150+
for u, v in to_add:
1151+
temp_cpdag.add_edge(v, u, temp_cpdag.undirected_edge_name)
1152+
1153+
flag = True
1154+
1155+
# convert the graph into a DAG with no unshielded colliders
1156+
1157+
while flag:
1158+
undedges = temp_cpdag.undirected_edges
1159+
if len(undedges) != 0:
1160+
for (u, v) in undedges:
1161+
temp_cpdag.remove_edge(u, v, temp_cpdag.undirected_edge_name)
1162+
temp_cpdag.add_edge(u, v, temp_cpdag.directed_edge_name)
1163+
_apply_meek_rules(temp_cpdag)
1164+
break
1165+
else:
1166+
flag = False
1167+
1168+
mag = ADMG() # provisional MAG
1169+
1170+
# construct the final MAG
1171+
1172+
for (u, v) in copy_graph.directed_edges:
1173+
mag.add_edge(u, v, mag.directed_edge_name)
1174+
1175+
for (u, v) in temp_cpdag.directed_edges:
1176+
mag.add_edge(u, v, mag.directed_edge_name)
1177+
1178+
return mag

pywhy_graphs/algorithms/tests/test_pag.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -647,3 +647,80 @@ def test_pdst(pdst_graph):
647647
ex_pdsep_t = pds_t(G, ("E", 0), ("x", -1))
648648
assert ("y", -2) not in xe_pdsep_t
649649
assert ("y", -2) not in ex_pdsep_t
650+
651+
652+
def test_pag_to_mag():
653+
654+
# C o- A o-> D <-o B
655+
# B o-o A o-o C o-> D
656+
657+
pag = PAG()
658+
pag.add_edge("A", "D", pag.directed_edge_name)
659+
pag.add_edge("A", "C", pag.circle_edge_name)
660+
pag.add_edge("D", "A", pag.circle_edge_name)
661+
pag.add_edge("B", "D", pag.directed_edge_name)
662+
pag.add_edge("C", "D", pag.directed_edge_name)
663+
pag.add_edge("D", "B", pag.circle_edge_name)
664+
pag.add_edge("D", "C", pag.circle_edge_name)
665+
pag.add_edge("C", "A", pag.circle_edge_name)
666+
pag.add_edge("B", "A", pag.circle_edge_name)
667+
pag.add_edge("A", "B", pag.circle_edge_name)
668+
669+
out_mag = pywhy_graphs.pag_to_mag(pag)
670+
671+
# C <- A -> B -> D or C -> A -> B -> D or C <- A <- B -> D
672+
# A -> D <- C
673+
674+
assert (
675+
((out_mag.has_edge("A", "B")) or (out_mag.has_edge("B", "A")))
676+
and ((out_mag.has_edge("A", "C")) or (out_mag.has_edge("C", "A")))
677+
and (out_mag.has_edge("A", "D"))
678+
and (out_mag.has_edge("B", "D"))
679+
and (out_mag.has_edge("C", "D"))
680+
)
681+
682+
# D o-> A <-o B
683+
# D o-o B
684+
pag = PAG()
685+
pag.add_edge("A", "B", pag.circle_edge_name)
686+
pag.add_edge("B", "A", pag.directed_edge_name)
687+
pag.add_edge("D", "A", pag.directed_edge_name)
688+
pag.add_edge("A", "D", pag.circle_edge_name)
689+
pag.add_edge("D", "B", pag.circle_edge_name)
690+
pag.add_edge("B", "D", pag.circle_edge_name)
691+
692+
out_mag = pywhy_graphs.pag_to_mag(pag)
693+
694+
# B -> A <- D
695+
# D -> B or D <- B
696+
697+
assert (
698+
out_mag.has_edge("B", "A")
699+
and out_mag.has_edge("D", "A")
700+
and (out_mag.has_edge("D", "B") or out_mag.has_edge("B", "D"))
701+
)
702+
703+
# A -> B <- C o-o D
704+
# D o-o E -> B
705+
706+
pag = PAG()
707+
pag.add_edge("A", "B", pag.directed_edge_name)
708+
pag.add_edge("C", "B", pag.directed_edge_name)
709+
pag.add_edge("E", "B", pag.directed_edge_name)
710+
pag.add_edge("E", "D", pag.circle_edge_name)
711+
pag.add_edge("C", "D", pag.circle_edge_name)
712+
pag.add_edge("D", "E", pag.circle_edge_name)
713+
pag.add_edge("D", "C", pag.circle_edge_name)
714+
715+
out_mag = pywhy_graphs.pag_to_mag(pag)
716+
717+
# A -> B <- C <- D or A -> B <- C -> D
718+
# D <- E -> B or D <- E -> B
719+
720+
assert (
721+
out_mag.has_edge("A", "B")
722+
and out_mag.has_edge("C", "B")
723+
and out_mag.has_edge("E", "B")
724+
and (out_mag.has_edge("E", "D") or out_mag.has_edge("D", "E"))
725+
and (out_mag.has_edge("D", "C") or out_mag.has_edge("C", "D"))
726+
)

0 commit comments

Comments
 (0)