Skip to content

Commit 8918247

Browse files
authored
Merge pull request #5483 from Mousius/bgemm-correctness
Fix bf16->f32 conversion for NEOVERSEV1 and NEOVERSEN2 targets
2 parents 106fabc + 958f721 commit 8918247

File tree

8 files changed: +43 −16 lines changed

8 files changed: +43 −16 lines changed

kernel/arm64/KERNEL.NEOVERSEN2

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -189,7 +189,7 @@ ZGEMMONCOPYOBJ = zgemm_oncopy$(TSUFFIX).$(SUFFIX)
189189
ZGEMMOTCOPYOBJ = zgemm_otcopy$(TSUFFIX).$(SUFFIX)
190190

191191
ifeq ($(BUILD_BFLOAT16), 1)
192-
BGEMM_BETA = sbgemm_beta_neoversen2.c
192+
BGEMM_BETA = bgemm_beta_neon.c
193193
BGEMMKERNEL = sbgemm_kernel_$(BGEMM_UNROLL_M)x$(BGEMM_UNROLL_N)_neoversen2.c
194194
BGEMMINCOPY = sbgemm_ncopy_$(BGEMM_UNROLL_M)_neoversen2.c
195195
BGEMMITCOPY = sbgemm_tcopy_$(BGEMM_UNROLL_M)_neoversen2.c

kernel/arm64/bgemm_kernel_2vlx4_neoversev1_impl.c

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,8 @@
4040

