@@ -7,7 +7,8 @@ use crate::{cholesky::*, close_l2, eigh::*, norm::*, triangular::*};
77use crate :: { Lapack , Scalar } ;
88use ndarray:: prelude:: * ;
99use ndarray:: OwnedRepr ;
10- use num_traits:: NumCast ;
10+ use ndarray:: ScalarOperand ;
11+ use num_traits:: { NumCast , Float } ;
1112
1213/// Find largest or smallest eigenvalues
1314#[ derive( Debug , Clone ) ]
@@ -136,7 +137,7 @@ fn orthonormalize<T: Scalar + Lapack>(v: Array2<T>) -> Result<(Array2<T>, Array2
136137/// for it. All iterations are tracked and the optimal solution returned. In case of an error a
137138/// special variant `EigResult::NotConverged` additionally carries the error. This can happen when
138139/// the precision of the matrix is too low (switch from `f32` to `f64` for example).
139- pub fn lobpcg < A : Scalar + Lapack + PartialOrd + Default , F : Fn ( ArrayView2 < A > ) -> Array2 < A > , G : Fn ( ArrayViewMut2 < A > ) > (
140+ pub fn lobpcg < A : Float + Scalar + Lapack + ScalarOperand + PartialOrd + Default , F : Fn ( ArrayView2 < A > ) -> Array2 < A > , G : Fn ( ArrayViewMut2 < A > ) > (
140141 a : F ,
141142 mut x : Array2 < A > ,
142143 m : G ,
@@ -193,15 +194,17 @@ pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) ->
193194 let mut ax = ax. dot ( & eig_block) ;
194195
195196 let mut activemask = vec ! [ true ; size_x] ;
196- let mut residual_norms = Vec :: new ( ) ;
197+ let mut residual_norms_history = Vec :: new ( ) ;
197198 let mut best_result = None ;
198199
199200 let mut previous_block_size = size_x;
200201
201202 let mut ident: Array2 < A > = Array2 :: eye ( size_x) ;
202203 let ident0: Array2 < A > = Array2 :: eye ( size_x) ;
204+ let two: A = NumCast :: from ( 2.0 ) . unwrap ( ) ;
203205
204206 let mut ap: Option < ( Array2 < A > , Array2 < A > ) > = None ;
207+ let mut explicit_gram_flag = false ;
205208
206209 let final_norm = loop {
207210 // calculate residual
@@ -212,17 +215,17 @@ pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) ->
212215 let r = & ax - & lambda_x;
213216
214217 // calculate L2 norm of error for every eigenvalue
215- let tmp = r. gencolumns ( ) . into_iter ( ) . map ( |x| x. norm ( ) ) . collect :: < Vec < A :: Real > > ( ) ;
216- residual_norms . push ( tmp . clone ( ) ) ;
218+ let residual_norms = r. gencolumns ( ) . into_iter ( ) . map ( |x| x. norm ( ) ) . collect :: < Vec < A :: Real > > ( ) ;
219+ residual_norms_history . push ( residual_norms . clone ( ) ) ;
217220
218221 // compare best result and update if we improved
219- let sum_rnorm: A :: Real = tmp . iter ( ) . cloned ( ) . sum ( ) ;
222+ let sum_rnorm: A :: Real = residual_norms . iter ( ) . cloned ( ) . sum ( ) ;
220223 if best_result. as_ref ( ) . map ( |x : & ( _ , _ , Vec < A :: Real > ) | x. 2 . iter ( ) . cloned ( ) . sum :: < A :: Real > ( ) > sum_rnorm) . unwrap_or ( true ) {
221- best_result = Some ( ( lambda. clone ( ) , x. clone ( ) , tmp . clone ( ) ) ) ;
224+ best_result = Some ( ( lambda. clone ( ) , x. clone ( ) , residual_norms . clone ( ) ) ) ;
222225 }
223226
224227 // disable eigenvalues which are below the tolerance threshold
225- activemask = tmp . iter ( ) . zip ( activemask. iter ( ) ) . map ( |( x, a) | * x > tol && * a) . collect ( ) ;
228+ activemask = residual_norms . iter ( ) . zip ( activemask. iter ( ) ) . map ( |( x, a) | * x > tol && * a) . collect ( ) ;
226229
227230 // resize identity block if necessary
228231 let current_block_size = activemask. iter ( ) . filter ( |x| * * x) . count ( ) ;
@@ -234,17 +237,19 @@ pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) ->
234237 // if we are below the threshold for all eigenvalues or exceeded the number of iterations,
235238 // abort
236239 if current_block_size == 0 || iter == 0 {
237- break Ok ( tmp ) ;
240+ break Ok ( residual_norms ) ;
238241 }
239242
240243 // select active eigenvalues, apply pre-conditioner, orthogonalize to Y and orthonormalize
241244 let mut active_block_r = ndarray_mask ( r. view ( ) , & activemask) ;
242245 // apply preconditioner
243246 m ( active_block_r. view_mut ( ) ) ;
244- // apply constraints
247+ // apply constraints to the preconditioned residuals
245248 if let ( Some ( ref y) , Some ( ref fact_yy) ) = ( & y, & fact_yy) {
246249 apply_constraints ( active_block_r. view_mut ( ) , fact_yy, y. view ( ) ) ;
247250 }
251+ // orthogonalize the preconditioned residual to x
252+ active_block_r -= & x. dot ( & x. t ( ) . dot ( & active_block_r) ) ;
248253
249254 let ( r, _) = match orthonormalize ( active_block_r) {
250255 Ok ( x) => x,
@@ -253,57 +258,99 @@ pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) ->
253258
254259 let ar = a ( r. view ( ) ) ;
255260
261+ let max_rnorm = if A :: epsilon ( ) > NumCast :: from ( 1e-8 ) . unwrap ( ) {
262+ NumCast :: from ( 1.0 ) . unwrap ( )
263+ } else {
264+ NumCast :: from ( 1.0e-8 ) . unwrap ( )
265+ } ;
266+
267+ // once the largest residual norm is at or below max_rnorm, enable the explicit gram flag (it then stays enabled)
268+ let max_norm = residual_norms. into_iter ( ) . fold ( A :: Real :: neg_infinity ( ) , A :: Real :: max) ;
269+
270+ explicit_gram_flag = max_norm <= max_rnorm || explicit_gram_flag;
271+
256272 // perform the Rayleigh Ritz procedure
257- let xaw = x. t ( ) . dot ( & ar) ;
258- let waw = r. t ( ) . dot ( & ar) ;
259- let xw = x. t ( ) . dot ( & r) ;
273+ let xar = x. t ( ) . dot ( & ar) ;
274+ let mut rar = r. t ( ) . dot ( & ar) ;
260275
261- // compute symmetric gram matrices
262- let ( gram_a, gram_b, active_p, active_ap) = if let Some ( ( ref p, ref ap) ) = ap {
276+ let ( xax, xx, rr, xr) = if explicit_gram_flag {
277+ rar = ( & rar + & rar. t ( ) ) / two;
278+ let xax = x. t ( ) . dot ( & ax) ;
279+
280+ (
281+ ( & xax + & xax. t ( ) ) / two,
282+ x. t ( ) . dot ( & x) ,
283+ r. t ( ) . dot ( & r) ,
284+ x. t ( ) . dot ( & r)
285+ )
286+ } else {
287+ (
288+ lambda_diag,
289+ ident0. clone ( ) ,
290+ ident. clone ( ) ,
291+ Array2 :: zeros ( ( size_x, current_block_size) )
292+ )
293+ } ;
294+
295+ let p_ap: Option < ( _ , _ ) > = ap. as_ref ( ) . and_then ( |( p, ap) | {
263296 let active_p = ndarray_mask ( p. view ( ) , & activemask) ;
264297 let active_ap = ndarray_mask ( ap. view ( ) , & activemask) ;
265298
266- let ( active_p, p_r) = orthonormalize ( active_p) . unwrap ( ) ;
299+ if let Ok ( ( active_p, p_r) ) = orthonormalize ( active_p) {
300+ if let Ok ( active_ap) = p_r. solve_triangular ( UPLO :: Lower , Diag :: NonUnit , & active_ap. reversed_axes ( ) ) {
301+ let active_ap = active_ap. reversed_axes ( ) ;
302+
303+ Some ( ( active_p, active_ap) )
304+ } else {
305+ None
306+ }
307+ } else {
308+ None
309+ }
310+ } ) ;
267311
268- let active_ap = match p_r. solve_triangular ( UPLO :: Lower , Diag :: NonUnit , & active_ap. reversed_axes ( ) ) {
269- Ok ( x) => x,
270- Err ( err) => break Err ( err) ,
312+ // compute symmetric gram matrices
313+ let ( gram_a, gram_b) = if let Some ( ( active_p, active_ap) ) = & p_ap {
314+ let xap = x. t ( ) . dot ( active_ap) ;
315+ let rap = r. t ( ) . dot ( active_ap) ;
316+ let pap = active_p. t ( ) . dot ( active_ap) ;
317+ let xp = x. t ( ) . dot ( active_p) ;
318+ let rp = r. t ( ) . dot ( active_p) ;
319+ let ( pap, pp) = if explicit_gram_flag {
320+ (
321+ ( & pap + & pap. t ( ) ) / two,
322+ active_p. t ( ) . dot ( active_p)
323+ )
324+ } else {
325+ ( pap, ident. clone ( ) )
271326 } ;
272327
273- let active_ap = active_ap. reversed_axes ( ) ;
274-
275- let xap = x. t ( ) . dot ( & active_ap) ;
276- let wap = r. t ( ) . dot ( & active_ap) ;
277- let pap = active_p. t ( ) . dot ( & active_ap) ;
278- let xp = x. t ( ) . dot ( & active_p) ;
279- let wp = r. t ( ) . dot ( & active_p) ;
280-
281328 (
282329 stack ! [
283330 Axis ( 0 ) ,
284- stack![ Axis ( 1 ) , lambda_diag , xaw , xap] ,
285- stack![ Axis ( 1 ) , xaw . t( ) , waw , wap ] ,
286- stack![ Axis ( 1 ) , xap. t( ) , wap . t( ) , pap]
331+ stack![ Axis ( 1 ) , xax , xar , xap] ,
332+ stack![ Axis ( 1 ) , xar . t( ) , rar , rap ] ,
333+ stack![ Axis ( 1 ) , xap. t( ) , rap . t( ) , pap]
287334 ] ,
288335 stack ! [
289336 Axis ( 0 ) ,
290- stack![ Axis ( 1 ) , ident0 , xw , xp] ,
291- stack![ Axis ( 1 ) , xw . t( ) , ident , wp ] ,
292- stack![ Axis ( 1 ) , xp. t( ) , wp . t( ) , ident ]
337+ stack![ Axis ( 1 ) , xx , xr , xp] ,
338+ stack![ Axis ( 1 ) , xr . t( ) , rr , rp ] ,
339+ stack![ Axis ( 1 ) , xp. t( ) , rp . t( ) , pp ]
293340 ] ,
294- Some ( active_p) ,
295- Some ( active_ap) ,
296341 )
297342 } else {
298343 (
299344 stack ! [
300345 Axis ( 0 ) ,
301- stack![ Axis ( 1 ) , lambda_diag, xaw] ,
302- stack![ Axis ( 1 ) , xaw. t( ) , waw]
346+ stack![ Axis ( 1 ) , xax, xar] ,
347+ stack![ Axis ( 1 ) , xar. t( ) , rar]
348+ ] ,
349+ stack ! [
350+ Axis ( 0 ) ,
351+ stack![ Axis ( 1 ) , xx, xr] ,
352+ stack![ Axis ( 1 ) , xr. t( ) , rr]
303353 ] ,
304- stack ! [ Axis ( 0 ) , stack![ Axis ( 1 ) , ident0, xw] , stack![ Axis ( 1 ) , xw. t( ) , ident] ] ,
305- None ,
306- None ,
307354 )
308355 } ;
309356
@@ -313,7 +360,7 @@ pub fn lobpcg<A: Scalar + Lapack + PartialOrd + Default, F: Fn(ArrayView2<A>) ->
313360 } ;
314361 lambda = new_lambda;
315362
316- let ( pp, app, eig_x) = if let ( Some ( _) , ( Some ( ref active_p) , Some ( ref active_ap) ) ) = ( ap, ( active_p , active_ap ) )
363+ let ( pp, app, eig_x) = if let ( Some ( _) , Some ( ( active_p, active_ap) ) ) = ( ap, p_ap )
317364 {
318365 let eig_x = eig_vecs. slice ( s ! [ ..size_x, ..] ) ;
319366 let eig_r = eig_vecs. slice ( s ! [ size_x..size_x + current_block_size, ..] ) ;
0 commit comments