diff --git a/naga/src/back/msl/keywords.rs b/naga/src/back/msl/keywords.rs index 315af754e6c..fe9316b20a7 100644 --- a/naga/src/back/msl/keywords.rs +++ b/naga/src/back/msl/keywords.rs @@ -1,12 +1,13 @@ -use crate::proc::KeywordSet; +use crate::proc::{concrete_int_scalars, vector_size_str, vector_sizes, KeywordSet}; use crate::racy_lock::RacyLock; +use alloc::{format, string::String, vec::Vec}; // MSLS - Metal Shading Language Specification: // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf // // C++ - Standard for Programming Language C++ (N4431) // https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2015/n4431.pdf -pub const RESERVED: &[&str] = &[ +const RESERVED: &[&str] = &[ // Undocumented "assert", // found in https://github.com/gfx-rs/wgpu/issues/5347 // Standard for Programming Language C++ (N4431): 2.5 Alternative tokens @@ -346,6 +347,7 @@ pub const RESERVED: &[&str] = &[ super::writer::MODF_FUNCTION, super::writer::ABS_FUNCTION, super::writer::DIV_FUNCTION, + // DOT_FUNCTION_PREFIX variants are added dynamically below super::writer::MOD_FUNCTION, super::writer::NEG_FUNCTION, super::writer::F2I32_FUNCTION, @@ -359,8 +361,31 @@ pub const RESERVED: &[&str] = &[ super::writer::EXTERNAL_TEXTURE_WRAPPER_STRUCT, ]; +// The set of concrete integer dot product function variants. +// This must match the set of names that could be produced by +// `Writer::get_dot_wrapper_function_helper_name`. +static DOT_FUNCTION_NAMES: RacyLock> = RacyLock::new(|| { + let mut names = Vec::new(); + for scalar in concrete_int_scalars().map(crate::Scalar::to_msl_name) { + for size_suffix in vector_sizes().map(vector_size_str) { + let fun_name = format!( + "{}_{}{}", + super::writer::DOT_FUNCTION_PREFIX, + scalar, + size_suffix + ); + names.push(fun_name); + } + } + names +}); + /// The above set of reserved keywords, turned into a cached HashSet. This saves /// significant time during [`Namer::reset`](crate::proc::Namer::reset). /// /// See for benchmarks. -pub static RESERVED_SET: RacyLock = RacyLock::new(|| KeywordSet::from_iter(RESERVED)); +pub static RESERVED_SET: RacyLock = RacyLock::new(|| { + let mut set = KeywordSet::from_iter(RESERVED); + set.extend(DOT_FUNCTION_NAMES.iter().map(String::as_str)); + set +}); diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index ca7da02a930..8c21c944718 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -19,7 +19,7 @@ use crate::{ back::{self, get_entry_points, Baked}, common, proc::{ - self, + self, concrete_int_scalars, index::{self, BoundsCheck}, ExternalTextureNameKey, NameKey, TypeResolution, }, @@ -55,6 +55,7 @@ pub(crate) const MODF_FUNCTION: &str = "naga_modf"; pub(crate) const FREXP_FUNCTION: &str = "naga_frexp"; pub(crate) const ABS_FUNCTION: &str = "naga_abs"; pub(crate) const DIV_FUNCTION: &str = "naga_div"; +pub(crate) const DOT_FUNCTION_PREFIX: &str = "naga_dot"; pub(crate) const MOD_FUNCTION: &str = "naga_mod"; pub(crate) const NEG_FUNCTION: &str = "naga_neg"; pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32"; @@ -488,7 +489,7 @@ pub struct Writer { } impl crate::Scalar { - fn to_msl_name(self) -> &'static str { + pub(super) fn to_msl_name(self) -> &'static str { use crate::ScalarKind as Sk; match self { Self { @@ -2334,26 +2335,28 @@ impl Writer { crate::TypeInner::Vector { scalar: crate::Scalar { + // Resolve float values to MSL's builtin dot function. kind: crate::ScalarKind::Float, .. }, .. } => "dot", - crate::TypeInner::Vector { size, .. } => { - return self.put_dot_product( - arg, - arg1.unwrap(), - size as usize, - |writer, arg, index| { - // Write the vector expression; this expression is marked to be - // cached so unless it can't be cached (for example, it's a Constant) - // it shouldn't produce large expressions. - writer.put_expression(arg, context, true)?; - // Access the current component on the vector. - write!(writer.out, ".{}", back::COMPONENTS[index])?; - Ok(()) + crate::TypeInner::Vector { + size, + scalar: + scalar @ crate::Scalar { + kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint, + .. }, - ); + } => { + // Integer vector dot: call our mangled helper `dot_{type}{N}(a, b)`. + let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size); + write!(self.out, "{fun_name}(")?; + self.put_expression(arg, context, true)?; + write!(self.out, ", ")?; + self.put_expression(arg1.unwrap(), context, true)?; + write!(self.out, ")")?; + return Ok(()); } _ => unreachable!( "Correct TypeInner for dot product should be already validated" @@ -3370,26 +3373,15 @@ impl Writer { } = *expr { match fun { - crate::MathFunction::Dot => { - // WGSL's `dot` function works on any `vecN` type, but Metal's only - // works on floating-point vectors, so we emit inline code for - // integer vector `dot` calls. But that code uses each argument `N` - // times, once for each component (see `put_dot_product`), so to - // avoid duplicated evaluation, we must bake integer operands. - - // check what kind of product this is depending - // on the resolve type of the Dot function itself - let inner = context.resolve_type(expr_handle); - if let crate::TypeInner::Scalar(scalar) = *inner { - match scalar.kind { - crate::ScalarKind::Sint | crate::ScalarKind::Uint => { - self.need_bake_expressions.insert(arg); - self.need_bake_expressions.insert(arg1.unwrap()); - } - _ => {} - } - } - } + // WGSL's `dot` function works on any `vecN` type, but Metal's only + // works on floating-point vectors, so we emit inline code for + // integer vector `dot` calls. But that code uses each argument `N` + // times, once for each component (see `put_dot_product`), so to + // avoid duplicated evaluation, we must bake integer operands. + // This applies both when using the polyfill (because of the duplicate + // evaluation issue) and when we don't use the polyfill (because we + // need them to be emitted before casting to packed chars -- see the + // comment at the call to `put_casting_to_packed_chars`). crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => { self.need_bake_expressions.insert(arg); self.need_bake_expressions.insert(arg1.unwrap()); @@ -5806,6 +5798,24 @@ template Ok(()) } + /// Build the mangled helper name for integer vector dot products. + /// + /// `scalar` must be a concrete integer scalar type. + /// + /// Result format: `{DOT_FUNCTION_PREFIX}_{type}{N}` (e.g., `naga_dot_int3`). + fn get_dot_wrapper_function_helper_name( + &self, + scalar: crate::Scalar, + size: crate::VectorSize, + ) -> String { + // Check for consistency with [`super::keywords::RESERVED_SET`] + debug_assert!(concrete_int_scalars().any(|s| s == scalar)); + + let type_name = scalar.to_msl_name(); + let size_suffix = common::vector_size_str(size); + format!("{DOT_FUNCTION_PREFIX}_{type_name}{size_suffix}") + } + #[allow(clippy::too_many_arguments)] fn write_wrapped_math_function( &mut self, @@ -5861,6 +5871,45 @@ template writeln!(self.out, "}}")?; writeln!(self.out)?; } + + crate::MathFunction::Dot => match *arg_ty { + crate::TypeInner::Vector { size, scalar } + if matches!( + scalar.kind, + crate::ScalarKind::Sint | crate::ScalarKind::Uint + ) => + { + // De-duplicate per (fun, arg type) like other wrapped math functions + let wrapped = WrappedFunction::Math { + fun, + arg_ty: (Some(size), scalar), + }; + if !self.wrapped_functions.insert(wrapped) { + return Ok(()); + } + + let mut vec_ty = String::new(); + put_numeric_type(&mut vec_ty, scalar, &[size])?; + let mut ret_ty = String::new(); + put_numeric_type(&mut ret_ty, scalar, &[])?; + + let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size); + + // Emit function signature and body using put_dot_product for the expression + writeln!(self.out, "{ret_ty} {fun_name}({vec_ty} a, {vec_ty} b) {{")?; + let level = back::Level(1); + write!(self.out, "{level}return ")?; + self.put_dot_product("a", "b", size as usize, |writer, name, index| { + write!(writer.out, "{name}.{}", back::COMPONENTS[index])?; + Ok(()) + })?; + writeln!(self.out, ";")?; + writeln!(self.out, "}}")?; + writeln!(self.out)?; + } + _ => {} + }, + _ => {} } Ok(()) diff --git a/naga/src/common/mod.rs b/naga/src/common/mod.rs index 34de0185b69..e1fe06d5c0a 100644 --- a/naga/src/common/mod.rs +++ b/naga/src/common/mod.rs @@ -8,11 +8,5 @@ pub mod wgsl; pub use diagnostic_debug::{DiagnosticDebug, ForDebug, ForDebugWithTypes}; pub use diagnostic_display::DiagnosticDisplay; -/// Helper function that returns the string corresponding to the [`VectorSize`](crate::VectorSize) -pub const fn vector_size_str(size: crate::VectorSize) -> &'static str { - match size { - crate::VectorSize::Bi => "2", - crate::VectorSize::Tri => "3", - crate::VectorSize::Quad => "4", - } -} +// Re-exported here for backwards compatibility +pub use super::proc::vector_size_str; diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index 64da0a9661e..cebd98f2e47 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -24,7 +24,9 @@ pub use namer::{EntryPointIndex, ExternalTextureNameKey, NameKey, Namer}; pub use overloads::{Conclusion, MissingSpecialType, OverloadSet, Rule}; pub use terminator::ensure_block_returns; use thiserror::Error; -pub use type_methods::min_max_float_representable_by; +pub use type_methods::{ + concrete_int_scalars, min_max_float_representable_by, vector_size_str, vector_sizes, +}; pub use typifier::{compare_types, ResolveContext, ResolveError, TypeResolution}; use crate::non_max_u32::NonMaxU32; diff --git a/naga/src/proc/overloads/mathfunction.rs b/naga/src/proc/overloads/mathfunction.rs index 132f99aec42..82ab96930b6 100644 --- a/naga/src/proc/overloads/mathfunction.rs +++ b/naga/src/proc/overloads/mathfunction.rs @@ -4,10 +4,10 @@ use crate::proc::overloads::any_overload_set::AnyOverloadSet; use crate::proc::overloads::list::List; use crate::proc::overloads::regular::regular; use crate::proc::overloads::utils::{ - concrete_int_scalars, float_scalars, float_scalars_unimplemented_abstract, list, pairs, rule, - scalar_or_vecn, triples, vector_sizes, + float_scalars, float_scalars_unimplemented_abstract, list, pairs, rule, scalar_or_vecn, triples, }; use crate::proc::overloads::OverloadSet; +use crate::proc::type_methods::{concrete_int_scalars, vector_sizes}; use crate::ir; diff --git a/naga/src/proc/overloads/utils.rs b/naga/src/proc/overloads/utils.rs index 4b4396c21bd..8adfca2c5ea 100644 --- a/naga/src/proc/overloads/utils.rs +++ b/naga/src/proc/overloads/utils.rs @@ -9,17 +9,6 @@ use crate::proc::TypeResolution; use alloc::vec::Vec; -/// Produce all vector sizes. -pub fn vector_sizes() -> impl Iterator + Clone { - static SIZES: [ir::VectorSize; 3] = [ - ir::VectorSize::Bi, - ir::VectorSize::Tri, - ir::VectorSize::Quad, - ]; - - SIZES.iter().cloned() -} - /// Produce all the floating-point [`ir::Scalar`]s. /// /// Note that `F32` must appear before other sizes; this is how we @@ -40,20 +29,6 @@ pub fn float_scalars_unimplemented_abstract() -> impl Iterator impl Iterator { - [ - ir::Scalar::I32, - ir::Scalar::U32, - ir::Scalar::I64, - ir::Scalar::U64, - ] - .into_iter() -} - /// Produce the scalar and vector [`ir::TypeInner`]s that have `s` as /// their scalar. pub fn scalar_or_vecn(scalar: ir::Scalar) -> impl Iterator { diff --git a/naga/src/proc/type_methods.rs b/naga/src/proc/type_methods.rs index c59d524f13e..85678f564d3 100644 --- a/naga/src/proc/type_methods.rs +++ b/naga/src/proc/type_methods.rs @@ -1,8 +1,9 @@ -//! Methods on [`TypeInner`], [`Scalar`], and [`ScalarKind`]. +//! Methods on or related to [`TypeInner`], [`Scalar`], [`ScalarKind`], and [`VectorSize`]. //! //! [`TypeInner`]: crate::TypeInner //! [`Scalar`]: crate::Scalar //! [`ScalarKind`]: crate::ScalarKind +//! [`VectorSize`]: crate::VectorSize use crate::{ir, valid::MAX_TYPE_SIZE}; @@ -97,6 +98,31 @@ impl crate::Scalar { } } +/// Produce all concrete integer [`ir::Scalar`]s. +/// +/// Note that `I32` and `U32` must come first; this represents conversion rank +/// in overload resolution. +pub fn concrete_int_scalars() -> impl Iterator { + [ + ir::Scalar::I32, + ir::Scalar::U32, + ir::Scalar::I64, + ir::Scalar::U64, + ] + .into_iter() +} + +/// Produce all vector sizes. +pub fn vector_sizes() -> impl Iterator + Clone { + static SIZES: [ir::VectorSize; 3] = [ + ir::VectorSize::Bi, + ir::VectorSize::Tri, + ir::VectorSize::Quad, + ]; + + SIZES.iter().cloned() +} + const POINTER_SPAN: u32 = 4; impl crate::TypeInner { @@ -612,3 +638,12 @@ pub fn min_max_float_representable_by( _ => unreachable!(), } } + +/// Helper function that returns the string corresponding to the [`VectorSize`](crate::VectorSize) +pub const fn vector_size_str(size: crate::VectorSize) -> &'static str { + match size { + crate::VectorSize::Bi => "2", + crate::VectorSize::Tri => "3", + crate::VectorSize::Quad => "4", + } +} diff --git a/naga/tests/out/msl/wgsl-functions.msl b/naga/tests/out/msl/wgsl-functions.msl index fd3bdc249ed..97afeed186b 100644 --- a/naga/tests/out/msl/wgsl-functions.msl +++ b/naga/tests/out/msl/wgsl-functions.msl @@ -13,14 +13,22 @@ metal::float2 test_fma( return metal::fma(a, b, c); } +int naga_dot_int2(metal::int2 a, metal::int2 b) { + return ( + a.x * b.x + a.y * b.y); +} + +uint naga_dot_uint3(metal::uint3 a, metal::uint3 b) { + return ( + a.x * b.x + a.y * b.y + a.z * b.z); +} + int test_integer_dot_product( ) { metal::int2 a_2_ = metal::int2(1); metal::int2 b_2_ = metal::int2(1); - int c_2_ = ( + a_2_.x * b_2_.x + a_2_.y * b_2_.y); + int c_2_ = naga_dot_int2(a_2_, b_2_); metal::uint3 a_3_ = metal::uint3(1u); metal::uint3 b_3_ = metal::uint3(1u); - uint c_3_ = ( + a_3_.x * b_3_.x + a_3_.y * b_3_.y + a_3_.z * b_3_.z); + uint c_3_ = naga_dot_uint3(a_3_, b_3_); return 32; } diff --git a/naga/tests/out/msl/wgsl-int64.msl b/naga/tests/out/msl/wgsl-int64.msl index 2df5dfc9550..6e9c987da81 100644 --- a/naga/tests/out/msl/wgsl-int64.msl +++ b/naga/tests/out/msl/wgsl-int64.msl @@ -43,6 +43,10 @@ long naga_abs(long val) { return metal::select(as_type(-as_type(val)), val, val >= 0); } +long naga_dot_long2(metal::long2 a, metal::long2 b) { + return ( + a.x * b.x + a.y * b.y); +} + long int64_function( long x, thread long& private_variable, @@ -111,11 +115,9 @@ long int64_function( long _e130 = val; val = as_type(as_type(_e130) + as_type(metal::clamp(_e126, _e127, _e128))); long _e132 = val; - metal::long2 _e133 = metal::long2(_e132); long _e134 = val; - metal::long2 _e135 = metal::long2(_e134); long _e137 = val; - val = as_type(as_type(_e137) + as_type(( + _e133.x * _e135.x + _e133.y * _e135.y))); + val = as_type(as_type(_e137) + as_type(naga_dot_long2(metal::long2(_e132), metal::long2(_e134)))); long _e139 = val; long _e140 = val; long _e142 = val; @@ -135,6 +137,10 @@ ulong naga_f2u64(float value) { return static_cast(metal::clamp(value, 0.0, 18446743000000000000.0)); } +ulong naga_dot_ulong2(metal::ulong2 a, metal::ulong2 b) { + return ( + a.x * b.x + a.y * b.y); +} + ulong uint64_function( ulong x_1, constant UniformCompatible& input_uniform, @@ -199,11 +205,9 @@ ulong uint64_function( ulong _e125 = val_1; val_1 = _e125 + metal::clamp(_e121, _e122, _e123); ulong _e127 = val_1; - metal::ulong2 _e128 = metal::ulong2(_e127); ulong _e129 = val_1; - metal::ulong2 _e130 = metal::ulong2(_e129); ulong _e132 = val_1; - val_1 = _e132 + ( + _e128.x * _e130.x + _e128.y * _e130.y); + val_1 = _e132 + naga_dot_ulong2(metal::ulong2(_e127), metal::ulong2(_e129)); ulong _e134 = val_1; ulong _e135 = val_1; ulong _e137 = val_1;