4141
#define UPDATE_C(PG, PTR, DST, SRC) \
4242
do { \
43-
DST = svreinterpret_f32_u32(svld1uh_u32((pghalf), (uint16_t*)PTR)); \
43+
svtmp16 = svld1_bf16((pghalf), (PTR)); \
44+
DST = svreinterpret_f32(svzip1_bf16(zeros, svtmp16)); \
4445
DST = svadd_z((PG), SRC, DST); \
4546
svtmp16 = svcvt_bf16_f32_z((PG), DST); \
4647
svtmp16 = svuzp1_bf16(svtmp16, svtmp16); \
@@ -55,7 +56,8 @@
5556

5657
#define UPDATE_C(PG, PTR, DST, SRC) \
5758
do { \
58-
DST = svreinterpret_f32_u32(svld1uh_u32((pghalf), (uint16_t*)PTR)); \
59+
svtmp16 = svld1_bf16((pghalf), (PTR)); \
60+
DST = svreinterpret_f32(svzip1_bf16(zeros, svtmp16)); \
5961
DST = svmad_z((PG), svalpha, SRC, DST); \
6062
svtmp16 = svcvt_bf16_f32_z((PG), DST); \
6163
svtmp16 = svuzp1_bf16(svtmp16, svtmp16); \
@@ -133,6 +135,7 @@ static int bgemm_kernel_neoversev1_alpha(BLASLONG m, BLASLONG n, BLASLONG k,
133135
OUTPUT_FLOAT *ptr_c0, *ptr_c1, *ptr_c2, *ptr_c3;
134136
svfloat32_t tmp0, tmp1, tmp2, tmp3;
135137
#ifdef BGEMM
138+
svbfloat16_t zeros = svdup_n_bf16(TO16(0.0));
136139
svbfloat16_t svtmp16;
137140
#else
138141
float32x2_t tmp4, tmp5, tmp6, tmp7;

kernel/arm64/sbgemm_kernel_8x4_neoversen2_impl.c

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -51,7 +51,8 @@
5151
#ifdef ALPHA_ONE
5252
#define UPDATE_C(PG16, PG32, PTR, SRC) \
5353
do { \
54-
tmp32 = svreinterpret_f32_u32(svld1uh_u32((PG16), (uint16_t*)PTR)); \
54+
tmp16 = svld1_bf16((PG16), (PTR)); \
55+
tmp32 = svreinterpret_f32(svzip1_bf16(zeros, tmp16)); \
5556
tmp32 = svadd_z((PG32), SRC, tmp32); \
5657
tmp16 = svcvt_bf16_f32_z((PG32), tmp32); \
5758
tmp16 = svuzp1_bf16(tmp16, tmp16); \
@@ -60,7 +61,8 @@
6061
#else
6162
#define UPDATE_C(PG16, PG32, PTR, SRC) \
6263
do { \
63-
tmp32 = svreinterpret_f32_u32(svld1uh_u32((PG16), (uint16_t*)PTR)); \
64+
tmp16 = svld1_bf16((PG16), (PTR)); \
65+
tmp32 = svreinterpret_f32(svzip1_bf16(zeros, tmp16)); \
6466
tmp32 = svmad_z((PG32), svalpha, SRC, tmp32); \
6567
tmp16 = svcvt_bf16_f32_z((PG32), tmp32); \
6668
tmp16 = svuzp1_bf16(tmp16, tmp16); \
@@ -121,6 +123,7 @@ static int gemm_kernel_neoversen2_alpha(BLASLONG m, BLASLONG n, BLASLONG k, FLOA
121123
#ifdef BGEMM
122124
svbool_t pg16_first_2 = svdupq_b16(1, 1, 0, 0, 0, 0, 0, 0);
123125
svbool_t pg16_first_1 = svdupq_b16(1, 0, 0, 0, 0, 0, 0, 0);
126+
svbfloat16_t zeros = svdup_n_bf16(vcvth_bf16_f32(0.0));
124127
#endif
125128

126129
bfloat16_t *ptr_a = (bfloat16_t *)A;

kernel/generic/gemv_t.c

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -52,7 +52,14 @@ int CNAME(BLASLONG m, BLASLONG n, FLOAT alpha, IFLOAT *a, BLASLONG lda, IFLOAT *
5252
temp += BF16TOF32(a_ptr[i]) * BF16TOF32(x[ix]);
5353
ix += inc_x;
5454
}
55-
y[iy] += F32TOBF16(ALPHA * temp);
55+
if (BETA == ZERO)
56+
{
57+
y[iy] = F32TOBF16(ALPHA * temp);
58+
}
59+
else
60+
{
61+
y[iy] = F32TOBF16(ALPHA * temp + BETA * BF16TOF32(y[iy]));
62+
}
5663
iy += inc_y;
5764
a_ptr += lda;
5865
}

test/compare_sgemm_bgemm.c

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -44,7 +44,7 @@ main (int argc, char *argv[])
4444
int ret = 0;
4545
int loop = BGEMM_LARGEST;
4646
char transA = 'N', transB = 'N';
47-
float alpha = 1.0, beta = 0.0;
47+
float alpha = 1.0, beta = 1.0;
4848
bfloat16 alpha_bf16;
4949
sbstobf16_(&one, &alpha, &one, &alpha_bf16, &one);
5050
bfloat16 beta_bf16;
@@ -94,9 +94,15 @@ main (int argc, char *argv[])
9494
transB = 'T';
9595
}
9696

97-
memset(CC, 0, m * n * sizeof(bfloat16));
98-
memset(DD, 0, m * n * sizeof(FLOAT));
99-
memset(C, 0, m * n * sizeof(FLOAT));
97+
for (j = 0; j < m; j++)
98+
{
99+
for (i = 0; i < n; i++)
100+
{
101+
C[j * n + i] = 100.0;
102+
DD[j * n + i] = 100.0;
103+
sbstobf16_(&one, &C[j * n + i], &one, &CC[j * n + i], &one);
104+
}
105+
}
100106

101107
SGEMM (&transA, &transB, &m, &n, &k, &alpha, A,
102108
&m, B, &k, &beta, C, &m);
@@ -152,7 +158,8 @@ main (int argc, char *argv[])
152158
}
153159

154160
if (ret != 0) {
155-
fprintf (stderr, "FATAL ERROR BGEMM - Return code: %d\n", ret);
161+
fprintf(stderr, "BGEMM FAILURES: %d\n", ret);
162+
return 1;
156163
}
157164

158165
return ret;

test/compare_sgemm_sbgemm.c

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -140,7 +140,8 @@ main (int argc, char *argv[])
140140
}
141141

142142
if (ret != 0) {
143-
fprintf (stderr, "FATAL ERROR SBGEMM - Return code: %d\n", ret);
143+
fprintf(stderr, "SBGEMM FAILURES: %d\n", ret);
144+
return 1;
144145
}
145146

146147
return ret;

test/compare_sgemv_bgemv.c

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -147,7 +147,10 @@ int main(int argc, char *argv[])
147147
} // alpha
148148
} // beta
149149

150-
if (ret != 0)
151-
fprintf(stderr, "FATAL ERROR BGEMV - Return code: %d\n", ret);
150+
if (ret != 0) {
151+
fprintf(stderr, "BGEMV FAILURES: %d\n", ret);
152+
return 1;
153+
}
154+
152155
return ret;
153156
}

test/compare_sgemv_sbgemv.c

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -122,7 +122,10 @@ main (int argc, char *argv[])
122122
} // alpha
123123
} // beta
124124

125-
if (ret != 0)
126-
fprintf (stderr, "FATAL ERROR SBGEMV - Return code: %d\n", ret);
125+
if (ret != 0) {
126+
fprintf(stderr, "SBGEMV FAILURES: %d\n", ret);
127+
return 1;
128+
}
129+
127130
return ret;
128131
}

0 commit comments

Comments
 (0)