Skip to content

Commit 4f7404d

Browse files
committed
Merge Solveh_ into Lapack trait
1 parent 3dcf19b commit 4f7404d

File tree

2 files changed

+75
-174
lines changed

2 files changed

+75
-174
lines changed

lax/src/lib.rs

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,9 +123,7 @@ pub type Pivot = Vec<i32>;
123123

124124
#[cfg_attr(doc, katexit::katexit)]
125125
/// Trait for primitive types which implements LAPACK subroutines
126-
pub trait Lapack:
127-
OperatorNorm_ + Solveh_ + Cholesky_ + Triangular_ + Tridiagonal_ + Rcond_
128-
{
126+
pub trait Lapack: OperatorNorm_ + Cholesky_ + Triangular_ + Tridiagonal_ + Rcond_ {
129127
/// Compute right eigenvalue and eigenvectors for a general matrix
130128
fn eig(
131129
calc_v: bool,
@@ -217,6 +215,30 @@ pub trait Lapack:
217215

218216
/// Solve linear equations $Ax = b$ using the output of LU-decomposition
219217
fn solve(l: MatrixLayout, t: Transpose, a: &[Self], p: &Pivot, b: &mut [Self]) -> Result<()>;
218+
219+
/// Factorize symmetric/Hermitian matrix using Bunch-Kaufman diagonal pivoting method
220+
///
221+
///
222+
/// For a given symmetric matrix $A$,
223+
/// this method factorizes $A = U^T D U$ or $A = L D L^T$ where
224+
///
225+
/// - $U$ (or $L$) are is a product of permutation and unit upper (lower) triangular matrices
226+
/// - $D$ is symmetric and block diagonal with 1-by-1 and 2-by-2 diagonal blocks.
227+
///
228+
/// This takes two-step approach based in LAPACK:
229+
///
230+
/// 1. Factorize given matrix $A$ into upper ($U$) or lower ($L$) form with diagonal matrix $D$
231+
/// 2. Then solve linear equation $Ax = b$, and/or calculate inverse matrix $A^{-1}$
232+
///
233+
/// [BK]: https://doi.org/10.2307/2005787
234+
///
235+
fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<Pivot>;
236+
237+
/// Compute inverse matrix $A^{-1}$ of symmetric/Hermitian matrix using factroized result
238+
fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()>;
239+
240+
/// Solve symmetric/Hermitian linear equation $Ax = b$ using factroized result
241+
fn solveh(l: MatrixLayout, uplo: UPLO, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()>;
220242
}
221243

222244
macro_rules! impl_lapack {
@@ -335,6 +357,29 @@ macro_rules! impl_lapack {
335357
use solve::*;
336358
SolveImpl::solve(l, t, a, p, b)
337359
}
360+
361+
fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<Pivot> {
362+
use solveh::*;
363+
let work = BkWork::<$s>::new(l)?;
364+
work.eval(uplo, a)
365+
}
366+
367+
fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()> {
368+
use solveh::*;
369+
let mut work = InvhWork::<$s>::new(l)?;
370+
work.calc(uplo, a, ipiv)
371+
}
372+
373+
fn solveh(
374+
l: MatrixLayout,
375+
uplo: UPLO,
376+
a: &[Self],
377+
ipiv: &Pivot,
378+
b: &mut [Self],
379+
) -> Result<()> {
380+
use solveh::*;
381+
SolvehImpl::solveh(l, uplo, a, ipiv, b)
382+
}
338383
}
339384
};
340385
}

lax/src/solveh.rs

Lines changed: 27 additions & 171 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,15 @@ pub struct BkWork<T: Scalar> {
88
pub ipiv: Vec<MaybeUninit<i32>>,
99
}
1010

11+
/// Factorize symmetric/Hermitian matrix using Bunch-Kaufman diagonal pivoting method
12+
///
13+
/// LAPACK correspondance
14+
/// ----------------------
15+
///
16+
/// | f32 | f64 | c32 | c64 |
17+
/// |:-------|:-------|:-------|:-------|
18+
/// | ssytrf | dsytrf | chetrf | zhetrf |
19+
///
1120
pub trait BkWorkImpl: Sized {
1221
type Elem: Scalar;
1322
fn new(l: MatrixLayout) -> Result<Self>;
@@ -80,6 +89,15 @@ pub struct InvhWork<T: Scalar> {
8089
pub work: Vec<MaybeUninit<T>>,
8190
}
8291

92+
/// Compute inverse matrix of symmetric/Hermitian matrix
93+
///
94+
/// LAPACK correspondance
95+
/// ----------------------
96+
///
97+
/// | f32 | f64 | c32 | c64 |
98+
/// |:-------|:-------|:-------|:-------|
99+
/// | ssytri | dsytri | chetri | zhetri |
100+
///
83101
pub trait InvhWorkImpl: Sized {
84102
type Elem;
85103
fn new(layout: MatrixLayout) -> Result<Self>;
@@ -122,6 +140,15 @@ impl_invh_work!(c32, lapack_sys::chetri_);
122140
impl_invh_work!(f64, lapack_sys::dsytri_);
123141
impl_invh_work!(f32, lapack_sys::ssytri_);
124142

143+
/// Solve symmetric/Hermitian linear equation
144+
///
145+
/// LAPACK correspondance
146+
/// ----------------------
147+
///
148+
/// | f32 | f64 | c32 | c64 |
149+
/// |:-------|:-------|:-------|:-------|
150+
/// | ssytrs | dsytrs | chetrs | zhetrs |
151+
///
125152
pub trait SolvehImpl: Scalar {
126153
fn solveh(l: MatrixLayout, uplo: UPLO, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()>;
127154
}
@@ -162,174 +189,3 @@ impl_solveh_!(c64, lapack_sys::zhetrs_);
162189
impl_solveh_!(c32, lapack_sys::chetrs_);
163190
impl_solveh_!(f64, lapack_sys::dsytrs_);
164191
impl_solveh_!(f32, lapack_sys::ssytrs_);
165-
166-
#[cfg_attr(doc, katexit::katexit)]
167-
/// Solve symmetric/hermite indefinite linear problem using the [Bunch-Kaufman diagonal pivoting method][BK].
168-
///
169-
/// For a given symmetric matrix $A$,
170-
/// this method factorizes $A = U^T D U$ or $A = L D L^T$ where
171-
///
172-
/// - $U$ (or $L$) are is a product of permutation and unit upper (lower) triangular matrices
173-
/// - $D$ is symmetric and block diagonal with 1-by-1 and 2-by-2 diagonal blocks.
174-
///
175-
/// This takes two-step approach based in LAPACK:
176-
///
177-
/// 1. Factorize given matrix $A$ into upper ($U$) or lower ($L$) form with diagonal matrix $D$
178-
/// 2. Then solve linear equation $Ax = b$, and/or calculate inverse matrix $A^{-1}$
179-
///
180-
/// [BK]: https://doi.org/10.2307/2005787
181-
///
182-
pub trait Solveh_: Sized {
183-
/// Factorize input matrix using Bunch-Kaufman diagonal pivoting method
184-
///
185-
/// LAPACK correspondance
186-
/// ----------------------
187-
///
188-
/// | f32 | f64 | c32 | c64 |
189-
/// |:-------|:-------|:-------|:-------|
190-
/// | ssytrf | dsytrf | chetrf | zhetrf |
191-
///
192-
fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<Pivot>;
193-
194-
/// Compute inverse matrix $A^{-1}$ from factroized result
195-
///
196-
/// LAPACK correspondance
197-
/// ----------------------
198-
///
199-
/// | f32 | f64 | c32 | c64 |
200-
/// |:-------|:-------|:-------|:-------|
201-
/// | ssytri | dsytri | chetri | zhetri |
202-
///
203-
fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()>;
204-
205-
/// Solve linear equation $Ax = b$ using factroized result
206-
///
207-
/// LAPACK correspondance
208-
/// ----------------------
209-
///
210-
/// | f32 | f64 | c32 | c64 |
211-
/// |:-------|:-------|:-------|:-------|
212-
/// | ssytrs | dsytrs | chetrs | zhetrs |
213-
///
214-
fn solveh(l: MatrixLayout, uplo: UPLO, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()>;
215-
}
216-
217-
macro_rules! impl_solveh {
218-
($scalar:ty, $trf:path, $tri:path, $trs:path) => {
219-
impl Solveh_ for $scalar {
220-
fn bk(l: MatrixLayout, uplo: UPLO, a: &mut [Self]) -> Result<Pivot> {
221-
let (n, _) = l.size();
222-
let mut ipiv = vec_uninit(n as usize);
223-
if n == 0 {
224-
return Ok(Vec::new());
225-
}
226-
227-
// calc work size
228-
let mut info = 0;
229-
let mut work_size = [Self::zero()];
230-
unsafe {
231-
$trf(
232-
uplo.as_ptr(),
233-
&n,
234-
AsPtr::as_mut_ptr(a),
235-
&l.lda(),
236-
AsPtr::as_mut_ptr(&mut ipiv),
237-
AsPtr::as_mut_ptr(&mut work_size),
238-
&(-1),
239-
&mut info,
240-
)
241-
};
242-
info.as_lapack_result()?;
243-
244-
// actual
245-
let lwork = work_size[0].to_usize().unwrap();
246-
let mut work: Vec<MaybeUninit<Self>> = vec_uninit(lwork);
247-
unsafe {
248-
$trf(
249-
uplo.as_ptr(),
250-
&n,
251-
AsPtr::as_mut_ptr(a),
252-
&l.lda(),
253-
AsPtr::as_mut_ptr(&mut ipiv),
254-
AsPtr::as_mut_ptr(&mut work),
255-
&(lwork as i32),
256-
&mut info,
257-
)
258-
};
259-
info.as_lapack_result()?;
260-
let ipiv = unsafe { ipiv.assume_init() };
261-
Ok(ipiv)
262-
}
263-
264-
fn invh(l: MatrixLayout, uplo: UPLO, a: &mut [Self], ipiv: &Pivot) -> Result<()> {
265-
let (n, _) = l.size();
266-
let mut info = 0;
267-
let mut work: Vec<MaybeUninit<Self>> = vec_uninit(n as usize);
268-
unsafe {
269-
$tri(
270-
uplo.as_ptr(),
271-
&n,
272-
AsPtr::as_mut_ptr(a),
273-
&l.lda(),
274-
ipiv.as_ptr(),
275-
AsPtr::as_mut_ptr(&mut work),
276-
&mut info,
277-
)
278-
};
279-
info.as_lapack_result()?;
280-
Ok(())
281-
}
282-
283-
fn solveh(
284-
l: MatrixLayout,
285-
uplo: UPLO,
286-
a: &[Self],
287-
ipiv: &Pivot,
288-
b: &mut [Self],
289-
) -> Result<()> {
290-
let (n, _) = l.size();
291-
let mut info = 0;
292-
unsafe {
293-
$trs(
294-
uplo.as_ptr(),
295-
&n,
296-
&1,
297-
AsPtr::as_ptr(a),
298-
&l.lda(),
299-
ipiv.as_ptr(),
300-
AsPtr::as_mut_ptr(b),
301-
&n,
302-
&mut info,
303-
)
304-
};
305-
info.as_lapack_result()?;
306-
Ok(())
307-
}
308-
}
309-
};
310-
} // impl_solveh!
311-
312-
impl_solveh!(
313-
f64,
314-
lapack_sys::dsytrf_,
315-
lapack_sys::dsytri_,
316-
lapack_sys::dsytrs_
317-
);
318-
impl_solveh!(
319-
f32,
320-
lapack_sys::ssytrf_,
321-
lapack_sys::ssytri_,
322-
lapack_sys::ssytrs_
323-
);
324-
impl_solveh!(
325-
c64,
326-
lapack_sys::zhetrf_,
327-
lapack_sys::zhetri_,
328-
lapack_sys::zhetrs_
329-
);
330-
impl_solveh!(
331-
c32,
332-
lapack_sys::chetrf_,
333-
lapack_sys::chetri_,
334-
lapack_sys::chetrs_
335-
);

0 commit comments

Comments
 (0)