@@ -41,9 +41,11 @@ main (int argc, char *argv[])
4141 int i , j , l ;
4242 blasint x , y ;
4343 int ret = 0 ;
44+ int rret = 0 ;
4445 int loop = SHGEMM_LARGEST ;
4546 char transA = 'N' , transB = 'N' ;
4647 float alpha = 1.0 , beta = 0.0 ;
48+ int xvals [6 ]= {3 ,24 ,55 ,71 ,SHGEMM_LARGEST /2 ,SHGEMM_LARGEST };
4749
4850 for (x = 0 ; x <= loop ; x ++ )
4951 {
@@ -52,8 +54,8 @@ main (int argc, char *argv[])
5254 float * A = (float * )malloc_safe (m * k * sizeof (FLOAT ));
5355 float * B = (float * )malloc_safe (k * n * sizeof (FLOAT ));
5456 float * C = (float * )malloc_safe (m * n * sizeof (FLOAT ));
55- hfloat16 * AA = (hfloat16 * )malloc_safe (m * k * sizeof (hfloat16 ));
56- hfloat16 * BB = (hfloat16 * )malloc_safe (k * n * sizeof (hfloat16 ));
57+ _Float16 * AA = (_Float16 * )malloc_safe (m * k * sizeof (_Float16 ));
58+ _Float16 * BB = (_Float16 * )malloc_safe (k * n * sizeof (_Float16 ));
5759 float * DD = (float * )malloc_safe (m * n * sizeof (FLOAT ));
5860 float * CC = (float * )malloc_safe (m * n * sizeof (FLOAT ));
5961 if ((A == NULL ) || (B == NULL ) || (C == NULL ) || (AA == NULL ) || (BB == NULL ) ||
@@ -65,15 +67,15 @@ main (int argc, char *argv[])
6567 for (i = 0 ; i < k ; i ++ )
6668 {
6769 A [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
68- AA [j * k + i ] = (hfloat16 ) A [j * k + i ];
70+ AA [j * k + i ] = (_Float16 ) A [j * k + i ];
6971 }
7072 }
7173 for (j = 0 ; j < n ; j ++ )
7274 {
7375 for (i = 0 ; i < k ; i ++ )
7476 {
7577 B [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
76- BB [j * k + i ] = (hfloat16 ) B [j * k + i ];
78+ BB [j * k + i ] = (_Float16 ) B [j * k + i ];
7779 }
7880 }
7981 for (y = 0 ; y < 4 ; y ++ )
@@ -95,8 +97,8 @@ main (int argc, char *argv[])
9597
9698 SGEMM (& transA , & transB , & m , & n , & k , & alpha , A ,
9799 & m , B , & k , & beta , C , & m );
98- SHGEMM (& transA , & transB , & m , & n , & k , & alpha , AA ,
99- & m , BB , & k , & beta , CC , & m );
100+ SHGEMM (& transA , & transB , & m , & n , & k , & alpha , ( _Float16 * ) AA ,
101+ & m , ( _Float16 * ) BB , & k , & beta , CC , & m );
100102
101103 for (i = 0 ; i < n ; i ++ )
102104 for (j = 0 ; j < m ; j ++ )
@@ -120,9 +122,11 @@ main (int argc, char *argv[])
120122 (float )AA [k * j + l ] * (float )BB [i + l * n ];
121123 }
122124 if (!is_close (CC [i * m + j ], C [i * m + j ], 0.01 , 0.001 )) {
125+ fprintf (stderr ,"CC %f C %f \n" ,(float )CC [i * m + j ],C [i * m + j ]);
123126 ret ++ ;
124127 }
125128 if (!is_close (CC [i * m + j ], DD [i * m + j ], 0.001 , 0.0001 )) {
129+ fprintf (stderr ,"CC %f DD %f \n" ,(float )CC [i * m + j ],(float )DD [i * m + j ]);
126130 ret ++ ;
127131 }
128132 }
@@ -135,11 +139,96 @@ main (int argc, char *argv[])
135139 free (DD );
136140 free (CC );
137141 }
138-
139142 if (ret != 0 ) {
140- fprintf (stderr , "SHGEMM FAILURES: %d\n" , ret );
143+ fprintf (stderr , "SHGEMM FAILURES: %d!!! \n" , ret );
141144 return 1 ;
142145 }
143146
144- return ret ;
147+
148+ for (loop = 0 ; loop < 6 ; loop ++ ) {
149+ x = xvals [loop ];
150+ for (alpha = 0. ;alpha <=1. ;alpha += 0.5 )
151+ {
152+ for (beta = 0.0 ; beta <=1. ; beta += 0.5 ) {
153+
154+ m = k = n = x ;
155+ float * A = (float * )malloc_safe (m * k * sizeof (FLOAT ));
156+ float * B = (float * )malloc_safe (k * n * sizeof (FLOAT ));
157+ float * C = (float * )malloc_safe (m * n * sizeof (FLOAT ));
158+ _Float16 * AA = (_Float16 * )malloc_safe (m * k * sizeof (_Float16 ));
159+ _Float16 * BB = (_Float16 * )malloc_safe (k * n * sizeof (_Float16 ));
160+ float * CC = (float * )malloc_safe (m * n * sizeof (FLOAT ));
161+ if ((A == NULL ) || (B == NULL ) || (C == NULL ) || (AA == NULL ) || (BB == NULL ) ||
162+ (CC == NULL ))
163+ return 1 ;
164+
165+ for (j = 0 ; j < m ; j ++ )
166+ {
167+ for (i = 0 ; i < k ; i ++ )
168+ {
169+ A [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
170+ AA [j * k + i ] = (_Float16 ) A [j * k + i ];
171+ }
172+ }
173+ for (j = 0 ; j < n ; j ++ )
174+ {
175+ for (i = 0 ; i < k ; i ++ )
176+ {
177+ B [j * k + i ] = ((FLOAT ) rand () / (FLOAT ) RAND_MAX ) + 0.5 ;
178+ BB [j * k + i ] = (_Float16 ) B [j * k + i ];
179+ }
180+ }
181+
182+ for (y = 0 ; y < 4 ; y ++ )
183+ {
184+ if ((y == 0 ) || (y == 2 )) {
185+ transA = 'N' ;
186+ } else {
187+ transA = 'T' ;
188+ }
189+ if ((y == 0 ) || (y == 1 )) {
190+ transB = 'N' ;
191+ } else {
192+ transB = 'T' ;
193+ }
194+
195+ memset (CC , 0 , m * n * sizeof (FLOAT ));
196+ memset (C , 0 , m * n * sizeof (FLOAT ));
197+
198+ SGEMM (& transA , & transB , & m , & n , & k , & alpha , A ,
199+ & m , B , & k , & beta , C , & m );
200+ SHGEMM (& transA , & transB , & m , & n , & k , & alpha , (_Float16 * ) AA ,
201+ & m , (_Float16 * )BB , & k , & beta , CC , & m );
202+
203+ for (i = 0 ; i < n ; i ++ )
204+ for (j = 0 ; j < m ; j ++ )
205+ {
206+ if (!is_close (CC [i * m + j ], C [i * m + j ], 0.01 , 0.001 )) {
207+ ret ++ ;
208+ }
209+ }
210+ }
211+ free (A );
212+ free (B );
213+ free (C );
214+ free (AA );
215+ free (BB );
216+ free (CC );
217+
218+ if (ret != 0 ) {
219+ /*
220+ * fprintf(stderr, "SHGEMM FAILURES FOR n=%d, alpha=%f beta=%f : %d\n", x, alpha, beta, ret);
221+ */
222+ rret ++ ;
223+ ret = 0 ;
224+ /* } else {
225+ fprintf(stderr, "SHGEMM SUCCEEDED FOR n=%d, alpha=%f beta=%f : %d\n", x, alpha, beta, ret);
226+ */
227+ }
228+ }
229+
230+ }
231+ }
232+ if (rret > 0 ) return (1 );
233+ return (0 );
145234}
0 commit comments