Skip to content

Commit affe742

Browse files
Extend support of batch_cast<...> to upcasting to a type twice as big
Fix #1179
1 parent cbf693c commit affe742

File tree

6 files changed

+166
-1
lines changed

6 files changed

+166
-1
lines changed
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
/***************************************************************************
2+
* Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and *
3+
* Martin Renou *
4+
* Copyright (c) QuantStack *
5+
* Copyright (c) Serge Guelton *
6+
* *
7+
* Distributed under the terms of the BSD 3-Clause License. *
8+
* *
9+
* The full license is in the file LICENSE, distributed with this software. *
10+
****************************************************************************/
11+
12+
#ifndef XSIMD_COMMON_CAST_HPP
13+
#define XSIMD_COMMON_CAST_HPP
14+
15+
namespace xsimd
16+
{
17+
namespace kernel
18+
{
19+
// upcast
20+
namespace detail
21+
{
22+
template <typename T>
23+
struct upcast;
24+
template <>
25+
struct upcast<uint32_t>
26+
{
27+
using type = uint64_t;
28+
};
29+
template <>
30+
struct upcast<uint16_t>
31+
{
32+
using type = uint32_t;
33+
};
34+
template <>
35+
struct upcast<uint8_t>
36+
{
37+
using type = uint8_t;
38+
};
39+
template <>
40+
struct upcast<int32_t>
41+
{
42+
using type = int64_t;
43+
};
44+
template <>
45+
struct upcast<int16_t>
46+
{
47+
using type = int32_t;
48+
};
49+
template <>
50+
struct upcast<int8_t>
51+
{
52+
using type = int8_t;
53+
};
54+
template <>
55+
struct upcast<float>
56+
{
57+
using type = double;
58+
};
59+
template <typename T>
60+
using upcast_t = typename upcast<T>::type;
61+
}
62+
63+
template <class T, class A>
64+
XSIMD_INLINE std::array<batch<detail::upcast_t<T>, A>, 2> batch_upcast(batch<T, A> const& x, requires_arch<common>) noexcept
65+
{
66+
alignas(A::alignment()) T buffer[batch<T, A>::size];
67+
x.store_aligned(&buffer[0]);
68+
69+
using T_out = detail::upcast_t<T>;
70+
alignas(A::alignment()) T_out out_buffer[batch<T, A>::size];
71+
for (size_t i = 0; i < batch<T, A>::size; ++i)
72+
out_buffer[i] = static_cast<T_out>(buffer[i]);
73+
74+
return { batch<T_out, A>::load_aligned(&out_buffer[0]),
75+
batch<T_out, A>::load_aligned(&out_buffer[batch<T_out, A>::size]) };
76+
}
77+
78+
}
79+
80+
}
81+
82+
#endif

include/xsimd/arch/xsimd_avx.hpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,22 @@ namespace xsimd
527527
}
528528
}
529529

530+
// upcast
531+
template <class A, class T>
532+
XSIMD_INLINE std::array<batch<detail::upcast_t<T>, A>, 2> batch_upcast(batch<T, A> const& x, requires_arch<avx>) noexcept
533+
{
534+
auto pair_lo = batch_upcast(batch<T, sse4_2>(_mm256_extractf128_si256(x, 0)), sse4_2 {});
535+
auto pair_hi = batch_upcast(batch<T, sse4_2>(_mm256_extractf128_si256(x, 1)), sse4_2 {});
536+
return { detail::merge_sse(pair_lo[0], pair_lo[1]), detail::merge_sse(pair_hi[0], pair_hi[1]) };
537+
}
538+
template <class A>
539+
XSIMD_INLINE std::array<batch<double, A>, 2> batch_upcast(batch<float, A> const& x, requires_arch<avx>) noexcept
540+
{
541+
__m256d lo = _mm256_cvtps_pd(_mm256_extractf128_ps(x, 0));
542+
__m256d hi = _mm256_cvtps_pd(_mm256_extractf128_ps(x, 1));
543+
return { lo, hi };
544+
}
545+
530546
// decr_if
531547
template <class A, class T, class = typename std::enable_if<std::is_integral<T>::value, void>::type>
532548
XSIMD_INLINE batch<T, A> decr_if(batch<T, A> const& self, batch_bool<T, A> const& mask, requires_arch<avx>) noexcept

include/xsimd/arch/xsimd_avx2.hpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -431,6 +431,31 @@ namespace xsimd
431431
}
432432
}
433433

