|
| 1 | +/*************************************************************************** |
| 2 | + * Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and * |
| 3 | + * Martin Renou * |
| 4 | + * Copyright (c) QuantStack * |
| 5 | + * Copyright (c) Serge Guelton * |
| 6 | + * Copyright (c) Marco Barbone * |
| 7 | + * * |
| 8 | + * Distributed under the terms of the BSD 3-Clause License. * |
| 9 | + * * |
| 10 | + * The full license is in the file LICENSE, distributed with this software.* |
| 11 | + ****************************************************************************/ |
| 12 | +#ifndef XSIMD_COMMON_SWIZZLE_HPP |
| 13 | +#define XSIMD_COMMON_SWIZZLE_HPP |
| 14 | + |
| 15 | +#include <cstddef> |
| 16 | +#include <cstdint> |
| 17 | +#include <type_traits> |
| 18 | + |
| 19 | +namespace xsimd |
| 20 | +{ |
| 21 | + template <typename T, class A, T... Values> |
| 22 | + struct batch_constant; |
| 23 | + |
| 24 | + namespace kernel |
| 25 | + { |
| 26 | + namespace detail |
| 27 | + { |
| 28 | + // ──────────────────────────────────────────────────────────────────────── |
| 29 | + // get_at<I,Values...> → the I-th element of the pack |
| 30 | + template <typename T, std::size_t I, T V0, T... Vs> |
| 31 | + struct get_at |
| 32 | + { |
| 33 | + static constexpr T value = get_at<T, I - 1, Vs...>::value; |
| 34 | + }; |
| 35 | + template <typename T, T V0, T... Vs> |
| 36 | + struct get_at<T, 0, V0, Vs...> |
| 37 | + { |
| 38 | + static constexpr T value = V0; |
| 39 | + }; |
| 40 | + |
| 41 | + // ──────────────────────────────────────────────────────────────────────── |
| 42 | + // 1) identity_impl |
| 43 | + template <std::size_t /*I*/, typename T> |
| 44 | + XSIMD_INLINE constexpr bool identity_impl() noexcept { return true; } |
| 45 | + template <std::size_t I, typename T, T V0, T... Vs> |
| 46 | + XSIMD_INLINE constexpr bool identity_impl() noexcept |
| 47 | + { |
| 48 | + return V0 == static_cast<T>(I) |
| 49 | + && identity_impl<I + 1, T, Vs...>(); |
| 50 | + } |
| 51 | + |
| 52 | + // ──────────────────────────────────────────────────────────────────────── |
| 53 | + // 2) bitmask_impl |
| 54 | + template <std::size_t /*I*/, std::size_t /*N*/, typename T> |
| 55 | + XSIMD_INLINE constexpr std::uint32_t bitmask_impl() noexcept { return 0u; } |
| 56 | + template <std::size_t I, std::size_t N, typename T, T V0, T... Vs> |
| 57 | + XSIMD_INLINE constexpr std::uint32_t bitmask_impl() noexcept |
| 58 | + { |
| 59 | + return (1u << (static_cast<std::uint32_t>(V0) & (N - 1))) |
| 60 | + | bitmask_impl<I + 1, N, T, Vs...>(); |
| 61 | + } |
| 62 | + |
| 63 | + // ──────────────────────────────────────────────────────────────────────── |
| 64 | + // 3) dup_lo_impl |
| 65 | + template <std::size_t I, std::size_t N, typename T, |
| 66 | + T... Vs, typename std::enable_if<I == N / 2, int>::type = 0> |
| 67 | + XSIMD_INLINE constexpr bool dup_lo_impl() noexcept { return true; } |
| 68 | + |
| 69 | + template <std::size_t I, std::size_t N, typename T, |
| 70 | + T... Vs, typename std::enable_if<(I < N / 2), int>::type = 0> |
| 71 | + XSIMD_INLINE constexpr bool dup_lo_impl() noexcept |
| 72 | + { |
| 73 | + return get_at<T, I, Vs...>::value < static_cast<T>(N / 2) |
| 74 | + && get_at<T, I + N / 2, Vs...>::value == get_at<T, I, Vs...>::value |
| 75 | + && dup_lo_impl<I + 1, N, T, Vs...>(); |
| 76 | + } |
| 77 | + |
| 78 | + // ──────────────────────────────────────────────────────────────────────── |
| 79 | + // 4) dup_hi_impl |
| 80 | + template <std::size_t I, std::size_t N, typename T, |
| 81 | + T... Vs, typename std::enable_if<I == N / 2, int>::type = 0> |
| 82 | + XSIMD_INLINE constexpr bool dup_hi_impl() noexcept { return true; } |
| 83 | + |
| 84 | + template <std::size_t I, std::size_t N, typename T, |
| 85 | + T... Vs, typename std::enable_if<(I < N / 2), int>::type = 0> |
| 86 | + XSIMD_INLINE constexpr bool dup_hi_impl() noexcept |
| 87 | + { |
| 88 | + return get_at<T, I, Vs...>::value >= static_cast<T>(N / 2) |
| 89 | + && get_at<T, I, Vs...>::value < static_cast<T>(N) |
| 90 | + && get_at<T, I + N / 2, Vs...>::value == get_at<T, I, Vs...>::value |
| 91 | + && dup_hi_impl<I + 1, N, T, Vs...>(); |
| 92 | + } |
| 93 | + |
| 94 | + // ──────────────────────────────────────────────────────────────────────── |
| 95 | + // 1) helper to get the I-th value from the Vs pack |
| 96 | + template <std::size_t I, uint32_t Head, uint32_t... Tail> |
| 97 | + struct get_nth_value |
| 98 | + { |
| 99 | + static constexpr uint32_t value = get_nth_value<I - 1, Tail...>::value; |
| 100 | + }; |
| 101 | + template <uint32_t Head, uint32_t... Tail> |
| 102 | + struct get_nth_value<0, Head, Tail...> |
| 103 | + { |
| 104 | + static constexpr uint32_t value = Head; |
| 105 | + }; |
| 106 | + |
| 107 | + // ──────────────────────────────────────────────────────────────────────── |
| 108 | + // 2) recursive cross‐lane test: true if any output‐lane i pulls from the opposite half |
| 109 | + template <std::size_t I, |
| 110 | + std::size_t N, |
| 111 | + std::size_t H, |
| 112 | + uint32_t... Vs> |
| 113 | + struct cross_impl |
| 114 | + { |
| 115 | + // does element I cross? (i.e. i<H but V>=H) or (i>=H but V<H) |
| 116 | + static constexpr uint32_t Vi = get_nth_value<I, Vs...>::value; |
| 117 | + static constexpr bool curr = (I < H ? (Vi >= H) : (Vi < H)); |
| 118 | + static constexpr bool next = cross_impl<I + 1, N, H, Vs...>::value; |
| 119 | + static constexpr bool value = curr || next; |
| 120 | + }; |
| 121 | + template <std::size_t N, std::size_t H, uint32_t... Vs> |
| 122 | + struct cross_impl<N, N, H, Vs...> |
| 123 | + { |
| 124 | + static constexpr bool value = false; |
| 125 | + }; |
| 126 | + template <std::size_t I, std::size_t N, typename T, |
| 127 | + T... Vs> |
| 128 | + XSIMD_INLINE constexpr bool no_duplicates_impl() noexcept |
| 129 | + { |
| 130 | + // build the bitmask of (Vs & (N-1)) across all lanes |
| 131 | + return detail::bitmask_impl<0, N, T, Vs...>() == ((1u << N) - 1u); |
| 132 | + } |
| 133 | + template <uint32_t... Vs> |
| 134 | + XSIMD_INLINE constexpr bool no_duplicates_v() noexcept |
| 135 | + { |
| 136 | + // forward to your existing no_duplicates_impl |
| 137 | + return no_duplicates_impl<0, sizeof...(Vs), uint32_t, Vs...>(); |
| 138 | + } |
| 139 | + template <uint32_t... Vs> |
| 140 | + XSIMD_INLINE constexpr bool is_cross_lane() noexcept |
| 141 | + { |
| 142 | + static_assert(sizeof...(Vs) >= 1, "Need at least one lane"); |
| 143 | + return cross_impl<0, sizeof...(Vs), sizeof...(Vs) / 2, Vs...>::value; |
| 144 | + } |
| 145 | + template <typename T, T... Vs> |
| 146 | + XSIMD_INLINE constexpr bool is_identity() noexcept { return detail::identity_impl<0, T, Vs...>(); } |
| 147 | + template <typename T, T... Vs> |
| 148 | + XSIMD_INLINE constexpr bool is_all_different() noexcept |
| 149 | + { |
| 150 | + return detail::bitmask_impl<0, sizeof...(Vs), T, Vs...>() == ((1u << sizeof...(Vs)) - 1); |
| 151 | + } |
| 152 | + |
| 153 | + template <typename T, T... Vs> |
| 154 | + XSIMD_INLINE constexpr bool is_dup_lo() noexcept { return detail::dup_lo_impl<0, sizeof...(Vs), T, Vs...>(); } |
| 155 | + template <typename T, T... Vs> |
| 156 | + XSIMD_INLINE constexpr bool is_dup_hi() noexcept { return detail::dup_hi_impl<0, sizeof...(Vs), T, Vs...>(); } |
| 157 | + template <typename T, class A, T... Vs> |
| 158 | + XSIMD_INLINE constexpr bool is_identity(batch_constant<T, A, Vs...>) noexcept { return is_identity<T, Vs...>(); } |
| 159 | + template <typename T, class A, T... Vs> |
| 160 | + XSIMD_INLINE constexpr bool is_all_different(batch_constant<T, A, Vs...>) noexcept { return is_all_different<T, Vs...>(); } |
| 161 | + template <typename T, class A, T... Vs> |
| 162 | + XSIMD_INLINE constexpr bool is_dup_lo(batch_constant<T, A, Vs...>) noexcept { return is_dup_lo<T, Vs...>(); } |
| 163 | + template <typename T, class A, T... Vs> |
| 164 | + XSIMD_INLINE constexpr bool is_dup_hi(batch_constant<T, A, Vs...>) noexcept { return is_dup_hi<T, Vs...>(); } |
| 165 | + template <typename T, class A, T... Vs> |
| 166 | + XSIMD_INLINE constexpr bool is_cross_lane(batch_constant<T, A, Vs...>) noexcept { return detail::is_cross_lane<Vs...>(); } |
| 167 | + template <typename T, class A, T... Vs> |
| 168 | + XSIMD_INLINE constexpr bool no_duplicates(batch_constant<T, A, Vs...>) noexcept { return no_duplicates_impl<0, sizeof...(Vs), T, Vs...>(); } |
| 169 | + // ──────────────────────────────────────────────────────────────────────── |
| 170 | + // compile-time tests (identity, all-different, dup-lo, dup-hi) |
| 171 | + // 8-lane identity |
| 172 | + static_assert(is_identity<std::uint32_t, 0, 1, 2, 3, 4, 5, 6, 7>(), "identity failed"); |
| 173 | + // 8-lane reverse is all-different but not identity |
| 174 | + static_assert(is_all_different<std::uint32_t, 7, 6, 5, 4, 3, 2, 1, 0>(), "all-diff failed"); |
| 175 | + static_assert(!is_identity<std::uint32_t, 7, 6, 5, 4, 3, 2, 1, 0>(), "identity on reverse"); |
| 176 | + // 8-lane dup-lo (repeat 0..3 twice) |
| 177 | + static_assert(is_dup_lo<std::uint32_t, 0, 1, 2, 3, 0, 1, 2, 3>(), "dup_lo failed"); |
| 178 | + static_assert(!is_dup_hi<std::uint32_t, 0, 1, 2, 3, 0, 1, 2, 3>(), "dup_hi on dup_lo"); |
| 179 | + // 8-lane dup-hi (repeat 4..7 twice) |
| 180 | + static_assert(is_dup_hi<std::uint32_t, 4, 5, 6, 7, 4, 5, 6, 7>(), "dup_hi failed"); |
| 181 | + static_assert(!is_dup_lo<std::uint32_t, 4, 5, 6, 7, 4, 5, 6, 7>(), "dup_lo on dup_hi"); |
| 182 | + // ──────────────────────────────────────────────────────────────────────── |
| 183 | + // 4-lane identity |
| 184 | + static_assert(is_identity<std::uint32_t, 0, 1, 2, 3>(), "4-lane identity failed"); |
| 185 | + // 4-lane reverse all-different but not identity |
| 186 | + static_assert(is_all_different<std::uint32_t, 3, 2, 1, 0>(), "4-lane all-diff failed"); |
| 187 | + static_assert(!is_identity<std::uint32_t, 3, 2, 1, 0>(), "4-lane identity on reverse"); |
| 188 | + // 4-lane dup-lo (repeat 0..1 twice) |
| 189 | + static_assert(is_dup_lo<std::uint32_t, 0, 1, 0, 1>(), "4-lane dup_lo failed"); |
| 190 | + static_assert(!is_dup_hi<std::uint32_t, 0, 1, 0, 1>(), "4-lane dup_hi on dup_lo"); |
| 191 | + // 4-lane dup-hi (repeat 2..3 twice) |
| 192 | + static_assert(is_dup_hi<std::uint32_t, 2, 3, 2, 3>(), "4-lane dup_hi failed"); |
| 193 | + static_assert(!is_dup_lo<std::uint32_t, 2, 3, 2, 3>(), "4-lane dup_lo on dup_hi"); |
| 194 | + |
| 195 | + static_assert(is_cross_lane<0, 1, 0, 1>(), "dup-lo only → crossing"); |
| 196 | + static_assert(is_cross_lane<2, 3, 2, 3>(), "dup-hi only → crossing"); |
| 197 | + static_assert(is_cross_lane<0, 3, 3, 3>(), "one low + rest high → crossing"); |
| 198 | + static_assert(!is_cross_lane<1, 0, 2, 3>(), "mixed low/high → no crossing"); |
| 199 | + static_assert(!is_cross_lane<0, 1, 2, 3>(), "mixed low/high → no crossing"); |
| 200 | + |
| 201 | + static_assert(no_duplicates_v<0, 1, 2, 3>(), "N=4: [0,1,2,3] → distinct"); |
| 202 | + static_assert(!no_duplicates_v<0, 1, 2, 2>(), "N=4: [0,1,2,2] → dup"); |
| 203 | + |
| 204 | + static_assert(no_duplicates_v<0, 1, 2, 3, 4, 5, 6, 7>(), "N=8: [0..7] → distinct"); |
| 205 | + static_assert(!no_duplicates_v<0, 1, 2, 3, 4, 5, 6, 0>(), "N=8: last repeats 0"); |
| 206 | + |
| 207 | + } // namespace detail |
| 208 | + } // namespace kernel |
| 209 | +} // namespace xsimd |
| 210 | + |
| 211 | +#endif // XSIMD_COMMON_SWIZZLE_HPP |
0 commit comments