Skip to content

Commit 825d298

Browse files
DiamonDinoia and serge-sans-paille
authored and committed
improved avx/avx2 swizzles
Fixes: some swizzles did not allow duplicates in the output
1 parent 5af6bcd commit 825d298

File tree

7 files changed

+314
-64
lines changed

7 files changed

+314
-64
lines changed

.github/workflows/emulated.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717
- { compiler: 'clang', version: '16'}
1818
steps:
1919
- name: Setup compiler
20-
if: ${{ matrix.sys.compiler == 'gcc' }}
20+
if: ${{ matrix.sys.compiler == 'gcc' }}
2121
run: |
2222
GCC_VERSION=${{ matrix.sys.version }}
2323
sudo apt-get update
@@ -31,7 +31,7 @@ jobs:
3131
- name: Setup compiler
3232
if: ${{ matrix.sys.compiler == 'clang' }}
3333
run: |
34-
LLVM_VERSION=${{ matrix.sys.version }}
34+
LLVM_VERSION=${{ matrix.sys.version }}
3535
sudo apt-get update || exit 1
3636
sudo apt-get --no-install-suggests --no-install-recommends install clang-$LLVM_VERSION || exit 1
3737
sudo apt-get --no-install-suggests --no-install-recommends install g++ g++-multilib || exit 1
@@ -49,7 +49,7 @@ jobs:
4949
- name: Configure build
5050
env:
5151
CC: ${{ env.CC }}
52-
CXX: ${{ env.CXX }}
52+
CXX: ${{ env.CXX }}
5353
run: |
5454
5555
mkdir _build

