Skip to content

Commit 19be504

Browse files
authored
Add tests varying alpha and beta
1 parent 05adb52 commit 19be504

File tree

1 file changed

+98
-9
lines changed

1 file changed

+98
-9
lines changed

test/compare_sgemm_shgemm.c

Lines changed: 98 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,11 @@ main (int argc, char *argv[])
4141
int i, j, l;
4242
blasint x, y;
4343
int ret = 0;
44+
int rret = 0;
4445
int loop = SHGEMM_LARGEST;
4546
char transA = 'N', transB = 'N';
4647
float alpha = 1.0, beta = 0.0;
48+
int xvals[6]={3,24,55,71,SHGEMM_LARGEST/2,SHGEMM_LARGEST};
4749

4850
for (x = 0; x <= loop; x++)
4951
{
@@ -52,8 +54,8 @@ main (int argc, char *argv[])
5254
float *A = (float *)malloc_safe(m * k * sizeof(FLOAT));
5355
float *B = (float *)malloc_safe(k * n * sizeof(FLOAT));
5456
float *C = (float *)malloc_safe(m * n * sizeof(FLOAT));
55-
hfloat16 *AA = (hfloat16 *)malloc_safe(m * k * sizeof(hfloat16));
56-
hfloat16 *BB = (hfloat16 *)malloc_safe(k * n * sizeof(hfloat16));
57+
_Float16 *AA = (_Float16 *)malloc_safe(m * k * sizeof(_Float16));
58+
_Float16 *BB = (_Float16 *)malloc_safe(k * n * sizeof(_Float16));
5759
float *DD = (float *)malloc_safe(m * n * sizeof(FLOAT));
5860
float *CC = (float *)malloc_safe(m * n * sizeof(FLOAT));
5961
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
@@ -65,15 +67,15 @@ main (int argc, char *argv[])
6567
for (i = 0; i < k; i++)
6668
{
6769
A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
68-
AA[j * k + i] = (hfloat16) A[j * k + i];
70+
AA[j * k + i] = (_Float16) A[j * k + i];
6971
}
7072
}
7173
for (j = 0; j < n; j++)
7274
{
7375
for (i = 0; i < k; i++)
7476
{
7577
B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
76-
BB[j * k + i] = (hfloat16) B[j * k + i];
78+
BB[j * k + i] = (_Float16) B[j * k + i];
7779
}
7880
}
7981
for (y = 0; y < 4; y++)
@@ -95,8 +97,8 @@ main (int argc, char *argv[])
9597

9698
SGEMM (&transA, &transB, &m, &n, &k, &alpha, A,
9799
&m, B, &k, &beta, C, &m);
98-
SHGEMM (&transA, &transB, &m, &n, &k, &alpha, AA,
99-
&m, BB, &k, &beta, CC, &m);
100+
SHGEMM (&transA, &transB, &m, &n, &k, &alpha, (_Float16*) AA,
101+
&m, (_Float16*)BB, &k, &beta, CC, &m);
100102

101103
for (i = 0; i < n; i++)
102104
for (j = 0; j < m; j++)
@@ -120,9 +122,11 @@ main (int argc, char *argv[])
120122
(float)AA[k * j + l] * (float)BB[i + l * n];
121123
}
122124
if (!is_close(CC[i * m + j], C[i * m + j], 0.01, 0.001)) {
125+
fprintf(stderr,"CC %f C %f \n",(float)CC[i*m+j],C[i*m+j]);
123126
ret++;
124127
}
125128
if (!is_close(CC[i * m + j], DD[i * m + j], 0.001, 0.0001)) {
129+
fprintf(stderr,"CC %f DD %f \n",(float)CC[i*m+j],(float)DD[i*m+j]);
126130
ret++;
127131
}
128132
}
@@ -135,11 +139,96 @@ main (int argc, char *argv[])
135139
free(DD);
136140
free(CC);
137141
}
138-
139142
if (ret != 0) {
140-
fprintf(stderr, "SHGEMM FAILURES: %d\n", ret);
143+
fprintf(stderr, "SHGEMM FAILURES: %d!!!\n", ret);
141144
return 1;
142145
}
143146

144-
return ret;
147+
148+
for (loop = 0; loop<6; loop++) {
149+
x=xvals[loop];
150+
for (alpha=0.;alpha<=1.;alpha+=0.5)
151+
{
152+
for (beta = 0.0; beta <=1.; beta+=0.5) {
153+
154+
m = k = n = x;
155+
float *A = (float *)malloc_safe(m * k * sizeof(FLOAT));
156+
float *B = (float *)malloc_safe(k * n * sizeof(FLOAT));
157+
float *C = (float *)malloc_safe(m * n * sizeof(FLOAT));
158+
_Float16 *AA = (_Float16 *)malloc_safe(m * k * sizeof(_Float16));
159+
_Float16 *BB = (_Float16 *)malloc_safe(k * n * sizeof(_Float16));
160+
float *CC = (float *)malloc_safe(m * n * sizeof(FLOAT));
161+
if ((A == NULL) || (B == NULL) || (C == NULL) || (AA == NULL) || (BB == NULL) ||
162+
(CC == NULL))
163+
return 1;
164+
165+
for (j = 0; j < m; j++)
166+
{
167+
for (i = 0; i < k; i++)
168+
{
169+
A[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
170+
AA[j * k + i] = (_Float16) A[j * k + i];
171+
}
172+
}
173+
for (j = 0; j < n; j++)
174+
{
175+
for (i = 0; i < k; i++)
176+
{
177+
B[j * k + i] = ((FLOAT) rand () / (FLOAT) RAND_MAX) + 0.5;
178+
BB[j * k + i] = (_Float16) B[j * k + i];
179+
}
180+
}
181+
182+
for (y = 0; y < 4; y++)
183+
{
184+
if ((y == 0) || (y == 2)) {
185+
transA = 'N';
186+
} else {
187+
transA = 'T';
188+
}
189+
if ((y == 0) || (y == 1)) {
190+
transB = 'N';
191+
} else {
192+
transB = 'T';
193+
}
194+
195+
memset(CC, 0, m * n * sizeof(FLOAT));
196+
memset(C, 0, m * n * sizeof(FLOAT));
197+
198+
SGEMM (&transA, &transB, &m, &n, &k, &alpha, A,
199+
&m, B, &k, &beta, C, &m);
200+
SHGEMM (&transA, &transB, &m, &n, &k, &alpha, (_Float16*) AA,
201+
&m, (_Float16*)BB, &k, &beta, CC, &m);
202+
203+
for (i = 0; i < n; i++)
204+
for (j = 0; j < m; j++)
205+
{
206+
if (!is_close(CC[i * m + j], C[i * m + j], 0.01, 0.001)) {
207+
ret++;
208+
}
209+
}
210+
}
211+
free(A);
212+
free(B);
213+
free(C);
214+
free(AA);
215+
free(BB);
216+
free(CC);
217+
218+
if (ret != 0) {
219+
/*
220+
* fprintf(stderr, "SHGEMM FAILURES FOR n=%d, alpha=%f beta=%f : %d\n", x, alpha, beta, ret);
221+
*/
222+
rret++;
223+
ret=0;
224+
/* } else {
225+
fprintf(stderr, "SHGEMM SUCCEEDED FOR n=%d, alpha=%f beta=%f : %d\n", x, alpha, beta, ret);
226+
*/
227+
}
228+
}
229+
230+
}
231+
}
232+
if (rret > 0) return(1);
233+
return(0);
145234
}

0 commit comments

Comments
 (0)