434+
// upcast
435+
template <class A, class T>
436+
XSIMD_INLINE std::array<batch<detail::upcast_t<T>, A>, 2> batch_upcast(batch<T, A> const& x, requires_arch<avx2>) noexcept
437+
{
438+
__m128i x_lo = _mm256_extracti128_si256(x, 0);
439+
__m128i x_hi = _mm256_extracti128_si256(x, 1);
440+
__m256i lo, hi;
441+
XSIMD_IF_CONSTEXPR(sizeof(T) == 4)
442+
{
443+
lo = _mm256_cvtepi32_epi64(x_lo);
444+
hi = _mm256_cvtepi32_epi64(x_hi);
445+
}
446+
else XSIMD_IF_CONSTEXPR(sizeof(T) == 2)
447+
{
448+
lo = _mm256_cvtepi16_epi32(x_lo);
449+
hi = _mm256_cvtepi16_epi32(x_hi);
450+
}
451+
else XSIMD_IF_CONSTEXPR(sizeof(T) == 1)
452+
{
453+
lo = _mm256_cvtepi8_epi16(x_lo);
454+
hi = _mm256_cvtepi8_epi16(x_hi);
455+
}
456+
return { lo, hi };
457+
}
458+
434459
// eq
435460
template <class A, class T, class = typename std::enable_if<std::is_integral<T>::value, void>::type>
436461
XSIMD_INLINE batch_bool<T, A> eq(batch<T, A> const& self, batch<T, A> const& other, requires_arch<avx2>) noexcept

include/xsimd/arch/xsimd_common.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#define XSIMD_COMMON_HPP
1414

1515
#include "./common/xsimd_common_arithmetic.hpp"
16+
#include "./common/xsimd_common_cast.hpp"
1617
#include "./common/xsimd_common_complex.hpp"
1718
#include "./common/xsimd_common_logical.hpp"
1819
#include "./common/xsimd_common_math.hpp"

include/xsimd/arch/xsimd_sse4_1.hpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <type_traits>
1616

1717
#include "../types/xsimd_sse4_1_register.hpp"
18+
#include "./common/xsimd_common_cast.hpp"
1819

1920
namespace xsimd
2021
{
@@ -67,6 +68,38 @@ namespace xsimd
6768
}
6869
}
6970

71+
// upcast
72+
template <class T, class A>
73+
XSIMD_INLINE std::array<batch<detail::upcast_t<T>, A>, 2> batch_upcast(batch<T, A> const& x, requires_arch<sse4_1>) noexcept
74+
{
75+
__m128i x_shuf = _mm_unpackhi_epi64(x, x);
76+
__m128i lo, hi;
77+
XSIMD_IF_CONSTEXPR(sizeof(T) == 4)
78+
{
79+
lo = _mm_cvtepi32_epi64(x);
80+
hi = _mm_cvtepi32_epi64(x_shuf);
81+
}
82+
else XSIMD_IF_CONSTEXPR(sizeof(T) == 2)
83+
{
84+
lo = _mm_cvtepi16_epi32(x);
85+
hi = _mm_cvtepi16_epi32(x_shuf);
86+
}
87+
else XSIMD_IF_CONSTEXPR(sizeof(T) == 1)
88+
{
89+
lo = _mm_cvtepi8_epi16(x);
90+
hi = _mm_cvtepi8_epi16(x_shuf);
91+
}
92+
return { lo, hi };
93+
}
94+
template <class A>
95+
XSIMD_INLINE std::array<batch<double, A>, 2> batch_upcast(batch<float, A> const& x, requires_arch<sse4_1>) noexcept
96+
{
97+
__m128 x_shuf = _mm_unpackhi_ps(x, x);
98+
__m128d lo = _mm_cvtps_pd(x);
99+
__m128d hi = _mm_cvtps_pd(x_shuf);
100+
return { lo, hi };
101+
}
102+
70103
// eq
71104
template <class A, class T, class = typename std::enable_if<std::is_integral<T>::value, void>::type>
72105
XSIMD_INLINE batch_bool<T, A> eq(batch<T, A> const& self, batch<T, A> const& other, requires_arch<sse4_1>) noexcept

include/xsimd/types/xsimd_api.hpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,14 +256,22 @@ namespace xsimd
256256
* @param x batch of \c T_in
257257
* @return \c x cast to \c T_out
258258
*/
259-
template <class T_out, class T_in, class A>
259+
template <class T_out, class T_in, class A, class = typename std::enable_if<sizeof(T_out) == sizeof(T_in), void>::type>
260260
XSIMD_INLINE batch<T_out, A> batch_cast(batch<T_in, A> const& x) noexcept
261261
{
262262
detail::static_check_supported_config<T_out, A>();
263263
detail::static_check_supported_config<T_in, A>();
264264
return kernel::batch_cast<A>(x, batch<T_out, A> {}, A {});
265265
}
266266

267+
template <class T_out, class T_in, class A, class = typename std::enable_if<sizeof(T_out) == 2 * sizeof(T_in), void>::type>
268+
XSIMD_INLINE std::array<batch<T_out, A>, 2> batch_cast(batch<T_in, A> const& x) noexcept
269+
{
270+
detail::static_check_supported_config<T_out, A>();
271+
detail::static_check_supported_config<T_in, A>();
272+
return kernel::batch_upcast<A>(x, A {});
273+
}
274+
267275
/**
268276
* @ingroup batch_miscellaneous
269277
*

0 commit comments

Comments
 (0)