.github/workflows/linux.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ jobs:
2929
- { compiler: 'clang', version: '18', flags: 'avx512' }
3030
steps:
3131
- name: Setup compiler
32-
if: ${{ matrix.sys.compiler == 'gcc' }}
32+
if: ${{ matrix.sys.compiler == 'gcc' }}
3333
run: |
3434
GCC_VERSION=${{ matrix.sys.version }}
3535
sudo apt-get update
@@ -45,7 +45,7 @@ jobs:
4545
- name: Setup compiler
4646
if: ${{ matrix.sys.compiler == 'clang' }}
4747
run: |
48-
LLVM_VERSION=${{ matrix.sys.version }}
48+
LLVM_VERSION=${{ matrix.sys.version }}
4949
sudo apt-get update || exit 1
5050
sudo apt-get --no-install-suggests --no-install-recommends install clang-$LLVM_VERSION || exit 1
5151
sudo apt-get --no-install-suggests --no-install-recommends install g++ g++-multilib || exit 1
@@ -66,7 +66,7 @@ jobs:
6666
- name: Configure build
6767
env:
6868
CC: ${{ env.CC }}
69-
CXX: ${{ env.CXX }}
69+
CXX: ${{ env.CXX }}
7070
run: |
7171
if [[ '${{ matrix.sys.flags }}' == 'enable_xtl_complex' ]]; then
7272
CMAKE_EXTRA_ARGS="$CMAKE_EXTRA_ARGS -DENABLE_XTL_COMPLEX=ON"
Lines changed: 211 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,211 @@
1+
/***************************************************************************
2+
* Copyright (c) Johan Mabille, Sylvain Corlay, Wolf Vollprecht and *
3+
* Martin Renou *
4+
* Copyright (c) QuantStack *
5+
* Copyright (c) Serge Guelton *
6+
* Copyright (c) Marco Barbone *
7+
* *
8+
* Distributed under the terms of the BSD 3-Clause License. *
9+
* *
10+
* The full license is in the file LICENSE, distributed with this software.*
11+
****************************************************************************/
12+
#ifndef XSIMD_COMMON_SWIZZLE_HPP
13+
#define XSIMD_COMMON_SWIZZLE_HPP
14+
15+
#include <cstddef>
16+
#include <cstdint>
17+
#include <type_traits>
18+
19+
namespace xsimd
20+
{
21+
template <typename T, class A, T... Values>
22+
struct batch_constant;
23+
24+
namespace kernel
25+
{
26+
namespace detail
27+
{
28+
// ────────────────────────────────────────────────────────────────────────
29+
// get_at<T, I, Values...>::value is the I-th (zero-based) element of the
// non-type pack `Values...`, whose elements all have type T.
//
// Primary template: drop the head of the pack and look up index I - 1 in
// the remaining tail.
template <typename T, std::size_t I, T V0, T... Vs>
struct get_at
{
    // Gives a readable diagnostic for an out-of-range index; without it the
    // recursion only reports a missing template argument.
    static_assert(I <= sizeof...(Vs), "get_at: index out of range");
    static constexpr T value = get_at<T, I - 1, Vs...>::value;
};
// Terminal case: index 0 selects the head of the pack.
template <typename T, T V0, T... Vs>
struct get_at<T, 0, V0, Vs...>
{
    static constexpr T value = V0;
};
40+
41+
// ────────────────────────────────────────────────────────────────────────
42+
// 1) identity_impl
43+
template <std::size_t /*I*/, typename T>
44+
XSIMD_INLINE constexpr bool identity_impl() noexcept { return true; }
45+
template <std::size_t I, typename T, T V0, T... Vs>
46+
XSIMD_INLINE constexpr bool identity_impl() noexcept
47+
{
48+
return V0 == static_cast<T>(I)
49+
&& identity_impl<I + 1, T, Vs...>();
50+
}
51+
52+
// ────────────────────────────────────────────────────────────────────────
53+
// 2) bitmask_impl
54+
template <std::size_t /*I*/, std::size_t /*N*/, typename T>
55+
XSIMD_INLINE constexpr std::uint32_t bitmask_impl() noexcept { return 0u; }
56+
template <std::size_t I, std::size_t N, typename T, T V0, T... Vs>
57+
XSIMD_INLINE constexpr std::uint32_t bitmask_impl() noexcept
58+
{
59+
return (1u << (static_cast<std::uint32_t>(V0) & (N - 1)))
60+
| bitmask_impl<I + 1, N, T, Vs...>();
61+
}
62+
63+
// ────────────────────────────────────────────────────────────────────────
64+
// 3) dup_lo_impl
65+
template <std::size_t I, std::size_t N, typename T,
66+
T... Vs, typename std::enable_if<I == N / 2, int>::type = 0>
67+
XSIMD_INLINE constexpr bool dup_lo_impl() noexcept { return true; }
68+
69+
template <std::size_t I, std::size_t N, typename T,
70+
T... Vs, typename std::enable_if<(I < N / 2), int>::type = 0>
71+
XSIMD_INLINE constexpr bool dup_lo_impl() noexcept
72+
{
73+
return get_at<T, I, Vs...>::value < static_cast<T>(N / 2)
74+
&& get_at<T, I + N / 2, Vs...>::value == get_at<T, I, Vs...>::value
75+
&& dup_lo_impl<I + 1, N, T, Vs...>();
76+
}
77+
78+
// ────────────────────────────────────────────────────────────────────────
79+
// 4) dup_hi_impl
80+
template <std::size_t I, std::size_t N, typename T,
81+
T... Vs, typename std::enable_if<I == N / 2, int>::type = 0>
82+
XSIMD_INLINE constexpr bool dup_hi_impl() noexcept { return true; }
83+
84+
template <std::size_t I, std::size_t N, typename T,
85+
T... Vs, typename std::enable_if<(I < N / 2), int>::type = 0>
86+
XSIMD_INLINE constexpr bool dup_hi_impl() noexcept
87+
{
88+
return get_at<T, I, Vs...>::value >= static_cast<T>(N / 2)
89+
&& get_at<T, I, Vs...>::value < static_cast<T>(N)
90+
&& get_at<T, I + N / 2, Vs...>::value == get_at<T, I, Vs...>::value
91+
&& dup_hi_impl<I + 1, N, T, Vs...>();
92+
}
93+
94+
// ────────────────────────────────────────────────────────────────────────
95+
// get_nth_value<I, Vs...>::value is the I-th (zero-based) element of a pack
// of std::uint32_t values.
//
// Primary template: drop the head and look up index I - 1 in the tail.
template <std::size_t I, std::uint32_t Head, std::uint32_t... Tail>
struct get_nth_value
{
    // Readable diagnostic for an out-of-range index.
    static_assert(I <= sizeof...(Tail), "get_nth_value: index out of range");
    static constexpr std::uint32_t value = get_nth_value<I - 1, Tail...>::value;
};
// Terminal case: index 0 selects the head of the pack.
template <std::uint32_t Head, std::uint32_t... Tail>
struct get_nth_value<0, Head, Tail...>
{
    static constexpr std::uint32_t value = Head;
};
106+
107+
// ────────────────────────────────────────────────────────────────────────
108+
// 2) recursive cross‐lane test: true if any output‐lane i pulls from the opposite half
109+
template <std::size_t I,
110+
std::size_t N,
111+
std::size_t H,
112+
uint32_t... Vs>
113+
struct cross_impl
114+
{
115+
// does element I cross? (i.e. i<H but V>=H) or (i>=H but V<H)
116+
static constexpr uint32_t Vi = get_nth_value<I, Vs...>::value;
117+
static constexpr bool curr = (I < H ? (Vi >= H) : (Vi < H));
118+
static constexpr bool next = cross_impl<I + 1, N, H, Vs...>::value;
119+
static constexpr bool value = curr || next;
120+
};
121+
template <std::size_t N, std::size_t H, uint32_t... Vs>
122+
struct cross_impl<N, N, H, Vs...>
123+
{
124+
static constexpr bool value = false;
125+
};
126+
template <std::size_t I, std::size_t N, typename T,
127+
T... Vs>
128+
XSIMD_INLINE constexpr bool no_duplicates_impl() noexcept
129+
{
130+
// build the bitmask of (Vs & (N-1)) across all lanes
131+
return detail::bitmask_impl<0, N, T, Vs...>() == ((1u << N) - 1u);
132+
}
133+
template <uint32_t... Vs>
134+
XSIMD_INLINE constexpr bool no_duplicates_v() noexcept
135+
{
136+
// forward to your existing no_duplicates_impl
137+
return no_duplicates_impl<0, sizeof...(Vs), uint32_t, Vs...>();
138+
}
139+
template <uint32_t... Vs>
140+
XSIMD_INLINE constexpr bool is_cross_lane() noexcept
141+
{
142+
static_assert(sizeof...(Vs) >= 1, "Need at least one lane");
143+
return cross_impl<0, sizeof...(Vs), sizeof...(Vs) / 2, Vs...>::value;
144+
}
145+
template <typename T, T... Vs>
146+
XSIMD_INLINE constexpr bool is_identity() noexcept { return detail::identity_impl<0, T, Vs...>(); }
147+
template <typename T, T... Vs>
148+
XSIMD_INLINE constexpr bool is_all_different() noexcept
149+
{
150+
return detail::bitmask_impl<0, sizeof...(Vs), T, Vs...>() == ((1u << sizeof...(Vs)) - 1);
151+
}
152+
153+
template <typename T, T... Vs>
154+
XSIMD_INLINE constexpr bool is_dup_lo() noexcept { return detail::dup_lo_impl<0, sizeof...(Vs), T, Vs...>(); }
155+
template <typename T, T... Vs>
156+
XSIMD_INLINE constexpr bool is_dup_hi() noexcept { return detail::dup_hi_impl<0, sizeof...(Vs), T, Vs...>(); }
157+
template <typename T, class A, T... Vs>
158+
XSIMD_INLINE constexpr bool is_identity(batch_constant<T, A, Vs...>) noexcept { return is_identity<T, Vs...>(); }
159+
template <typename T, class A, T... Vs>
160+
XSIMD_INLINE constexpr bool is_all_different(batch_constant<T, A, Vs...>) noexcept { return is_all_different<T, Vs...>(); }
161+
template <typename T, class A, T... Vs>
162+
XSIMD_INLINE constexpr bool is_dup_lo(batch_constant<T, A, Vs...>) noexcept { return is_dup_lo<T, Vs...>(); }
163+
template <typename T, class A, T... Vs>
164+
XSIMD_INLINE constexpr bool is_dup_hi(batch_constant<T, A, Vs...>) noexcept { return is_dup_hi<T, Vs...>(); }
165+
template <typename T, class A, T... Vs>
166+
XSIMD_INLINE constexpr bool is_cross_lane(batch_constant<T, A, Vs...>) noexcept { return detail::is_cross_lane<Vs...>(); }
167+
template <typename T, class A, T... Vs>
168+
XSIMD_INLINE constexpr bool no_duplicates(batch_constant<T, A, Vs...>) noexcept { return no_duplicates_impl<0, sizeof...(Vs), T, Vs...>(); }
169+
// ────────────────────────────────────────────────────────────────────────
170+
// compile-time tests (identity, all-different, dup-lo, dup-hi)
171+
// 8-lane identity
172+
static_assert(is_identity<std::uint32_t, 0, 1, 2, 3, 4, 5, 6, 7>(), "identity failed");
173+
// 8-lane reverse is all-different but not identity
174+
static_assert(is_all_different<std::uint32_t, 7, 6, 5, 4, 3, 2, 1, 0>(), "all-diff failed");
175+
static_assert(!is_identity<std::uint32_t, 7, 6, 5, 4, 3, 2, 1, 0>(), "identity on reverse");
176+
// 8-lane dup-lo (repeat 0..3 twice)
177+
static_assert(is_dup_lo<std::uint32_t, 0, 1, 2, 3, 0, 1, 2, 3>(), "dup_lo failed");
178+
static_assert(!is_dup_hi<std::uint32_t, 0, 1, 2, 3, 0, 1, 2, 3>(), "dup_hi on dup_lo");
179+
// 8-lane dup-hi (repeat 4..7 twice)
180+
static_assert(is_dup_hi<std::uint32_t, 4, 5, 6, 7, 4, 5, 6, 7>(), "dup_hi failed");
181+
static_assert(!is_dup_lo<std::uint32_t, 4, 5, 6, 7, 4, 5, 6, 7>(), "dup_lo on dup_hi");
182+
// ────────────────────────────────────────────────────────────────────────
183+
// 4-lane identity
184+
static_assert(is_identity<std::uint32_t, 0, 1, 2, 3>(), "4-lane identity failed");
185+
// 4-lane reverse all-different but not identity
186+
static_assert(is_all_different<std::uint32_t, 3, 2, 1, 0>(), "4-lane all-diff failed");
187+
static_assert(!is_identity<std::uint32_t, 3, 2, 1, 0>(), "4-lane identity on reverse");
188+
// 4-lane dup-lo (repeat 0..1 twice)
189+
static_assert(is_dup_lo<std::uint32_t, 0, 1, 0, 1>(), "4-lane dup_lo failed");
190+
static_assert(!is_dup_hi<std::uint32_t, 0, 1, 0, 1>(), "4-lane dup_hi on dup_lo");
191+
// 4-lane dup-hi (repeat 2..3 twice)
192+
static_assert(is_dup_hi<std::uint32_t, 2, 3, 2, 3>(), "4-lane dup_hi failed");
193+
static_assert(!is_dup_lo<std::uint32_t, 2, 3, 2, 3>(), "4-lane dup_lo on dup_hi");
194+
195+
static_assert(is_cross_lane<0, 1, 0, 1>(), "dup-lo only → crossing");
196+
static_assert(is_cross_lane<2, 3, 2, 3>(), "dup-hi only → crossing");
197+
static_assert(is_cross_lane<0, 3, 3, 3>(), "one low + rest high → crossing");
198+
static_assert(!is_cross_lane<1, 0, 2, 3>(), "mixed low/high → no crossing");
199+
static_assert(!is_cross_lane<0, 1, 2, 3>(), "mixed low/high → no crossing");
200+
201+
static_assert(no_duplicates_v<0, 1, 2, 3>(), "N=4: [0,1,2,3] → distinct");
202+
static_assert(!no_duplicates_v<0, 1, 2, 2>(), "N=4: [0,1,2,2] → dup");
203+
204+
static_assert(no_duplicates_v<0, 1, 2, 3, 4, 5, 6, 7>(), "N=8: [0..7] → distinct");
205+
static_assert(!no_duplicates_v<0, 1, 2, 3, 4, 5, 6, 0>(), "N=8: last repeats 0");
206+
207+
} // namespace detail
208+
} // namespace kernel
209+
} // namespace xsimd
210+
211+
#endif // XSIMD_COMMON_SWIZZLE_HPP

0 commit comments

Comments
 (0)