|
| 1 | +use crate::core_arch::{simd::*, x86::*}; |
| 2 | + |
1 | 3 | #[cfg(test)] |
2 | 4 | use stdarch_test::assert_instr; |
3 | 5 |
|
@@ -242,6 +244,206 @@ pub unsafe fn _tile_cmmrlfp16ps<const DST: i32, const A: i32, const B: i32>() { |
242 | 244 | tcmmrlfp16ps(DST as i8, A as i8, B as i8); |
243 | 245 | } |
244 | 246 |
|
| 247 | +/// Compute dot-product of BF8 (8-bit E5M2) floating-point elements in tile a and BF8 (8-bit E5M2) |
| 248 | +/// floating-point elements in tile b, accumulating the intermediate single-precision |
| 249 | +/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result |
| 250 | +/// back to tile dst. |
| 251 | +#[inline] |
| 252 | +#[rustc_legacy_const_generics(0, 1, 2)] |
| 253 | +#[target_feature(enable = "amx-fp8")] |
| 254 | +#[cfg_attr( |
| 255 | + all(test, any(target_os = "linux", target_env = "msvc")), |
| 256 | + assert_instr(tdpbf8ps, DST = 0, A = 1, B = 2) |
| 257 | +)] |
| 258 | +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] |
| 259 | +pub unsafe fn _tile_dpbf8ps<const DST: i32, const A: i32, const B: i32>() { |
| 260 | + static_assert_uimm_bits!(DST, 3); |
| 261 | + static_assert_uimm_bits!(A, 3); |
| 262 | + static_assert_uimm_bits!(B, 3); |
| 263 | + tdpbf8ps(DST as i8, A as i8, B as i8); |
| 264 | +} |
| 265 | + |
| 266 | +/// Compute dot-product of BF8 (8-bit E5M2) floating-point elements in tile a and HF8 |
| 267 | +/// (8-bit E4M3) floating-point elements in tile b, accumulating the intermediate single-precision |
| 268 | +/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result |
| 269 | +/// back to tile dst. |
| 270 | +#[inline] |
| 271 | +#[rustc_legacy_const_generics(0, 1, 2)] |
| 272 | +#[target_feature(enable = "amx-fp8")] |
| 273 | +#[cfg_attr( |
| 274 | + all(test, any(target_os = "linux", target_env = "msvc")), |
| 275 | + assert_instr(tdpbhf8ps, DST = 0, A = 1, B = 2) |
| 276 | +)] |
| 277 | +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] |
| 278 | +pub unsafe fn _tile_dpbhf8ps<const DST: i32, const A: i32, const B: i32>() { |
| 279 | + static_assert_uimm_bits!(DST, 3); |
| 280 | + static_assert_uimm_bits!(A, 3); |
| 281 | + static_assert_uimm_bits!(B, 3); |
| 282 | + tdpbhf8ps(DST as i8, A as i8, B as i8); |
| 283 | +} |
| 284 | + |
| 285 | +/// Compute dot-product of HF8 (8-bit E4M3) floating-point elements in tile a and BF8 |
| 286 | +/// (8-bit E5M2) floating-point elements in tile b, accumulating the intermediate single-precision |
| 287 | +/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result |
| 288 | +/// back to tile dst. |
| 289 | +#[inline] |
| 290 | +#[rustc_legacy_const_generics(0, 1, 2)] |
| 291 | +#[target_feature(enable = "amx-fp8")] |
| 292 | +#[cfg_attr( |
| 293 | + all(test, any(target_os = "linux", target_env = "msvc")), |
| 294 | + assert_instr(tdphbf8ps, DST = 0, A = 1, B = 2) |
| 295 | +)] |
| 296 | +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] |
| 297 | +pub unsafe fn _tile_dphbf8ps<const DST: i32, const A: i32, const B: i32>() { |
| 298 | + static_assert_uimm_bits!(DST, 3); |
| 299 | + static_assert_uimm_bits!(A, 3); |
| 300 | + static_assert_uimm_bits!(B, 3); |
| 301 | + tdphbf8ps(DST as i8, A as i8, B as i8); |
| 302 | +} |
| 303 | + |
| 304 | +/// Compute dot-product of HF8 (8-bit E4M3) floating-point elements in tile a and HF8 (8-bit E4M3) |
| 305 | +/// floating-point elements in tile b, accumulating the intermediate single-precision |
| 306 | +/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result |
| 307 | +/// back to tile dst. |
| 308 | +#[inline] |
| 309 | +#[rustc_legacy_const_generics(0, 1, 2)] |
| 310 | +#[target_feature(enable = "amx-fp8")] |
| 311 | +#[cfg_attr( |
| 312 | + all(test, any(target_os = "linux", target_env = "msvc")), |
| 313 | + assert_instr(tdphf8ps, DST = 0, A = 1, B = 2) |
| 314 | +)] |
| 315 | +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] |
| 316 | +pub unsafe fn _tile_dphf8ps<const DST: i32, const A: i32, const B: i32>() { |
| 317 | + static_assert_uimm_bits!(DST, 3); |
| 318 | + static_assert_uimm_bits!(A, 3); |
| 319 | + static_assert_uimm_bits!(B, 3); |
| 320 | + tdphf8ps(DST as i8, A as i8, B as i8); |
| 321 | +} |
| 322 | + |
| 323 | +/// Load tile rows from memory specified by base address and stride into destination tile dst |
| 324 | +/// using the tile configuration previously configured via _tile_loadconfig. |
| 325 | +/// Additionally, this intrinsic indicates the source memory location is likely to become |
| 326 | +/// read-shared by multiple processors, i.e., read in the future by at least one other processor |
| 327 | +/// before it is written, assuming it is ever written in the future. |
| 328 | +#[inline] |
| 329 | +#[rustc_legacy_const_generics(0)] |
| 330 | +#[target_feature(enable = "amx-movrs")] |
| 331 | +#[cfg_attr( |
| 332 | + all(test, any(target_os = "linux", target_env = "msvc")), |
| 333 | + assert_instr(tileloaddrs, DST = 0) |
| 334 | +)] |
| 335 | +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] |
| 336 | +pub unsafe fn _tile_loaddrs<const DST: i32>(base: *const u8, stride: usize) { |
| 337 | + static_assert_uimm_bits!(DST, 3); |
| 338 | + tileloaddrs64(DST as i8, base, stride); |
| 339 | +} |
| 340 | + |
| 341 | +/// Load tile rows from memory specified by base address and stride into destination tile dst |
| 342 | +/// using the tile configuration previously configured via _tile_loadconfig. |
| 343 | +/// Provides a hint to the implementation that the data would be reused but does not need |
| 344 | +/// to be resident in the nearest cache levels. |
| 345 | +/// Additionally, this intrinsic indicates the source memory location is likely to become |
| 346 | +/// read-shared by multiple processors, i.e., read in the future by at least one other processor |
| 347 | +/// before it is written, assuming it is ever written in the future. |
| 348 | +#[inline] |
| 349 | +#[rustc_legacy_const_generics(0)] |
| 350 | +#[target_feature(enable = "amx-movrs")] |
| 351 | +#[cfg_attr( |
| 352 | + all(test, any(target_os = "linux", target_env = "msvc")), |
| 353 | + assert_instr(tileloaddrst1, DST = 0) |
| 354 | +)] |
| 355 | +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] |
| 356 | +pub unsafe fn _tile_stream_loaddrs<const DST: i32>(base: *const u8, stride: usize) { |
| 357 | + static_assert_uimm_bits!(DST, 3); |
| 358 | + tileloaddrst164(DST as i8, base, stride); |
| 359 | +} |
| 360 | + |
| 361 | +/// Perform matrix multiplication of two tiles a and b, containing packed single precision (32-bit) |
| 362 | +/// floating-point elements, which are converted to TF32 (tensor-float32) format, and accumulate the |
| 363 | +/// results into a packed single precision tile. |
| 364 | +/// For each possible combination of (row of a, column of b), it performs |
| 365 | +/// - convert to TF32 |
| 366 | +/// - multiply the corresponding elements of a and b |
| 367 | +/// - accumulate the results into the corresponding row and column of dst using round-to-nearest-even |
| 368 | +/// rounding mode. |
| 369 | +/// Output FP32 denormals are always flushed to zero, input single precision denormals are always |
| 370 | +/// handled and *not* treated as zero. |
| 371 | +#[inline] |
| 372 | +#[rustc_legacy_const_generics(0, 1, 2)] |
| 373 | +#[target_feature(enable = "amx-tf32")] |
| 374 | +#[cfg_attr( |
| 375 | + all(test, any(target_os = "linux", target_env = "msvc")), |
| 376 | + assert_instr(tmmultf32ps, DST = 0, A = 1, B = 2) |
| 377 | +)] |
| 378 | +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] |
| 379 | +pub unsafe fn _tile_mmultf32ps<const DST: i32, const A: i32, const B: i32>() { |
| 380 | + static_assert_uimm_bits!(DST, 3); |
| 381 | + static_assert_uimm_bits!(A, 3); |
| 382 | + static_assert_uimm_bits!(B, 3); |
| 383 | + tmmultf32ps(DST as i8, A as i8, B as i8); |
| 384 | +} |
| 385 | + |
| 386 | +/// Moves a row from a tile register to a zmm register, converting the packed 32-bit signed integer |
| 387 | +/// elements to packed single-precision (32-bit) floating-point elements. |
| 388 | +#[inline] |
| 389 | +#[rustc_legacy_const_generics(0)] |
| 390 | +#[target_feature(enable = "amx-avx512,avx10.2")] |
| 391 | +#[cfg_attr( |
| 392 | + all(test, any(target_os = "linux", target_env = "msvc")), |
| 393 | + assert_instr(tcvtrowd2ps, TILE = 0) |
| 394 | +)] |
| 395 | +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] |
| 396 | +pub unsafe fn _tile_cvtrowd2ps<const TILE: i32>(row: u32) -> __m512 { |
| 397 | + static_assert_uimm_bits!(TILE, 3); |
| 398 | + tcvtrowd2ps(TILE as i8, row).as_m512() |
| 399 | +} |
| 400 | + |
| 401 | +/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) |
| 402 | +/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting |
| 403 | +/// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector. |
| 404 | +#[inline] |
| 405 | +#[rustc_legacy_const_generics(0)] |
| 406 | +#[target_feature(enable = "amx-avx512,avx10.2")] |
| 407 | +#[cfg_attr( |
| 408 | + all(test, any(target_os = "linux", target_env = "msvc")), |
| 409 | + assert_instr(tcvtrowps2phh, TILE = 0) |
| 410 | +)] |
| 411 | +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] |
| 412 | +pub unsafe fn _tile_cvtrowps2phh<const TILE: i32>(row: u32) -> __m512h { |
| 413 | + static_assert_uimm_bits!(TILE, 3); |
| 414 | + tcvtrowps2phh(TILE as i8, row).as_m512h() |
| 415 | +} |
| 416 | + |
| 417 | +/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit) |
| 418 | +/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting |
| 419 | +/// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector. |
| 420 | +#[inline] |
| 421 | +#[rustc_legacy_const_generics(0)] |
| 422 | +#[target_feature(enable = "amx-avx512,avx10.2")] |
| 423 | +#[cfg_attr( |
| 424 | + all(test, any(target_os = "linux", target_env = "msvc")), |
| 425 | + assert_instr(tcvtrowps2phl, TILE = 0) |
| 426 | +)] |
| 427 | +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] |
| 428 | +pub unsafe fn _tile_cvtrowps2phl<const TILE: i32>(row: u32) -> __m512h { |
| 429 | + static_assert_uimm_bits!(TILE, 3); |
| 430 | + tcvtrowps2phl(TILE as i8, row).as_m512h() |
| 431 | +} |
| 432 | + |
| 433 | +/// Moves one row of tile data into a zmm vector register |
| 434 | +#[inline] |
| 435 | +#[rustc_legacy_const_generics(0)] |
| 436 | +#[target_feature(enable = "amx-avx512,avx10.2")] |
| 437 | +#[cfg_attr( |
| 438 | + all(test, any(target_os = "linux", target_env = "msvc")), |
| 439 | + assert_instr(tilemovrow, TILE = 0) |
| 440 | +)] |
| 441 | +#[unstable(feature = "x86_amx_intrinsics", issue = "126622")] |
| 442 | +pub unsafe fn _tile_movrow<const TILE: i32>(row: u32) -> __m512i { |
| 443 | + static_assert_uimm_bits!(TILE, 3); |
| 444 | + tilemovrow(TILE as i8, row).as_m512i() |
| 445 | +} |
| 446 | + |
245 | 447 | #[allow(improper_ctypes)] |
246 | 448 | unsafe extern "C" { |
247 | 449 | #[link_name = "llvm.x86.ldtilecfg"] |
@@ -274,6 +476,28 @@ unsafe extern "C" { |
274 | 476 | fn tcmmimfp16ps(dst: i8, a: i8, b: i8); |
275 | 477 | #[link_name = "llvm.x86.tcmmrlfp16ps"] |
276 | 478 | fn tcmmrlfp16ps(dst: i8, a: i8, b: i8); |
| 479 | + #[link_name = "llvm.x86.tdpbf8ps"] |
| 480 | + fn tdpbf8ps(dst: i8, a: i8, b: i8); |
| 481 | + #[link_name = "llvm.x86.tdpbhf8ps"] |
| 482 | + fn tdpbhf8ps(dst: i8, a: i8, b: i8); |
| 483 | + #[link_name = "llvm.x86.tdphbf8ps"] |
| 484 | + fn tdphbf8ps(dst: i8, a: i8, b: i8); |
| 485 | + #[link_name = "llvm.x86.tdphf8ps"] |
| 486 | + fn tdphf8ps(dst: i8, a: i8, b: i8); |
| 487 | + #[link_name = "llvm.x86.tileloaddrs64"] |
| 488 | + fn tileloaddrs64(dst: i8, base: *const u8, stride: usize); |
| 489 | + #[link_name = "llvm.x86.tileloaddrst164"] |
| 490 | + fn tileloaddrst164(dst: i8, base: *const u8, stride: usize); |
| 491 | + #[link_name = "llvm.x86.tmmultf32ps"] |
| 492 | + fn tmmultf32ps(dst: i8, a: i8, b: i8); |
| 493 | + #[link_name = "llvm.x86.tcvtrowd2ps"] |
| 494 | + fn tcvtrowd2ps(tile: i8, row: u32) -> f32x16; |
| 495 | + #[link_name = "llvm.x86.tcvtrowps2phh"] |
| 496 | + fn tcvtrowps2phh(tile: i8, row: u32) -> f16x32; |
| 497 | + #[link_name = "llvm.x86.tcvtrowps2phl"] |
| 498 | + fn tcvtrowps2phl(tile: i8, row: u32) -> f16x32; |
| 499 | + #[link_name = "llvm.x86.tilemovrow"] |
| 500 | + fn tilemovrow(tile: i8, row: u32) -> i32x16; |
277 | 501 | } |
278 | 502 |
|
279 | 503 | #[cfg(test)] |
|
0 commit comments