Skip to content

Commit 4bf9a90

Browse files
committed
Add AMX intrinsics
1 parent e3accfc commit 4bf9a90

File tree

2 files changed

+226
-1
lines changed

2 files changed

+226
-1
lines changed

crates/core_arch/src/lib.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@
3434
f16,
3535
aarch64_unstable_target_feature,
3636
bigint_helper_methods,
37-
funnel_shifts
37+
funnel_shifts,
38+
avx10_target_feature
3839
)]
3940
#![cfg_attr(test, feature(test, abi_vectorcall, stdarch_internal))]
4041
#![deny(clippy::missing_inline_in_public_items)]

crates/core_arch/src/x86_64/amx.rs

Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use crate::core_arch::{simd::*, x86::*};
2+
13
#[cfg(test)]
24
use stdarch_test::assert_instr;
35

@@ -242,6 +244,206 @@ pub unsafe fn _tile_cmmrlfp16ps<const DST: i32, const A: i32, const B: i32>() {
242244
tcmmrlfp16ps(DST as i8, A as i8, B as i8);
243245
}
244246

247+
/// Compute dot-product of BF8 (8-bit E5M2) floating-point elements in tile a and BF8 (8-bit E5M2)
248+
/// floating-point elements in tile b, accumulating the intermediate single-precision
249+
/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result
250+
/// back to tile dst.
251+
#[inline]
252+
#[rustc_legacy_const_generics(0, 1, 2)]
253+
#[target_feature(enable = "amx-fp8")]
254+
#[cfg_attr(
255+
all(test, any(target_os = "linux", target_env = "msvc")),
256+
assert_instr(tdpbf8ps, DST = 0, A = 1, B = 2)
257+
)]
258+
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
259+
pub unsafe fn _tile_dpbf8ps<const DST: i32, const A: i32, const B: i32>() {
260+
static_assert_uimm_bits!(DST, 3);
261+
static_assert_uimm_bits!(A, 3);
262+
static_assert_uimm_bits!(B, 3);
263+
tdpbf8ps(DST as i8, A as i8, B as i8);
264+
}
265+
266+
/// Compute dot-product of BF8 (8-bit E5M2) floating-point elements in tile a and HF8
267+
/// (8-bit E4M3) floating-point elements in tile b, accumulating the intermediate single-precision
268+
/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result
269+
/// back to tile dst.
270+
#[inline]
271+
#[rustc_legacy_const_generics(0, 1, 2)]
272+
#[target_feature(enable = "amx-fp8")]
273+
#[cfg_attr(
274+
all(test, any(target_os = "linux", target_env = "msvc")),
275+
assert_instr(tdpbhf8ps, DST = 0, A = 1, B = 2)
276+
)]
277+
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
278+
pub unsafe fn _tile_dpbhf8ps<const DST: i32, const A: i32, const B: i32>() {
279+
static_assert_uimm_bits!(DST, 3);
280+
static_assert_uimm_bits!(A, 3);
281+
static_assert_uimm_bits!(B, 3);
282+
tdpbhf8ps(DST as i8, A as i8, B as i8);
283+
}
284+
285+
/// Compute dot-product of HF8 (8-bit E4M3) floating-point elements in tile a and BF8
286+
/// (8-bit E5M2) floating-point elements in tile b, accumulating the intermediate single-precision
287+
/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result
288+
/// back to tile dst.
289+
#[inline]
290+
#[rustc_legacy_const_generics(0, 1, 2)]
291+
#[target_feature(enable = "amx-fp8")]
292+
#[cfg_attr(
293+
all(test, any(target_os = "linux", target_env = "msvc")),
294+
assert_instr(tdphbf8ps, DST = 0, A = 1, B = 2)
295+
)]
296+
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
297+
pub unsafe fn _tile_dphbf8ps<const DST: i32, const A: i32, const B: i32>() {
298+
static_assert_uimm_bits!(DST, 3);
299+
static_assert_uimm_bits!(A, 3);
300+
static_assert_uimm_bits!(B, 3);
301+
tdphbf8ps(DST as i8, A as i8, B as i8);
302+
}
303+
304+
/// Compute dot-product of HF8 (8-bit E4M3) floating-point elements in tile a and HF8 (8-bit E4M3)
305+
/// floating-point elements in tile b, accumulating the intermediate single-precision
306+
/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result
307+
/// back to tile dst.
308+
#[inline]
309+
#[rustc_legacy_const_generics(0, 1, 2)]
310+
#[target_feature(enable = "amx-fp8")]
311+
#[cfg_attr(
312+
all(test, any(target_os = "linux", target_env = "msvc")),
313+
assert_instr(tdphf8ps, DST = 0, A = 1, B = 2)
314+
)]
315+
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
316+
pub unsafe fn _tile_dphf8ps<const DST: i32, const A: i32, const B: i32>() {
317+
static_assert_uimm_bits!(DST, 3);
318+
static_assert_uimm_bits!(A, 3);
319+
static_assert_uimm_bits!(B, 3);
320+
tdphf8ps(DST as i8, A as i8, B as i8);
321+
}
322+
323+
/// Load tile rows from memory specified by base address and stride into destination tile dst
324+
/// using the tile configuration previously configured via _tile_loadconfig.
325+
/// Additionally, this intrinsic indicates the source memory location is likely to become
326+
/// read-shared by multiple processors, i.e., read in the future by at least one other processor
327+
/// before it is written, assuming it is ever written in the future.
328+
#[inline]
329+
#[rustc_legacy_const_generics(0)]
330+
#[target_feature(enable = "amx-movrs")]
331+
#[cfg_attr(
332+
all(test, any(target_os = "linux", target_env = "msvc")),
333+
assert_instr(tileloaddrs, DST = 0)
334+
)]
335+
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
336+
pub unsafe fn _tile_loaddrs<const DST: i32>(base: *const u8, stride: usize) {
337+
static_assert_uimm_bits!(DST, 3);
338+
tileloaddrs64(DST as i8, base, stride);
339+
}
340+
341+
/// Load tile rows from memory specified by base address and stride into destination tile dst
342+
/// using the tile configuration previously configured via _tile_loadconfig.
343+
/// Provides a hint to the implementation that the data would be reused but does not need
344+
/// to be resident in the nearest cache levels.
345+
/// Additionally, this intrinsic indicates the source memory location is likely to become
346+
/// read-shared by multiple processors, i.e., read in the future by at least one other processor
347+
/// before it is written, assuming it is ever written in the future.
348+
#[inline]
349+
#[rustc_legacy_const_generics(0)]
350+
#[target_feature(enable = "amx-movrs")]
351+
#[cfg_attr(
352+
all(test, any(target_os = "linux", target_env = "msvc")),
353+
assert_instr(tileloaddrst1, DST = 0)
354+
)]
355+
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
356+
pub unsafe fn _tile_stream_loaddrs<const DST: i32>(base: *const u8, stride: usize) {
357+
static_assert_uimm_bits!(DST, 3);
358+
tileloaddrst164(DST as i8, base, stride);
359+
}
360+
361+
/// Perform matrix multiplication of two tiles a and b, containing packed single precision (32-bit)
362+
/// floating-point elements, which are converted to TF32 (tensor-float32) format, and accumulate the
363+
/// results into a packed single precision tile.
364+
/// For each possible combination of (row of a, column of b), it performs
365+
/// - convert to TF32
366+
/// - multiply the corresponding elements of a and b
367+
/// - accumulate the results into the corresponding row and column of dst using round-to-nearest-even
368+
/// rounding mode.
369+
/// Output FP32 denormals are always flushed to zero, input single precision denormals are always
370+
/// handled and *not* treated as zero.
371+
#[inline]
372+
#[rustc_legacy_const_generics(0, 1, 2)]
373+
#[target_feature(enable = "amx-tf32")]
374+
#[cfg_attr(
375+
all(test, any(target_os = "linux", target_env = "msvc")),
376+
assert_instr(tmmultf32ps, DST = 0, A = 1, B = 2)
377+
)]
378+
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
379+
pub unsafe fn _tile_mmultf32ps<const DST: i32, const A: i32, const B: i32>() {
380+
static_assert_uimm_bits!(DST, 3);
381+
static_assert_uimm_bits!(A, 3);
382+
static_assert_uimm_bits!(B, 3);
383+
tmmultf32ps(DST as i8, A as i8, B as i8);
384+
}
385+
386+
/// Moves a row from a tile register to a zmm register, converting the packed 32-bit signed integer
387+
/// elements to packed single-precision (32-bit) floating-point elements.
388+
#[inline]
389+
#[rustc_legacy_const_generics(0)]
390+
#[target_feature(enable = "amx-avx512,avx10.2")]
391+
#[cfg_attr(
392+
all(test, any(target_os = "linux", target_env = "msvc")),
393+
assert_instr(tcvtrowd2ps, TILE = 0)
394+
)]
395+
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
396+
pub unsafe fn _tile_cvtrowd2ps<const TILE: i32>(row: u32) -> __m512 {
397+
static_assert_uimm_bits!(TILE, 3);
398+
tcvtrowd2ps(TILE as i8, row).as_m512()
399+
}
400+
401+
/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
402+
/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting
403+
/// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector.
404+
#[inline]
405+
#[rustc_legacy_const_generics(0)]
406+
#[target_feature(enable = "amx-avx512,avx10.2")]
407+
#[cfg_attr(
408+
all(test, any(target_os = "linux", target_env = "msvc")),
409+
assert_instr(tcvtrowps2phh, TILE = 0)
410+
)]
411+
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
412+
pub unsafe fn _tile_cvtrowps2phh<const TILE: i32>(row: u32) -> __m512h {
413+
static_assert_uimm_bits!(TILE, 3);
414+
tcvtrowps2phh(TILE as i8, row).as_m512h()
415+
}
416+
417+
/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
418+
/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting
419+
/// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector.
420+
#[inline]
421+
#[rustc_legacy_const_generics(0)]
422+
#[target_feature(enable = "amx-avx512,avx10.2")]
423+
#[cfg_attr(
424+
all(test, any(target_os = "linux", target_env = "msvc")),
425+
assert_instr(tcvtrowps2phl, TILE = 0)
426+
)]
427+
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
428+
pub unsafe fn _tile_cvtrowps2phl<const TILE: i32>(row: u32) -> __m512h {
429+
static_assert_uimm_bits!(TILE, 3);
430+
tcvtrowps2phl(TILE as i8, row).as_m512h()
431+
}
432+
433+
/// Moves one row of tile data into a zmm vector register
434+
#[inline]
435+
#[rustc_legacy_const_generics(0)]
436+
#[target_feature(enable = "amx-avx512,avx10.2")]
437+
#[cfg_attr(
438+
all(test, any(target_os = "linux", target_env = "msvc")),
439+
assert_instr(tilemovrow, TILE = 0)
440+
)]
441+
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
442+
pub unsafe fn _tile_movrow<const TILE: i32>(row: u32) -> __m512i {
443+
static_assert_uimm_bits!(TILE, 3);
444+
tilemovrow(TILE as i8, row).as_m512i()
445+
}
446+
245447
#[allow(improper_ctypes)]
246448
unsafe extern "C" {
247449
#[link_name = "llvm.x86.ldtilecfg"]
@@ -274,6 +476,28 @@ unsafe extern "C" {
274476
fn tcmmimfp16ps(dst: i8, a: i8, b: i8);
275477
#[link_name = "llvm.x86.tcmmrlfp16ps"]
276478
fn tcmmrlfp16ps(dst: i8, a: i8, b: i8);
479+
#[link_name = "llvm.x86.tdpbf8ps"]
480+
fn tdpbf8ps(dst: i8, a: i8, b: i8);
481+
#[link_name = "llvm.x86.tdpbhf8ps"]
482+
fn tdpbhf8ps(dst: i8, a: i8, b: i8);
483+
#[link_name = "llvm.x86.tdphbf8ps"]
484+
fn tdphbf8ps(dst: i8, a: i8, b: i8);
485+
#[link_name = "llvm.x86.tdphf8ps"]
486+
fn tdphf8ps(dst: i8, a: i8, b: i8);
487+
#[link_name = "llvm.x86.tileloaddrs64"]
488+
fn tileloaddrs64(dst: i8, base: *const u8, stride: usize);
489+
#[link_name = "llvm.x86.tileloaddrst164"]
490+
fn tileloaddrst164(dst: i8, base: *const u8, stride: usize);
491+
#[link_name = "llvm.x86.tmmultf32ps"]
492+
fn tmmultf32ps(dst: i8, a: i8, b: i8);
493+
#[link_name = "llvm.x86.tcvtrowd2ps"]
494+
fn tcvtrowd2ps(tile: i8, row: u32) -> f32x16;
495+
#[link_name = "llvm.x86.tcvtrowps2phh"]
496+
fn tcvtrowps2phh(tile: i8, row: u32) -> f16x32;
497+
#[link_name = "llvm.x86.tcvtrowps2phl"]
498+
fn tcvtrowps2phl(tile: i8, row: u32) -> f16x32;
499+
#[link_name = "llvm.x86.tilemovrow"]
500+
fn tilemovrow(tile: i8, row: u32) -> i32x16;
277501
}
278502

279503
#[cfg(test)]

0 commit comments

Comments
 (0)