88#include < benchmark/benchmark.h>
99
1010#include < torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h>
11+ #include < torchao/experimental/kernels/cpu/aarch64/bitpacking/uint1.h>
1112#include < torchao/experimental/kernels/cpu/aarch64/bitpacking/uint2.h>
1213#include < torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h>
1314#include < torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h>
1617
1718namespace {
1819
20+ // Benchmark utility to compare variants of uint1 packing
21+ void pack_uint1_values (
22+ uint8_t * packed,
23+ uint8_t * unpacked,
24+ int packed_size,
25+ int unpacked_size,
26+ int variant) {
27+ constexpr int nbit = 1 ;
28+ constexpr int bitsPerByte = 8 ;
29+ assert (unpacked_size * nbit / bitsPerByte == packed_size);
30+ assert (packed_size % variant == 0 );
31+
32+ uint8x16_t unpacked0;
33+ uint8x16_t unpacked1;
34+ uint8x16_t unpacked2;
35+ uint8x16_t unpacked3;
36+ uint8x16_t unpacked4;
37+ uint8x16_t unpacked5;
38+ uint8x16_t unpacked6;
39+ uint8x16_t unpacked7;
40+
41+ switch (variant) {
42+ case 8 :
43+ for (int i = 0 ; i < unpacked_size; i += 8 ) {
44+ torchao::bitpacking::internal::pack_8_uint1_values (
45+ packed + ((i * nbit) / bitsPerByte), unpacked + i);
46+ }
47+ break ;
48+ case 64 :
49+ for (int i = 0 ; i < unpacked_size; i += 64 ) {
50+ torchao::bitpacking::internal::vec_load_64_uint8_values (
51+ unpacked0, unpacked1, unpacked2, unpacked3, unpacked + i);
52+ torchao::bitpacking::internal::vec_pack_64_uint1_values (
53+ packed + ((i * nbit) / bitsPerByte),
54+ unpacked0,
55+ unpacked1,
56+ unpacked2,
57+ unpacked3);
58+ }
59+ break ;
60+ case 128 :
61+ for (int i = 0 ; i < unpacked_size; i += 128 ) {
62+ torchao::bitpacking::internal::vec_load_64_uint8_values (
63+ unpacked0, unpacked1, unpacked2, unpacked3, unpacked + i);
64+ torchao::bitpacking::internal::vec_load_64_uint8_values (
65+ unpacked4, unpacked5, unpacked6, unpacked7, unpacked + i + 64 );
66+ torchao::bitpacking::internal::vec_pack_128_uint1_values (
67+ packed + ((i * nbit) / bitsPerByte),
68+ unpacked0,
69+ unpacked1,
70+ unpacked2,
71+ unpacked3,
72+ unpacked4,
73+ unpacked5,
74+ unpacked6,
75+ unpacked7);
76+ }
77+ break ;
78+ }
79+ }
80+
81+ // Benchmark utility to compare variants of uint1 packing
82+ void unpack_uint1_values (
83+ uint8_t * unpacked,
84+ uint8_t * packed,
85+ int unpacked_size,
86+ int packed_size,
87+ int variant) {
88+ constexpr int nbit = 1 ;
89+ constexpr int bitsPerByte = 8 ;
90+ assert (unpacked_size * nbit / bitsPerByte == packed_size);
91+ assert (packed_size % variant == 0 );
92+
93+ uint8x16_t unpacked0;
94+ uint8x16_t unpacked1;
95+ uint8x16_t unpacked2;
96+ uint8x16_t unpacked3;
97+ uint8x16_t unpacked4;
98+ uint8x16_t unpacked5;
99+ uint8x16_t unpacked6;
100+ uint8x16_t unpacked7;
101+
102+ switch (variant) {
103+ case 8 :
104+ for (int i = 0 ; i < unpacked_size; i += 8 ) {
105+ torchao::bitpacking::internal::unpack_8_uint1_values (
106+ unpacked + i, packed + ((i * nbit) / bitsPerByte));
107+ }
108+ break ;
109+ case 64 :
110+ for (int i = 0 ; i < unpacked_size; i += 64 ) {
111+ torchao::bitpacking::internal::vec_unpack_64_uint1_values (
112+ unpacked0,
113+ unpacked1,
114+ unpacked2,
115+ unpacked3,
116+ packed + ((i * nbit) / bitsPerByte));
117+ torchao::bitpacking::internal::vec_store_64_uint8_values (
118+ unpacked + i, unpacked0, unpacked1, unpacked2, unpacked3);
119+ }
120+ break ;
121+ case 128 :
122+ for (int i = 0 ; i < unpacked_size; i += 128 ) {
123+ torchao::bitpacking::internal::vec_unpack_128_uint1_values (
124+ unpacked0,
125+ unpacked1,
126+ unpacked2,
127+ unpacked3,
128+ unpacked4,
129+ unpacked5,
130+ unpacked6,
131+ unpacked7,
132+ packed + ((i * nbit) / bitsPerByte));
133+ torchao::bitpacking::internal::vec_store_64_uint8_values (
134+ unpacked + i, unpacked0, unpacked1, unpacked2, unpacked3);
135+ torchao::bitpacking::internal::vec_store_64_uint8_values (
136+ unpacked + i + 64 , unpacked4, unpacked5, unpacked6, unpacked7);
137+ }
138+ break ;
139+ }
140+ }
141+
19142// Benchmark utility to compare variants of uint2 packing
20143void pack_uint2_values (
21144 uint8_t * packed,
@@ -470,6 +593,44 @@ void unpack_uint5_values(
470593
471594} // namespace
472595
596+ static void benchmark_pack_uint1_values (benchmark::State& state) {
597+ int unpacked_size = state.range (0 );
598+ int variant = state.range (1 );
599+ int nbit = 1 ;
600+
601+ assert (unpacked_size % 8 == 0 );
602+ int packed_size = (unpacked_size / 8 ) * nbit;
603+
604+ auto packed = std::vector<uint8_t >(packed_size, 0 );
605+ auto unpacked = torchao::get_random_lowbit_vector (unpacked_size, nbit);
606+
607+ for (auto _ : state) {
608+ pack_uint1_values (
609+ packed.data (), unpacked.data (), packed_size, unpacked_size, variant);
610+ }
611+ }
612+
613+ static void benchmark_unpack_uint1_values (benchmark::State& state) {
614+ int unpacked_size = state.range (0 );
615+ int variant = state.range (1 );
616+ int nbit = 1 ;
617+
618+ assert (unpacked_size % 8 == 0 );
619+ int packed_size = (unpacked_size / 8 ) * nbit;
620+
621+ auto packed = torchao::get_random_lowbit_vector (packed_size, 8 );
622+ auto unpacked = std::vector<uint8_t >(unpacked_size, 0 );
623+
624+ for (auto _ : state) {
625+ unpack_uint1_values (
626+ unpacked.data (),
627+ packed.data (),
628+ unpacked.size (),
629+ packed.size (),
630+ variant);
631+ }
632+ }
633+
473634static void benchmark_pack_uint2_values (benchmark::State& state) {
474635 int unpacked_size = state.range (0 );
475636 int variant = state.range (1 );
@@ -478,8 +639,8 @@ static void benchmark_pack_uint2_values(benchmark::State& state) {
478639 assert (unpacked_size % 8 == 0 );
479640 int packed_size = (unpacked_size / 8 ) * nbit;
480641
481- auto packed = std::vector<uint8_t >(unpacked_size , 0 );
482- auto unpacked = torchao::get_random_lowbit_vector (packed_size, 8 );
642+ auto packed = std::vector<uint8_t >(packed_size , 0 );
643+ auto unpacked = torchao::get_random_lowbit_vector (unpacked_size, nbit );
483644
484645 for (auto _ : state) {
485646 pack_uint2_values (
@@ -516,8 +677,8 @@ static void benchmark_pack_uint3_values(benchmark::State& state) {
516677 assert (unpacked_size % 8 == 0 );
517678 int packed_size = (unpacked_size / 8 ) * nbit;
518679
519- auto packed = std::vector<uint8_t >(unpacked_size , 0 );
520- auto unpacked = torchao::get_random_lowbit_vector (packed_size, 8 );
680+ auto packed = std::vector<uint8_t >(packed_size , 0 );
681+ auto unpacked = torchao::get_random_lowbit_vector (unpacked_size, nbit );
521682
522683 for (auto _ : state) {
523684 pack_uint3_values (
@@ -554,8 +715,8 @@ static void benchmark_pack_uint4_values(benchmark::State& state) {
554715 assert (unpacked_size % 8 == 0 );
555716 int packed_size = (unpacked_size / 8 ) * nbit;
556717
557- auto packed = std::vector<uint8_t >(unpacked_size , 0 );
558- auto unpacked = torchao::get_random_lowbit_vector (packed_size, 8 );
718+ auto packed = std::vector<uint8_t >(packed_size , 0 );
719+ auto unpacked = torchao::get_random_lowbit_vector (unpacked_size, nbit );
559720
560721 for (auto _ : state) {
561722 pack_uint4_values (
@@ -592,8 +753,8 @@ static void benchmark_pack_uint5_values(benchmark::State& state) {
592753 assert (unpacked_size % 8 == 0 );
593754 int packed_size = (unpacked_size / 8 ) * nbit;
594755
595- auto packed = std::vector<uint8_t >(unpacked_size , 0 );
596- auto unpacked = torchao::get_random_lowbit_vector (packed_size, 8 );
756+ auto packed = std::vector<uint8_t >(packed_size , 0 );
757+ auto unpacked = torchao::get_random_lowbit_vector (unpacked_size, nbit );
597758
598759 for (auto _ : state) {
599760 pack_uint5_values (
@@ -622,6 +783,8 @@ static void benchmark_unpack_uint5_values(benchmark::State& state) {
622783 }
623784}
624785
786+ BENCHMARK (benchmark_pack_uint1_values)->ArgsProduct({{128 }, {8 , 64 , 128 }});
787+ BENCHMARK (benchmark_unpack_uint1_values)->ArgsProduct({{128 }, {8 , 64 , 128 }});
625788BENCHMARK (benchmark_pack_uint2_values)->ArgsProduct({{128 }, {4 , 32 , 64 }});
626789BENCHMARK (benchmark_unpack_uint2_values)->ArgsProduct({{128 }, {4 , 32 , 64 }});
627790BENCHMARK (benchmark_pack_uint3_values)->ArgsProduct({{128 }, {8 , 64 , 128 }});
0 commit comments