@@ -48,7 +48,15 @@ def build(self):
         ij = n_objects - 1

         # for every pair of data points
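+        # preallocate one row per unordered pair up front instead of
+        # growing tmp_merge with np.vstack on every iteration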
+        pair_count = n_objects * (n_objects - 1) // 2
+        tmp_merge = np.empty((pair_count, 5), dtype=float)
+        row = 0
         for i in range(n_objects):
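+            # one call yields the log marginal likelihoods for merging point i
+            # with every j > i, replacing a calc_log_mlh call per pair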
+            log_p_k_row = self.model.row_of_log_likelihood_for_pairs(self.data, i)
             for j in range(i + 1, n_objects):
                 # compute log(d_k)
                 n_ch = n[i] + n[j]
@@ -59,28 +67,24 @@ def build(self):
                 # compute log(pi_k)
                 log_pik = np.log(self.alpha) + gammaln(n_ch) - log_dk
                 # compute log(p_k)
-                data_merged = np.vstack((self.data[i], self.data[j]))
-                log_p_k = self.model.calc_log_mlh(data_merged)
+                log_p_k = log_p_k_row[j - i - 1]  # since j starts at i + 1
                 # compute log(r_k)
                 log_p_ch = log_p[i] + log_p[j]
                 r1 = log_pik + log_p_k
                 r2 = log_d_ch - log_dk + log_p_ch
                 log_r = r1 - r2
                 # store results
-                merge_info = [i, j, log_r, r1, r2]
-                tmp_merge = (
-                    merge_info
-                    if tmp_merge is None
-                    else np.vstack((tmp_merge, merge_info))
-                )
+                tmp_merge[row] = [i, j, log_r, r1, r2]
+                row += 1

         # find clusters to merge
         arc_list = np.empty(0, dtype=api.Arc)
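+        # cache each cluster's data rows so merged data can be stacked
+        # directly instead of being rebuilt from `assignments` on every merge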
+        data_per_cluster = [np.array([self.data[i]]) for i in range(n_objects)]
         while active_nodes.size > 1:
             # find i, j with the highest probability of the merged hypothesis
-            max_log_rk = np.max(tmp_merge[:, 2])
-            ids_matched = np.argwhere(tmp_merge[:, 2] == max_log_rk)
-            position = np.min(ids_matched)
+            position = np.argmax(tmp_merge[:, 2])  # returns the first occurrence
             i, j, log_r, r1, r2 = tmp_merge[position]
             i = int(i)
             j = int(j)
@@ -91,12 +95,6 @@ def build(self):
                 hierarchy_cut = True
                 break

-            # turn nodes i,j off
-            tmp_merge[np.argwhere(tmp_merge[:, 0] == i).flatten(), 2] = -np.inf
-            tmp_merge[np.argwhere(tmp_merge[:, 1] == i).flatten(), 2] = -np.inf
-            tmp_merge[np.argwhere(tmp_merge[:, 0] == j).flatten(), 2] = -np.inf
-            tmp_merge[np.argwhere(tmp_merge[:, 1] == j).flatten(), 2] = -np.inf
-
             # new node ij
             ij = n.size
             n_ch = n[i] + n[j]
@@ -107,7 +105,13 @@ def build(self):
                 self.alpha, n[ij], log_d_ch
             )
             log_d = np.append(log_d, log_d_ij)
-            # update assignments
+            # update cluster assignments
+            data_per_cluster.append(
+                np.vstack((data_per_cluster[i], data_per_cluster[j]))
+            )
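+            # nodes i and j are never merged again; release their cached rows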
+            data_per_cluster[i] = None
+            data_per_cluster[j] = None
             assignments[np.argwhere(assignments == i)] = ij
             assignments[np.argwhere(assignments == j)] = ij

@@ -121,14 +125,24 @@ def build(self):
             j_idx = np.argwhere(active_nodes == j).flatten()
             active_nodes = np.delete(active_nodes, [i_idx, j_idx])
             active_nodes = np.append(active_nodes, ij)
+
+            # clean up tmp_merge:
+            # keep rows where neither column 0 nor column 1 equals i or j
+            mask = ~np.isin(tmp_merge[:, :2], [i, j]).any(axis=1)
+            tmp_merge = tmp_merge[mask]
+
             # compute log(p_ij)
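+            # log-sum-exp: t1 + log(1 + exp(t2 - t1)) evaluates
+            # log(exp(r1) + exp(r2)) without overflow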
             t1 = np.maximum(r1, r2)
             t2 = np.minimum(r1, r2)
             log_p_ij = t1 + np.log(1 + np.exp(t2 - t1))
             log_p = np.append(log_p, log_p_ij)

             # for every pair ij x active
-            x_mat_ij = self.data[np.argwhere(assignments == ij).flatten()]
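+            # buffer this node's candidate merges and append them to
+            # tmp_merge with a single vstack after the loop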
+            collected_merge_info = np.empty((len(active_nodes) - 1, 5), dtype=float)
             for k in range(active_nodes.size - 1):
                 # compute log(d_k)
                 n_ch = n[k] + n[ij]
@@ -139,18 +153,19 @@ def build(self):
                 # compute log(pi_k)
                 log_pik = np.log(self.alpha) + gammaln(n_ch) - log_dij
                 # compute log(p_k)
-                data_merged = self.data[
-                    np.argwhere(assignments == active_nodes[k]).flatten()
-                ]
-                log_p_ij = self.model.calc_log_mlh(np.vstack((x_mat_ij, data_merged)))
+                data_merged = np.vstack(
+                    (data_per_cluster[ij], data_per_cluster[active_nodes[k]])
+                )
+                log_p_ij = self.model.calc_log_mlh(data_merged)
                 # compute log(r_k)
                 log_p_ch = log_p[ij] + log_p[active_nodes[k]]
                 r1 = log_pik + log_p_ij
                 r2 = log_d_ch - log_dij + log_p_ch
                 log_r = r1 - r2
                 # store results
-                merge_info = [ij, active_nodes[k], log_r, r1, r2]
-                tmp_merge = np.vstack((tmp_merge, merge_info))
+                collected_merge_info[k] = [ij, active_nodes[k], log_r, r1, r2]
+
+            tmp_merge = np.vstack((tmp_merge, collected_merge_info))

         return api.Result(
             arc_list,