@@ -4,6 +4,7 @@ use std::cmp;
44use libc:: c_uint;
55use rustc_abi:: { BackendRepr , HasDataLayout , Primitive , Reg , RegKind , Size } ;
66use rustc_codegen_ssa:: MemFlags ;
7+ use rustc_codegen_ssa:: common:: TypeKind ;
78use rustc_codegen_ssa:: mir:: operand:: { OperandRef , OperandValue } ;
89use rustc_codegen_ssa:: mir:: place:: { PlaceRef , PlaceValue } ;
910use rustc_codegen_ssa:: traits:: * ;
@@ -308,7 +309,12 @@ impl<'ll, 'tcx> ArgAbiBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> {
308309}
309310
310311pub ( crate ) trait FnAbiLlvmExt < ' ll , ' tcx > {
311- fn llvm_type ( & self , cx : & CodegenCx < ' ll , ' tcx > ) -> & ' ll Type ;
312+ fn llvm_type (
313+ & self ,
314+ cx : & CodegenCx < ' ll , ' tcx > ,
315+ name : & [ u8 ] ,
316+ is_llvm_intrinsic : bool ,
317+ ) -> & ' ll Type ;
312318 fn ptr_to_llvm_type ( & self , cx : & CodegenCx < ' ll , ' tcx > ) -> & ' ll Type ;
313319 fn llvm_cconv ( & self , cx : & CodegenCx < ' ll , ' tcx > ) -> llvm:: CallConv ;
314320
@@ -325,26 +331,45 @@ pub(crate) trait FnAbiLlvmExt<'ll, 'tcx> {
325331}
326332
327333impl < ' ll , ' tcx > FnAbiLlvmExt < ' ll , ' tcx > for FnAbi < ' tcx , Ty < ' tcx > > {
328- fn llvm_type ( & self , cx : & CodegenCx < ' ll , ' tcx > ) -> & ' ll Type {
334+ fn llvm_type (
335+ & self ,
336+ cx : & CodegenCx < ' ll , ' tcx > ,
337+ name : & [ u8 ] ,
338+ is_llvm_intrinsic : bool ,
339+ ) -> & ' ll Type {
329340 // Ignore "extra" args from the call site for C variadic functions.
330341 // Only the "fixed" args are part of the LLVM function signature.
331342 let args =
332343 if self . c_variadic { & self . args [ ..self . fixed_count as usize ] } else { & self . args } ;
333344
345+ let amx_intrinsic =
346+ is_llvm_intrinsic && name. starts_with ( b"llvm.x86." ) && name. ends_with ( b".internal" ) ;
347+ let adjust_ty = |ty| {
348+ // Change type to `x86amx` from `i32x256` for x86_64 AMX intrinsics
349+ if amx_intrinsic && cx. type_kind ( ty) == TypeKind :: Vector && cx. vector_length ( ty) == 256
350+ {
351+ let element_ty = cx. element_type ( ty) ;
352+ if cx. type_kind ( element_ty) == TypeKind :: Integer && cx. int_width ( element_ty) == 32 {
353+ return cx. type_x86amx ( ) ;
354+ }
355+ }
356+ ty
357+ } ;
358+
334359 // This capacity calculation is approximate.
335360 let mut llargument_tys = Vec :: with_capacity (
336361 self . args . len ( ) + if let PassMode :: Indirect { .. } = self . ret . mode { 1 } else { 0 } ,
337362 ) ;
338363
339- let llreturn_ty = match & self . ret . mode {
364+ let llreturn_ty = adjust_ty ( match & self . ret . mode {
340365 PassMode :: Ignore => cx. type_void ( ) ,
341366 PassMode :: Direct ( _) | PassMode :: Pair ( ..) => self . ret . layout . immediate_llvm_type ( cx) ,
342367 PassMode :: Cast { cast, pad_i32 : _ } => cast. llvm_type ( cx) ,
343368 PassMode :: Indirect { .. } => {
344369 llargument_tys. push ( cx. type_ptr ( ) ) ;
345370 cx. type_void ( )
346371 }
347- } ;
372+ } ) ;
348373
349374 for arg in args {
350375 // Note that the exact number of arguments pushed here is carefully synchronized with
@@ -388,7 +413,7 @@ impl<'ll, 'tcx> FnAbiLlvmExt<'ll, 'tcx> for FnAbi<'tcx, Ty<'tcx>> {
388413 cast. llvm_type ( cx)
389414 }
390415 } ;
391- llargument_tys. push ( llarg_ty) ;
416+ llargument_tys. push ( adjust_ty ( llarg_ty) ) ;
392417 }
393418
394419 if self . c_variadic {
0 commit comments