|
1 | 1 | import logging |
2 | 2 | from collections import deque |
3 | | -from itertools import chain |
| 3 | +from itertools import chain, combinations, permutations |
4 | 4 | from typing import List, Optional, Set, Tuple |
5 | 5 |
|
6 | 6 | import networkx as nx |
7 | 7 | import numpy as np |
8 | 8 |
|
9 | | -from pywhy_graphs import PAG, StationaryTimeSeriesPAG |
| 9 | +from pywhy_graphs import ADMG, CPDAG, PAG, StationaryTimeSeriesPAG |
10 | 10 | from pywhy_graphs.algorithms.generic import single_source_shortest_mixed_path |
11 | 11 | from pywhy_graphs.typing import Node, TsNode |
12 | 12 |
|
|
22 | 22 | "pds_t", |
23 | 23 | "pds_t_path", |
24 | 24 | "is_definite_noncollider", |
| 25 | + "pag_to_mag", |
25 | 26 | ] |
26 | 27 |
|
27 | 28 |
|
@@ -908,3 +909,270 @@ def _check_ts_node(node): |
908 | 909 | ) |
909 | 910 | if node[1] > 0: |
910 | 911 | raise ValueError(f"All lag points should be 0, or less. You passed in {node}.") |
| 912 | + |
| 913 | + |
def _apply_meek_rules(graph: CPDAG) -> None:
    """Orient undirected edges of a CPDAG in place using the Meek rules.

    The four Meek rules :footcite:`Meek1995` are tried on every ordered pair
    of adjacent nodes ``(i, j)``. Sweeps over the whole graph are repeated
    until one complete sweep orients no edge.

    Parameters
    ----------
    graph : CPDAG
        A graph containing directed and undirected edges. It is modified
        in place.
    """
    meek_rules = (_meek_rule1, _meek_rule2, _meek_rule3, _meek_rule4)

    while True:
        oriented_any = False
        for i in graph.nodes:
            for j in graph.neighbors(i):
                if i == j:
                    continue
                # All four rules are attempted for (i, j), even if an earlier
                # one already fired; each returns True when it oriented i -> j:
                #   R1: there is k -> i with k and j nonadjacent.
                #   R2: there is a chain i -> k -> j.
                #   R3: there are chains i - k -> j and i - l -> j with k and l
                #       nonadjacent.
                #   R4: there is a chain i - k -> l -> j with k and j
                #       nonadjacent.
                fired = [rule(graph, i, j) for rule in meek_rules]
                if any(fired):
                    oriented_any = True

        # a full sweep without any new orientation means we reached a fixpoint
        if not oriented_any:
            return
| 955 | + |
| 956 | + |
def _meek_rule1(graph: CPDAG, i: str, j: str) -> bool:
    """Apply rule 1 of Meek's rules.

    If ``i - j`` is undirected and some parent ``k -> i`` is not adjacent to
    ``j`` (so ``(k, i, j)`` is an unshielded triple), orient ``i -> j``.

    Parameters
    ----------
    graph : CPDAG
        Graph to inspect; modified in place when the rule fires.
    i, j : node
        Endpoints of the candidate edge.

    Returns
    -------
    added_arrows : bool
        True if ``i - j`` was oriented into ``i -> j``.
    """
    # the rule only concerns an undirected i - j edge
    if not graph.has_edge(i, j, graph.undirected_edge_name):
        return False

    for parent in graph.predecessors(i):
        # a parent adjacent to j makes the triple shielded
        if j in graph.neighbors(parent):
            continue

        # triples explicitly excluded on the graph are never used
        if frozenset((parent, i, j)) in graph.excluded_triples:
            continue

        # one qualifying parent is enough: make i - j into i -> j
        graph.orient_uncertain_edge(i, j)
        return True

    return False
