Skip to content

Commit 3dcf19b

Browse files
committed
Add BkWork, InvhWork, SolvehImpl
1 parent 23baa44 commit 3dcf19b

File tree

1 file changed

+161
-0
lines changed

1 file changed

+161
-0
lines changed

lax/src/solveh.rs

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,167 @@ use crate::{error::*, layout::MatrixLayout, *};
22
use cauchy::*;
33
use num_traits::{ToPrimitive, Zero};
44

5+
pub struct BkWork<T: Scalar> {
6+
pub layout: MatrixLayout,
7+
pub work: Vec<MaybeUninit<T>>,
8+
pub ipiv: Vec<MaybeUninit<i32>>,
9+
}
10+
11+
pub trait BkWorkImpl: Sized {
12+
type Elem: Scalar;
13+
fn new(l: MatrixLayout) -> Result<Self>;
14+
fn calc(&mut self, uplo: UPLO, a: &mut [Self::Elem]) -> Result<&[i32]>;
15+
fn eval(self, uplo: UPLO, a: &mut [Self::Elem]) -> Result<Pivot>;
16+
}
17+
18+
macro_rules! impl_bk_work {
19+
($s:ty, $trf:path) => {
20+
impl BkWorkImpl for BkWork<$s> {
21+
type Elem = $s;
22+
23+
fn new(layout: MatrixLayout) -> Result<Self> {
24+
let (n, _) = layout.size();
25+
let ipiv = vec_uninit(n as usize);
26+
let mut info = 0;
27+
let mut work_size = [Self::Elem::zero()];
28+
unsafe {
29+
$trf(
30+
UPLO::Upper.as_ptr(),
31+
&n,
32+
std::ptr::null_mut(),
33+
&layout.lda(),
34+
std::ptr::null_mut(),
35+
AsPtr::as_mut_ptr(&mut work_size),
36+
&(-1),
37+
&mut info,
38+
)
39+
};
40+
info.as_lapack_result()?;
41+
let lwork = work_size[0].to_usize().unwrap();
42+
let work = vec_uninit(lwork);
43+
Ok(BkWork { layout, work, ipiv })
44+
}
45+
46+
fn calc(&mut self, uplo: UPLO, a: &mut [Self::Elem]) -> Result<&[i32]> {
47+
let (n, _) = self.layout.size();
48+
let lwork = self.work.len().to_i32().unwrap();
49+
let mut info = 0;
50+
unsafe {
51+
$trf(
52+
uplo.as_ptr(),
53+
&n,
54+
AsPtr::as_mut_ptr(a),
55+
&self.layout.lda(),
56+
AsPtr::as_mut_ptr(&mut self.ipiv),
57+
AsPtr::as_mut_ptr(&mut self.work),
58+
&lwork,
59+
&mut info,
60+
)
61+
};
62+
info.as_lapack_result()?;
63+
Ok(unsafe { self.ipiv.slice_assume_init_ref() })
64+
}
65+
66+
fn eval(mut self, uplo: UPLO, a: &mut [Self::Elem]) -> Result<Pivot> {
67+
let _ref = self.calc(uplo, a)?;
68+
Ok(unsafe { self.ipiv.assume_init() })
69+
}
70+
}
71+
};
72+
}
73+
impl_bk_work!(c64, lapack_sys::zhetrf_);
74+
impl_bk_work!(c32, lapack_sys::chetrf_);
75+
impl_bk_work!(f64, lapack_sys::dsytrf_);
76+
impl_bk_work!(f32, lapack_sys::ssytrf_);
77+
78+
pub struct InvhWork<T: Scalar> {
79+
pub layout: MatrixLayout,
80+
pub work: Vec<MaybeUninit<T>>,
81+
}
82+
83+
pub trait InvhWorkImpl: Sized {
84+
type Elem;
85+
fn new(layout: MatrixLayout) -> Result<Self>;
86+
fn calc(&mut self, uplo: UPLO, a: &mut [Self::Elem], ipiv: &Pivot) -> Result<()>;
87+
}
88+
89+
macro_rules! impl_invh_work {
90+
($s:ty, $tri:path) => {
91+
impl InvhWorkImpl for InvhWork<$s> {
92+
type Elem = $s;
93+
94+
fn new(layout: MatrixLayout) -> Result<Self> {
95+
let (n, _) = layout.size();
96+
let work = vec_uninit(n as usize);
97+
Ok(InvhWork { layout, work })
98+
}
99+
100+
fn calc(&mut self, uplo: UPLO, a: &mut [Self::Elem], ipiv: &Pivot) -> Result<()> {
101+
let (n, _) = self.layout.size();
102+
let mut info = 0;
103+
unsafe {
104+
$tri(
105+
uplo.as_ptr(),
106+
&n,
107+
AsPtr::as_mut_ptr(a),
108+
&self.layout.lda(),
109+
ipiv.as_ptr(),
110+
AsPtr::as_mut_ptr(&mut self.work),
111+
&mut info,
112+
)
113+
};
114+
info.as_lapack_result()?;
115+
Ok(())
116+
}
117+
}
118+
};
119+
}
120+
impl_invh_work!(c64, lapack_sys::zhetri_);
121+
impl_invh_work!(c32, lapack_sys::chetri_);
122+
impl_invh_work!(f64, lapack_sys::dsytri_);
123+
impl_invh_work!(f32, lapack_sys::ssytri_);
124+
125+
pub trait SolvehImpl: Scalar {
126+
fn solveh(l: MatrixLayout, uplo: UPLO, a: &[Self], ipiv: &Pivot, b: &mut [Self]) -> Result<()>;
127+
}
128+
129+
macro_rules! impl_solveh_ {
130+
($s:ty, $trs:path) => {
131+
impl SolvehImpl for $s {
132+
fn solveh(
133+
l: MatrixLayout,
134+
uplo: UPLO,
135+
a: &[Self],
136+
ipiv: &Pivot,
137+
b: &mut [Self],
138+
) -> Result<()> {
139+
let (n, _) = l.size();
140+
let mut info = 0;
141+
unsafe {
142+
$trs(
143+
uplo.as_ptr(),
144+
&n,
145+
&1,
146+
AsPtr::as_ptr(a),
147+
&l.lda(),
148+
ipiv.as_ptr(),
149+
AsPtr::as_mut_ptr(b),
150+
&n,
151+
&mut info,
152+
)
153+
};
154+
info.as_lapack_result()?;
155+
Ok(())
156+
}
157+
}
158+
};
159+
}
160+
161+
impl_solveh_!(c64, lapack_sys::zhetrs_);
162+
impl_solveh_!(c32, lapack_sys::chetrs_);
163+
impl_solveh_!(f64, lapack_sys::dsytrs_);
164+
impl_solveh_!(f32, lapack_sys::ssytrs_);
165+
5166
#[cfg_attr(doc, katexit::katexit)]
6167
/// Solve symmetric/hermite indefinite linear problem using the [Bunch-Kaufman diagonal pivoting method][BK].
7168
///

0 commit comments

Comments
 (0)