@@ -27,72 +27,15 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2727#include <stdio.h>
2828#include <stdint.h>
2929#include "../common.h"
30+
31+ #include "test_helpers.h"
32+
3033#define SGEMM BLASFUNC(sgemm)
3134#define SBGEMM BLASFUNC(sbgemm)
3235#define SGEMV BLASFUNC(sgemv)
3336#define SBGEMV BLASFUNC(sbgemv)
/* Bit-level view of a 16-bit bfloat16 value.

   `v` holds the raw pattern; `bits` overlays it with three fields:
   s (1 bit), e (8 bits) and m (7 bits).  The order in which the fields
   are declared is switched on the target byte order so that `s` lands on
   the most significant bit of `v`.

   The byte-order test is guarded with defined(): in a #if expression an
   undefined identifier evaluates to 0, so on a compiler that predefines
   neither macro the bare comparison would read `0 == 0` and silently
   select the big-endian field order.  Such compilers now get the
   little-endian order from the #else branch.  */
typedef union
{
  unsigned short v;
#if defined(_AIX)
  struct __attribute__((packed))
#else
  struct
#endif
  {
#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) \
    && (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
    unsigned short s:1;
    unsigned short e:8;
    unsigned short m:7;
#else
    unsigned short m:7;
    unsigned short e:8;
    unsigned short s:1;
#endif
  } bits;
} bfloat16_bits;
54-
/* Bit-level view of a 32-bit float.

   `v` is the float itself; `bits` overlays it with three fields:
   s (1 bit), e (8 bits) and m (23 bits).  The declaration order of the
   fields is switched on the target byte order so that `s` lands on the
   most significant bit.

   The byte-order test is guarded with defined(): in a #if expression an
   undefined identifier evaluates to 0, so on a compiler that predefines
   neither macro the bare comparison would read `0 == 0` and silently
   select the big-endian field order.  Such compilers now get the
   little-endian order from the #else branch.  */
typedef union
{
  float v;
#if defined(_AIX)
  struct __attribute__((packed))
#else
  struct
#endif
  {
#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) \
    && (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
    uint32_t s:1;
    uint32_t e:8;
    uint32_t m:23;
#else
    uint32_t m:23;
    uint32_t e:8;
    uint32_t s:1;
#endif
  } bits;
} float32_bits;
75-
76- float
77- float16to32 (bfloat16_bits f16 )
78- {
79- float32_bits f32 ;
80- f32 .bits .s = f16 .bits .s ;
81- f32 .bits .e = f16 .bits .e ;
82- f32 .bits .m = (uint32_t ) f16 .bits .m << 16 ;
83- return f32 .v ;
84- }
85-
8637#define SBGEMM_LARGEST 256
8738
/* Allocate `size` bytes with malloc, never asking malloc for zero bytes:
   a request of 0 is turned into a request for 1 byte.  Returns whatever
   malloc returns for the adjusted size.  */
void *
malloc_safe (size_t size)
{
  const size_t request = (size != 0) ? size : 1;

  return malloc (request);
}
95-
9639int
9740main (int argc , char * argv [])
9841{
@@ -111,32 +54,29 @@ main (int argc, char *argv[])
11154 float * A = (float * )malloc_safe (m * k * sizeof (FLOAT ));
11255 float * B = (float * )malloc_safe (k * n * sizeof (FLOAT ));
11356 float * C = (float * )malloc_safe (m * n * sizeof (FLOAT ));
114- bfloat16_bits * AA = (bfloat16_bits * )malloc_safe (m * k * sizeof (bfloat16_bits ));
115- bfloat16_bits * BB = (bfloat16_bits * )malloc_safe (k * n * sizeof (bfloat16_bits ));
57+ bfloat16 * AA = (bfloat16 * )malloc_safe (m * k * sizeof (bfloat16 ));
58+ bfloat16 * BB = (bfloat16 * )malloc_safe (k * n * sizeof (bfloat16 ));
11659 float * DD = (float * )malloc_safe (m * n * sizeof (FLOAT ));
11760 float * CC = (float * )malloc_safe (m * n * sizeof (FLOAT ));
11861 if ((A == NULL ) || (B == NULL ) || (C == NULL ) || (AA == NULL ) || (BB == NULL ) ||
11962 (DD == NULL ) || (CC == NULL ))
12063 return 1 ;
121- bfloat16 atmp ,btmp ;
12264 blasint one = 1 ;
12365
12466 for (j = 0 ; j < m ; j ++ )
12567 {
12668 for (i = 0 ; i < k ; i ++ )
12769 {
12870 A [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
129- sbstobf16_ (& one , & A [j * k + i ], & one , & atmp , & one );
130- AA [j * k + i ].v = atmp ;
71+ sbstobf16_ (& one , & A [j * k + i ], & one , & AA [j * k + i ], & one );
13172 }
13273 }
13374 for (j = 0 ; j < n ; j ++ )
13475 {
13576 for (i = 0 ; i < k ; i ++ )
13677 {
13778 B [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
138- sbstobf16_ (& one , & B [j * k + i ], & one , & btmp , & one );
139- BB [j * k + i ].v = btmp ;
79+ sbstobf16_ (& one , & B [j * k + i ], & one , & BB [j * k + i ], & one );
14080 }
14181 }
14282 for (y = 0 ; y < 4 ; y ++ )
@@ -182,10 +122,12 @@ main (int argc, char *argv[])
182122 DD [i * m + j ] +=
183123 float16to32 (AA [k * j + l ]) * float16to32 (BB [i + l * n ]);
184124 }
185- if (fabs (CC [i * m + j ] - C [i * m + j ]) > 1.0 )
125+ if (! is_close (CC [i * m + j ], C [i * m + j ], 0.01 , 0.001 )) {
186126 ret ++ ;
187- if (fabs (CC [i * m + j ] - DD [i * m + j ]) > 1.0 )
127+ }
128+ if (!is_close (CC [i * m + j ], DD [i * m + j ], 0.001 , 0.0001 )) {
188129 ret ++ ;
130+ }
189131 }
190132 }
191133 free (A );
@@ -211,27 +153,24 @@ main (int argc, char *argv[])
211153 float * A = (float * )malloc_safe (x * x * sizeof (FLOAT ));
212154 float * B = (float * )malloc_safe (x * sizeof (FLOAT ) << l );
213155 float * C = (float * )malloc_safe (x * sizeof (FLOAT ) << l );
214- bfloat16_bits * AA = (bfloat16_bits * )malloc_safe (x * x * sizeof (bfloat16_bits ));
215- bfloat16_bits * BB = (bfloat16_bits * )malloc_safe (x * sizeof (bfloat16_bits ) << l );
156+ bfloat16 * AA = (bfloat16 * )malloc_safe (x * x * sizeof (bfloat16 ));
157+ bfloat16 * BB = (bfloat16 * )malloc_safe (x * sizeof (bfloat16 ) << l );
216158 float * DD = (float * )malloc_safe (x * sizeof (FLOAT ));
217159 float * CC = (float * )malloc_safe (x * sizeof (FLOAT ) << l );
218160 if ((A == NULL ) || (B == NULL ) || (C == NULL ) || (AA == NULL ) || (BB == NULL ) ||
219161 (DD == NULL ) || (CC == NULL ))
220162 return 1 ;
221- bfloat16 atmp , btmp ;
222163 blasint one = 1 ;
223164
224165 for (j = 0 ; j < x ; j ++ )
225166 {
226167 for (i = 0 ; i < x ; i ++ )
227168 {
228169 A [j * x + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
229- sbstobf16_ (& one , & A [j * x + i ], & one , & atmp , & one );
230- AA [j * x + i ].v = atmp ;
170+ sbstobf16_ (& one , & A [j * x + i ], & one , & AA [j * x + i ], & one );
231171 }
232172 B [j << l ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
233- sbstobf16_ (& one , & B [j << l ], & one , & btmp , & one );
234- BB [j << l ].v = btmp ;
173+ sbstobf16_ (& one , & B [j << l ], & one , & BB [j << l ], & one );
235174
236175 CC [j << l ] = C [j << l ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
237176 }
@@ -262,10 +201,12 @@ main (int argc, char *argv[])
262201 }
263202
264203 for (j = 0 ; j < x ; j ++ ) {
265- if (fabs (CC [j << l ] - C [j << l ]) > 1.0 )
204+ if (! is_close (CC [j << l ], C [j << l ], 0.01 , 0.001 )) {
266205 ret ++ ;
267- if (fabs (CC [j << l ] - DD [j ]) > 1.0 )
206+ }
207+ if (!is_close (CC [j << l ], DD [j ], 0.001 , 0.0001 )) {
268208 ret ++ ;
209+ }
269210 }
270211 }
271212 free (A );
0 commit comments