Skip to content

Commit 05672d8

Browse files
committed
Use generic SIMD masked load/stores for avx512 masked load/stores
1 parent 09a537e commit 05672d8

File tree

3 files changed

+134
-187
lines changed

3 files changed

+134
-187
lines changed

crates/core_arch/src/macros.rs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,3 +163,17 @@ macro_rules! simd_extract {
163163
($x:expr, $idx:expr $(,)?) => {{ $crate::intrinsics::simd::simd_extract($x, const { $idx }) }};
164164
($x:expr, $idx:expr, $ty:ty $(,)?) => {{ $crate::intrinsics::simd::simd_extract::<_, $ty>($x, const { $idx }) }};
165165
}
166+
167+
#[allow(unused)]
168+
macro_rules! simd_masked_load {
169+
($align:expr, $mask:expr, $ptr:expr, $default:expr) => {
170+
$crate::intrinsics::simd::simd_masked_load::<_, _, _, { $align }>($mask, $ptr, $default)
171+
};
172+
}
173+
174+
#[allow(unused)]
175+
macro_rules! simd_masked_store {
176+
($align:expr, $mask:expr, $ptr:expr, $default:expr) => {
177+
$crate::intrinsics::simd::simd_masked_store::<_, _, _, { $align }>($mask, $ptr, $default)
178+
};
179+
}

crates/core_arch/src/x86/avx512bw.rs

