@@ -19,7 +19,7 @@ use crate::{
1919 back:: { self , get_entry_points, Baked } ,
2020 common,
2121 proc:: {
22- self ,
22+ self , concrete_int_scalars ,
2323 index:: { self , BoundsCheck } ,
2424 ExternalTextureNameKey , NameKey , TypeResolution ,
2525 } ,
@@ -55,6 +55,7 @@ pub(crate) const MODF_FUNCTION: &str = "naga_modf";
5555pub ( crate ) const FREXP_FUNCTION : & str = "naga_frexp" ;
5656pub ( crate ) const ABS_FUNCTION : & str = "naga_abs" ;
5757pub ( crate ) const DIV_FUNCTION : & str = "naga_div" ;
58+ pub ( crate ) const DOT_FUNCTION_PREFIX : & str = "naga_dot" ;
5859pub ( crate ) const MOD_FUNCTION : & str = "naga_mod" ;
5960pub ( crate ) const NEG_FUNCTION : & str = "naga_neg" ;
6061pub ( crate ) const F2I32_FUNCTION : & str = "naga_f2i32" ;
@@ -488,7 +489,7 @@ pub struct Writer<W> {
488489}
489490
490491impl crate :: Scalar {
491- fn to_msl_name ( self ) -> & ' static str {
492+ pub ( super ) fn to_msl_name ( self ) -> & ' static str {
492493 use crate :: ScalarKind as Sk ;
493494 match self {
494495 Self {
@@ -2334,26 +2335,28 @@ impl<W: Write> Writer<W> {
23342335 crate :: TypeInner :: Vector {
23352336 scalar :
23362337 crate :: Scalar {
2338+ // Resolve float values to MSL's builtin dot function.
23372339 kind : crate :: ScalarKind :: Float ,
23382340 ..
23392341 } ,
23402342 ..
23412343 } => "dot" ,
2342- crate :: TypeInner :: Vector { size, .. } => {
2343- return self . put_dot_product (
2344- arg,
2345- arg1. unwrap ( ) ,
2346- size as usize ,
2347- |writer, arg, index| {
2348- // Write the vector expression; this expression is marked to be
2349- // cached so unless it can't be cached (for example, it's a Constant)
2350- // it shouldn't produce large expressions.
2351- writer. put_expression ( arg, context, true ) ?;
2352- // Access the current component on the vector.
2353- write ! ( writer. out, ".{}" , back:: COMPONENTS [ index] ) ?;
2354- Ok ( ( ) )
2344+ crate :: TypeInner :: Vector {
2345+ size,
2346+ scalar :
2347+ scalar @ crate :: Scalar {
2348+ kind : crate :: ScalarKind :: Sint | crate :: ScalarKind :: Uint ,
2349+ ..
23552350 } ,
2356- ) ;
2351+ } => {
2352+ // Integer vector dot: call our mangled helper `dot_{type}{N}(a, b)`.
2353+ let fun_name = self . get_dot_wrapper_function_helper_name ( scalar, size) ;
2354+ write ! ( self . out, "{fun_name}(" ) ?;
2355+ self . put_expression ( arg, context, true ) ?;
2356+ write ! ( self . out, ", " ) ?;
2357+ self . put_expression ( arg1. unwrap ( ) , context, true ) ?;
2358+ write ! ( self . out, ")" ) ?;
2359+ return Ok ( ( ) ) ;
23572360 }
23582361 _ => unreachable ! (
23592362 "Correct TypeInner for dot product should be already validated"
@@ -3370,26 +3373,15 @@ impl<W: Write> Writer<W> {
33703373 } = * expr
33713374 {
33723375 match fun {
3373- crate :: MathFunction :: Dot => {
3374- // WGSL's `dot` function works on any `vecN` type, but Metal's only
3375- // works on floating-point vectors, so we emit inline code for
3376- // integer vector `dot` calls. But that code uses each argument `N`
3377- // times, once for each component (see `put_dot_product`), so to
3378- // avoid duplicated evaluation, we must bake integer operands.
3379-
3380- // check what kind of product this is depending
3381- // on the resolve type of the Dot function itself
3382- let inner = context. resolve_type ( expr_handle) ;
3383- if let crate :: TypeInner :: Scalar ( scalar) = * inner {
3384- match scalar. kind {
3385- crate :: ScalarKind :: Sint | crate :: ScalarKind :: Uint => {
3386- self . need_bake_expressions . insert ( arg) ;
3387- self . need_bake_expressions . insert ( arg1. unwrap ( ) ) ;
3388- }
3389- _ => { }
3390- }
3391- }
3392- }
3376+ // WGSL's `dot` function works on any `vecN` type, but Metal's only
3377+ // works on floating-point vectors, so we emit inline code for
3378+ // integer vector `dot` calls. But that code uses each argument `N`
3379+ // times, once for each component (see `put_dot_product`), so to
3380+ // avoid duplicated evaluation, we must bake integer operands.
3381+ // This applies both when using the polyfill (because of the duplicate
3382+ // evaluation issue) and when we don't use the polyfill (because we
3383+ // need them to be emitted before casting to packed chars -- see the
3384+ // comment at the call to `put_casting_to_packed_chars`).
33933385 crate :: MathFunction :: Dot4U8Packed | crate :: MathFunction :: Dot4I8Packed => {
33943386 self . need_bake_expressions . insert ( arg) ;
33953387 self . need_bake_expressions . insert ( arg1. unwrap ( ) ) ;
@@ -5806,6 +5798,24 @@ template <typename A>
58065798 Ok ( ( ) )
58075799 }
58085800
5801+ /// Build the mangled helper name for integer vector dot products.
5802+ ///
5803+ /// `scalar` must be a concrete integer scalar type.
5804+ ///
5805+ /// Result format: `{DOT_FUNCTION_PREFIX}_{type}{N}` (e.g., `naga_dot_int3`).
5806+ fn get_dot_wrapper_function_helper_name (
5807+ & self ,
5808+ scalar : crate :: Scalar ,
5809+ size : crate :: VectorSize ,
5810+ ) -> String {
5811+ // Check for consistency with [`super::keywords::RESERVED_SET`]
5812+ debug_assert ! ( concrete_int_scalars( ) . any( |s| s == scalar) ) ;
5813+
5814+ let type_name = scalar. to_msl_name ( ) ;
5815+ let size_suffix = common:: vector_size_str ( size) ;
5816+ format ! ( "{DOT_FUNCTION_PREFIX}_{type_name}{size_suffix}" )
5817+ }
5818+
58095819 #[ allow( clippy:: too_many_arguments) ]
58105820 fn write_wrapped_math_function (
58115821 & mut self ,
@@ -5861,6 +5871,45 @@ template <typename A>
58615871 writeln ! ( self . out, "}}" ) ?;
58625872 writeln ! ( self . out) ?;
58635873 }
5874+
5875+ crate :: MathFunction :: Dot => match * arg_ty {
5876+ crate :: TypeInner :: Vector { size, scalar }
5877+ if matches ! (
5878+ scalar. kind,
5879+ crate :: ScalarKind :: Sint | crate :: ScalarKind :: Uint
5880+ ) =>
5881+ {
5882+ // De-duplicate per (fun, arg type) like other wrapped math functions
5883+ let wrapped = WrappedFunction :: Math {
5884+ fun,
5885+ arg_ty : ( Some ( size) , scalar) ,
5886+ } ;
5887+ if !self . wrapped_functions . insert ( wrapped) {
5888+ return Ok ( ( ) ) ;
5889+ }
5890+
5891+ let mut vec_ty = String :: new ( ) ;
5892+ put_numeric_type ( & mut vec_ty, scalar, & [ size] ) ?;
5893+ let mut ret_ty = String :: new ( ) ;
5894+ put_numeric_type ( & mut ret_ty, scalar, & [ ] ) ?;
5895+
5896+ let fun_name = self . get_dot_wrapper_function_helper_name ( scalar, size) ;
5897+
5898+ // Emit function signature and body using put_dot_product for the expression
5899+ writeln ! ( self . out, "{ret_ty} {fun_name}({vec_ty} a, {vec_ty} b) {{" ) ?;
5900+ let level = back:: Level ( 1 ) ;
5901+ write ! ( self . out, "{level}return " ) ?;
5902+ self . put_dot_product ( "a" , "b" , size as usize , |writer, name, index| {
5903+ write ! ( writer. out, "{name}.{}" , back:: COMPONENTS [ index] ) ?;
5904+ Ok ( ( ) )
5905+ } ) ?;
5906+ writeln ! ( self . out, ";" ) ?;
5907+ writeln ! ( self . out, "}}" ) ?;
5908+ writeln ! ( self . out) ?;
5909+ }
5910+ _ => { }
5911+ } ,
5912+
58645913 _ => { }
58655914 }
58665915 Ok ( ( ) )
0 commit comments