Skip to content

Commit 297dd76

Browse files
Add avx512 support for transpose operator
1 parent 401e149 commit 297dd76

File tree

1 file changed

+79
-0
lines changed

1 file changed

+79
-0
lines changed

include/xsimd/arch/xsimd_avx512f.hpp

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ namespace xsimd
2525
{
2626
using namespace types;
2727

28+
// fwd
// Forward declarations of the generic transpose fallbacks so the AVX512F
// overloads defined later in this header can name them.
// NOTE(review): presumably required because the generic definitions appear in
// a header included after this point in the dispatch chain — confirm.
template <class A>
XSIMD_INLINE void transpose(batch<uint16_t, A>* matrix_begin, batch<uint16_t, A>* matrix_end, requires_arch<generic>) noexcept;
template <class A>
XSIMD_INLINE void transpose(batch<uint8_t, A>* matrix_begin, batch<uint8_t, A>* matrix_end, requires_arch<generic>) noexcept;
33+
2834
namespace detail
2935
{
3036
XSIMD_INLINE void split_avx512(__m512 val, __m256& low, __m256& high) noexcept
@@ -2010,6 +2016,79 @@ namespace xsimd
20102016
return bitwise_cast<int16_t>(swizzle(bitwise_cast<uint16_t>(self), mask, avx512f {}));
20112017
}
20122018

2019+
// transpose
2020+
template <class A>
2021+
XSIMD_INLINE void transpose(batch<uint16_t, A>* matrix_begin, batch<uint16_t, A>* matrix_end, requires_arch<avx512f>) noexcept
2022+
{
2023+
assert((matrix_end - matrix_begin == batch<uint16_t, A>::size) && "correctly sized matrix");
2024+
(void)matrix_end;
2025+
batch<uint16_t, avx2> tmp_lo0[16];
2026+
for (int i = 0; i < 16; ++i)
2027+
tmp_lo0[i] = _mm512_castsi512_si256(matrix_begin[i]);
2028+
transpose(tmp_lo0 + 0, tmp_lo0 + 16, avx2 {});
2029+
2030+
batch<uint16_t, avx2> tmp_hi0[16];
2031+
for (int i = 0; i < 16; ++i)
2032+
tmp_hi0[i] = _mm512_castsi512_si256(matrix_begin[16 + i]);
2033+
transpose(tmp_hi0 + 0, tmp_hi0 + 16, avx2 {});
2034+
2035+
batch<uint16_t, avx2> tmp_lo1[16];
2036+
for (int i = 0; i < 16; ++i)
2037+
tmp_lo1[i] = _mm512_extracti64x4_epi64(matrix_begin[i], 1);
2038+
transpose(tmp_lo1 + 0, tmp_lo1 + 16, avx2 {});
2039+
2040+
batch<uint16_t, avx2> tmp_hi1[16];
2041+
for (int i = 0; i < 16; ++i)
2042+
tmp_hi1[i] = _mm512_extracti64x4_epi64(matrix_begin[16 + i], 1);
2043+
transpose(tmp_hi1 + 0, tmp_hi1 + 16, avx2 {});
2044+
2045+
for (int i = 0; i < 16; ++i)
2046+
matrix_begin[i] = detail::merge_avx(tmp_lo0[i], tmp_hi0[i]);
2047+
for (int i = 0; i < 16; ++i)
2048+
matrix_begin[i + 16] = detail::merge_avx(tmp_lo1[i], tmp_hi1[i]);
2049+
}
2050+
template <class A>
2051+
XSIMD_INLINE void transpose(batch<int16_t, A>* matrix_begin, batch<int16_t, A>* matrix_end, requires_arch<avx512f>) noexcept
2052+
{
2053+
return transpose(reinterpret_cast<batch<uint16_t, A>*>(matrix_begin), reinterpret_cast<batch<uint16_t, A>*>(matrix_end), A {});
2054+
}
2055+
2056+
template <class A>
2057+
XSIMD_INLINE void transpose(batch<uint8_t, A>* matrix_begin, batch<uint8_t, A>* matrix_end, requires_arch<avx512f>) noexcept
2058+
{
2059+
assert((matrix_end - matrix_begin == batch<uint8_t, A>::size) && "correctly sized matrix");
2060+
(void)matrix_end;
2061+
batch<uint8_t, avx2> tmp_lo0[32];
2062+
for (int i = 0; i < 32; ++i)
2063+
tmp_lo0[i] = _mm512_castsi512_si256(matrix_begin[i]);
2064+
transpose(tmp_lo0 + 0, tmp_lo0 + 32, avx2 {});
2065+
2066+
batch<uint8_t, avx2> tmp_hi0[32];
2067+
for (int i = 0; i < 32; ++i)
2068+
tmp_hi0[i] = _mm512_castsi512_si256(matrix_begin[32 + i]);
2069+
transpose(tmp_hi0 + 0, tmp_hi0 + 32, avx2 {});
2070+
2071+
batch<uint8_t, avx2> tmp_lo1[32];
2072+
for (int i = 0; i < 32; ++i)
2073+
tmp_lo1[i] = _mm512_extracti64x4_epi64(matrix_begin[i], 1);
2074+
transpose(tmp_lo1 + 0, tmp_lo1 + 32, avx2 {});
2075+
2076+
batch<uint8_t, avx2> tmp_hi1[32];
2077+
for (int i = 0; i < 32; ++i)
2078+
tmp_hi1[i] = _mm512_extracti64x4_epi64(matrix_begin[32 + i], 1);
2079+
transpose(tmp_hi1 + 0, tmp_hi1 + 32, avx2 {});
2080+
2081+
for (int i = 0; i < 32; ++i)
2082+
matrix_begin[i] = detail::merge_avx(tmp_lo0[i], tmp_hi0[i]);
2083+
for (int i = 0; i < 32; ++i)
2084+
matrix_begin[i + 32] = detail::merge_avx(tmp_lo1[i], tmp_hi1[i]);
2085+
}
2086+
template <class A>
2087+
XSIMD_INLINE void transpose(batch<int8_t, A>* matrix_begin, batch<int8_t, A>* matrix_end, requires_arch<avx512f>) noexcept
2088+
{
2089+
return transpose(reinterpret_cast<batch<uint8_t, A>*>(matrix_begin), reinterpret_cast<batch<uint8_t, A>*>(matrix_end), A {});
2090+
}
2091+
20132092
// trunc
20142093
template <class A>
20152094
XSIMD_INLINE batch<float, A>

0 commit comments

Comments
 (0)