
Commit a7f4da6

Nicoshev authored and Silv3S committed
[Pytorch] Improve conversion to bfloat16 on aarch64/NEON (pytorch#166958)
Summary: Autovectorization of casting to bfloat16_t is broken in clang-[17, 20], fixed in clang-21. We are adding a workaround vectorized code, which improves conversion speed from smaller int data types. We've observed the following performance improvements, when compiling with clang-19 and targeting armv9a+sve2: before: uint8->bfloat16_t ===> 319.433us int8->bfloat16_t ===> 320.216us int16->bfloat16_t ===> 326.899us int32->bfloat16_t ===> 327.925us after: uint8->bfloat16_t ===> 185.189us -----> 72% higher throughput int8->bfloat16_t ===> 169.790us -----> 89% higher throughput int16->bfloat16_t ===> 180.744us -----> 81% higher throughput int32->bfloat16_t ===> 185.129us -----> 77% higher throughput Test Plan: Correctness: buck2 test mode/opt //caffe2/test:test_ops buck2 test mode/opt //caffe2/test:torch Performance: buck2 run mode/opt //caffe2/benchmarks/operator_benchmark/fb:operator_benchmark_test Differential Revision: D86207189 Pull Request resolved: pytorch#166958 Approved by: https://github.com/mcfi
1 parent 7e98c4c commit a7f4da6

File tree

1 file changed: +56 −0 lines changed

aten/src/ATen/cpu/vec/vec128/vec128_convert.h

Lines changed: 56 additions & 0 deletions
@@ -223,6 +223,62 @@ CONVERT_FROM_BF16_TEMPLATE(double)
 CONVERT_FROM_BF16_TEMPLATE(float16_t)
 #endif
 
+#ifdef __ARM_FEATURE_BF16
+
+// clang-[17, 20] crashes when autovectorizing static cast to bf16
+// Below is a workaround to have some vectorization
+// Works decently well for smaller int types
+template <typename from_type>
+inline void convertToBf16Impl(
+    const from_type* __restrict src,
+    c10::BFloat16* __restrict dst,
+    uint64_t n) {
+  bfloat16_t* dstPtr = reinterpret_cast<bfloat16_t*>(dst);
+  uint64_t loopBound = n - (n % 16);
+  uint64_t i = 0;
+  for (; i < loopBound; i += 16) {
+    float32x4_t a, b, c, d;
+    a[0] = static_cast<float>(src[i]);
+    a[1] = static_cast<float>(src[i + 1]);
+    a[2] = static_cast<float>(src[i + 2]);
+    a[3] = static_cast<float>(src[i + 3]);
+    b[0] = static_cast<float>(src[i + 4]);
+    b[1] = static_cast<float>(src[i + 5]);
+    b[2] = static_cast<float>(src[i + 6]);
+    b[3] = static_cast<float>(src[i + 7]);
+    c[0] = static_cast<float>(src[i + 8]);
+    c[1] = static_cast<float>(src[i + 9]);
+    c[2] = static_cast<float>(src[i + 10]);
+    c[3] = static_cast<float>(src[i + 11]);
+    d[0] = static_cast<float>(src[i + 12]);
+    d[1] = static_cast<float>(src[i + 13]);
+    d[2] = static_cast<float>(src[i + 14]);
+    d[3] = static_cast<float>(src[i + 15]);
+
+    vst1q_bf16(dstPtr + i, vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(a), b));
+    vst1q_bf16(dstPtr + i + 8, vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(c), d));
+  }
+
+#pragma clang loop vectorize(disable) interleave(disable) unroll(disable)
+  for (; i < n; i++) {
+    float a = static_cast<float>(src[i]);
+    dstPtr[i] = vcvth_bf16_f32(a);
+  }
+}
+
+#define CONVERT_TO_BF16_TEMPLATE(from_type)                                  \
+  template <>                                                                \
+  inline void convert(const from_type* src, c10::BFloat16* dst, int64_t n) { \
+    return convertToBf16Impl<from_type>(src, dst, n);                        \
+  }
+
+CONVERT_TO_BF16_TEMPLATE(uint8_t)
+CONVERT_TO_BF16_TEMPLATE(int8_t)
+CONVERT_TO_BF16_TEMPLATE(int16_t)
+CONVERT_TO_BF16_TEMPLATE(int32_t)
+
+#endif
+
 inline void convertBoolToBfloat16Impl(
     const bool* __restrict src,
     c10::BFloat16* __restrict dst,
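At the PyTorch level, these convert() specializations feed the CPU casting kernels, so the improvement should be observable from an ordinary dtype cast such as Tensor::to. The libtorch sketch below is an illustration only, not part of the commit's Test Plan, and assumes an aarch64 libtorch build.

// Illustrative only: exercises a uint8 -> bfloat16 CPU cast via libtorch.
// Assumes an aarch64 libtorch build; not part of the commit's Test Plan.
#include <torch/torch.h>
#include <iostream>

int main() {
  // 1M random uint8 values; .to() dispatches to the CPU conversion kernels.
  auto src = torch::randint(0, 256, {1 << 20}, torch::kUInt8);
  auto dst = src.to(torch::kBFloat16);
  std::cout << dst.dtype() << " " << dst.sizes() << "\n";
  // Spot-check the first few values round-tripped back to float.
  std::cout << src.slice(0, 0, 4) << "\n";
  std::cout << dst.slice(0, 0, 4).to(torch::kFloat) << "\n";
  return 0;
}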
