Skip to content

Commit 6216007

Browse files
RiverDaveandyleiserson
authored andcommitted
[naga msl-out] Annotate dot product functions as wrapped functions
1 parent 874b750 commit 6216007

File tree

4 files changed

+133
-47
lines changed

4 files changed

+133
-47
lines changed

naga/src/back/msl/keywords.rs

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
use crate::proc::KeywordSet;
1+
use crate::proc::{concrete_int_scalars, vector_size_str, vector_sizes, KeywordSet};
22
use crate::racy_lock::RacyLock;
3+
use alloc::{format, string::String, vec::Vec};
34

45
// MSLS - Metal Shading Language Specification:
56
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
67
//
78
// C++ - Standard for Programming Language C++ (N4431)
89
// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2015/n4431.pdf
9-
pub const RESERVED: &[&str] = &[
10+
const RESERVED: &[&str] = &[
1011
// Undocumented
1112
"assert", // found in https://github.com/gfx-rs/wgpu/issues/5347
1213
// Standard for Programming Language C++ (N4431): 2.5 Alternative tokens
@@ -346,6 +347,7 @@ pub const RESERVED: &[&str] = &[
346347
super::writer::MODF_FUNCTION,
347348
super::writer::ABS_FUNCTION,
348349
super::writer::DIV_FUNCTION,
350+
// DOT_FUNCTION_PREFIX variants are added dynamically below
349351
super::writer::MOD_FUNCTION,
350352
super::writer::NEG_FUNCTION,
351353
super::writer::F2I32_FUNCTION,
@@ -359,8 +361,31 @@ pub const RESERVED: &[&str] = &[
359361
super::writer::EXTERNAL_TEXTURE_WRAPPER_STRUCT,
360362
];
361363

364+
// The set of concrete integer dot product function variants.
365+
// This must match the set of names that could be produced by
366+
// `Writer::get_dot_wrapper_function_helper_name`.
367+
static DOT_FUNCTION_NAMES: RacyLock<Vec<String>> = RacyLock::new(|| {
368+
let mut names = Vec::new();
369+
for scalar in concrete_int_scalars().map(crate::Scalar::to_msl_name) {
370+
for size_suffix in vector_sizes().map(vector_size_str) {
371+
let fun_name = format!(
372+
"{}_{}{}",
373+
super::writer::DOT_FUNCTION_PREFIX,
374+
scalar,
375+
size_suffix
376+
);
377+
names.push(fun_name);
378+
}
379+
}
380+
names
381+
});
382+
362383
/// The above set of reserved keywords, turned into a cached HashSet. This saves
363384
/// significant time during [`Namer::reset`](crate::proc::Namer::reset).
364385
///
365386
/// See <https://github.com/gfx-rs/wgpu/pull/7338> for benchmarks.
366-
pub static RESERVED_SET: RacyLock<KeywordSet> = RacyLock::new(|| KeywordSet::from_iter(RESERVED));
387+
pub static RESERVED_SET: RacyLock<KeywordSet> = RacyLock::new(|| {
388+
let mut set = KeywordSet::from_iter(RESERVED);
389+
set.extend(DOT_FUNCTION_NAMES.iter().map(String::as_str));
390+
set
391+
});

naga/src/back/msl/writer.rs

Lines changed: 85 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ use crate::{
1919
back::{self, get_entry_points, Baked},
2020
common,
2121
proc::{
22-
self,
22+
self, concrete_int_scalars,
2323
index::{self, BoundsCheck},
2424
ExternalTextureNameKey, NameKey, TypeResolution,
2525
},
@@ -55,6 +55,7 @@ pub(crate) const MODF_FUNCTION: &str = "naga_modf";
5555
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
5656
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
5757
pub(crate) const DIV_FUNCTION: &str = "naga_div";
58+
pub(crate) const DOT_FUNCTION_PREFIX: &str = "naga_dot";
5859
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
5960
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
6061
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
@@ -488,7 +489,7 @@ pub struct Writer<W> {
488489
}
489490

490491
impl crate::Scalar {
491-
fn to_msl_name(self) -> &'static str {
492+
pub(super) fn to_msl_name(self) -> &'static str {
492493
use crate::ScalarKind as Sk;
493494
match self {
494495
Self {
@@ -2334,26 +2335,28 @@ impl<W: Write> Writer<W> {
23342335
crate::TypeInner::Vector {
23352336
scalar:
23362337
crate::Scalar {
2338+
// Resolve float values to MSL's builtin dot function.
23372339
kind: crate::ScalarKind::Float,
23382340
..
23392341
},
23402342
..
23412343
} => "dot",
2342-
crate::TypeInner::Vector { size, .. } => {
2343-
return self.put_dot_product(
2344-
arg,
2345-
arg1.unwrap(),
2346-
size as usize,
2347-
|writer, arg, index| {
2348-
// Write the vector expression; this expression is marked to be
2349-
// cached so unless it can't be cached (for example, it's a Constant)
2350-
// it shouldn't produce large expressions.
2351-
writer.put_expression(arg, context, true)?;
2352-
// Access the current component on the vector.
2353-
write!(writer.out, ".{}", back::COMPONENTS[index])?;
2354-
Ok(())
2344+
crate::TypeInner::Vector {
2345+
size,
2346+
scalar:
2347+
scalar @ crate::Scalar {
2348+
kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
2349+
..
23552350
},
2356-
);
2351+
} => {
2352+
// Integer vector dot: call our mangled helper `dot_{type}{N}(a, b)`.
2353+
let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
2354+
write!(self.out, "{fun_name}(")?;
2355+
self.put_expression(arg, context, true)?;
2356+
write!(self.out, ", ")?;
2357+
self.put_expression(arg1.unwrap(), context, true)?;
2358+
write!(self.out, ")")?;
2359+
return Ok(());
23572360
}
23582361
_ => unreachable!(
23592362
"Correct TypeInner for dot product should be already validated"
@@ -3370,26 +3373,15 @@ impl<W: Write> Writer<W> {
33703373
} = *expr
33713374
{
33723375
match fun {
3373-
crate::MathFunction::Dot => {
3374-
// WGSL's `dot` function works on any `vecN` type, but Metal's only
3375-
// works on floating-point vectors, so we emit inline code for
3376-
// integer vector `dot` calls. But that code uses each argument `N`
3377-
// times, once for each component (see `put_dot_product`), so to
3378-
// avoid duplicated evaluation, we must bake integer operands.
3379-
3380-
// check what kind of product this is depending
3381-
// on the resolve type of the Dot function itself
3382-
let inner = context.resolve_type(expr_handle);
3383-
if let crate::TypeInner::Scalar(scalar) = *inner {
3384-
match scalar.kind {
3385-
crate::ScalarKind::Sint | crate::ScalarKind::Uint => {
3386-
self.need_bake_expressions.insert(arg);
3387-
self.need_bake_expressions.insert(arg1.unwrap());
3388-
}
3389-
_ => {}
3390-
}
3391-
}
3392-
}
3376+
// WGSL's `dot` function works on any `vecN` type, but Metal's only
3377+
// works on floating-point vectors, so we emit inline code for
3378+
// integer vector `dot` calls. But that code uses each argument `N`
3379+
// times, once for each component (see `put_dot_product`), so to
3380+
// avoid duplicated evaluation, we must bake integer operands.
3381+
// This applies both when using the polyfill (because of the duplicate
3382+
// evaluation issue) and when we don't use the polyfill (because we
3383+
// need them to be emitted before casting to packed chars -- see the
3384+
// comment at the call to `put_casting_to_packed_chars`).
33933385
crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
33943386
self.need_bake_expressions.insert(arg);
33953387
self.need_bake_expressions.insert(arg1.unwrap());
@@ -5806,6 +5798,24 @@ template <typename A>
58065798
Ok(())
58075799
}
58085800

5801+
/// Build the mangled helper name for integer vector dot products.
5802+
///
5803+
/// `scalar` must be a concrete integer scalar type.
5804+
///
5805+
/// Result format: `{DOT_FUNCTION_PREFIX}_{type}{N}` (e.g., `naga_dot_int3`).
5806+
fn get_dot_wrapper_function_helper_name(
5807+
&self,
5808+
scalar: crate::Scalar,
5809+
size: crate::VectorSize,
5810+
) -> String {
5811+
// Check for consistency with [`super::keywords::RESERVED_SET`]
5812+
debug_assert!(concrete_int_scalars().any(|s| s == scalar));
5813+
5814+
let type_name = scalar.to_msl_name();
5815+
let size_suffix = common::vector_size_str(size);
5816+
format!("{DOT_FUNCTION_PREFIX}_{type_name}{size_suffix}")
5817+
}
5818+
58095819
#[allow(clippy::too_many_arguments)]
58105820
fn write_wrapped_math_function(
58115821
&mut self,
@@ -5861,6 +5871,45 @@ template <typename A>
58615871
writeln!(self.out, "}}")?;
58625872
writeln!(self.out)?;
58635873
}
5874+
5875+
crate::MathFunction::Dot => match *arg_ty {
5876+
crate::TypeInner::Vector { size, scalar }
5877+
if matches!(
5878+
scalar.kind,
5879+
crate::ScalarKind::Sint | crate::ScalarKind::Uint
5880+
) =>
5881+
{
5882+
// De-duplicate per (fun, arg type) like other wrapped math functions
5883+
let wrapped = WrappedFunction::Math {
5884+
fun,
5885+
arg_ty: (Some(size), scalar),
5886+
};
5887+
if !self.wrapped_functions.insert(wrapped) {
5888+
return Ok(());
5889+
}
5890+
5891+
let mut vec_ty = String::new();
5892+
put_numeric_type(&mut vec_ty, scalar, &[size])?;
5893+
let mut ret_ty = String::new();
5894+
put_numeric_type(&mut ret_ty, scalar, &[])?;
5895+
5896+
let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
5897+
5898+
// Emit function signature and body using put_dot_product for the expression
5899+
writeln!(self.out, "{ret_ty} {fun_name}({vec_ty} a, {vec_ty} b) {{")?;
5900+
let level = back::Level(1);
5901+
write!(self.out, "{level}return ")?;
5902+
self.put_dot_product("a", "b", size as usize, |writer, name, index| {
5903+
write!(writer.out, "{name}.{}", back::COMPONENTS[index])?;
5904+
Ok(())
5905+
})?;
5906+
writeln!(self.out, ";")?;
5907+
writeln!(self.out, "}}")?;
5908+
writeln!(self.out)?;
5909+
}
5910+
_ => {}
5911+
},
5912+
58645913
_ => {}
58655914
}
58665915
Ok(())

naga/tests/out/msl/wgsl-functions.msl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,22 @@ metal::float2 test_fma(
1313
return metal::fma(a, b, c);
1414
}
1515

16+
int naga_dot_int2(metal::int2 a, metal::int2 b) {
17+
return ( + a.x * b.x + a.y * b.y);
18+
}
19+
20+
uint naga_dot_uint3(metal::uint3 a, metal::uint3 b) {
21+
return ( + a.x * b.x + a.y * b.y + a.z * b.z);
22+
}
23+
1624
int test_integer_dot_product(
1725
) {
1826
metal::int2 a_2_ = metal::int2(1);
1927
metal::int2 b_2_ = metal::int2(1);
20-
int c_2_ = ( + a_2_.x * b_2_.x + a_2_.y * b_2_.y);
28+
int c_2_ = naga_dot_int2(a_2_, b_2_);
2129
metal::uint3 a_3_ = metal::uint3(1u);
2230
metal::uint3 b_3_ = metal::uint3(1u);
23-
uint c_3_ = ( + a_3_.x * b_3_.x + a_3_.y * b_3_.y + a_3_.z * b_3_.z);
31+
uint c_3_ = naga_dot_uint3(a_3_, b_3_);
2432
return 32;
2533
}
2634

naga/tests/out/msl/wgsl-int64.msl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ long naga_abs(long val) {
4343
return metal::select(as_type<long>(-as_type<ulong>(val)), val, val >= 0);
4444
}
4545

46+
long naga_dot_long2(metal::long2 a, metal::long2 b) {
47+
return ( + a.x * b.x + a.y * b.y);
48+
}
49+
4650
long int64_function(
4751
long x,
4852
thread long& private_variable,
@@ -111,11 +115,9 @@ long int64_function(
111115
long _e130 = val;
112116
val = as_type<long>(as_type<ulong>(_e130) + as_type<ulong>(metal::clamp(_e126, _e127, _e128)));
113117
long _e132 = val;
114-
metal::long2 _e133 = metal::long2(_e132);
115118
long _e134 = val;
116-
metal::long2 _e135 = metal::long2(_e134);
117119
long _e137 = val;
118-
val = as_type<long>(as_type<ulong>(_e137) + as_type<ulong>(( + _e133.x * _e135.x + _e133.y * _e135.y)));
120+
val = as_type<long>(as_type<ulong>(_e137) + as_type<ulong>(naga_dot_long2(metal::long2(_e132), metal::long2(_e134))));
119121
long _e139 = val;
120122
long _e140 = val;
121123
long _e142 = val;
@@ -135,6 +137,10 @@ ulong naga_f2u64(float value) {
135137
return static_cast<ulong>(metal::clamp(value, 0.0, 18446743000000000000.0));
136138
}
137139

140+
ulong naga_dot_ulong2(metal::ulong2 a, metal::ulong2 b) {
141+
return ( + a.x * b.x + a.y * b.y);
142+
}
143+
138144
ulong uint64_function(
139145
ulong x_1,
140146
constant UniformCompatible& input_uniform,
@@ -199,11 +205,9 @@ ulong uint64_function(
199205
ulong _e125 = val_1;
200206
val_1 = _e125 + metal::clamp(_e121, _e122, _e123);
201207
ulong _e127 = val_1;
202-
metal::ulong2 _e128 = metal::ulong2(_e127);
203208
ulong _e129 = val_1;
204-
metal::ulong2 _e130 = metal::ulong2(_e129);
205209
ulong _e132 = val_1;
206-
val_1 = _e132 + ( + _e128.x * _e130.x + _e128.y * _e130.y);
210+
val_1 = _e132 + naga_dot_ulong2(metal::ulong2(_e127), metal::ulong2(_e129));
207211
ulong _e134 = val_1;
208212
ulong _e135 = val_1;
209213
ulong _e137 = val_1;

0 commit comments

Comments
 (0)