 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include "vec_load.hpp"
+#include "vec_store.hpp"
 #include "vec_arithmetic.hpp"
+#include "cmath"
 
-inline fp32x16 sub_fp32x16(fp32x16 x, fp32x16 y) {
+fp32x16 sub_fp32x16(fp32x16 x, fp32x16 y) {
 #if __AVX512F__
-  return _mm512_sub_ps(x, y);
+  return {_mm512_sub_ps(x.first, y.first)};
 #else
   return {_mm256_sub_ps(x.first, y.first), _mm256_sub_ps(x.second, y.second)};
 #endif
 }
 
-inline fp32x16 fmsub_fp32x16(fp32x16 x, fp32x16 y, fp32x16 z) {
+fp32x16 fmsub_fp32x16(fp32x16 x, fp32x16 y, fp32x16 z) {
 #if __AVX512F__
-  return _mm512_fmsub_ps(x, y, z);
+  return {_mm512_fmsub_ps(x.first, y.first, z.first)};
 #else
   return {_mm256_fmsub_ps(x.first, y.first, z.first), _mm256_fmsub_ps(x.second, y.second, z.second)};
 #endif
 }
 
-inline fp32x16 maskz_fmsub_fp32x16(int mask, fp32x16 x, fp32x16 y, fp32x16 z) {
+fp32x16 maskz_fmsub_fp32x16(int mask, fp32x16 x, fp32x16 y, fp32x16 z) {
 #if __AVX512F__
-  return _mm512_maskz_fmsub_ps(mask, x, y, z);
+  return {_mm512_maskz_fmsub_ps(mask, x.first, y.first, z.first)};
 #else
   __m256 first, second;
   MASK_DECORATOR(_mm256_blend_ps, _mm256_setzero_ps(), _mm256_fmsub_ps(x.first, y.first, z.first), mask & 255, first);
@@ -42,33 +45,33 @@ inline fp32x16 maskz_fmsub_fp32x16(int mask, fp32x16 x, fp32x16 y, fp32x16 z) {
 #endif
 }
 
-inline fp32x16 add_fp32x16(fp32x16 x, fp32x16 y) {
+fp32x16 add_fp32x16(fp32x16 x, fp32x16 y) {
 #if __AVX512F__
-  return _mm512_add_ps(x, y);
+  return {_mm512_add_ps(x.first, y.first)};
 #else
   return {_mm256_add_ps(x.first, y.first), _mm256_add_ps(x.second, y.second)};
 #endif
 }
 
-inline fp32x16 fmadd_fp32x16(fp32x16 x, fp32x16 y, fp32x16 z) {
+fp32x16 fmadd_fp32x16(fp32x16 x, fp32x16 y, fp32x16 z) {
 #if __AVX512F__
-  return _mm512_fmadd_ps(x, y, z);
+  return {_mm512_fmadd_ps(x.first, y.first, z.first)};
 #else
   return {_mm256_fmadd_ps(x.first, y.first, z.first), _mm256_fmadd_ps(x.second, y.second, z.second)};
 #endif
 }
 
-inline fp32x16 mul_fp32x16(fp32x16 x, fp32x16 y) {
+fp32x16 mul_fp32x16(fp32x16 x, fp32x16 y) {
 #if __AVX512F__
-  return _mm512_mul_ps(x, y);
+  return {_mm512_mul_ps(x.first, y.first)};
 #else
   return {_mm256_mul_ps(x.first, y.first), _mm256_mul_ps(x.second, y.second)};
 #endif
 }
 
-inline fp32x16 maskz_mul_fp32x16(int mask, fp32x16 x, fp32x16 y) {
+fp32x16 maskz_mul_fp32x16(int mask, fp32x16 x, fp32x16 y) {
 #if __AVX512F__
-  return _mm512_maskz_mul_ps(mask, x, y);
+  return {_mm512_maskz_mul_ps(mask, x.first, y.first)};
 #else
   __m256 first, second;
   MASK_DECORATOR(_mm256_blend_ps, _mm256_setzero_ps(), _mm256_mul_ps(x.first, y.first), mask & 255, first);
@@ -78,31 +81,31 @@ inline fp32x16 maskz_mul_fp32x16(int mask, fp32x16 x, fp32x16 y) {
 }
 
 template <int rounding>
-inline fp32x16 mul_round_fp32x16(fp32x16 x, fp32x16 y) {
+fp32x16 mul_round_fp32x16(fp32x16 x, fp32x16 y) {
   static_assert(rounding == (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC) ||
                     rounding == (_MM_FROUND_TO_ZERO | _MM_FROUND_NO_EXC) ||
                     rounding == (_MM_FROUND_TO_NEG_INF | _MM_FROUND_NO_EXC) ||
                     rounding == (_MM_FROUND_TO_POS_INF | _MM_FROUND_NO_EXC) || rounding == (_MM_FROUND_CUR_DIRECTION),
                 "ERROR: Not support rounding");
 #if __AVX512F__
-  return _mm512_mul_round_ps(x, y, rounding);
+  return {_mm512_mul_round_ps(x.first, y.first, rounding)};
 #else
   return {_mm256_round_ps(_mm256_mul_ps(x.first, y.first), rounding),
           _mm256_round_ps(_mm256_mul_ps(x.second, y.second), rounding)};
 #endif
 }
 
-inline fp32x16 div_fp32x16(fp32x16 x, fp32x16 y) {
+fp32x16 div_fp32x16(fp32x16 x, fp32x16 y) {
 #if __AVX512F__
-  return _mm512_div_ps(x, y);
+  return {_mm512_div_ps(x.first, y.first)};
 #else
   return {_mm256_div_ps(x.first, y.first), _mm256_div_ps(x.second, y.second)};
 #endif
 }
 
-inline float reduce_add_fp32x16(fp32x16 x) {
+float reduce_add_fp32x16(fp32x16 x) {
 #if __AVX512F__
-  return _mm512_reduce_add_ps(x);
+  return {_mm512_reduce_add_ps(x.first)};
 #else
   const __m256 x256 = _mm256_add_ps(x.first, x.second);
   const __m128 x128 = _mm_add_ps(_mm256_extractf128_ps(x256, 1), _mm256_castps256_ps128(x256));
@@ -112,46 +115,55 @@ inline float reduce_add_fp32x16(fp32x16 x) {
 #endif
 }
 
-inline fp32x16 sqrt_fp32x16(fp32x16 x) {
+fp32x16 sqrt_fp32x16(fp32x16 x) {
 #if __AVX512F__
-  return _mm512_sqrt_ps(x);
+  return {_mm512_sqrt_ps(x.first)};
 #else
   return {_mm256_sqrt_ps(x.first), _mm256_sqrt_ps(x.second)};
 #endif
 }
 
-inline fp32x16 rsqrt14_fp32x16(fp32x16 x) {
+fp32x16 rsqrt14_fp32x16(fp32x16 x) {
 #if __AVX512F__
-  return _mm512_rsqrt14_ps(x);
+  return {_mm512_rsqrt14_ps(x.first)};
 #else
   // the max relative error is 6x than avx512
   return {_mm256_rsqrt_ps(x.first), _mm256_rsqrt_ps(x.second)};
 #endif
 }
-inline fp32x16 ceil_fp32x16(fp32x16 x) {
+fp32x16 ceil_fp32x16(fp32x16 x) {
 #if __AVX512F__
-  return _mm512_ceil_ps(x);
+  return {_mm512_ceil_ps(x.first)};
 #else
   // the max relative error is 6x than avx512
   return {_mm256_ceil_ps(x.first), _mm256_ceil_ps(x.second)};
 #endif
 }
 
-inline fp32x16 scale_fp32x16(fp32x16 x, fp32x16 y) {
+fp32x16 scale_fp32x16(fp32x16 x, fp32x16 y) {
 #if __AVX512F__
-  return _mm512_scalef_ps(x, y);
+  return {_mm512_scalef_ps(x.first, y.first)};
 #else
   // No intrinsic
-  assert("No intrinsic");
-  return {_mm256_rsqrt_ps(x.first), _mm256_rsqrt_ps(x.second)};
+  float* vec_x = new float[16];
+  float* vec_y = new float[16];
+  float* vec_z = new float[16];
+  store_fp32x16(vec_x, x);
+  store_fp32x16(vec_y, y);
+  for (int i = 0; i < 16; i++) vec_z[i] = vec_x[i] * exp2(vec_y[i]);
+  fp32x16 res = load_fp32x16(vec_z);
+  delete[] vec_x;
+  delete[] vec_y;
+  delete[] vec_z;
+  return res;
 #endif
 }
 
-inline float dot_fp32x16(fp32x16 x, fp32x16 y) { return reduce_add_fp32x16(mul_fp32x16(x, y)); }
+float dot_fp32x16(fp32x16 x, fp32x16 y) { return reduce_add_fp32x16(mul_fp32x16(x, y)); }
 
-inline fp32x16 abs_fp32x16(fp32x16 x) {
+fp32x16 abs_fp32x16(fp32x16 x) {
 #if __AVX512F__
-  return _mm512_abs_ps(x);
+  return {_mm512_abs_ps(x.first)};
 #else
   return {_mm256_castsi256_ps(_mm256_abs_epi32(_mm256_castps_si256(x.first))),
           _mm256_castsi256_ps(_mm256_abs_epi32(_mm256_castps_si256(x.second)))};
 #endif
 }
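For orientation, here is a minimal usage sketch (not part of the commit above) showing how the reworked helpers compose once data is in `fp32x16` form. It assumes that `load_fp32x16(float*)` and `store_fp32x16(float*, fp32x16)` from the newly included `vec_load.hpp` / `vec_store.hpp` read and write 16 contiguous floats, matching how `scale_fp32x16` uses them in its AVX2 fallback.

// Sketch only: relies on the assumed load/store signatures noted above.
#include "vec_arithmetic.hpp"
#include "vec_load.hpp"
#include "vec_store.hpp"

void fused_axpy_and_dot(float* a, float* x, float* y, float* out, float* dot) {
  fp32x16 va = load_fp32x16(a);           // 16 floats -> one fp32x16
  fp32x16 vx = load_fp32x16(x);
  fp32x16 vy = load_fp32x16(y);
  fp32x16 r = fmadd_fp32x16(va, vx, vy);  // elementwise a * x + y
  store_fp32x16(out, r);                  // write the 16 results back
  *dot = dot_fp32x16(vx, vy);             // horizontal sum of x[i] * y[i]
}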