Skip to content

Commit 780c8b4

Browse files
committed
Address comments
1 parent 83775b0 commit 780c8b4

File tree

2 files changed

+82
-11
lines changed

2 files changed

+82
-11
lines changed

naga/src/back/msl/keywords.rs

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
use crate::common;
12
use crate::proc::KeywordSet;
23
use crate::racy_lock::RacyLock;
4+
use alloc::vec::Vec;
5+
use alloc::{boxed::Box, format};
36

47
// MSLS - Metal Shading Language Specification:
58
// https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
@@ -346,7 +349,7 @@ pub const RESERVED: &[&str] = &[
346349
super::writer::MODF_FUNCTION,
347350
super::writer::ABS_FUNCTION,
348351
super::writer::DIV_FUNCTION,
349-
super::writer::DOT_FUNCTION,
352+
// DOT_FUNCTION_PREFIX variants are added dynamically below
350353
super::writer::MOD_FUNCTION,
351354
super::writer::NEG_FUNCTION,
352355
super::writer::F2I32_FUNCTION,
@@ -360,8 +363,65 @@ pub const RESERVED: &[&str] = &[
360363
super::writer::EXTERNAL_TEXTURE_WRAPPER_STRUCT,
361364
];
362365

366+
const CONCRETE_INTEGER_SCALARS: [crate::Scalar; 4] = [
367+
crate::Scalar::I32,
368+
crate::Scalar::U32,
369+
crate::Scalar::I64,
370+
crate::Scalar::U64,
371+
];
372+
373+
const CONCRETE_VECTOR_SIZES: [crate::VectorSize; 3] = [
374+
crate::VectorSize::Bi,
375+
crate::VectorSize::Tri,
376+
crate::VectorSize::Quad,
377+
];
378+
363379
/// The above set of reserved keywords, turned into a cached HashSet. This saves
364380
/// significant time during [`Namer::reset`](crate::proc::Namer::reset).
365381
///
366382
/// See <https://github.com/gfx-rs/wgpu/pull/7338> for benchmarks.
367-
pub static RESERVED_SET: RacyLock<KeywordSet> = RacyLock::new(|| KeywordSet::from_iter(RESERVED));
383+
pub static RESERVED_SET: RacyLock<KeywordSet> = RacyLock::new(|| {
384+
let mut set = KeywordSet::from_iter(RESERVED);
385+
// Add all concrete integer dot product function variants.
386+
// These are generated to match the names produced by
387+
// `Writer::get_dot_wrapper_function_helper_name`, ensuring they stay in sync.
388+
let mut dot_function_names: Vec<&'static str> = Vec::new();
389+
for scalar in CONCRETE_INTEGER_SCALARS {
390+
// Map scalar to MSL type name (matching the logic in Scalar::to_msl_name)
391+
let type_name = match scalar {
392+
crate::Scalar {
393+
kind: crate::ScalarKind::Sint,
394+
width: 4,
395+
} => "int",
396+
crate::Scalar {
397+
kind: crate::ScalarKind::Uint,
398+
width: 4,
399+
} => "uint",
400+
crate::Scalar {
401+
kind: crate::ScalarKind::Sint,
402+
width: 8,
403+
} => "long",
404+
crate::Scalar {
405+
kind: crate::ScalarKind::Uint,
406+
width: 8,
407+
} => "ulong",
408+
_ => continue, // Skip non-integer or unsupported types
409+
};
410+
411+
for size in CONCRETE_VECTOR_SIZES {
412+
let size_suffix = common::vector_size_str(size);
413+
let fun_name = format!(
414+
"{}_{}{}",
415+
super::writer::DOT_FUNCTION_PREFIX,
416+
type_name,
417+
size_suffix
418+
);
419+
// Convert to &'static str by leaking the String
420+
// (In theory) This is safe because these are generated once and cached
421+
let leaked = Box::leak(fun_name.into_boxed_str());
422+
dot_function_names.push(leaked);
423+
}
424+
}
425+
set.extend(dot_function_names.into_iter());
426+
set
427+
});

naga/src/back/msl/writer.rs

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ pub(crate) const MODF_FUNCTION: &str = "naga_modf";
5555
pub(crate) const FREXP_FUNCTION: &str = "naga_frexp";
5656
pub(crate) const ABS_FUNCTION: &str = "naga_abs";
5757
pub(crate) const DIV_FUNCTION: &str = "naga_div";
58-
pub(crate) const DOT_FUNCTION: &str = "naga_dot";
58+
pub(crate) const DOT_FUNCTION_PREFIX: &str = "naga_dot";
5959
pub(crate) const MOD_FUNCTION: &str = "naga_mod";
6060
pub(crate) const NEG_FUNCTION: &str = "naga_neg";
6161
pub(crate) const F2I32_FUNCTION: &str = "naga_f2i32";
@@ -2338,12 +2338,14 @@ impl<W: Write> Writer<W> {
23382338
},
23392339
..
23402340
} => "dot",
2341-
crate::TypeInner::Vector { size, scalar }
2342-
if matches!(
2343-
scalar.kind,
2344-
crate::ScalarKind::Sint | crate::ScalarKind::Uint
2345-
) =>
2346-
{
2341+
crate::TypeInner::Vector {
2342+
size,
2343+
scalar:
2344+
scalar @ crate::Scalar {
2345+
kind: crate::ScalarKind::Sint | crate::ScalarKind::Uint,
2346+
..
2347+
},
2348+
} => {
23472349
// Integer vector dot: call our mangled helper `dot_{type}{N}(a, b)`.
23482350
let fun_name = self.get_dot_wrapper_function_helper_name(scalar, size);
23492351
write!(self.out, "{fun_name}(")?;
@@ -3368,6 +3370,15 @@ impl<W: Write> Writer<W> {
33683370
} = *expr
33693371
{
33703372
match fun {
3373+
// WGSL's `dot` function works on any `vecN` type, but Metal's only
3374+
// works on floating-point vectors, so we emit inline code for
3375+
// integer vector `dot` calls. But that code uses each argument `N`
3376+
// times, once for each component (see `put_dot_product`), so to
3377+
// avoid duplicated evaluation, we must bake integer operands.
3378+
// This applies both when using the polyfill (because of the duplicate
3379+
// evaluation issue) and when we don't use the polyfill (because we
3380+
// need them to be emitted before casting to packed chars -- see the
3381+
// comment at the call to `put_casting_to_packed_chars`).
33713382
crate::MathFunction::Dot4U8Packed | crate::MathFunction::Dot4I8Packed => {
33723383
self.need_bake_expressions.insert(arg);
33733384
self.need_bake_expressions.insert(arg1.unwrap());
@@ -5785,15 +5796,15 @@ template <typename A>
57855796
}
57865797

57875798
/// Build the mangled helper name for integer vector dot products.
5788-
/// Result format: `{DOT_FUNCTION}_{type}{N}` (e.g., `naga_dot_int3`).
5799+
/// Result format: `{DOT_FUNCTION_PREFIX}_{type}{N}` (e.g., `naga_dot_int3`).
57895800
fn get_dot_wrapper_function_helper_name(
57905801
&self,
57915802
scalar: crate::Scalar,
57925803
size: crate::VectorSize,
57935804
) -> String {
57945805
let type_name = scalar.to_msl_name();
57955806
let size_suffix = common::vector_size_str(size);
5796-
format!("{DOT_FUNCTION}_{type_name}{size_suffix}")
5807+
format!("{DOT_FUNCTION_PREFIX}_{type_name}{size_suffix}")
57975808
}
57985809

57995810
#[allow(clippy::too_many_arguments)]

0 commit comments

Comments
 (0)