Skip to content

Commit 76b3c74

Browse files
committed
Add tests for new AMX intrinsics
1 parent 7ea5712 commit 76b3c74

File tree

1 file changed

+227
-1
lines changed
  • crates/core_arch/src/x86_64

1 file changed

+227
-1
lines changed

crates/core_arch/src/x86_64/amx.rs

Lines changed: 227 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -504,7 +504,7 @@ unsafe extern "C" {
504504
mod tests {
505505
use crate::core_arch::x86::_mm_cvtness_sbh;
506506
use crate::core_arch::x86_64::*;
507-
use core::mem::transmute;
507+
use core::{array, mem::transmute};
508508
use stdarch_test::simd_test;
509509
#[cfg(target_os = "linux")]
510510
use syscalls::{Sysno, syscall};
@@ -843,4 +843,230 @@ mod tests {
843843
_tile_release();
844844
assert_eq!(res, [[0f32; 16]; 16]);
845845
}
846+
847+
const BF8_ONE: u8 = 0x3c;
848+
const BF8_TWO: u8 = 0x40;
849+
const HF8_ONE: u8 = 0x38;
850+
const HF8_TWO: u8 = 0x40;
851+
852+
#[simd_test(enable = "amx-fp8")]
853+
unsafe fn test_tile_dpbf8ps() {
854+
_init_amx();
855+
let ones = [BF8_ONE; 1024];
856+
let twos = [BF8_TWO; 1024];
857+
let mut res = [[0.0_f32; 16]; 16];
858+
let mut config = __tilecfg::default();
859+
config.palette = 1;
860+
(0..=2).for_each(|i| {
861+
config.colsb[i] = 64;
862+
config.rows[i] = 16;
863+
});
864+
_tile_loadconfig(config.as_ptr());
865+
_tile_zero::<0>();
866+
_tile_loadd::<1>(&ones as *const u8, 64);
867+
_tile_loadd::<2>(&twos as *const u8, 64);
868+
_tile_dpbf8ps::<0, 1, 2>();
869+
_tile_stored::<0>(res.as_mut_ptr().cast(), 64);
870+
_tile_release();
871+
assert_eq!(res, [[128.0_f32; 16]; 16]);
872+
}
873+
874+
#[simd_test(enable = "amx-fp8")]
875+
unsafe fn test_tile_dpbhf8ps() {
876+
_init_amx();
877+
let ones = [BF8_ONE; 1024];
878+
let twos = [HF8_TWO; 1024];
879+
let mut res = [[0.0_f32; 16]; 16];
880+
let mut config = __tilecfg::default();
881+
config.palette = 1;
882+
(0..=2).for_each(|i| {
883+
config.colsb[i] = 64;
884+
config.rows[i] = 16;
885+
});
886+
_tile_loadconfig(config.as_ptr());
887+
_tile_zero::<0>();
888+
_tile_loadd::<1>(&ones as *const u8, 64);
889+
_tile_loadd::<2>(&twos as *const u8, 64);
890+
_tile_dpbhf8ps::<0, 1, 2>();
891+
_tile_stored::<0>(res.as_mut_ptr().cast(), 64);
892+
_tile_release();
893+
assert_eq!(res, [[128.0_f32; 16]; 16]);
894+
}
895+
896+
#[simd_test(enable = "amx-fp8")]
897+
unsafe fn test_tile_dphbf8ps() {
898+
_init_amx();
899+
let ones = [HF8_ONE; 1024];
900+
let twos = [BF8_TWO; 1024];
901+
let mut res = [[0.0_f32; 16]; 16];
902+
let mut config = __tilecfg::default();
903+
config.palette = 1;
904+
(0..=2).for_each(|i| {
905+
config.colsb[i] = 64;
906+
config.rows[i] = 16;
907+
});
908+
_tile_loadconfig(config.as_ptr());
909+
_tile_zero::<0>();
910+
_tile_loadd::<1>(&ones as *const u8, 64);
911+
_tile_loadd::<2>(&twos as *const u8, 64);
912+
_tile_dphbf8ps::<0, 1, 2>();
913+
_tile_stored::<0>(res.as_mut_ptr().cast(), 64);
914+
_tile_release();
915+
assert_eq!(res, [[128.0_f32; 16]; 16]);
916+
}
917+
918+
#[simd_test(enable = "amx-fp8")]
919+
unsafe fn test_tile_dphf8ps() {
920+
_init_amx();
921+
let ones = [HF8_ONE; 1024];
922+
let twos = [HF8_TWO; 1024];
923+
let mut res = [[0.0_f32; 16]; 16];
924+
let mut config = __tilecfg::default();
925+
config.palette = 1;
926+
(0..=2).for_each(|i| {
927+
config.colsb[i] = 64;
928+
config.rows[i] = 16;
929+
});
930+
_tile_loadconfig(config.as_ptr());
931+
_tile_zero::<0>();
932+
_tile_loadd::<1>(&ones as *const u8, 64);
933+
_tile_loadd::<2>(&twos as *const u8, 64);
934+
_tile_dphf8ps::<0, 1, 2>();
935+
_tile_stored::<0>(res.as_mut_ptr().cast(), 64);
936+
_tile_release();
937+
assert_eq!(res, [[128.0_f32; 16]; 16]);
938+
}
939+
940+
#[simd_test(enable = "amx-tile")]
941+
unsafe fn test_tile_loaddrs() {
942+
_init_amx();
943+
let mut config = __tilecfg::default();
944+
config.palette = 1;
945+
config.colsb[0] = 64;
946+
config.rows[0] = 16;
947+
_tile_loadconfig(config.as_ptr());
948+
_tile_zero::<0>();
949+
let mat = [1_i8; 1024];
950+
_tile_loaddrs::<0>(&mat as *const i8 as *const u8, 64);
951+
let mut out = [[0_i8; 64]; 16];
952+
_tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
953+
_tile_release();
954+
assert_eq!(out, [[1; 64]; 16]);
955+
}
956+
957+
#[simd_test(enable = "amx-tile")]
958+
unsafe fn test_tile_stream_loaddrs() {
959+
_init_amx();
960+
let mut config = __tilecfg::default();
961+
config.palette = 1;
962+
config.colsb[0] = 64;
963+
config.rows[0] = 16;
964+
_tile_loadconfig(config.as_ptr());
965+
_tile_zero::<0>();
966+
let mat = [1_i8; 1024];
967+
_tile_stream_loaddrs::<0>(&mat as *const i8 as *const u8, 64);
968+
let mut out = [[0_i8; 64]; 16];
969+
_tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
970+
_tile_release();
971+
assert_eq!(out, [[1; 64]; 16]);
972+
}
973+
974+
#[simd_test(enable = "amx-avx512,avx10.2")]
975+
unsafe fn test_tile_movrow() {
976+
_init_amx();
977+
let array: [[u8; 64]; 16] = array::from_fn(|i| [i as _; _]);
978+
979+
let mut config = __tilecfg::default();
980+
config.palette = 1;
981+
config.colsb[0] = 64;
982+
config.rows[0] = 16;
983+
_tile_loadconfig(config.as_ptr());
984+
_tile_loadd::<0>(array.as_ptr().cast(), 64);
985+
for i in 0..16 {
986+
let row = _tile_movrow::<0>(i);
987+
assert_eq!(*row.as_u8x64().as_array(), [i as _; _]);
988+
}
989+
}
990+
991+
#[simd_test(enable = "amx-avx512,avx10.2")]
992+
unsafe fn test_tile_cvtrowd2ps() {
993+
_init_amx();
994+
let array: [[u32; 16]; 16] = array::from_fn(|i| [i as _; _]);
995+
996+
let mut config = __tilecfg::default();
997+
config.palette = 1;
998+
config.colsb[0] = 64;
999+
config.rows[0] = 16;
1000+
_tile_loadconfig(config.as_ptr());
1001+
_tile_loadd::<0>(array.as_ptr().cast(), 64);
1002+
for i in 0..16 {
1003+
let row = _tile_cvtrowd2ps::<0>(i);
1004+
assert_eq!(*row.as_f32x16().as_array(), [i as _; _]);
1005+
}
1006+
}
1007+
1008+
#[simd_test(enable = "amx-avx512,avx10.2")]
1009+
unsafe fn test_tile_cvtrowps2phh() {
1010+
_init_amx();
1011+
let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1012+
1013+
let mut config = __tilecfg::default();
1014+
config.palette = 1;
1015+
config.colsb[0] = 64;
1016+
config.rows[0] = 16;
1017+
_tile_loadconfig(config.as_ptr());
1018+
_tile_loadd::<0>(array.as_ptr().cast(), 64);
1019+
for i in 0..16 {
1020+
let row = _tile_cvtrowps2phh::<0>(i);
1021+
assert_eq!(
1022+
*row.as_f16x32().as_array(),
1023+
array::from_fn(|j| if j & 1 == 0 { 0.0 } else { i as _ })
1024+
);
1025+
}
1026+
}
1027+
1028+
#[simd_test(enable = "amx-avx512,avx10.2")]
1029+
unsafe fn test_tile_cvtrowps2phl() {
1030+
_init_amx();
1031+
let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1032+
1033+
let mut config = __tilecfg::default();
1034+
config.palette = 1;
1035+
config.colsb[0] = 64;
1036+
config.rows[0] = 16;
1037+
_tile_loadconfig(config.as_ptr());
1038+
_tile_loadd::<0>(array.as_ptr().cast(), 64);
1039+
for i in 0..16 {
1040+
let row = _tile_cvtrowps2phl::<0>(i);
1041+
assert_eq!(
1042+
*row.as_f16x32().as_array(),
1043+
array::from_fn(|j| if j & 1 == 0 { i as _ } else { 0.0 })
1044+
);
1045+
}
1046+
}
1047+
1048+
#[simd_test(enable = "amx-tf32")]
1049+
unsafe fn test_tile_mmultf32ps() {
1050+
_init_amx();
1051+
let a: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1052+
let b: [[f32; 16]; 16] = [array::from_fn(|j| j as _); _];
1053+
let mut res = [[0.0; 16]; 16];
1054+
1055+
let mut config = __tilecfg::default();
1056+
config.palette = 1;
1057+
(0..=2).for_each(|i| {
1058+
config.colsb[i] = 64;
1059+
config.rows[i] = 16;
1060+
});
1061+
_tile_loadconfig(config.as_ptr());
1062+
_tile_zero::<0>();
1063+
_tile_loadd::<1>(a.as_ptr().cast(), 64);
1064+
_tile_loadd::<2>(b.as_ptr().cast(), 64);
1065+
_tile_mmultf32ps::<0, 1, 2>();
1066+
_tile_stored::<0>(res.as_mut_ptr().cast(), 64);
1067+
_tile_release();
1068+
1069+
let expected = array::from_fn(|i| array::from_fn(|j| 16.0 * i as f32 * j as f32));
1070+
assert_eq!(res, expected);
1071+
}
8461072
}

0 commit comments

Comments
 (0)