
Commit b169084

BHC speedup (memory optimization and some batching)

1 parent 34f688a

File tree

2 files changed: +76, -25 lines


bhc/core/bhc.py

Lines changed: 29 additions & 25 deletions
@@ -48,7 +48,11 @@ def build(self):
         ij = n_objects - 1

         # for every pair of data points
+        pair_count = n_objects * (n_objects - 1) // 2
+        tmp_merge = np.empty((pair_count, 5), dtype=float)
+        row = 0
         for i in range(n_objects):
+            log_p_k_row = self.model.row_of_log_likelihood_for_pairs(self.data, i)
             for j in range(i + 1, n_objects):
                 # compute log(d_k)
                 n_ch = n[i] + n[j]
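
The preallocation above replaces the old grow-by-vstack pattern. As a rough standalone illustration (toy values, not part of the commit): every np.vstack copies all rows accumulated so far, which is quadratic overall, while writing into a preallocated buffer is linear.

    import numpy as np

    n_objects = 100
    pair_count = n_objects * (n_objects - 1) // 2

    # old pattern: each vstack copies everything accumulated so far (quadratic)
    grown = None
    for _ in range(pair_count):
        info = [0, 0, 0.0, 0.0, 0.0]
        grown = np.array([info]) if grown is None else np.vstack((grown, info))

    # new pattern: allocate once, write rows in place (linear)
    buf = np.empty((pair_count, 5), dtype=float)
    for row in range(pair_count):
        buf[row] = [0, 0, 0.0, 0.0, 0.0]

    assert grown.shape == buf.shape == (pair_count, 5)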
@@ -59,28 +63,22 @@ def build(self):
                 # compute log(pi_k)
                 log_pik = np.log(self.alpha) + gammaln(n_ch) - log_dk
                 # compute log(p_k)
-                data_merged = np.vstack((self.data[i], self.data[j]))
-                log_p_k = self.model.calc_log_mlh(data_merged)
+                log_p_k = log_p_k_row[j - i - 1]  # since j starts at i + 1
                 # compute log(r_k)
                 log_p_ch = log_p[i] + log_p[j]
                 r1 = log_pik + log_p_k
                 r2 = log_d_ch - log_dk + log_p_ch
                 log_r = r1 - r2
                 # store results
-                merge_info = [i, j, log_r, r1, r2]
-                tmp_merge = (
-                    merge_info
-                    if tmp_merge is None
-                    else np.vstack((tmp_merge, merge_info))
-                )
+                tmp_merge[row] = [i, j, log_r, r1, r2]
+                row += 1

         # find clusters to merge
         arc_list = np.empty(0, dtype=api.Arc)
+        data_per_cluster = [np.array([self.data[i]]) for i in range(n_objects)]
         while active_nodes.size > 1:
             # find i, j with the highest probability of the merged hypothesis
-            max_log_rk = np.max(tmp_merge[:, 2])
-            ids_matched = np.argwhere(tmp_merge[:, 2] == max_log_rk)
-            position = np.min(ids_matched)
+            position = np.argmax(tmp_merge[:, 2])  # returns the first occurrence
             i, j, log_r, r1, r2 = tmp_merge[position]
             i = int(i)
             j = int(j)
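
The single np.argmax call is equivalent to the removed max / argwhere / min combination because np.argmax is documented to return the index of the first occurrence of the maximum. A small standalone check (illustrative values only):

    import numpy as np

    log_rk = np.array([0.3, 0.9, 0.1, 0.9])  # duplicate maximum on purpose

    old_position = np.min(np.argwhere(log_rk == np.max(log_rk)))
    new_position = np.argmax(log_rk)  # first occurrence by contract

    assert old_position == new_position == 1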
@@ -91,12 +89,6 @@ def build(self):
                 hierarchy_cut = True
                 break

-            # turn nodes i,j off
-            tmp_merge[np.argwhere(tmp_merge[:, 0] == i).flatten(), 2] = -np.inf
-            tmp_merge[np.argwhere(tmp_merge[:, 1] == i).flatten(), 2] = -np.inf
-            tmp_merge[np.argwhere(tmp_merge[:, 0] == j).flatten(), 2] = -np.inf
-            tmp_merge[np.argwhere(tmp_merge[:, 1] == j).flatten(), 2] = -np.inf
-
             # new node ij
             ij = n.size
             n_ch = n[i] + n[j]
@@ -107,7 +99,12 @@ def build(self):
                 self.alpha, n[ij], log_d_ch
             )
             log_d = np.append(log_d, log_d_ij)
-            # update assignments
+            # update cluster assignments
+            data_per_cluster.append(
+                np.vstack((data_per_cluster[i], data_per_cluster[j]))
+            )
+            data_per_cluster[i] = None
+            data_per_cluster[j] = None
             assignments[np.argwhere(assignments == i)] = ij
             assignments[np.argwhere(assignments == j)] = ij

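Caching each cluster's rows in data_per_cluster avoids re-slicing self.data through assignments on every merge; setting the two parents to None lets their arrays be reclaimed. A minimal standalone sketch of the pattern (toy data, not from the repo):

    import numpy as np

    data = np.arange(10.0).reshape(5, 2)

    # one single-point cluster per data point, as in the commit
    data_per_cluster = [np.array([data[i]]) for i in range(len(data))]

    # merging clusters 0 and 3 stacks their cached rows once...
    data_per_cluster.append(np.vstack((data_per_cluster[0], data_per_cluster[3])))
    # ...and releases the parents so their arrays can be garbage collected
    data_per_cluster[0] = None
    data_per_cluster[3] = None

    assert data_per_cluster[-1].shape == (2, 2)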

@@ -121,14 +118,20 @@ def build(self):
             j_idx = np.argwhere(active_nodes == j).flatten()
             active_nodes = np.delete(active_nodes, [i_idx, j_idx])
             active_nodes = np.append(active_nodes, ij)
+
+            # clean up tmp_merge
+            # keep rows where neither column 0 nor column 1 equals i or j
+            mask = ~np.isin(tmp_merge[:, :2], [i, j]).any(axis=1)
+            tmp_merge = tmp_merge[mask]
+
             # compute log(p_ij)
             t1 = np.maximum(r1, r2)
             t2 = np.minimum(r1, r2)
             log_p_ij = t1 + np.log(1 + np.exp(t2 - t1))
             log_p = np.append(log_p, log_p_ij)

             # for every pair ij x active
-            x_mat_ij = self.data[np.argwhere(assignments == ij).flatten()]
+            collected_merge_info = np.empty((len(active_nodes) - 1, 5), dtype=float)
             for k in range(active_nodes.size - 1):
                 # compute log(d_k)
                 n_ch = n[k] + n[ij]
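
Physically dropping the rows that mention i or j replaces the old approach of marking them with -inf (removed in the hunk above), so tmp_merge shrinks as clusters merge instead of only growing. A standalone example of the mask (toy values):

    import numpy as np

    tmp_merge = np.array([
        [0, 1, 0.5, 0.0, 0.0],
        [0, 2, 0.2, 0.0, 0.0],
        [1, 2, 0.1, 0.0, 0.0],
        [2, 3, 0.4, 0.0, 0.0],
    ])
    i, j = 0, 1

    # keep rows where neither of the first two columns equals i or j
    mask = ~np.isin(tmp_merge[:, :2], [i, j]).any(axis=1)
    assert np.array_equal(tmp_merge[mask], np.array([[2, 3, 0.4, 0.0, 0.0]]))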
@@ -139,18 +142,19 @@ def build(self):
                 # compute log(pi_k)
                 log_pik = np.log(self.alpha) + gammaln(n_ch) - log_dij
                 # compute log(p_k)
-                data_merged = self.data[
-                    np.argwhere(assignments == active_nodes[k]).flatten()
-                ]
-                log_p_ij = self.model.calc_log_mlh(np.vstack((x_mat_ij, data_merged)))
+                data_merged = np.vstack(
+                    (data_per_cluster[ij], data_per_cluster[active_nodes[k]])
+                )
+                log_p_ij = self.model.calc_log_mlh(data_merged)
                 # compute log(r_k)
                 log_p_ch = log_p[ij] + log_p[active_nodes[k]]
                 r1 = log_pik + log_p_ij
                 r2 = log_d_ch - log_dij + log_p_ch
                 log_r = r1 - r2
                 # store results
-                merge_info = [ij, active_nodes[k], log_r, r1, r2]
-                tmp_merge = np.vstack((tmp_merge, merge_info))
+                collected_merge_info[k] = [ij, active_nodes[k], log_r, r1, r2]
+
+            tmp_merge = np.vstack((tmp_merge, collected_merge_info))

         return api.Result(
             arc_list,
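
Collecting the candidate rows first and stacking once per merge replaces one np.vstack copy per candidate pair. A standalone sketch showing the two patterns produce the same array (toy values):

    import numpy as np

    tmp_merge = np.zeros((4, 5))
    candidates = [[9, k, 0.0, 0.0, 0.0] for k in range(3)]

    # old pattern: one copy of tmp_merge per candidate
    slow = tmp_merge
    for info in candidates:
        slow = np.vstack((slow, info))

    # new pattern: collect the rows, then stack once
    collected = np.empty((len(candidates), 5))
    for k, info in enumerate(candidates):
        collected[k] = info
    fast = np.vstack((tmp_merge, collected))

    assert np.array_equal(slow, fast)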

bhc/core/prior.py

Lines changed: 47 additions & 0 deletions
@@ -36,6 +36,53 @@ def calc_log_mlh(self, x_mat):
         log_prior = NormalInverseWishart.__calc_log_prior(s_mat_p, rp, vp)
         return log_prior - self.log_prior0 - LOG2PI * (n * d / 2.0)

+    def row_of_log_likelihood_for_pairs(
+        self,
+        X,  # (N, d) data matrix
+        i,  # index of the row you want (int)
+    ):
+        """
+        Returns a 1D array containing the log-likelihoods for the pairs of points
+        needed for the initialization of BHC. This function combines point i with
+        every point j > i and returns the log-likelihood of those two-point clusters.
+        """
+        N, d = X.shape
+        if d != self.s_mat.shape[0]:
+            raise ValueError("data dimension and prior scale matrix do not match")
+
+        # ------------------------------------------------------------------
+        # Pairwise sufficient statistics - only for j > i (batched)
+        # ------------------------------------------------------------------
+        # slice of points that matter
+        Xj = X[i + 1 :]  # shape (N-i-1, d)
+        diff = X[i] - Xj  # broadcast automatically
+        x_bar = 0.5 * (X[i] + Xj)  # (N-i-1, d)
+
+        # Scatter matrix S = ½ diff·diffᵀ -> (N-i-1, d, d)
+        S = 0.5 * np.einsum("...i,...j->...ij", diff, diff)
+        # Term (r·2/(r+2))·(x̄-m)(x̄-m)ᵀ
+        dt = x_bar - self.m  # (N-i-1, d)
+        outer_dt = np.einsum("...i,...j->...ij", dt, dt)  # (N-i-1, d, d)
+        term = (self.r * 2.0 / (self.r + 2.0)) * outer_dt
+        # Posterior scale matrix for each pair
+        s_mat_p = self.s_mat[None, :, :] + S + term  # (N-i-1, d, d)
+
+        # ------------------------------------------------------------------
+        # Log-posterior for each pair (batched)
+        # ------------------------------------------------------------------
+        rp = self.r + 2.0  # each cluster has two points
+        vp = self.v + 2.0
+        sign, logdet = slogdet(s_mat_p)  # (N-i-1,)
+        log_prior_post = (
+            LOG2 * (vp * d / 2.0)
+            + (d / 2.0) * np.log(2.0 * np.pi / rp)
+            + multigammaln(vp / 2.0, d)
+            - (vp / 2.0) * logdet
+        )  # (N-i-1,)
+
+        # Final log-likelihood for each pair
+        return log_prior_post - self.log_prior0 - LOG2PI * d  # (N-i-1,)
+
     @staticmethod
     def __calc_log_prior(s_mat, r, v):
         d = s_mat.shape[0]
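
A natural regression test for this commit is to compare the batched row against a loop over calc_log_mlh. The sketch below assumes a NormalInverseWishart instance niw has already been constructed with the repo's API (constructor arguments are repo-specific and omitted), and that X has d columns matching niw's scale matrix:

    import numpy as np

    # assumption: `niw` is a NormalInverseWishart instance built elsewhere
    X = np.random.default_rng(0).normal(size=(6, 3))
    i = 2

    batched = niw.row_of_log_likelihood_for_pairs(X, i)
    looped = np.array(
        [niw.calc_log_mlh(np.vstack((X[i], X[j]))) for j in range(i + 1, len(X))]
    )
    assert np.allclose(batched, looped)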
