// Code taken from the `packed_simd` crate
// Run this code with `cargo test --example dot_product`
//use core::iter::zip;
//use std::iter::zip;

#![feature(array_chunks)]
#![feature(slice_as_chunks)]
// Add these imports to use the stdsimd library
#![feature(portable_simd)]
use core_simd::*;

// This is your barebones dot product implementation:
// Take 2 vectors, multiply them element-wise and *then*
// go along the resulting array and add up the results.
// In the next example we will see if there
// is any difference to adding and multiplying in tandem.
pub fn dot_prod_0(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());

    a.iter()
        .zip(b.iter())
        .map(|(a, b)| a * b)
        .sum()
}

// When dealing with SIMD, it is very important to think about the amount
// of data movement and when it happens. We're only going over simple computation examples here,
// and yet it is not trivial to understand what may or may not contribute to performance changes.
// Eventually, you will need tools to inspect the generated assembly and benchmarks to confirm
// your hypotheses - we will mention them later on.
// With the use of `fold`, we're doing a multiplication and then adding it to the running sum,
// one element from both vectors at a time.
pub fn dot_prod_1(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
    a.iter()
        .zip(b.iter())
        .fold(0.0, |acc, zipped| acc + zipped.0 * zipped.1)
}

// We now move on to the SIMD implementations: notice the following constructs:
// `array_chunks::<4>`: mapping this over the slices will let us construct SIMD vectors
// `f32x4::from_array`: construct a SIMD vector from an array of 4 elements
// `(a * b).reduce_sum()`: multiply both f32x4 vectors together, and then sum up the lanes of the result
// This approach essentially uses SIMD to produce N/4 partial sums (one per chunk of products),
// and then adds those together with the scalar `sum()`. This is suboptimal.
// TODO: ASCII diagrams
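// In the meantime, a rough sketch of what happens to a single chunk of four lanes:
//   a chunk: [a0, a1, a2, a3]      b chunk: [b0, b1, b2, b3]
//   a * b          -> [a0*b0, a1*b1, a2*b2, a3*b3]
//   .reduce_sum()  -> a0*b0 + a1*b1 + a2*b2 + a3*b3     (one scalar per chunk)
// The final `.sum()` then adds up those per-chunk partial sums on the scalar side.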
pub fn dot_prod_simd_0(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
    // TODO handle remainder when a.len() % 4 != 0
    a.array_chunks::<4>()
        .map(|&a| f32x4::from_array(a))
        .zip(b.array_chunks::<4>().map(|&b| f32x4::from_array(b)))
        .map(|(a, b)| (a * b).reduce_sum())
        .sum()
}

// There are some simple ways to improve the previous code:
// 1. Make a `zero` `f32x4` SIMD vector that we will be accumulating into,
//    so that there is only one `sum()` reduction when the last `f32x4` has been processed.
// 2. Exploit fused multiply-add so that the multiplication, addition and sinking into the reduction
//    happen in the same step.
// If the arrays are large, minimizing the data shuffling will lead to great perf.
// If the arrays are small, handling the remainder elements when the length isn't a multiple of 4
// can become a problem.
pub fn dot_prod_simd_1(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
    // TODO handle remainder when a.len() % 4 != 0
    a.array_chunks::<4>()
        .map(|&a| f32x4::from_array(a))
        .zip(b.array_chunks::<4>().map(|&b| f32x4::from_array(b)))
        .fold(f32x4::splat(0.0), |acc, zipped| acc + zipped.0 * zipped.1)
        .reduce_sum()
}

// `dot_prod_simd_2` is, for now, identical to `dot_prod_simd_1`: it is the natural place to
// apply improvement 2 above and fuse the multiply and the add into a single operation.
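// A sketch of that change, assuming a `mul_add` (fused multiply-add) method is available on
// `f32x4` (on nightly it is provided by the `StdFloat` trait):
//     .fold(f32x4::splat(0.0), |acc, zipped| zipped.0.mul_add(zipped.1, acc))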
pub fn dot_prod_simd_2(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());
    // TODO handle remainder when a.len() % 4 != 0
    a.array_chunks::<4>()
        .map(|&a| f32x4::from_array(a))
        .zip(b.array_chunks::<4>().map(|&b| f32x4::from_array(b)))
        .fold(f32x4::splat(0.0), |acc, zipped| acc + zipped.0 * zipped.1)
        .reduce_sum()
}

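// Finally, handle the remainder instead of leaving it as a TODO: `as_rchunks` splits each slice
// into a short remainder slice (fewer than LANES elements) plus exact chunks of LANES elements.
// The remainder products seed the `sums` accumulator, every full chunk is then multiplied and
// accumulated with SIMD, and a single `reduce_sum()` collapses the lanes at the end.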
const LANES: usize = 4;
pub fn dot_prod_simd_3(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());

    let (a_extra, a_chunks) = a.as_rchunks();
    let (b_extra, b_chunks) = b.as_rchunks();

    // These are always true, but for emphasis:
    assert_eq!(a_chunks.len(), b_chunks.len());
    assert_eq!(a_extra.len(), b_extra.len());

    let mut sums = [0.0; LANES];
    for ((x, y), d) in std::iter::zip(a_extra, b_extra).zip(&mut sums) {
        *d = x * y;
    }

    let mut sums = f32x4::from_array(sums);
    std::iter::zip(a_chunks, b_chunks).for_each(|(x, y)| {
        sums += f32x4::from_array(*x) * f32x4::from_array(*y);
    });

    sums.reduce_sum()
}

fn main() {
    // Empty main to make cargo happy
}

#[cfg(test)]
mod tests {
    #[test]
    fn smoke_test() {
        use super::*;
        let a: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b: Vec<f32> = vec![-8.0, -7.0, -6.0, -5.0, 4.0, 3.0, 2.0, 1.0];
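        // 1003 is not a multiple of the 4-lane width, so these vectors exercise the
        // remainder handling in `dot_prod_simd_3`.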
        let x: Vec<f32> = [0.5; 1003].to_vec();
        let y: Vec<f32> = [2.0; 1003].to_vec();

        assert_eq!(0.0, dot_prod_0(&a, &b));
        assert_eq!(0.0, dot_prod_1(&a, &b));
        assert_eq!(0.0, dot_prod_simd_0(&a, &b));
        assert_eq!(0.0, dot_prod_simd_1(&a, &b));
        assert_eq!(0.0, dot_prod_simd_2(&a, &b));
        assert_eq!(0.0, dot_prod_simd_3(&a, &b));
        assert_eq!(1003.0, dot_prod_simd_3(&x, &y));
    }
}