Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions naga/src/back/msl/keywords.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
use crate::proc::KeywordSet;
use crate::proc::{concrete_int_scalars, vector_size_str, vector_sizes, KeywordSet};
use crate::racy_lock::RacyLock;
use alloc::{format, string::String, vec::Vec};

// MSLS - Metal Shading Language Specification:
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
//
// C++ - Standard for Programming Language C++ (N4431)
// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2015/n4431.pdf
pub const RESERVED: &[&str] = &[
const RESERVED: &[&str] = &[
// Undocumented
"assert", // found in https://github.com/gfx-rs/wgpu/issues/5347
// Standard for Programming Language C++ (N4431): 2.5 Alternative tokens
Expand Down Expand Up @@ -346,6 +347,7 @@ pub const RESERVED: &[&str] = &[
super::writer::MODF_FUNCTION,
super::writer::ABS_FUNCTION,
super::writer::DIV_FUNCTION,
// DOT_FUNCTION_PREFIX variants are added dynamically below
super::writer::MOD_FUNCTION,
super::writer::NEG_FUNCTION,
super::writer::F2I32_FUNCTION,
Expand All @@ -359,8 +361,31 @@ pub const RESERVED: &[&str] = &[
super::writer::EXTERNAL_TEXTURE_WRAPPER_STRUCT,
];

// The set of concrete integer dot product function variants.
// This must match the set of names that could be produced by
// `Writer::get_dot_wrapper_function_helper_name`.
static DOT_FUNCTION_NAMES: RacyLock<Vec<String>> = RacyLock::new(|| {
let mut names = Vec::new();
for scalar in concrete_int_scalars().map(crate::Scalar::to_msl_name) {
for size_suffix in vector_sizes().map(vector_size_str) {
let fun_name = format!(
"{}_{}{}",
super::writer::DOT_FUNCTION_PREFIX,
scalar,
size_suffix
);
names.push(fun_name);
}
}
names
});

/// The above set of reserved keywords, turned into a cached HashSet. This saves
/// significant time during [`Namer::reset`](crate::proc::Namer::reset).
///
/// See <https://github.com/gfx-rs/wgpu/pull/7338> for benchmarks.
pub static RESERVED_SET: RacyLock<KeywordSet> = RacyLock::new(|| KeywordSet::from_iter(RESERVED));
pub static RESERVED_SET: RacyLock<KeywordSet> = RacyLock::new(|| {
let mut set = KeywordSet::from_iter(RESERVED);
set.extend(DOT_FUNCTION_NAMES.iter().map(String::as_str));
set
});
121 changes: 85 additions & 36 deletions naga/src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use crate::{
back::{self, get_entry_points, Baked},
common,
proc::{
self,
self, concrete_int_scalars,
index::{self, BoundsCheck},
ExternalTextureNameKey, NameKey, TypeResolution,
},
Expand Down Expand Up @@ -55,6 +55,7 @@ pub(crate) const MODF_FUNCTION: &str = "naga_modf";
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
pub(crate) const DIV_FUNCTION: &str = "naga_div";
pub(crate) const DOT_FUNCTION_PREFIX: &str = "naga_dot";
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
Expand Down Expand Up @@ -488,7 +489,7 @@ pub struct Writer<W> {
}

impl crate::Scalar {
fn to_msl_name(self) -> &'static str {
pub(super) fn to_msl_name(self) -> &'static str {
use crate::ScalarKind as Sk;
match self {
Self {
Expand Down Expand Up @@ -2334,26 +2335,28 @@ impl<W: Write> Writer<W> {
crate::TypeInner::Vector {
scalar:
crate::Scalar {
// Resolve float values to MSL's builtin dot function.
kind: crate::ScalarKind::Float,
..
},
..
} => "dot",
crate::TypeInner::Vector { size, .. } => {
return self.put_dot_product(
arg,
arg1.unwrap(),
size as usize,
|writer, arg, index| {
// Write the vector expression; this expression is marked to be
// cached so unless it can't be cached (for example, it's a Constant)
// it shouldn't produce large expressions.
writer.put_expression(arg, context, true)?;
// Access the current component on the vector.
write!(writer.out, ".{}", back::COMPONENTS[index])?;
Ok(())
crate::TypeInner::Vector {
size,
scalar:
scalar @ crate::Scalar {
kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
..
},
);
} => {
// Integer vector dot: call our mangled helper `dot_{type}{N}(a, b)`.
let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
write!(self.out, "{fun_name}(")?;
self.put_expression(arg, context, true)?;
write!(self.out, ", ")?;
self.put_expression(arg1.unwrap(), context, true)?;
write!(self.out, ")")?;
return Ok(());
}
_ => unreachable!(
"Correct TypeInner for dot product should be already validated"
Expand Down Expand Up @@ -3370,26 +3373,15 @@ impl<W: Write> Writer<W> {
} = *expr
{
match fun {
crate::MathFunction::Dot => {
// WGSL's `dot` function works on any `vecN` type, but Metal's only
// works on floating-point vectors, so we emit inline code for
// integer vector `dot` calls. But that code uses each argument `N`
// times, once for each component (see `put_dot_product`), so to
// avoid duplicated evaluation, we must bake integer operands.

// check what kind of product this is depending
// on the resolve type of the Dot function itself
let inner = context.resolve_type(expr_handle);
if let crate::TypeInner::Scalar(scalar) = *inner {
match scalar.kind {
crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
self.need_bake_expressions.insert(arg);
self.need_bake_expressions.insert(arg1.unwrap());
}
_ => {}
}
}
}
// WGSL's `dot` function works on any `vecN` type, but Metal's only
// works on floating-point vectors, so we emit inline code for
// integer vector `dot` calls. But that code uses each argument `N`
// times, once for each component (see `put_dot_product`), so to
// avoid duplicated evaluation, we must bake integer operands.
// This applies both when using the polyfill (because of the duplicate
// evaluation issue) and when we don't use the polyfill (because we
// need them to be emitted before casting to packed chars -- see the
// comment at the call to `put_casting_to_packed_chars`).
crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
self.need_bake_expressions.insert(arg);
self.need_bake_expressions.insert(arg1.unwrap());
Expand Down Expand Up @@ -5806,6 +5798,24 @@ template <typename A>
Ok(())
}

/// Build the mangled helper name for integer vector dot products.
///
/// `scalar` must be a concrete integer scalar type.
///
/// Result format: `{DOT_FUNCTION_PREFIX}_{type}{N}` (e.g., `naga_dot_int3`).
fn get_dot_wrapper_function_helper_name(
&self,
scalar: crate::Scalar,
size: crate::VectorSize,
) -> String {
// Check for consistency with [`super::keywords::RESERVED_SET`]
debug_assert!(concrete_int_scalars().any(|s| s == scalar));

let type_name = scalar.to_msl_name();
let size_suffix = common::vector_size_str(size);
format!("{DOT_FUNCTION_PREFIX}_{type_name}{size_suffix}")
}

#[allow(clippy::too_many_arguments)]
fn write_wrapped_math_function(
&mut self,
Expand Down Expand Up @@ -5861,6 +5871,45 @@ template <typename A>
writeln!(self.out, "}}")?;
writeln!(self.out)?;
}

crate::MathFunction::Dot => match *arg_ty {
crate::TypeInner::Vector { size, scalar }
if matches!(
scalar.kind,
crate::ScalarKind::Sint | crate::ScalarKind::Uint
) =>
{
// De-duplicate per (fun, arg type) like other wrapped math functions
let wrapped = WrappedFunction::Math {
fun,
arg_ty: (Some(size), scalar),
};
if !self.wrapped_functions.insert(wrapped) {
return Ok(());
}

let mut vec_ty = String::new();
put_numeric_type(&mut vec_ty, scalar, &[size])?;
let mut ret_ty = String::new();
put_numeric_type(&mut ret_ty, scalar, &[])?;

let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);

// Emit function signature and body using put_dot_product for the expression
writeln!(self.out, "{ret_ty} {fun_name}({vec_ty} a, {vec_ty} b) {{")?;
let level = back::Level(1);
write!(self.out, "{level}return ")?;
self.put_dot_product("a", "b", size as usize, |writer, name, index| {
write!(writer.out, "{name}.{}", back::COMPONENTS[index])?;
Ok(())
})?;
writeln!(self.out, ";")?;
writeln!(self.out, "}}")?;
writeln!(self.out)?;
}
_ => {}
},

_ => {}
}
Ok(())
Expand Down
10 changes: 2 additions & 8 deletions naga/src/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,5 @@ pub mod wgsl;
pub use diagnostic_debug::{DiagnosticDebug, ForDebug, ForDebugWithTypes};
pub use diagnostic_display::DiagnosticDisplay;

/// Helper function that returns the string corresponding to the [`VectorSize`](crate::VectorSize)
pub const fn vector_size_str(size: crate::VectorSize) -> &'static str {
match size {
crate::VectorSize::Bi => "2",
crate::VectorSize::Tri => "3",
crate::VectorSize::Quad => "4",
}
}
// Re-exported here for backwards compatibility
pub use super::proc::vector_size_str;
4 changes: 3 additions & 1 deletion naga/src/proc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@ pub use namer::{EntryPointIndex, ExternalTextureNameKey, NameKey, Namer};
pub use overloads::{Conclusion, MissingSpecialType, OverloadSet, Rule};
pub use terminator::ensure_block_returns;
use thiserror::Error;
pub use type_methods::min_max_float_representable_by;
pub use type_methods::{
concrete_int_scalars, min_max_float_representable_by, vector_size_str, vector_sizes,
};
pub use typifier::{compare_types, ResolveContext, ResolveError, TypeResolution};

use crate::non_max_u32::NonMaxU32;
Expand Down
4 changes: 2 additions & 2 deletions naga/src/proc/overloads/mathfunction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ use crate::proc::overloads::any_overload_set::AnyOverloadSet;
use crate::proc::overloads::list::List;
use crate::proc::overloads::regular::regular;
use crate::proc::overloads::utils::{
concrete_int_scalars, float_scalars, float_scalars_unimplemented_abstract, list, pairs, rule,
scalar_or_vecn, triples, vector_sizes,
float_scalars, float_scalars_unimplemented_abstract, list, pairs, rule, scalar_or_vecn, triples,
};
use crate::proc::overloads::OverloadSet;
use crate::proc::type_methods::{concrete_int_scalars, vector_sizes};

use crate::ir;

Expand Down
25 changes: 0 additions & 25 deletions naga/src/proc/overloads/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,6 @@ use crate::proc::TypeResolution;

use alloc::vec::Vec;

/// Produce all vector sizes.
pub fn vector_sizes() -> impl Iterator<Item = ir::VectorSize> + Clone {
static SIZES: [ir::VectorSize; 3] = [
ir::VectorSize::Bi,
ir::VectorSize::Tri,
ir::VectorSize::Quad,
];

SIZES.iter().cloned()
}

/// Produce all the floating-point [`ir::Scalar`]s.
///
/// Note that `F32` must appear before other sizes; this is how we
Expand All @@ -40,20 +29,6 @@ pub fn float_scalars_unimplemented_abstract() -> impl Iterator<Item = ir::Scalar
[ir::Scalar::F32, ir::Scalar::F16, ir::Scalar::F64].into_iter()
}

/// Produce all concrete integer [`ir::Scalar`]s.
///
/// Note that `I32` and `U32` must come first; this is how we
/// represent conversion rank.
pub fn concrete_int_scalars() -> impl Iterator<Item = ir::Scalar> {
[
ir::Scalar::I32,
ir::Scalar::U32,
ir::Scalar::I64,
ir::Scalar::U64,
]
.into_iter()
}

/// Produce the scalar and vector [`ir::TypeInner`]s that have `s` as
/// their scalar.
pub fn scalar_or_vecn(scalar: ir::Scalar) -> impl Iterator<Item = ir::TypeInner> {
Expand Down
37 changes: 36 additions & 1 deletion naga/src/proc/type_methods.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
//! Methods on [`TypeInner`], [`Scalar`], and [`ScalarKind`].
//! Methods on or related to [`TypeInner`], [`Scalar`], [`ScalarKind`], and [`VectorSize`].
//!
//! [`TypeInner`]: crate::TypeInner
//! [`Scalar`]: crate::Scalar
//! [`ScalarKind`]: crate::ScalarKind
//! [`VectorSize`]: crate::VectorSize

use crate::{ir, valid::MAX_TYPE_SIZE};

Expand Down Expand Up @@ -97,6 +98,31 @@ impl crate::Scalar {
}
}

/// Produce all concrete integer [`ir::Scalar`]s.
///
/// Note that `I32` and `U32` must come first; this represents conversion rank
/// in overload resolution.
pub fn concrete_int_scalars() -> impl Iterator<Item = ir::Scalar> {
[
ir::Scalar::I32,
ir::Scalar::U32,
ir::Scalar::I64,
ir::Scalar::U64,
]
.into_iter()
}

/// Produce all vector sizes.
pub fn vector_sizes() -> impl Iterator<Item = ir::VectorSize> + Clone {
static SIZES: [ir::VectorSize; 3] = [
ir::VectorSize::Bi,
ir::VectorSize::Tri,
ir::VectorSize::Quad,
];

SIZES.iter().cloned()
}

const POINTER_SPAN: u32 = 4;

impl crate::TypeInner {
Expand Down Expand Up @@ -612,3 +638,12 @@ pub fn min_max_float_representable_by(
_ => unreachable!(),
}
}

/// Helper function that returns the string corresponding to the [`VectorSize`](crate::VectorSize)
pub const fn vector_size_str(size: crate::VectorSize) -> &'static str {
match size {
crate::VectorSize::Bi => "2",
crate::VectorSize::Tri => "3",
crate::VectorSize::Quad => "4",
}
}
12 changes: 10 additions & 2 deletions naga/tests/out/msl/wgsl-functions.msl
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,22 @@ metal::float2 test_fma(
return metal::fma(a, b, c);
}

int naga_dot_int2(metal::int2 a, metal::int2 b) {
return ( + a.x * b.x + a.y * b.y);
}

uint naga_dot_uint3(metal::uint3 a, metal::uint3 b) {
return ( + a.x * b.x + a.y * b.y + a.z * b.z);
}

int test_integer_dot_product(
) {
metal::int2 a_2_ = metal::int2(1);
metal::int2 b_2_ = metal::int2(1);
int c_2_ = ( + a_2_.x * b_2_.x + a_2_.y * b_2_.y);
int c_2_ = naga_dot_int2(a_2_, b_2_);
metal::uint3 a_3_ = metal::uint3(1u);
metal::uint3 b_3_ = metal::uint3(1u);
uint c_3_ = ( + a_3_.x * b_3_.x + a_3_.y * b_3_.y + a_3_.z * b_3_.z);
uint c_3_ = naga_dot_uint3(a_3_, b_3_);
return 32;
}

Expand Down
Loading