-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathldp_utils.py
More file actions
158 lines (127 loc) · 5.39 KB
/
ldp_utils.py
File metadata and controls
158 lines (127 loc) · 5.39 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
# General importations.
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
import platform
import time
import itertools
import warnings
from scipy import stats
# Baselines.
from causallearn.search.ConstraintBased.PC import pc
from causallearn.search.ConstraintBased.FCI import fci
from causallearn.utils.GraphUtils import GraphUtils
from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge
from causallearn.utils.cit import chisq
from IPython.display import Image, display
import networkx as nx
import pydot
class LDPUtils:
    '''
    Helper methods: confidence intervals, graph plotting (networkx / pydot),
    and an adjustment-set validity check on a DAG given as an adjacency matrix.
    '''

    def get_ci(self,
               results: list,
               z: float = 1.96,
               return_mean: bool = True) -> tuple:
        '''
        Normal-approximation confidence interval for the mean of `results`.

        The interval is mean +/- z * SEM, where SEM is `scipy.stats.sem`.
        The default z = 1.96 gives a 95% confidence interval.

        Parameters:
        ------------
        results: Sample values (anything `np.mean` / `stats.sem` accept).
        z: Critical value multiplying the standard error.
        return_mean: Whether to also return the sample mean.

        Return:
        ---------
        (mean, [lower, upper]) when `return_mean` is True; otherwise just the
        list [lower, upper].
        '''
        mean = np.mean(results)
        se = stats.sem(results)
        ci = [mean - (z * se), mean + (z * se)]
        if return_mean:
            return mean, ci
        return ci

    def plot_nx(self,
                adjacency_matrix,
                labels,
                figsize = (10,10),
                dpi = 200,
                node_size = 800,
                arrow_size = 10):
        '''
        Draw the directed graph described by `adjacency_matrix` in a circular
        layout with matplotlib, show it, then close the figure.

        Parameters:
        ------------
        adjacency_matrix: Matrix passed to `nx.from_numpy_array`; a nonzero
            entry (i, j) becomes an edge i -> j.
        labels: Node names; labels[i] is drawn on node i.
        figsize, dpi: Passed to `plt.figure`.
        node_size, arrow_size: Passed to `nx.draw_circular`.
        '''
        g = nx.from_numpy_array(adjacency_matrix, create_using = nx.DiGraph)
        plt.figure(figsize = figsize, dpi = dpi)
        nx.draw_circular(g,
                         node_size = node_size,
                         labels = dict(enumerate(labels)),
                         arrowsize = arrow_size,
                         with_labels = True)
        plt.show()
        plt.close()

    def view_pydot(self,
                   pdot,
                   figsize = 500):
        '''
        Render a pydot graph to PNG and display it through IPython.

        Parameters:
        ------------
        pdot: A pydot graph (must provide `create_png()`).
        figsize: Display width of the image, passed as `Image(width=...)`.
        '''
        img = Image(pdot.create_png(), width = figsize)
        display(img)

    @staticmethod
    def _restyle_pydot_edges(p, edges, extra_attrs):
        '''
        Append `extra_attrs` to the DOT attributes of each (tail, head) edge in
        `edges`, then return the re-parsed pydot graph.
        '''
        string = p.to_string()
        for tail, head in edges:
            replace_from = '{} -> {} [weight="1.0"]'.format(tail, head)
            replace_to = '{} -> {} [weight="1.0", {}]'.format(tail, head, extra_attrs)
            string = string.replace(replace_from, replace_to)
        p = pydot.graph_from_dot_data(string)[0]
        # Drop a node literally named "\n" after re-parsing; presumably an
        # artifact of pydot's DOT round-trip -- TODO confirm with pydot version.
        p.del_node('"\\n"')
        return p

    def plot_pydot_from_adjacency_matrix(self,
                                         adjacency_matrix,
                                         labels,
                                         undirected_edges = (("Z7", "X"),),
                                         uncertain_edges = (("X", "Z2"), ("X", "Z6"))):
        '''
        Draw the directed graph in `adjacency_matrix` with pydot, restyling
        selected edges, and display it.

        The X -> Y edge is always drawn dashed.

        Parameters:
        ------------
        adjacency_matrix: Matrix passed to `nx.from_numpy_array`; a nonzero
            entry (i, j) becomes an edge i -> j.
        labels: Node names; labels[i] names node i.
        undirected_edges: Iterable of (tail, head) name pairs drawn without
            arrowheads, or None to skip.
        uncertain_edges: Iterable of (tail, head) name pairs drawn dotted,
            or None to skip.

        NOTE(review): edges are matched on the literal DOT text
        `tail -> head [weight="1.0"]`, so restyling only takes effect when pydot
        emits exactly that form (e.g. float adjacency entries equal to 1).
        '''
        # Convert adjacency matrix to networkx graph, then convert to pydot.
        # (`nx.from_numpy_matrix` was removed in NetworkX 3.0.)
        g = nx.from_numpy_array(adjacency_matrix, create_using = nx.DiGraph)
        g = nx.relabel_nodes(g, dict(enumerate(labels)), copy = True)
        p = nx.drawing.nx_pydot.to_pydot(g)
        # Replace X -> Y with dashed arrow.
        p = self._restyle_pydot_edges(p, [("X", "Y")], 'style="dashed"')
        # Remove arrow heads from undirected edges.
        if undirected_edges is not None:
            p = self._restyle_pydot_edges(p, undirected_edges, 'dir=none')
        # Make uncertain edges dotted.
        if uncertain_edges is not None:
            p = self._restyle_pydot_edges(p, uncertain_edges, 'style="dotted"')
        # View graph.
        self.view_pydot(p)
        plt.close()

    def is_valid_adjustment_set(self,
                                adj_matrix: np.ndarray,
                                x: int,
                                y: int,
                                S: list,
                                verbose: bool = False) -> bool:
        '''
        Checks if `S` is a valid adjustment set wrt `X` and `Y`.

        `S` is rejected if it contains any node on a directed path from X to Y
        (other than X) or any descendant of such a node. Otherwise, the first
        edge of every directed path X -> ... -> Y is removed, and `S` is valid
        iff it d-separates X and Y in the resulting graph.

        Parameters:
        ------------
        adj_matrix: The adjacency matrix for the DAG (entry (i, j) != 0 means i -> j).
        x: The index of the `X` node.
        y: The index of the `Y` node.
        S: The indices of the `S` nodes (list[int]).
        verbose: If True, print the forbidden nodes found in `S`.

        Return:
        ---------
        Bool: Whether S is a valid adjustment set wrt X and Y.
        '''
        S = set(S)
        dag = nx.from_numpy_array(adj_matrix, create_using=nx.DiGraph)
        # Nodes (including Y, excluding X) on some directed path from X to Y.
        causal_nodes = {node
                        for path in nx.all_simple_paths(dag, source=x, target=y)
                        for node in path
                        if node != x}
        # Forbidden set: nodes on those causal paths plus all their descendants.
        forbidden = set(causal_nodes)
        for node in causal_nodes:
            forbidden |= nx.descendants(dag, node)
        # A valid adjustment set cannot contain forbidden nodes.
        if forbidden & S:
            if verbose:
                print("Forbidden nodes in adjustment set:", S.intersection(forbidden))
            return False
        # Remove X's outgoing edges that start a causal path to Y.
        first_edges = [edge_path[0] for edge_path in nx.all_simple_edge_paths(dag, x, y)]
        dag.remove_edges_from(first_edges)
        # NetworkX 3.3 added `is_d_separator` and deprecated `d_separated`
        # (removed in 3.5); fall back to the old name on older releases.
        d_separated = getattr(nx, "is_d_separator", None) or nx.d_separated
        # Check if S d-separates X and Y in the mutilated graph.
        return d_separated(dag, {x}, {y}, S)