@@ -208,18 +208,93 @@ impl DebugHexF16 for __m512i {
208208}
209209 "# ;
210210
211- pub const LANE_FUNCTION_HELPERS : & str = r#"
212- typedef _Float16 float16_t;
213- typedef float float32_t;
214- typedef double float64_t;
215-
216- #define __int64 long long
217- #define __int32 int
211+ pub const PLATFORM_C_FORWARD_DECLARATIONS : & str = r#"
212+ #ifndef X86_DECLARATIONS
213+ #define X86_DECLARATIONS
214+ typedef _Float16 float16_t;
215+ typedef float float32_t;
216+ typedef double float64_t;
217+
218+ #define __int64 long long
219+ #define __int32 int
218220
219- std::ostream& operator<<(std::ostream& os, _Float16 value);
220- std::ostream& operator<<(std::ostream& os, __m128i value);
221- std::ostream& operator<<(std::ostream& os, __m256i value);
222- std::ostream& operator<<(std::ostream& os, __m512i value);
221+ std::ostream& operator<<(std::ostream& os, _Float16 value);
222+ std::ostream& operator<<(std::ostream& os, __m128i value);
223+ std::ostream& operator<<(std::ostream& os, __m256i value);
224+ std::ostream& operator<<(std::ostream& os, __m512i value);
225+
226+ #define _mm512_extract_intrinsic_test_epi8(m, lane) \
227+ _mm_extract_epi8(_mm512_extracti64x2_epi64((m), (lane) / 16), (lane) % 16)
228+
229+ #define _mm512_extract_intrinsic_test_epi16(m, lane) \
230+ _mm_extract_epi16(_mm512_extracti64x2_epi64((m), (lane) / 8), (lane) % 8)
231+
232+ #define _mm512_extract_intrinsic_test_epi32(m, lane) \
233+ _mm_extract_epi32(_mm512_extracti64x2_epi64((m), (lane) / 4), (lane) % 4)
234+
235+ #define _mm512_extract_intrinsic_test_epi64(m, lane) \
236+ _mm_extract_epi64(_mm512_extracti64x2_epi64((m), (lane) / 2), (lane) % 2)
237+
238+ #define _mm64_extract_intrinsic_test_epi8(m, lane) \
239+ ((_mm_extract_pi16((m), (lane) / 2) >> (((lane) % 2) * 8)) & 0xFF)
240+
241+ #define _mm64_extract_intrinsic_test_epi32(m, lane) \
242+ _mm_cvtsi64_si32(_mm_srli_si64(m, (lane) * 32))
243+
244+ // Load f16 (__m128h) and cast to integer (__m128i)
245+ #define _mm_loadu_ph_to___m128i(mem_addr) _mm_castph_si128(_mm_loadu_ph(mem_addr))
246+ #define _mm256_loadu_ph_to___m256i(mem_addr) _mm256_castph_si256(_mm256_loadu_ph(mem_addr))
247+ #define _mm512_loadu_ph_to___m512i(mem_addr) _mm512_castph_si512(_mm512_loadu_ph(mem_addr))
248+
249+ // Load f32 (__m128) and cast to f16 (__m128h)
250+ #define _mm_loadu_ps_to___m128h(mem_addr) _mm_castps_ph(_mm_loadu_ps(mem_addr))
251+ #define _mm256_loadu_ps_to___m256h(mem_addr) _mm256_castps_ph(_mm256_loadu_ps(mem_addr))
252+ #define _mm512_loadu_ps_to___m512h(mem_addr) _mm512_castps_ph(_mm512_loadu_ps(mem_addr))
253+
254+ // Load integer types and cast to double (__m128d, __m256d, __m512d)
255+ #define _mm_loadu_epi16_to___m128d(mem_addr) _mm_castsi128_pd(_mm_loadu_si128((__m128i const*)(mem_addr)))
256+ #define _mm256_loadu_epi16_to___m256d(mem_addr) _mm256_castsi256_pd(_mm256_loadu_si256((__m256i const*)(mem_addr)))
257+ #define _mm512_loadu_epi16_to___m512d(mem_addr) _mm512_castsi512_pd(_mm512_loadu_si512((__m512i const*)(mem_addr)))
258+
259+ #define _mm_loadu_epi32_to___m128d(mem_addr) _mm_castsi128_pd(_mm_loadu_si128((__m128i const*)(mem_addr)))
260+ #define _mm256_loadu_epi32_to___m256d(mem_addr) _mm256_castsi256_pd(_mm256_loadu_si256((__m256i const*)(mem_addr)))
261+ #define _mm512_loadu_epi32_to___m512d(mem_addr) _mm512_castsi512_pd(_mm512_loadu_si512((__m512i const*)(mem_addr)))
262+
263+ #define _mm_loadu_epi64_to___m128d(mem_addr) _mm_castsi128_pd(_mm_loadu_si128((__m128i const*)(mem_addr)))
264+ #define _mm256_loadu_epi64_to___m256d(mem_addr) _mm256_castsi256_pd(_mm256_loadu_si256((__m256i const*)(mem_addr)))
265+ #define _mm512_loadu_epi64_to___m512d(mem_addr) _mm512_castsi512_pd(_mm512_loadu_si512((__m512i const*)(mem_addr)))
266+
267+ // Load integer types and cast to float (__m128, __m256, __m512)
268+ #define _mm_loadu_epi16_to___m128(mem_addr) _mm_castsi128_ps(_mm_loadu_si128((__m128i const*)(mem_addr)))
269+ #define _mm256_loadu_epi16_to___m256(mem_addr) _mm256_castsi256_ps(_mm256_loadu_si256((__m256i const*)(mem_addr)))
270+ #define _mm512_loadu_epi16_to___m512(mem_addr) _mm512_castsi512_ps(_mm512_loadu_si512((__m512i const*)(mem_addr)))
271+
272+ #define _mm_loadu_epi32_to___m128(mem_addr) _mm_castsi128_ps(_mm_loadu_si128((__m128i const*)(mem_addr)))
273+ #define _mm256_loadu_epi32_to___m256(mem_addr) _mm256_castsi256_ps(_mm256_loadu_si256((__m256i const*)(mem_addr)))
274+ #define _mm512_loadu_epi32_to___m512(mem_addr) _mm512_castsi512_ps(_mm512_loadu_si512((__m512i const*)(mem_addr)))
275+
276+ #define _mm_loadu_epi64_to___m128(mem_addr) _mm_castsi128_ps(_mm_loadu_si128((__m128i const*)(mem_addr)))
277+ #define _mm256_loadu_epi64_to___m256(mem_addr) _mm256_castsi256_ps(_mm256_loadu_si256((__m256i const*)(mem_addr)))
278+ #define _mm512_loadu_epi64_to___m512(mem_addr) _mm512_castsi512_ps(_mm512_loadu_si512((__m512i const*)(mem_addr)))
279+
280+
281+ // T1 is the `To` type, T2 is the `From` type
282+ template<typename T1, typename T2> T1 cast(T2 x) {
283+ if constexpr (std::is_convertible_v<T2, T1>) {
284+ return x;
285+ } else if constexpr (sizeof(T1) == sizeof(T2)) {
286+ T1 ret{};
287+ std::memcpy(&ret, &x, sizeof(T1));
288+ return ret;
289+ } else {
290+ static_assert(sizeof(T1) == sizeof(T2) || std::is_convertible_v<T2, T1>,
291+ "T2 must either be convertible to T1, or have the same size as T1!");
292+ return T1{};
293+ }
294+ }
295+ #endif
296+ "# ;
297+ pub const PLATFORM_C_DEFINITIONS : & str = r#"
223298
224299std::ostream& operator<<(std::ostream& os, _Float16 value) {
225300 uint16_t temp = 0;
@@ -268,74 +343,6 @@ std::ostream& operator<<(std::ostream& os, __m512i value) {
268343 os << ss.str();
269344 return os;
270345}
271-
272- // T1 is the `To` type, T2 is the `From` type
273- template<typename T1, typename T2> T1 cast(T2 x) {
274- if (std::is_convertible<T2, T1>::value) {
275- return x;
276- } else if (sizeof(T1) == sizeof(T2)) {
277- T1 ret{};
278- memcpy(&ret, &x, sizeof(T1));
279- return ret;
280- } else {
281- assert("T2 must either be convertable to T1, or have the same size as T1!");
282- }
283- }
284-
285- #define _mm512_extract_intrinsic_test_epi8(m, lane) \
286- _mm_extract_epi8(_mm512_extracti64x2_epi64((m), (lane) / 16), (lane) % 16)
287-
288- #define _mm512_extract_intrinsic_test_epi16(m, lane) \
289- _mm_extract_epi16(_mm512_extracti64x2_epi64((m), (lane) / 8), (lane) % 8)
290-
291- #define _mm512_extract_intrinsic_test_epi32(m, lane) \
292- _mm_extract_epi32(_mm512_extracti64x2_epi64((m), (lane) / 4), (lane) % 4)
293-
294- #define _mm512_extract_intrinsic_test_epi64(m, lane) \
295- _mm_extract_epi64(_mm512_extracti64x2_epi64((m), (lane) / 2), (lane) % 2)
296-
297- #define _mm64_extract_intrinsic_test_epi8(m, lane) \
298- ((_mm_extract_pi16((m), (lane) / 2) >> (((lane) % 2) * 8)) & 0xFF)
299-
300- #define _mm64_extract_intrinsic_test_epi32(m, lane) \
301- _mm_cvtsi64_si32(_mm_srli_si64(m, (lane) * 32))
302-
303- // Load f16 (__m128h) and cast to integer (__m128i)
304- #define _mm_loadu_ph_to___m128i(mem_addr) _mm_castph_si128(_mm_loadu_ph(mem_addr))
305- #define _mm256_loadu_ph_to___m256i(mem_addr) _mm256_castph_si256(_mm256_loadu_ph(mem_addr))
306- #define _mm512_loadu_ph_to___m512i(mem_addr) _mm512_castph_si512(_mm512_loadu_ph(mem_addr))
307-
308- // Load f32 (__m128) and cast to f16 (__m128h)
309- #define _mm_loadu_ps_to___m128h(mem_addr) _mm_castps_ph(_mm_loadu_ps(mem_addr))
310- #define _mm256_loadu_ps_to___m256h(mem_addr) _mm256_castps_ph(_mm256_loadu_ps(mem_addr))
311- #define _mm512_loadu_ps_to___m512h(mem_addr) _mm512_castps_ph(_mm512_loadu_ps(mem_addr))
312-
313- // Load integer types and cast to double (__m128d, __m256d, __m512d)
314- #define _mm_loadu_epi16_to___m128d(mem_addr) _mm_castsi128_pd(_mm_loadu_si128((__m128i const*)(mem_addr)))
315- #define _mm256_loadu_epi16_to___m256d(mem_addr) _mm256_castsi256_pd(_mm256_loadu_si256((__m256i const*)(mem_addr)))
316- #define _mm512_loadu_epi16_to___m512d(mem_addr) _mm512_castsi512_pd(_mm512_loadu_si512((__m512i const*)(mem_addr)))
317-
318- #define _mm_loadu_epi32_to___m128d(mem_addr) _mm_castsi128_pd(_mm_loadu_si128((__m128i const*)(mem_addr)))
319- #define _mm256_loadu_epi32_to___m256d(mem_addr) _mm256_castsi256_pd(_mm256_loadu_si256((__m256i const*)(mem_addr)))
320- #define _mm512_loadu_epi32_to___m512d(mem_addr) _mm512_castsi512_pd(_mm512_loadu_si512((__m512i const*)(mem_addr)))
321-
322- #define _mm_loadu_epi64_to___m128d(mem_addr) _mm_castsi128_pd(_mm_loadu_si128((__m128i const*)(mem_addr)))
323- #define _mm256_loadu_epi64_to___m256d(mem_addr) _mm256_castsi256_pd(_mm256_loadu_si256((__m256i const*)(mem_addr)))
324- #define _mm512_loadu_epi64_to___m512d(mem_addr) _mm512_castsi512_pd(_mm512_loadu_si512((__m512i const*)(mem_addr)))
325-
326- // Load integer types and cast to float (__m128, __m256, __m512)
327- #define _mm_loadu_epi16_to___m128(mem_addr) _mm_castsi128_ps(_mm_loadu_si128((__m128i const*)(mem_addr)))
328- #define _mm256_loadu_epi16_to___m256(mem_addr) _mm256_castsi256_ps(_mm256_loadu_si256((__m256i const*)(mem_addr)))
329- #define _mm512_loadu_epi16_to___m512(mem_addr) _mm512_castsi512_ps(_mm512_loadu_si512((__m512i const*)(mem_addr)))
330-
331- #define _mm_loadu_epi32_to___m128(mem_addr) _mm_castsi128_ps(_mm_loadu_si128((__m128i const*)(mem_addr)))
332- #define _mm256_loadu_epi32_to___m256(mem_addr) _mm256_castsi256_ps(_mm256_loadu_si256((__m256i const*)(mem_addr)))
333- #define _mm512_loadu_epi32_to___m512(mem_addr) _mm512_castsi512_ps(_mm512_loadu_si512((__m512i const*)(mem_addr)))
334-
335- #define _mm_loadu_epi64_to___m128(mem_addr) _mm_castsi128_ps(_mm_loadu_si128((__m128i const*)(mem_addr)))
336- #define _mm256_loadu_epi64_to___m256(mem_addr) _mm256_castsi256_ps(_mm256_loadu_si256((__m256i const*)(mem_addr)))
337- #define _mm512_loadu_epi64_to___m512(mem_addr) _mm512_castsi512_ps(_mm512_loadu_si512((__m512i const*)(mem_addr)))
338-
339346"# ;
340347
341348pub const X86_CONFIGURATIONS : & str = r#"
0 commit comments