|
| 1 | +///! Implements truncated eigenvalue decomposition |
| 2 | +/// |
| 3 | +
|
| 4 | +use ndarray::prelude::*; |
| 5 | +use ndarray::stack; |
| 6 | +use ndarray_rand::rand_distr::Uniform; |
| 7 | +use ndarray_rand::RandomExt; |
| 8 | +use num_traits::{Float, NumCast}; |
| 9 | +use crate::{Scalar, Lapack}; |
| 10 | +use super::lobpcg::{lobpcg, EigResult, Order}; |
| 11 | + |
| 12 | +pub struct TruncatedEig<A: Scalar> { |
| 13 | + order: Order, |
| 14 | + problem: Array2<A>, |
| 15 | + pub constraints: Option<Array2<A>>, |
| 16 | + precision: A::Real, |
| 17 | + maxiter: usize |
| 18 | +} |
| 19 | + |
| 20 | +impl<A: Scalar + Lapack + PartialOrd + Default> TruncatedEig<A> { |
| 21 | + pub fn new(problem: Array2<A>, order: Order) -> TruncatedEig<A> { |
| 22 | + TruncatedEig { |
| 23 | + precision: NumCast::from(1e-5).unwrap(), |
| 24 | + maxiter: problem.len_of(Axis(0)) * 2, |
| 25 | + constraints: None, |
| 26 | + order, |
| 27 | + problem |
| 28 | + } |
| 29 | + } |
| 30 | + |
| 31 | + pub fn precision(mut self, precision: A::Real) -> Self { |
| 32 | + self.precision = precision; |
| 33 | + |
| 34 | + self |
| 35 | + } |
| 36 | + |
| 37 | + pub fn maxiter(mut self, maxiter: usize) -> Self { |
| 38 | + self.maxiter = maxiter; |
| 39 | + |
| 40 | + self |
| 41 | + |
| 42 | + } |
| 43 | + |
| 44 | + pub fn constraints(mut self, constraints: Array2<A>) -> Self { |
| 45 | + self.constraints = Some(constraints); |
| 46 | + |
| 47 | + self |
| 48 | + } |
| 49 | + |
| 50 | + pub fn once(&self, num: usize) -> EigResult<A> { |
| 51 | + let x = Array2::random((self.problem.len_of(Axis(0)), num), Uniform::new(0.0, 1.0)) |
| 52 | + .mapv(|x| NumCast::from(x).unwrap()); |
| 53 | + |
| 54 | + lobpcg(|y| self.problem.dot(&y), x, None, self.constraints.clone(), self.precision, self.maxiter, self.order.clone()) |
| 55 | + } |
| 56 | +} |
| 57 | + |
| 58 | +impl<A: Float + Scalar + Lapack + PartialOrd + Default> IntoIterator for TruncatedEig<A> { |
| 59 | + type Item = (Array1<A>, Array2<A>); |
| 60 | + type IntoIter = TruncatedEigIterator<A>; |
| 61 | + |
| 62 | + fn into_iter(self) -> TruncatedEigIterator<A>{ |
| 63 | + TruncatedEigIterator { |
| 64 | + step_size: 1, |
| 65 | + eig: self |
| 66 | + } |
| 67 | + } |
| 68 | +} |
| 69 | + |
| 70 | +pub struct TruncatedEigIterator<A: Scalar> { |
| 71 | + step_size: usize, |
| 72 | + eig: TruncatedEig<A> |
| 73 | +} |
| 74 | + |
| 75 | +impl<A: Float + Scalar + Lapack + PartialOrd + Default> Iterator for TruncatedEigIterator<A> { |
| 76 | + type Item = (Array1<A>, Array2<A>); |
| 77 | + |
| 78 | + fn next(&mut self) -> Option<Self::Item> { |
| 79 | + let res = self.eig.once(self.step_size); |
| 80 | + dbg!(&res); |
| 81 | + |
| 82 | + match res { |
| 83 | + EigResult::Ok(vals, vecs, norms) | EigResult::Err(vals, vecs, norms, _) => { |
| 84 | + // abort if any eigenproblem did not converge |
| 85 | + for r_norm in norms { |
| 86 | + if r_norm > NumCast::from(0.1).unwrap() { |
| 87 | + return None; |
| 88 | + } |
| 89 | + } |
| 90 | + |
| 91 | + let new_constraints = if let Some(ref constraints) = self.eig.constraints { |
| 92 | + let eigvecs_arr = constraints.gencolumns().into_iter() |
| 93 | + .chain(vecs.gencolumns().into_iter()) |
| 94 | + .map(|x| x.insert_axis(Axis(1))) |
| 95 | + .collect::<Vec<_>>(); |
| 96 | + |
| 97 | + stack(Axis(1), &eigvecs_arr).unwrap() |
| 98 | + } else { |
| 99 | + vecs.clone() |
| 100 | + }; |
| 101 | + |
| 102 | + dbg!(&new_constraints); |
| 103 | + |
| 104 | + self.eig.constraints = Some(new_constraints); |
| 105 | + |
| 106 | + Some((vals, vecs)) |
| 107 | + }, |
| 108 | + EigResult::NoResult(_) => None |
| 109 | + } |
| 110 | + } |
| 111 | +} |
| 112 | + |
| 113 | +#[cfg(test)] |
| 114 | +mod tests { |
| 115 | + use super::TruncatedEig; |
| 116 | + use super::Order; |
| 117 | + use ndarray::Array2; |
| 118 | + use ndarray_rand::rand_distr::Uniform; |
| 119 | + use ndarray_rand::RandomExt; |
| 120 | + |
| 121 | + #[test] |
| 122 | + fn test_truncated_eig() { |
| 123 | + let a = Array2::random((50, 50), Uniform::new(0., 1.0)); |
| 124 | + let a = a.t().dot(&a); |
| 125 | + |
| 126 | + let teig = TruncatedEig::new(a, Order::Largest) |
| 127 | + .precision(1e-5) |
| 128 | + .maxiter(500); |
| 129 | + |
| 130 | + let res = teig.into_iter().take(3).flat_map(|x| x.0.to_vec()).collect::<Vec<_>>(); |
| 131 | + dbg!(&res); |
| 132 | + panic!(""); |
| 133 | + } |
| 134 | +} |
0 commit comments