@@ -504,7 +504,7 @@ unsafe extern "C" {
504504mod tests {
505505 use crate :: core_arch:: x86:: _mm_cvtness_sbh;
506506 use crate :: core_arch:: x86_64:: * ;
507- use core:: mem:: transmute;
507+ use core:: { array , mem:: transmute} ;
508508 use stdarch_test:: simd_test;
509509 #[ cfg( target_os = "linux" ) ]
510510 use syscalls:: { Sysno , syscall} ;
@@ -843,4 +843,230 @@ mod tests {
843843 _tile_release ( ) ;
844844 assert_eq ! ( res, [ [ 0f32 ; 16 ] ; 16 ] ) ;
845845 }
846+
847+ const BF8_ONE : u8 = 0x3c ;
848+ const BF8_TWO : u8 = 0x40 ;
849+ const HF8_ONE : u8 = 0x38 ;
850+ const HF8_TWO : u8 = 0x40 ;
851+
852+ #[ simd_test( enable = "amx-fp8" ) ]
853+ unsafe fn test_tile_dpbf8ps ( ) {
854+ _init_amx ( ) ;
855+ let ones = [ BF8_ONE ; 1024 ] ;
856+ let twos = [ BF8_TWO ; 1024 ] ;
857+ let mut res = [ [ 0.0_f32 ; 16 ] ; 16 ] ;
858+ let mut config = __tilecfg:: default ( ) ;
859+ config. palette = 1 ;
860+ ( 0 ..=2 ) . for_each ( |i| {
861+ config. colsb [ i] = 64 ;
862+ config. rows [ i] = 16 ;
863+ } ) ;
864+ _tile_loadconfig ( config. as_ptr ( ) ) ;
865+ _tile_zero :: < 0 > ( ) ;
866+ _tile_loadd :: < 1 > ( & ones as * const u8 , 64 ) ;
867+ _tile_loadd :: < 2 > ( & twos as * const u8 , 64 ) ;
868+ _tile_dpbf8ps :: < 0 , 1 , 2 > ( ) ;
869+ _tile_stored :: < 0 > ( res. as_mut_ptr ( ) . cast ( ) , 64 ) ;
870+ _tile_release ( ) ;
871+ assert_eq ! ( res, [ [ 128.0_f32 ; 16 ] ; 16 ] ) ;
872+ }
873+
874+ #[ simd_test( enable = "amx-fp8" ) ]
875+ unsafe fn test_tile_dpbhf8ps ( ) {
876+ _init_amx ( ) ;
877+ let ones = [ BF8_ONE ; 1024 ] ;
878+ let twos = [ HF8_TWO ; 1024 ] ;
879+ let mut res = [ [ 0.0_f32 ; 16 ] ; 16 ] ;
880+ let mut config = __tilecfg:: default ( ) ;
881+ config. palette = 1 ;
882+ ( 0 ..=2 ) . for_each ( |i| {
883+ config. colsb [ i] = 64 ;
884+ config. rows [ i] = 16 ;
885+ } ) ;
886+ _tile_loadconfig ( config. as_ptr ( ) ) ;
887+ _tile_zero :: < 0 > ( ) ;
888+ _tile_loadd :: < 1 > ( & ones as * const u8 , 64 ) ;
889+ _tile_loadd :: < 2 > ( & twos as * const u8 , 64 ) ;
890+ _tile_dpbhf8ps :: < 0 , 1 , 2 > ( ) ;
891+ _tile_stored :: < 0 > ( res. as_mut_ptr ( ) . cast ( ) , 64 ) ;
892+ _tile_release ( ) ;
893+ assert_eq ! ( res, [ [ 128.0_f32 ; 16 ] ; 16 ] ) ;
894+ }
895+
896+ #[ simd_test( enable = "amx-fp8" ) ]
897+ unsafe fn test_tile_dphbf8ps ( ) {
898+ _init_amx ( ) ;
899+ let ones = [ HF8_ONE ; 1024 ] ;
900+ let twos = [ BF8_TWO ; 1024 ] ;
901+ let mut res = [ [ 0.0_f32 ; 16 ] ; 16 ] ;
902+ let mut config = __tilecfg:: default ( ) ;
903+ config. palette = 1 ;
904+ ( 0 ..=2 ) . for_each ( |i| {
905+ config. colsb [ i] = 64 ;
906+ config. rows [ i] = 16 ;
907+ } ) ;
908+ _tile_loadconfig ( config. as_ptr ( ) ) ;
909+ _tile_zero :: < 0 > ( ) ;
910+ _tile_loadd :: < 1 > ( & ones as * const u8 , 64 ) ;
911+ _tile_loadd :: < 2 > ( & twos as * const u8 , 64 ) ;
912+ _tile_dphbf8ps :: < 0 , 1 , 2 > ( ) ;
913+ _tile_stored :: < 0 > ( res. as_mut_ptr ( ) . cast ( ) , 64 ) ;
914+ _tile_release ( ) ;
915+ assert_eq ! ( res, [ [ 128.0_f32 ; 16 ] ; 16 ] ) ;
916+ }
917+
918+ #[ simd_test( enable = "amx-fp8" ) ]
919+ unsafe fn test_tile_dphf8ps ( ) {
920+ _init_amx ( ) ;
921+ let ones = [ HF8_ONE ; 1024 ] ;
922+ let twos = [ HF8_TWO ; 1024 ] ;
923+ let mut res = [ [ 0.0_f32 ; 16 ] ; 16 ] ;
924+ let mut config = __tilecfg:: default ( ) ;
925+ config. palette = 1 ;
926+ ( 0 ..=2 ) . for_each ( |i| {
927+ config. colsb [ i] = 64 ;
928+ config. rows [ i] = 16 ;
929+ } ) ;
930+ _tile_loadconfig ( config. as_ptr ( ) ) ;
931+ _tile_zero :: < 0 > ( ) ;
932+ _tile_loadd :: < 1 > ( & ones as * const u8 , 64 ) ;
933+ _tile_loadd :: < 2 > ( & twos as * const u8 , 64 ) ;
934+ _tile_dphf8ps :: < 0 , 1 , 2 > ( ) ;
935+ _tile_stored :: < 0 > ( res. as_mut_ptr ( ) . cast ( ) , 64 ) ;
936+ _tile_release ( ) ;
937+ assert_eq ! ( res, [ [ 128.0_f32 ; 16 ] ; 16 ] ) ;
938+ }
939+
940+ #[ simd_test( enable = "amx-tile" ) ]
941+ unsafe fn test_tile_loaddrs ( ) {
942+ _init_amx ( ) ;
943+ let mut config = __tilecfg:: default ( ) ;
944+ config. palette = 1 ;
945+ config. colsb [ 0 ] = 64 ;
946+ config. rows [ 0 ] = 16 ;
947+ _tile_loadconfig ( config. as_ptr ( ) ) ;
948+ _tile_zero :: < 0 > ( ) ;
949+ let mat = [ 1_i8 ; 1024 ] ;
950+ _tile_loaddrs :: < 0 > ( & mat as * const i8 as * const u8 , 64 ) ;
951+ let mut out = [ [ 0_i8 ; 64 ] ; 16 ] ;
952+ _tile_stored :: < 0 > ( & mut out as * mut [ i8 ; 64 ] as * mut u8 , 64 ) ;
953+ _tile_release ( ) ;
954+ assert_eq ! ( out, [ [ 1 ; 64 ] ; 16 ] ) ;
955+ }
956+
957+ #[ simd_test( enable = "amx-tile" ) ]
958+ unsafe fn test_tile_stream_loaddrs ( ) {
959+ _init_amx ( ) ;
960+ let mut config = __tilecfg:: default ( ) ;
961+ config. palette = 1 ;
962+ config. colsb [ 0 ] = 64 ;
963+ config. rows [ 0 ] = 16 ;
964+ _tile_loadconfig ( config. as_ptr ( ) ) ;
965+ _tile_zero :: < 0 > ( ) ;
966+ let mat = [ 1_i8 ; 1024 ] ;
967+ _tile_stream_loaddrs :: < 0 > ( & mat as * const i8 as * const u8 , 64 ) ;
968+ let mut out = [ [ 0_i8 ; 64 ] ; 16 ] ;
969+ _tile_stored :: < 0 > ( & mut out as * mut [ i8 ; 64 ] as * mut u8 , 64 ) ;
970+ _tile_release ( ) ;
971+ assert_eq ! ( out, [ [ 1 ; 64 ] ; 16 ] ) ;
972+ }
973+
974+ #[ simd_test( enable = "amx-avx512,avx10.2" ) ]
975+ unsafe fn test_tile_movrow ( ) {
976+ _init_amx ( ) ;
977+ let array: [ [ u8 ; 64 ] ; 16 ] = array:: from_fn ( |i| [ i as _ ; _] ) ;
978+
979+ let mut config = __tilecfg:: default ( ) ;
980+ config. palette = 1 ;
981+ config. colsb [ 0 ] = 64 ;
982+ config. rows [ 0 ] = 16 ;
983+ _tile_loadconfig ( config. as_ptr ( ) ) ;
984+ _tile_loadd :: < 0 > ( array. as_ptr ( ) . cast ( ) , 64 ) ;
985+ for i in 0 ..16 {
986+ let row = _tile_movrow :: < 0 > ( i) ;
987+ assert_eq ! ( * row. as_u8x64( ) . as_array( ) , [ i as _; _] ) ;
988+ }
989+ }
990+
991+ #[ simd_test( enable = "amx-avx512,avx10.2" ) ]
992+ unsafe fn test_tile_cvtrowd2ps ( ) {
993+ _init_amx ( ) ;
994+ let array: [ [ u32 ; 16 ] ; 16 ] = array:: from_fn ( |i| [ i as _ ; _] ) ;
995+
996+ let mut config = __tilecfg:: default ( ) ;
997+ config. palette = 1 ;
998+ config. colsb [ 0 ] = 64 ;
999+ config. rows [ 0 ] = 16 ;
1000+ _tile_loadconfig ( config. as_ptr ( ) ) ;
1001+ _tile_loadd :: < 0 > ( array. as_ptr ( ) . cast ( ) , 64 ) ;
1002+ for i in 0 ..16 {
1003+ let row = _tile_cvtrowd2ps :: < 0 > ( i) ;
1004+ assert_eq ! ( * row. as_f32x16( ) . as_array( ) , [ i as _; _] ) ;
1005+ }
1006+ }
1007+
1008+ #[ simd_test( enable = "amx-avx512,avx10.2" ) ]
1009+ unsafe fn test_tile_cvtrowps2phh ( ) {
1010+ _init_amx ( ) ;
1011+ let array: [ [ f32 ; 16 ] ; 16 ] = array:: from_fn ( |i| [ i as _ ; _] ) ;
1012+
1013+ let mut config = __tilecfg:: default ( ) ;
1014+ config. palette = 1 ;
1015+ config. colsb [ 0 ] = 64 ;
1016+ config. rows [ 0 ] = 16 ;
1017+ _tile_loadconfig ( config. as_ptr ( ) ) ;
1018+ _tile_loadd :: < 0 > ( array. as_ptr ( ) . cast ( ) , 64 ) ;
1019+ for i in 0 ..16 {
1020+ let row = _tile_cvtrowps2phh :: < 0 > ( i) ;
1021+ assert_eq ! (
1022+ * row. as_f16x32( ) . as_array( ) ,
1023+ array:: from_fn( |j| if j & 1 == 0 { 0.0 } else { i as _ } )
1024+ ) ;
1025+ }
1026+ }
1027+
1028+ #[ simd_test( enable = "amx-avx512,avx10.2" ) ]
1029+ unsafe fn test_tile_cvtrowps2phl ( ) {
1030+ _init_amx ( ) ;
1031+ let array: [ [ f32 ; 16 ] ; 16 ] = array:: from_fn ( |i| [ i as _ ; _] ) ;
1032+
1033+ let mut config = __tilecfg:: default ( ) ;
1034+ config. palette = 1 ;
1035+ config. colsb [ 0 ] = 64 ;
1036+ config. rows [ 0 ] = 16 ;
1037+ _tile_loadconfig ( config. as_ptr ( ) ) ;
1038+ _tile_loadd :: < 0 > ( array. as_ptr ( ) . cast ( ) , 64 ) ;
1039+ for i in 0 ..16 {
1040+ let row = _tile_cvtrowps2phl :: < 0 > ( i) ;
1041+ assert_eq ! (
1042+ * row. as_f16x32( ) . as_array( ) ,
1043+ array:: from_fn( |j| if j & 1 == 0 { i as _ } else { 0.0 } )
1044+ ) ;
1045+ }
1046+ }
1047+
1048+ #[ simd_test( enable = "amx-tf32" ) ]
1049+ unsafe fn test_tile_mmultf32ps ( ) {
1050+ _init_amx ( ) ;
1051+ let a: [ [ f32 ; 16 ] ; 16 ] = array:: from_fn ( |i| [ i as _ ; _] ) ;
1052+ let b: [ [ f32 ; 16 ] ; 16 ] = [ array:: from_fn ( |j| j as _ ) ; _] ;
1053+ let mut res = [ [ 0.0 ; 16 ] ; 16 ] ;
1054+
1055+ let mut config = __tilecfg:: default ( ) ;
1056+ config. palette = 1 ;
1057+ ( 0 ..=2 ) . for_each ( |i| {
1058+ config. colsb [ i] = 64 ;
1059+ config. rows [ i] = 16 ;
1060+ } ) ;
1061+ _tile_loadconfig ( config. as_ptr ( ) ) ;
1062+ _tile_zero :: < 0 > ( ) ;
1063+ _tile_loadd :: < 1 > ( a. as_ptr ( ) . cast ( ) , 64 ) ;
1064+ _tile_loadd :: < 2 > ( b. as_ptr ( ) . cast ( ) , 64 ) ;
1065+ _tile_mmultf32ps :: < 0 , 1 , 2 > ( ) ;
1066+ _tile_stored :: < 0 > ( res. as_mut_ptr ( ) . cast ( ) , 64 ) ;
1067+ _tile_release ( ) ;
1068+
1069+ let expected = array:: from_fn ( |i| array:: from_fn ( |j| 16.0 * i as f32 * j as f32 ) ) ;
1070+ assert_eq ! ( res, expected) ;
1071+ }
8461072}
0 commit comments