Lines changed: 24 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -5609,7 +5609,8 @@ pub unsafe fn _mm_storeu_epi8(mem_addr: *mut i8, a: __m128i) {
56095609
#[cfg_attr(test, assert_instr(vmovdqu16))]
56105610
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
56115611
pub unsafe fn _mm512_mask_loadu_epi16(src: __m512i, k: __mmask32, mem_addr: *const i16) -> __m512i {
5612-
transmute(loaddqu16_512(mem_addr, src.as_i16x32(), k))
5612+
let mask = simd_select_bitmask(k, i16x32::splat(!0), i16x32::ZERO);
5613+
simd_masked_load!(SimdAlign::Unaligned, mask, mem_addr, src.as_i16x32()).as_m512i()
56135614
}
56145615

56155616
/// Load packed 16-bit integers from memory into dst using zeromask k
@@ -5635,7 +5636,8 @@ pub unsafe fn _mm512_maskz_loadu_epi16(k: __mmask32, mem_addr: *const i16) -> __
56355636
#[cfg_attr(test, assert_instr(vmovdqu8))]
56365637
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
56375638
pub unsafe fn _mm512_mask_loadu_epi8(src: __m512i, k: __mmask64, mem_addr: *const i8) -> __m512i {
5638-
transmute(loaddqu8_512(mem_addr, src.as_i8x64(), k))
5639+
let mask = simd_select_bitmask(k, i8x64::splat(!0), i8x64::ZERO);
5640+
simd_masked_load!(SimdAlign::Unaligned, mask, mem_addr, src.as_i8x64()).as_m512i()
56395641
}
56405642

56415643
/// Load packed 8-bit integers from memory into dst using zeromask k
@@ -5661,7 +5663,8 @@ pub unsafe fn _mm512_maskz_loadu_epi8(k: __mmask64, mem_addr: *const i8) -> __m5
56615663
#[cfg_attr(test, assert_instr(vmovdqu16))]
56625664
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
56635665
pub unsafe fn _mm256_mask_loadu_epi16(src: __m256i, k: __mmask16, mem_addr: *const i16) -> __m256i {
5664-
transmute(loaddqu16_256(mem_addr, src.as_i16x16(), k))
5666+
let mask = simd_select_bitmask(k, i16x16::splat(!0), i16x16::ZERO);
5667+
simd_masked_load!(SimdAlign::Unaligned, mask, mem_addr, src.as_i16x16()).as_m256i()
56655668
}
56665669

56675670
/// Load packed 16-bit integers from memory into dst using zeromask k
@@ -5687,7 +5690,8 @@ pub unsafe fn _mm256_maskz_loadu_epi16(k: __mmask16, mem_addr: *const i16) -> __
56875690
#[cfg_attr(test, assert_instr(vmovdqu8))]
56885691
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
56895692
pub unsafe fn _mm256_mask_loadu_epi8(src: __m256i, k: __mmask32, mem_addr: *const i8) -> __m256i {
5690-
transmute(loaddqu8_256(mem_addr, src.as_i8x32(), k))
5693+
let mask = simd_select_bitmask(k, i8x32::splat(!0), i8x32::ZERO);
5694+
simd_masked_load!(SimdAlign::Unaligned, mask, mem_addr, src.as_i8x32()).as_m256i()
56915695
}
56925696

56935697
/// Load packed 8-bit integers from memory into dst using zeromask k
@@ -5713,7 +5717,8 @@ pub unsafe fn _mm256_maskz_loadu_epi8(k: __mmask32, mem_addr: *const i8) -> __m2
57135717
#[cfg_attr(test, assert_instr(vmovdqu16))]
57145718
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
57155719
pub unsafe fn _mm_mask_loadu_epi16(src: __m128i, k: __mmask8, mem_addr: *const i16) -> __m128i {
5716-
transmute(loaddqu16_128(mem_addr, src.as_i16x8(), k))
5720+
let mask = simd_select_bitmask(k, i16x8::splat(!0), i16x8::ZERO);
5721+
simd_masked_load!(SimdAlign::Unaligned, mask, mem_addr, src.as_i16x8()).as_m128i()
57175722
}
57185723

57195724
/// Load packed 16-bit integers from memory into dst using zeromask k
@@ -5739,7 +5744,8 @@ pub unsafe fn _mm_maskz_loadu_epi16(k: __mmask8, mem_addr: *const i16) -> __m128
57395744
#[cfg_attr(test, assert_instr(vmovdqu8))]
57405745
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
57415746
pub unsafe fn _mm_mask_loadu_epi8(src: __m128i, k: __mmask16, mem_addr: *const i8) -> __m128i {
5742-
transmute(loaddqu8_128(mem_addr, src.as_i8x16(), k))
5747+
let mask = simd_select_bitmask(k, i8x16::splat(!0), i8x16::ZERO);
5748+
simd_masked_load!(SimdAlign::Unaligned, mask, mem_addr, src.as_i8x16()).as_m128i()
57435749
}
57445750

57455751
/// Load packed 8-bit integers from memory into dst using zeromask k
@@ -5764,7 +5770,8 @@ pub unsafe fn _mm_maskz_loadu_epi8(k: __mmask16, mem_addr: *const i8) -> __m128i
57645770
#[cfg_attr(test, assert_instr(vmovdqu16))]
57655771
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
57665772
pub unsafe fn _mm512_mask_storeu_epi16(mem_addr: *mut i16, mask: __mmask32, a: __m512i) {
5767-
storedqu16_512(mem_addr, a.as_i16x32(), mask)
5773+
let mask = simd_select_bitmask(mask, i16x32::splat(!0), i16x32::ZERO);
5774+
simd_masked_store!(SimdAlign::Unaligned, mask, mem_addr, a.as_i16x32());
57685775
}
57695776

57705777
/// Store packed 8-bit integers from a into memory using writemask k.
@@ -5776,7 +5783,8 @@ pub unsafe fn _mm512_mask_storeu_epi16(mem_addr: *mut i16, mask: __mmask32, a: _
57765783
#[cfg_attr(test, assert_instr(vmovdqu8))]
57775784
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
57785785
pub unsafe fn _mm512_mask_storeu_epi8(mem_addr: *mut i8, mask: __mmask64, a: __m512i) {
5779-
storedqu8_512(mem_addr, a.as_i8x64(), mask)
5786+
let mask = simd_select_bitmask(mask, i8x64::splat(!0), i8x64::ZERO);
5787+
simd_masked_store!(SimdAlign::Unaligned, mask, mem_addr, a.as_i8x64());
57805788
}
57815789

57825790
/// Store packed 16-bit integers from a into memory using writemask k.
@@ -5788,7 +5796,8 @@ pub unsafe fn _mm512_mask_storeu_epi8(mem_addr: *mut i8, mask: __mmask64, a: __m
57885796
#[cfg_attr(test, assert_instr(vmovdqu16))]
57895797
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
57905798
pub unsafe fn _mm256_mask_storeu_epi16(mem_addr: *mut i16, mask: __mmask16, a: __m256i) {
5791-
storedqu16_256(mem_addr, a.as_i16x16(), mask)
5799+
let mask = simd_select_bitmask(mask, i16x16::splat(!0), i16x16::ZERO);
5800+
simd_masked_store!(SimdAlign::Unaligned, mask, mem_addr, a.as_i16x16());
57925801
}
57935802

57945803
/// Store packed 8-bit integers from a into memory using writemask k.
@@ -5800,7 +5809,8 @@ pub unsafe fn _mm256_mask_storeu_epi16(mem_addr: *mut i16, mask: __mmask16, a: _
58005809
#[cfg_attr(test, assert_instr(vmovdqu8))]
58015810
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
58025811
pub unsafe fn _mm256_mask_storeu_epi8(mem_addr: *mut i8, mask: __mmask32, a: __m256i) {
5803-
storedqu8_256(mem_addr, a.as_i8x32(), mask)
5812+
let mask = simd_select_bitmask(mask, i8x32::splat(!0), i8x32::ZERO);
5813+
simd_masked_store!(SimdAlign::Unaligned, mask, mem_addr, a.as_i8x32());
58045814
}
58055815

58065816
/// Store packed 16-bit integers from a into memory using writemask k.
@@ -5812,7 +5822,8 @@ pub unsafe fn _mm256_mask_storeu_epi8(mem_addr: *mut i8, mask: __mmask32, a: __m
58125822
#[cfg_attr(test, assert_instr(vmovdqu16))]
58135823
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
58145824
pub unsafe fn _mm_mask_storeu_epi16(mem_addr: *mut i16, mask: __mmask8, a: __m128i) {
5815-
storedqu16_128(mem_addr, a.as_i16x8(), mask)
5825+
let mask = simd_select_bitmask(mask, i16x8::splat(!0), i16x8::ZERO);
5826+
simd_masked_store!(SimdAlign::Unaligned, mask, mem_addr, a.as_i16x8());
58165827
}
58175828

58185829
/// Store packed 8-bit integers from a into memory using writemask k.
@@ -5824,7 +5835,8 @@ pub unsafe fn _mm_mask_storeu_epi16(mem_addr: *mut i16, mask: __mmask8, a: __m12
58245835
#[cfg_attr(test, assert_instr(vmovdqu8))]
58255836
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
58265837
pub unsafe fn _mm_mask_storeu_epi8(mem_addr: *mut i8, mask: __mmask16, a: __m128i) {
5827-
storedqu8_128(mem_addr, a.as_i8x16(), mask)
5838+
let mask = simd_select_bitmask(mask, i8x16::splat(!0), i8x16::ZERO);
5839+
simd_masked_store!(SimdAlign::Unaligned, mask, mem_addr, a.as_i8x16());
58285840
}
58295841

58305842
/// Multiply packed signed 16-bit integers in a and b, producing intermediate signed 32-bit integers. Horizontally add adjacent pairs of intermediate 32-bit integers, and pack the results in dst.
@@ -11733,33 +11745,6 @@ unsafe extern "C" {
1173311745
fn vpmovuswbmem256(mem_addr: *mut i8, a: i16x16, mask: u16);
1173411746
#[link_name = "llvm.x86.avx512.mask.pmovus.wb.mem.128"]
1173511747
fn vpmovuswbmem128(mem_addr: *mut i8, a: i16x8, mask: u8);
11736-
11737-
#[link_name = "llvm.x86.avx512.mask.loadu.b.128"]
11738-
fn loaddqu8_128(mem_addr: *const i8, a: i8x16, mask: u16) -> i8x16;
11739-
#[link_name = "llvm.x86.avx512.mask.loadu.w.128"]
11740-
fn loaddqu16_128(mem_addr: *const i16, a: i16x8, mask: u8) -> i16x8;
11741-
#[link_name = "llvm.x86.avx512.mask.loadu.b.256"]
11742-
fn loaddqu8_256(mem_addr: *const i8, a: i8x32, mask: u32) -> i8x32;
11743-
#[link_name = "llvm.x86.avx512.mask.loadu.w.256"]
11744-
fn loaddqu16_256(mem_addr: *const i16, a: i16x16, mask: u16) -> i16x16;
11745-
#[link_name = "llvm.x86.avx512.mask.loadu.b.512"]
11746-
fn loaddqu8_512(mem_addr: *const i8, a: i8x64, mask: u64) -> i8x64;
11747-
#[link_name = "llvm.x86.avx512.mask.loadu.w.512"]
11748-
fn loaddqu16_512(mem_addr: *const i16, a: i16x32, mask: u32) -> i16x32;
11749-
11750-
#[link_name = "llvm.x86.avx512.mask.storeu.b.128"]
11751-
fn storedqu8_128(mem_addr: *mut i8, a: i8x16, mask: u16);
11752-
#[link_name = "llvm.x86.avx512.mask.storeu.w.128"]
11753-
fn storedqu16_128(mem_addr: *mut i16, a: i16x8, mask: u8);
11754-
#[link_name = "llvm.x86.avx512.mask.storeu.b.256"]
11755-
fn storedqu8_256(mem_addr: *mut i8, a: i8x32, mask: u32);
11756-
#[link_name = "llvm.x86.avx512.mask.storeu.w.256"]
11757-
fn storedqu16_256(mem_addr: *mut i16, a: i16x16, mask: u16);
11758-
#[link_name = "llvm.x86.avx512.mask.storeu.b.512"]
11759-
fn storedqu8_512(mem_addr: *mut i8, a: i8x64, mask: u64);
11760-
#[link_name = "llvm.x86.avx512.mask.storeu.w.512"]
11761-
fn storedqu16_512(mem_addr: *mut i16, a: i16x32, mask: u32);
11762-
1176311748
}
1176411749

1176511750
#[cfg(test)]

0 commit comments

Comments
 (0)