Commit 34f688a

black formatting
1 parent abb4655 commit 34f688a

3 files changed: +46 -41 lines changed

bhc/api.py

Lines changed: 13 additions & 14 deletions

@@ -6,13 +6,15 @@
 
 
 class Result(object):
-    def __init__(self,
-                 arc_list,
-                 node_ids,
-                 last_log_p,
-                 weights,
-                 hierarchy_cut,
-                 n_clusters):
+    def __init__(
+        self,
+        arc_list,
+        node_ids,
+        last_log_p,
+        weights,
+        hierarchy_cut,
+        n_clusters,
+    ):
         self.arc_list = arc_list
         self.node_ids = node_ids
         self.last_log_p = last_log_p
@@ -30,23 +32,20 @@ def __eq__(self, other):
         return self.source == other.source and self.target == other.target
 
     def __repr__(self):
-        return '{0} -> {1}'.format(str(self.source), str(self.target))
+        return "{0} -> {1}".format(str(self.source), str(self.target))
 
 
 class AbstractPrior(ABC):
     @abstractmethod
-    def calc_log_mlh(self, x_mat):
-        ...
+    def calc_log_mlh(self, x_mat): ...
 
 
 class AbstractHierarchicalClustering(ABC):
     @abstractmethod
-    def build(self):
-        ...
+    def build(self): ...
 
 
-class AbstractBayesianBasedHierarchicalClustering(
-        AbstractHierarchicalClustering, ABC):
+class AbstractBayesianBasedHierarchicalClustering(AbstractHierarchicalClustering, ABC):
     def __init__(self, data, model, alpha, cut_allowed):
         self.data = data
         self.model = model
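Aside: the only contract AbstractPrior imposes is calc_log_mlh(x_mat), which must return the log marginal likelihood of the rows of x_mat under a single cluster. Below is a minimal sketch of a drop-in prior against that interface; the spherical-Gaussian model and its parameters are illustrative and not part of this repository (NormalInverseWishart in bhc/core/prior.py is the implementation the clusterer actually ships with).

import numpy as np

from bhc.api import AbstractPrior


class SphericalGaussianPrior(AbstractPrior):
    """Illustrative conjugate model: x_i ~ N(mu, sigma2 * I) with a
    N(0, tau2 * I) prior on mu. Only demonstrates the AbstractPrior
    contract used by the clusterer."""

    def __init__(self, sigma2=1.0, tau2=1.0):
        self.sigma2 = sigma2
        self.tau2 = tau2

    def calc_log_mlh(self, x_mat):
        # calc_log_mlh is called with either a single point (1-D) or a
        # stack of points (2-D), so normalise the shape first.
        x = np.atleast_2d(np.asarray(x_mat, dtype=float))
        n, d = x.shape
        # Closed-form marginal likelihood of the conjugate Gaussian model,
        # computed per dimension and summed.
        a = n / self.sigma2 + 1.0 / self.tau2  # posterior precision per dimension
        b = x.sum(axis=0) / self.sigma2        # n * x_bar / sigma2 per dimension
        return (
            -0.5 * n * d * np.log(2.0 * np.pi * self.sigma2)
            - 0.5 * d * np.log(2.0 * np.pi * self.tau2)
            + 0.5 * d * np.log(2.0 * np.pi / a)
            - 0.5 * np.sum(x * x) / self.sigma2
            + 0.5 * np.sum(b * b) / a
        )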

bhc/core/bhc.py

Lines changed: 26 additions & 18 deletions

@@ -8,8 +8,7 @@
 import bhc.api as api
 
 
-class BayesianHierarchicalClustering(
-        api.AbstractBayesianBasedHierarchicalClustering):
+class BayesianHierarchicalClustering(api.AbstractBayesianBasedHierarchicalClustering):
     """
     Reference: HELLER, Katherine A.; GHAHRAMANI, Zoubin.
     Bayesian hierarchical clustering.
@@ -26,7 +25,7 @@ def build(self):
 
         weights = []
 
-        # active nodes
+        # active nodes (all)
         active_nodes = np.arange(n_objects)
         # assignments - starting each point in its own cluster
         assignments = np.arange(n_objects)
@@ -41,7 +40,8 @@ def build(self):
         for i in range(n_objects):
             # compute log(d_k)
             log_d[i] = BayesianHierarchicalClustering.__calc_log_d(
-                self.alpha, n[i], None)
+                self.alpha, n[i], None
+            )
             # compute log(p_i)
             log_p[i] = self.model.calc_log_mlh(self.data[i])
 
@@ -54,7 +54,8 @@ def build(self):
                 n_ch = n[i] + n[j]
                 log_d_ch = log_d[i] + log_d[j]
                 log_dk = BayesianHierarchicalClustering.__calc_log_d(
-                    self.alpha, n_ch, log_d_ch)
+                    self.alpha, n_ch, log_d_ch
+                )
                 # compute log(pi_k)
                 log_pik = np.log(self.alpha) + gammaln(n_ch) - log_dk
                 # compute log(p_k)
@@ -67,8 +68,11 @@ def build(self):
                 log_r = r1 - r2
                 # store results
                 merge_info = [i, j, log_r, r1, r2]
-                tmp_merge = merge_info if tmp_merge is None \
+                tmp_merge = (
+                    merge_info
+                    if tmp_merge is None
                     else np.vstack((tmp_merge, merge_info))
+                )
 
         # find clusters to merge
         arc_list = np.empty(0, dtype=api.Arc)
@@ -100,7 +104,8 @@ def build(self):
             # compute log(d_ij)
             log_d_ch = log_d[i] + log_d[j]
             log_d_ij = BayesianHierarchicalClustering.__calc_log_d(
-                self.alpha, n[ij], log_d_ch)
+                self.alpha, n[ij], log_d_ch
+            )
             log_d = np.append(log_d, log_d_ij)
             # update assignments
             assignments[np.argwhere(assignments == i)] = ij
@@ -129,14 +134,15 @@ def build(self):
                 n_ch = n[k] + n[ij]
                 log_d_ch = log_d[k] + log_d[ij]
                 log_dij = BayesianHierarchicalClustering.__calc_log_d(
-                    self.alpha, n_ch, log_d_ch)
+                    self.alpha, n_ch, log_d_ch
+                )
                 # compute log(pi_k)
                 log_pik = np.log(self.alpha) + gammaln(n_ch) - log_dij
                 # compute log(p_k)
-                data_merged = self.data[np.argwhere(
-                    assignments == active_nodes[k]).flatten()]
-                log_p_ij = self.model.calc_log_mlh(
-                    np.vstack((x_mat_ij, data_merged)))
+                data_merged = self.data[
+                    np.argwhere(assignments == active_nodes[k]).flatten()
+                ]
+                log_p_ij = self.model.calc_log_mlh(np.vstack((x_mat_ij, data_merged)))
                 # compute log(r_k)
                 log_p_ch = log_p[ij] + log_p[active_nodes[k]]
                 r1 = log_pik + log_p_ij
@@ -146,12 +152,14 @@ def build(self):
                 merge_info = [ij, active_nodes[k], log_r, r1, r2]
                 tmp_merge = np.vstack((tmp_merge, merge_info))
 
-        return api.Result(arc_list,
-                          np.arange(0, ij + 1),
-                          log_p[-1],
-                          np.array(weights),
-                          hierarchy_cut,
-                          len(np.unique(assignments)))
+        return api.Result(
+            arc_list,
+            np.arange(0, ij + 1),
+            log_p[-1],
+            np.array(weights),
+            hierarchy_cut,
+            len(np.unique(assignments)),
+        )
 
     @staticmethod
     def __calc_log_d(alpha, nk, log_d_ch):
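As background for the quantities build() accumulates in log space, these are the recursions from the Heller & Ghahramani paper cited in the class docstring (notation follows the paper; this is context, not part of the diff). For a leaf, d_k = alpha * Gamma(n_k), which is what __calc_log_d computes when log_d_ch is None; for a merged node k with children i and j:

d_k = \alpha \, \Gamma(n_k) + d_{i} \, d_{j},
\qquad
\pi_k = \frac{\alpha \, \Gamma(n_k)}{d_k},
\qquad
r_k = \frac{\pi_k \, p(\mathcal{D}_k \mid \mathcal{H}_1^k)}
           {\pi_k \, p(\mathcal{D}_k \mid \mathcal{H}_1^k)
            + (1 - \pi_k) \, p(\mathcal{D}_i \mid T_i) \, p(\mathcal{D}_j \mid T_j)}

In the code, log_pik = np.log(self.alpha) + gammaln(n_ch) - log_dk is exactly log(pi_k), and r1 = log_pik + log_p_ij is the log of the numerator of r_k; r2 (defined outside the hunks shown here) supplies the competing term, so log_r = r1 - r2 is the score used to rank candidate merges.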

bhc/core/prior.py

Lines changed: 7 additions & 9 deletions

@@ -31,16 +31,16 @@ def calc_log_mlh(self, x_mat):
         x_mat_l = x_mat_l[np.newaxis] if x_mat_l.ndim == 1 else x_mat_l
         n, d = x_mat_l.shape
         s_mat_p, rp, vp = NormalInverseWishart.__calc_posterior(
-            x_mat_l, self.s_mat, self.r, self.v, self.m)
+            x_mat_l, self.s_mat, self.r, self.v, self.m
+        )
         log_prior = NormalInverseWishart.__calc_log_prior(s_mat_p, rp, vp)
         return log_prior - self.log_prior0 - LOG2PI * (n * d / 2.0)
 
     @staticmethod
     def __calc_log_prior(s_mat, r, v):
         d = s_mat.shape[0]
         log_prior = LOG2 * (v * d / 2.0) + (d / 2.0) * np.log(2.0 * np.pi / r)
-        log_prior += multigammaln(v / 2.0, d) - \
-            (v / 2.0) * np.log(np.linalg.det(s_mat))
+        log_prior += multigammaln(v / 2.0, d) - (v / 2.0) * np.log(np.linalg.det(s_mat))
         return log_prior
 
     @staticmethod
@@ -49,8 +49,7 @@ def __calc_posterior(x_mat, s_mat, r, v, m):
         x_bar = np.mean(x_mat, axis=0)
         rp = r + n
         vp = v + n
-        s_mat_t = np.zeros(s_mat.shape) if n == 1 else (
-            n - 1) * np.cov(x_mat.T)
+        s_mat_t = np.zeros(s_mat.shape) if n == 1 else (n - 1) * np.cov(x_mat.T)
         dt = (x_bar - m)[np.newaxis]
         s_mat_p = s_mat + s_mat_t + (r * n / rp) * np.dot(dt.T, dt)
         return s_mat_p, rp, vp
@@ -62,7 +61,6 @@ def create(data, g, scale_factor):
         data_matrix_cov = np.cov(data.T)
         scatter_matrix = (data_matrix_cov / g).T
 
-        return NormalInverseWishart(scatter_matrix,
-                                    scale_factor,
-                                    degrees_of_freedom,
-                                    data_mean)
+        return NormalInverseWishart(
+            scatter_matrix, scale_factor, degrees_of_freedom, data_mean
+        )
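Written out, the update that __calc_posterior implements is the standard Normal-inverse-Wishart conjugate update, with symbols matching the code (r the mean precision scale, v the degrees of freedom, m the prior mean, S the scatter matrix, n data points with sample mean \bar{x}):

r' = r + n, \qquad v' = v + n, \qquad
S' = S + S_t + \frac{r n}{r'} (\bar{x} - m)(\bar{x} - m)^{\top},
\quad
S_t = \begin{cases} 0 & n = 1 \\ (n - 1)\,\widehat{\operatorname{cov}}(X) & n > 1 \end{cases}

__calc_log_prior evaluates the corresponding log normaliser

\log Z(S, r, v) = \frac{v d}{2} \log 2 + \frac{d}{2} \log \frac{2\pi}{r}
                  + \log \Gamma_d\!\left(\tfrac{v}{2}\right)
                  - \frac{v}{2} \log \lvert S \rvert,

so calc_log_mlh returns \log Z(S', r', v') - \log Z_0 - \frac{n d}{2} \log 2\pi, where \log Z_0 (self.log_prior0) is presumably the same normaliser evaluated at the prior parameters.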

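For orientation, a hedged end-to-end sketch of how the classes touched by this commit fit together. The import paths follow the file layout (bhc/core/bhc.py, bhc/core/prior.py) and are assumed to be importable as a package; the toy data and the g, scale_factor, and alpha values are illustrative, not taken from the repository.

import numpy as np

from bhc.core.bhc import BayesianHierarchicalClustering
from bhc.core.prior import NormalInverseWishart

# Toy data: two well-separated 2-D blobs (illustrative values only).
rng = np.random.default_rng(0)
data = np.vstack(
    (
        rng.normal(loc=0.0, scale=0.5, size=(20, 2)),
        rng.normal(loc=5.0, scale=0.5, size=(20, 2)),
    )
)

# Empirical Normal-inverse-Wishart prior via the factory shown in prior.py;
# g and scale_factor are illustrative hyperparameters.
model = NormalInverseWishart.create(data, g=10.0, scale_factor=0.001)

# alpha is the DPM concentration parameter; cut_allowed enables the hierarchy cut.
result = BayesianHierarchicalClustering(data, model, alpha=1.0, cut_allowed=True).build()

# build() returns the api.Result defined in bhc/api.py.
print(result.n_clusters, len(result.arc_list))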