| 982 | + |
| 983 | + |
def _meek_rule2(graph: CPDAG, i: str, j: str) -> bool:
    """Apply rule 2 of Meek's rules.

    If ``i - j`` is undirected and there is a directed chain ``i -> k -> j``
    whose triple ``(i, k, j)`` is not excluded on the graph, orient
    ``i -> j``.

    Parameters
    ----------
    graph : CPDAG
        Graph to inspect; modified in place when the rule fires.
    i, j : node
        Endpoints of the candidate edge.

    Returns
    -------
    added_arrows : bool
        True if ``i - j`` was oriented into ``i -> j``.
    """
    # the rule only concerns an undirected i - j edge
    if not graph.has_edge(i, j, graph.undirected_edge_name):
        return False

    directed = graph.directed_edge_name

    # nodes k with i -> k (and no opposing k -> i edge)
    child_i = {k for k in graph.successors(i) if not graph.has_edge(k, i, directed)}
    # nodes k with k -> j (and no opposing j -> k edge)
    parent_j = {k for k in graph.predecessors(j) if not graph.has_edge(j, k, directed)}

    # every k that closes a chain i -> k -> j
    candidate_k = child_i & parent_j

    # Drop chains whose triple is excluded. A new set is built on purpose:
    # removing elements from a set while iterating over it raises
    # "RuntimeError: Set changed size during iteration".
    excluded = graph.excluded_triples
    if excluded:
        candidate_k = {k for k in candidate_k if frozenset((i, k, j)) not in excluded}

    if not candidate_k:
        return False

    # Make i - j into i -> j
    graph.orient_uncertain_edge(i, j)
    return True
| 1020 | + |
| 1021 | + |
def _meek_rule3(graph: CPDAG, i: str, j: str) -> bool:
    """Apply rule 3 of Meek's rules.

    If ``i - j`` is undirected and two nonadjacent neighbors ``k`` and ``l``
    of ``i`` form a collider ``k -> j <- l`` while ``i - k`` and ``i - l``
    are undirected, orient ``i -> j``.

    Parameters
    ----------
    graph : CPDAG
        Graph to inspect; modified in place when the rule fires.
    i, j : node
        Endpoints of the candidate edge.

    Returns
    -------
    added_arrows : bool
        True if ``i - j`` was oriented into ``i -> j``.
    """
    undirected = graph.undirected_edge_name

    # the rule only concerns an undirected i - j edge
    if not graph.has_edge(i, j, undirected):
        return False

    directed = graph.directed_edge_name

    def points_into_j(node) -> bool:
        # node -> j with no opposing j -> node edge
        return not graph.has_edge(j, node, directed) and graph.has_edge(node, j, directed)

    # examine every unordered pair of nodes adjacent to i
    for first, second in combinations(graph.neighbors(i), 2):
        # the two collider parents must not be adjacent to each other
        if second in graph.neighbors(first):
            continue
        # both must be parents of j
        if not points_into_j(first):
            continue
        if not points_into_j(second):
            continue

        # triples explicitly excluded on the graph are never used
        if frozenset((second, i, first)) in graph.excluded_triples:
            continue

        # both links to i have to be undirected for the rule to fire
        if graph.has_edge(first, i, undirected) and graph.has_edge(second, i, undirected):
            graph.orient_uncertain_edge(i, j)
            return True

    return False
| 1061 | + |
| 1062 | + |
def _meek_rule4(graph: CPDAG, i: str, j: str) -> bool:
    """Apply rule 4 of Meek's rules.

    If ``i - j`` is undirected and there is a chain ``i - k -> p -> j`` in
    which ``k`` and ``j`` are nonadjacent, orient ``i -> j``.

    Parameters
    ----------
    graph : CPDAG
        Graph to inspect; modified in place when the rule fires.
    i, j : node
        Endpoints of the candidate edge.

    Returns
    -------
    added_arrows : bool
        True if ``i - j`` was oriented into ``i -> j``.
    """
    # the rule only concerns an undirected i - j edge
    if not graph.has_edge(i, j, graph.undirected_edge_name):
        return False

    directed = graph.directed_edge_name

    # neighbors k of i that do not point into i (no k -> i edge)
    adj_i = {k for k in graph.neighbors(i) if not graph.has_edge(k, i, directed)}
    # j cannot serve as the intermediate node k of its own chain
    adj_i.discard(j)

    # parents p of j (p -> j with no opposing j -> p edge)
    parent_j = {p for p in graph.predecessors(j) if not graph.has_edge(j, p, directed)}

    # Every (k, p) combination is considered directly. Pairing the two sets
    # through permutations(adj_i, len(parent_j)) produced no pairs at all when
    # j had more parents than i had candidate neighbors, and was factorial in
    # cost otherwise.
    for k in adj_i:
        # k and j must be nonadjacent, whatever the edge type or the order in
        # which the endpoints are stored
        if j in graph.neighbors(k):
            continue
        if any(graph.has_edge(k, p, directed) for p in parent_j):
            # Make i - j into i -> j
            graph.orient_uncertain_edge(i, j)
            return True

    return False
