|
5 | 5 | // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your |
6 | 6 | // option. This file may not be copied, modified, or distributed |
7 | 7 | // except according to those terms. |
8 | | -use libnum; |
9 | | - |
10 | 8 | use std::cmp; |
11 | | -use std::ops::{ |
12 | | - Add, |
13 | | -}; |
14 | 9 |
|
15 | 10 | use LinalgScalar; |
16 | 11 |
|
/// Fold the elements of `xs` into a single value using `f`, with the main
/// loop manually unrolled eight ways.
///
/// `init` produces the starting value of every accumulator and is called nine
/// times in total: once for the final accumulator and once for each of the
/// eight partial accumulators. `f` combines two values into one.
///
/// Elements are *not* combined in left-to-right order: element `i` of each
/// full group of eight goes into partial accumulator `i`, the partials are
/// then merged pairwise (`0` with `4`, `1` with `5`, ...), and finally the
/// remaining (at most seven) trailing elements are folded in. The result
/// therefore only matches a plain sequential fold when `f` is associative and
/// commutative and `init()` is a neutral element for `f` (e.g. `0` for
/// addition, `1` for multiplication).
///
/// Every element of `xs` is cloned exactly once; the accumulators themselves
/// are moved through `f` and never cloned.
pub fn unrolled_fold<A, I, F>(mut xs: &[A], init: I, f: F) -> A
    where A: Clone,
          I: Fn() -> A,
          F: Fn(A, A) -> A,
{
    // eightfold unrolled so that floating point can be vectorized
    // (even with strict floating point accuracy semantics)
    let mut acc = init();
    let (mut p0, mut p1, mut p2, mut p3,
         mut p4, mut p5, mut p6, mut p7) =
        (init(), init(), init(), init(),
         init(), init(), init(), init());
    while xs.len() >= 8 {
        p0 = f(p0, xs[0].clone());
        p1 = f(p1, xs[1].clone());
        p2 = f(p2, xs[2].clone());
        p3 = f(p3, xs[3].clone());
        p4 = f(p4, xs[4].clone());
        p5 = f(p5, xs[5].clone());
        p6 = f(p6, xs[6].clone());
        p7 = f(p7, xs[7].clone());

        xs = &xs[8..];
    }
    // Merge the partial accumulators. `acc` is moved into `f` and rebound to
    // the result, so no clone of the running value is needed.
    acc = f(acc, f(p0, p4));
    acc = f(acc, f(p1, p5));
    acc = f(acc, f(p2, p6));
    acc = f(acc, f(p3, p7));

    // make it clear to the optimizer that this loop is short
    // and can not be autovectorized: at most 7 elements are left here.
    for x in xs.iter().take(7) {
        acc = f(acc, x.clone());
    }
    acc
}
53 | 50 |
|
54 | 51 | /// Compute the dot product. |
|
0 commit comments