Skip to content

Commit 7fbe2fe

Browse files
committed
TEST: Add numeric test that uses general_mat_mul with strided output
1 parent d8078f6 commit 7fbe2fe

File tree

1 file changed

+72
-0
lines changed

1 file changed

+72
-0
lines changed

numeric-tests/tests/accuracy.rs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use ndarray::{
1111
Data,
1212
LinalgScalar,
1313
};
14+
use ndarray::linalg::general_mat_mul;
1415

1516
use rand::distributions::Normal;
1617

@@ -161,6 +162,42 @@ fn accurate_mul_f32() {
161162
}
162163
}
163164

165+
#[test]
166+
fn accurate_mul_f32_general() {
167+
// pick a few random sizes
168+
let mut rng = SmallRng::from_entropy();
169+
for i in 0..20 {
170+
let m = rng.gen_range(15, 512);
171+
let k = rng.gen_range(15, 512);
172+
let n = rng.gen_range(15, 1560);
173+
let a = gen(Ix2(m, k));
174+
let b = gen(Ix2(n, k));
175+
let mut c = gen(Ix2(m, n));
176+
let b = b.t();
177+
let (a, b, mut c) = if i > 10 {
178+
(a.slice(s![..;2, ..;2]),
179+
b.slice(s![..;2, ..;2]),
180+
c.slice_mut(s![..;2, ..;2]))
181+
} else { (a.view(), b, c.view_mut()) };
182+
183+
println!("Testing size {} by {} by {}", a.shape()[0], a.shape()[1], b.shape()[1]);
184+
general_mat_mul(1., &a, &b, 0., &mut c);
185+
let reference = reference_mat_mul(&a, &b);
186+
let diff = (&c - &reference).mapv_into(f32::abs);
187+
188+
let rtol = 1e-3;
189+
let atol = 1e-4;
190+
let crtol = c.mapv(|x| x.abs() * rtol);
191+
let tol = crtol + atol;
192+
let tol_m_diff = &diff - &tol;
193+
let maxdiff = *tol_m_diff.max();
194+
println!("diff offset from tolerance level= {:.2e}", maxdiff);
195+
if maxdiff > 0. {
196+
panic!("results differ");
197+
}
198+
}
199+
}
200+
164201
#[test]
165202
fn accurate_mul_f64() {
166203
// pick a few random sizes
@@ -195,6 +232,41 @@ fn accurate_mul_f64() {
195232
}
196233
}
197234

235+
#[test]
236+
fn accurate_mul_f64_general() {
237+
// pick a few random sizes
238+
let mut rng = SmallRng::from_entropy();
239+
for i in 0..20 {
240+
let m = rng.gen_range(15, 512);
241+
let k = rng.gen_range(15, 512);
242+
let n = rng.gen_range(15, 1560);
243+
let a = gen_f64(Ix2(m, k));
244+
let b = gen_f64(Ix2(n, k));
245+
let mut c = gen_f64(Ix2(m, n));
246+
let b = b.t();
247+
let (a, b, mut c) = if i > 10 {
248+
(a.slice(s![..;2, ..;2]),
249+
b.slice(s![..;2, ..;2]),
250+
c.slice_mut(s![..;2, ..;2]))
251+
} else { (a.view(), b, c.view_mut()) };
252+
253+
println!("Testing size {} by {} by {}", a.shape()[0], a.shape()[1], b.shape()[1]);
254+
general_mat_mul(1., &a, &b, 0., &mut c);
255+
let reference = reference_mat_mul(&a, &b);
256+
let diff = (&c - &reference).mapv_into(f64::abs);
257+
258+
let rtol = 1e-7;
259+
let atol = 1e-12;
260+
let crtol = c.mapv(|x| x.abs() * rtol);
261+
let tol = crtol + atol;
262+
let tol_m_diff = &diff - &tol;
263+
let maxdiff = *tol_m_diff.max();
264+
println!("diff offset from tolerance level= {:.2e}", maxdiff);
265+
if maxdiff > 0. {
266+
panic!("results differ");
267+
}
268+
}
269+
}
198270

199271
#[test]
200272
fn accurate_mul_with_column_f64() {

0 commit comments

Comments
 (0)