| 1106 | + |
| 1107 | + |
def pag_to_mag(graph):
    """Sample a MAG from a PAG using Zhang's algorithm.

    Using the algorithm defined in Theorem 2 of :footcite:`Zhang2008`, which turns all
    o-> edges to -> and -o edges to ->, then it converts the remaining o-o edges into a
    DAG with no unshielded colliders using the Meek rules.

    Parameters
    ----------
    graph : PAG
        The PAG. It is left untouched; all edits are made on a copy.

    Returns
    -------
    mag : ADMG
        The MAG constructed from the PAG. It has the nodes of ``graph``, its
        bidirected and undirected edges, and the directed edges obtained from
        the orientation steps above.
    """
    pag = graph.copy()

    # Iterate the circle edges in the graph's own order (not the order of a
    # hash set) so that the arbitrary orientation choices made below are
    # reproducible; the sets are only used for O(1) membership tests.
    circle_edges = list(pag.circle_edges)
    circle_lookup = set(circle_edges)
    directed_lookup = set(pag.directed_edges)

    drop_circle = []  # circle mark of an 'o->' edge, to be removed
    make_directed = []  # '-o' edges, to be turned into '->'
    circle_circle = []  # one entry per 'o-o' edge
    queued = set()

    for u, v in circle_edges:
        if (v, u) in directed_lookup:
            # remove the circle end from a 'o-->' edge to make a '-->' edge
            drop_circle.append((u, v))
        elif (v, u) not in circle_lookup:
            # reorient a '--o' edge to '-->'
            make_directed.append((u, v))
        elif (v, u) not in queued:
            # 'o--o' edge: both (u, v) and (v, u) are circle edges, keep one
            queued.add((u, v))
            circle_circle.append((u, v))

    for u, v in drop_circle:
        pag.remove_edge(u, v, pag.circle_edge_name)
    for u, v in make_directed:
        pag.orient_uncertain_edge(u, v)

    # the 'o--o' edges form the circle component, handled as a CPDAG
    circle_component = CPDAG()
    for u, v in circle_circle:
        circle_component.add_edge(v, u, circle_component.undirected_edge_name)

    # Convert the circle component into a DAG with no unshielded colliders:
    # orient one remaining undirected edge, propagate the consequences with
    # the Meek rules, and repeat until no undirected edge is left.
    while True:
        pending = next(iter(circle_component.undirected_edges), None)
        if pending is None:
            break
        u, v = pending
        circle_component.remove_edge(u, v, circle_component.undirected_edge_name)
        circle_component.add_edge(u, v, circle_component.directed_edge_name)
        _apply_meek_rules(circle_component)

    # construct the final MAG
    mag = ADMG()

    # keep nodes that end up without any incident edge
    mag.add_nodes_from(pag.nodes)

    for u, v in pag.directed_edges:
        mag.add_edge(u, v, mag.directed_edge_name)
    for u, v in circle_component.directed_edges:
        mag.add_edge(u, v, mag.directed_edge_name)

    # Edges without circle marks are already fully oriented in the PAG and
    # belong to the MAG unchanged.
    for u, v in pag.bidirected_edges:
        mag.add_edge(u, v, mag.bidirected_edge_name)
    for u, v in pag.undirected_edges:
        mag.add_edge(u, v, mag.undirected_edge_name)

    return mag
0 commit comments