@@ -58,24 +58,13 @@ fn sorted_eig<A: Scalar + Lapack>(
5858
5959/// Masks a matrix with the given `matrix`
6060fn ndarray_mask < A : Scalar > ( matrix : ArrayView2 < A > , mask : & [ bool ] ) -> Array2 < A > {
61- let ( rows , cols ) = ( matrix . nrows ( ) , matrix. ncols ( ) ) ;
61+ assert_eq ! ( mask . len ( ) , matrix. ncols( ) ) ;
6262
63- assert_eq ! ( mask. len( ) , cols) ;
63+ let indices = ( 0 ..mask. len ( ) ) . zip ( mask. into_iter ( ) )
64+ . filter ( |( _, b) | * * b) . map ( |( a, _) | a)
65+ . collect :: < Vec < usize > > ( ) ;
6466
65- let n_positive = mask. iter ( ) . filter ( |x| * * x) . count ( ) ;
66-
67- let matrix = matrix
68- . gencolumns ( )
69- . into_iter ( )
70- . zip ( mask. iter ( ) )
71- . filter ( |( _, x) | * * x)
72- . map ( |( x, _) | x. to_vec ( ) )
73- . flatten ( )
74- . collect :: < Vec < A > > ( ) ;
75-
76- Array2 :: from_shape_vec ( ( n_positive, rows) , matrix)
77- . unwrap ( )
78- . reversed_axes ( )
67+ matrix. select ( Axis ( 1 ) , & indices)
7968}
8069
8170/// Applies constraints ensuring that a matrix is orthogonal to it
@@ -193,19 +182,16 @@ pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) ->
193182 let ax = a ( x. view ( ) ) ;
194183 let xax = x. t ( ) . dot ( & ax) ;
195184
196- // perform eigenvalue decomposition on XAX
185+ // perform eigenvalue decomposition of XAX as initialization
197186 let ( mut lambda, eig_block) = match sorted_eig ( xax. view ( ) , None , size_x, & order) {
198187 Ok ( x) => x,
199188 Err ( err) => return EigResult :: NoResult ( err) ,
200189 } ;
201190
202- //dbg!(&lambda, &eig_block);
203-
204191 // initiate X and AX with eigenvector
205192 let mut x = x. dot ( & eig_block) ;
206193 let mut ax = ax. dot ( & eig_block) ;
207194
208- //dbg!(&X, &AX);
209195 let mut activemask = vec ! [ true ; size_x] ;
210196 let mut residual_norms = Vec :: new ( ) ;
211197 let mut results = vec ! [ ( lambda. clone( ) , x. clone( ) ) ] ;
@@ -219,10 +205,11 @@ pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) ->
219205
220206 let final_norm = loop {
221207 // calculate residual
222- let lambda_tmp = lambda . clone ( ) . insert_axis ( Axis ( 0 ) ) ;
223- let tmp = & x * & lambda_tmp ;
208+ let lambda_diag = Array2 :: from_diag ( & lambda ) ;
209+ let lambda_x = x . dot ( & lambda_diag ) ;
224210
225- let r = & ax - & tmp;
211+ // calculate residual AX - lambdaX
212+ let r = & ax - & lambda_x;
226213
227214 // calculate L2 norm of error for every eigenvalue
228215 let tmp = r. gencolumns ( ) . into_iter ( ) . map ( |x| x. norm ( ) ) . collect :: < Vec < A :: Real > > ( ) ;
@@ -248,7 +235,7 @@ pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) ->
248235 let mut active_block_r = ndarray_mask ( r. view ( ) , & activemask) ;
249236 // apply preconditioner
250237 m ( active_block_r. view_mut ( ) ) ;
251-
238+ // apply constraints
252239 if let ( Some ( ref y) , Some ( ref fact_yy) ) = ( & y, & fact_yy) {
253240 apply_constraints ( active_block_r. view_mut ( ) , fact_yy, y. view ( ) ) ;
254241 }
@@ -271,17 +258,14 @@ pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) ->
271258 let active_ap = ndarray_mask ( ap. view ( ) , & activemask) ;
272259
273260 let ( active_p, p_r) = orthonormalize ( active_p) . unwrap ( ) ;
274- //dbg!(&active_P, &P_R);
261+
275262 let active_ap = match p_r. solve_triangular ( UPLO :: Lower , Diag :: NonUnit , & active_ap. reversed_axes ( ) ) {
276263 Ok ( x) => x,
277264 Err ( err) => break Err ( err) ,
278265 } ;
279266
280267 let active_ap = active_ap. reversed_axes ( ) ;
281268
282- //dbg!(&active_AP);
283- //dbg!(&R);
284-
285269 let xap = x. t ( ) . dot ( & active_ap) ;
286270 let wap = r. t ( ) . dot ( & active_ap) ;
287271 let pap = active_p. t ( ) . dot ( & active_ap) ;
@@ -291,7 +275,7 @@ pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) ->
291275 (
292276 stack ! [
293277 Axis ( 0 ) ,
294- stack![ Axis ( 1 ) , Array2 :: from_diag ( & lambda ) , xaw, xap] ,
278+ stack![ Axis ( 1 ) , lambda_diag , xaw, xap] ,
295279 stack![ Axis ( 1 ) , xaw. t( ) , waw, wap] ,
296280 stack![ Axis ( 1 ) , xap. t( ) , wap. t( ) , pap]
297281 ] ,
@@ -308,7 +292,7 @@ pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) ->
308292 (
309293 stack ! [
310294 Axis ( 0 ) ,
311- stack![ Axis ( 1 ) , Array2 :: from_diag ( & lambda ) , xaw] ,
295+ stack![ Axis ( 1 ) , lambda_diag , xaw] ,
312296 stack![ Axis ( 1 ) , xaw. t( ) , waw]
313297 ] ,
314298 stack ! [ Axis ( 0 ) , stack![ Axis ( 1 ) , ident0, xw] , stack![ Axis ( 1 ) , xw. t( ) , ident] ] ,
@@ -323,25 +307,15 @@ pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) ->
323307 } ;
324308 lambda = new_lambda;
325309
326- //dbg!(&lambda, &eig_vecs);
327310 let ( pp, app, eig_x) = if let ( Some ( _) , ( Some ( ref active_p) , Some ( ref active_ap) ) ) = ( ap, ( active_p, active_ap) )
328311 {
329312 let eig_x = eig_vecs. slice ( s ! [ ..size_x, ..] ) ;
330313 let eig_r = eig_vecs. slice ( s ! [ size_x..size_x + current_block_size, ..] ) ;
331314 let eig_p = eig_vecs. slice ( s ! [ size_x + current_block_size.., ..] ) ;
332315
333- //dbg!(&eig_X);
334- //dbg!(&eig_R);
335- //dbg!(&eig_P);
336-
337- //dbg!(&R, &AR, &active_P, &active_AP);
338-
339316 let pp = r. dot ( & eig_r) + active_p. dot ( & eig_p) ;
340317 let app = ar. dot ( & eig_r) + active_ap. dot ( & eig_p) ;
341318
342- //dbg!(&pp);
343- //dbg!(&app);
344-
345319 ( pp, app, eig_x)
346320 } else {
347321 let eig_x = eig_vecs. slice ( s ! [ ..size_x, ..] ) ;
@@ -363,7 +337,6 @@ pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) ->
363337 iter -= 1 ;
364338 } ;
365339
366- //dbg!(&residual_norms);
367340 let best_idx = residual_norms. iter ( ) . enumerate ( ) . min_by (
368341 |& ( _, item1) : & ( usize , & Vec < A :: Real > ) , & ( _, item2) : & ( usize , & Vec < A :: Real > ) | {
369342 let norm1: A :: Real = item1. iter ( ) . map ( |x| ( * x) * ( * x) ) . sum ( ) ;
0 commit comments