@@ -85,6 +85,14 @@ float16to32 (bfloat16_bits f16)
8585
8686#define SBGEMM_LARGEST 256
8787
88+ void * malloc_safe (size_t size )
89+ {
90+ if (size == 0 )
91+ return malloc (1 );
92+ else
93+ return malloc (size );
94+ }
95+
8896int
8997main (int argc , char * argv [])
9098{
@@ -100,13 +108,13 @@ main (int argc, char *argv[])
100108 {
101109 if ((x > 100 ) && (x != SBGEMM_LARGEST )) continue ;
102110 m = k = n = x ;
103- float * A = (float * )malloc (m * k * sizeof (FLOAT ));
104- float * B = (float * )malloc (k * n * sizeof (FLOAT ));
105- float * C = (float * )malloc (m * n * sizeof (FLOAT ));
106- bfloat16_bits * AA = (bfloat16_bits * )malloc (m * k * sizeof (bfloat16_bits ));
107- bfloat16_bits * BB = (bfloat16_bits * )malloc (k * n * sizeof (bfloat16_bits ));
108- float * DD = (float * )malloc (m * n * sizeof (FLOAT ));
109- float * CC = (float * )malloc (m * n * sizeof (FLOAT ));
111+ float * A = (float * )malloc_safe (m * k * sizeof (FLOAT ));
112+ float * B = (float * )malloc_safe (k * n * sizeof (FLOAT ));
113+ float * C = (float * )malloc_safe (m * n * sizeof (FLOAT ));
114+ bfloat16_bits * AA = (bfloat16_bits * )malloc_safe (m * k * sizeof (bfloat16_bits ));
115+ bfloat16_bits * BB = (bfloat16_bits * )malloc_safe (k * n * sizeof (bfloat16_bits ));
116+ float * DD = (float * )malloc_safe (m * n * sizeof (FLOAT ));
117+ float * CC = (float * )malloc_safe (m * n * sizeof (FLOAT ));
110118 if ((A == NULL ) || (B == NULL ) || (C == NULL ) || (AA == NULL ) || (BB == NULL ) ||
111119 (DD == NULL ) || (CC == NULL ))
112120 return 1 ;
@@ -194,16 +202,16 @@ main (int argc, char *argv[])
194202 return ret ;
195203 }
196204
197- k = 1 ;
198205 for (x = 1 ; x <= loop ; x ++ )
199206 {
200- float * A = (float * )malloc (x * x * sizeof (FLOAT ));
201- float * B = (float * )malloc (x * sizeof (FLOAT ));
202- float * C = (float * )malloc (x * sizeof (FLOAT ));
203- bfloat16_bits * AA = (bfloat16_bits * )malloc (x * x * sizeof (bfloat16_bits ));
204- bfloat16_bits * BB = (bfloat16_bits * )malloc (x * sizeof (bfloat16_bits ));
205- float * DD = (float * )malloc (x * sizeof (FLOAT ));
206- float * CC = (float * )malloc (x * sizeof (FLOAT ));
207+ k = (x == 0 ) ? 0 : 1 ;
208+ float * A = (float * )malloc_safe (x * x * sizeof (FLOAT ));
209+ float * B = (float * )malloc_safe (x * sizeof (FLOAT ));
210+ float * C = (float * )malloc_safe (x * sizeof (FLOAT ));
211+ bfloat16_bits * AA = (bfloat16_bits * )malloc_safe (x * x * sizeof (bfloat16_bits ));
212+ bfloat16_bits * BB = (bfloat16_bits * )malloc_safe (x * sizeof (bfloat16_bits ));
213+ float * DD = (float * )malloc_safe (x * sizeof (FLOAT ));
214+ float * CC = (float * )malloc_safe (x * sizeof (FLOAT ));
207215 if ((A == NULL ) || (B == NULL ) || (C == NULL ) || (AA == NULL ) || (BB == NULL ) ||
208216 (DD == NULL ) || (CC == NULL ))
209217 return 1 ;
0 commit comments