Skip to content

Commit 6d15d1b

Browse files
committed
[naga][metal] Anottate dot product functions as wrapped functions
1 parent f9d946b commit 6d15d1b

File tree

3 files changed

+21
-11
lines changed

3 files changed

+21
-11
lines changed

naga/src/back/msl/writer.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5852,15 +5852,13 @@ template <typename A>
58525852
writeln!(self.out)?;
58535853
}
58545854

5855-
crate::MathFunction::Dot => match arg_ty {
5855+
crate::MathFunction::Dot => match *arg_ty {
58565856
crate::TypeInner::Vector { size, scalar }
58575857
if matches!(
58585858
scalar.kind,
58595859
crate::ScalarKind::Sint | crate::ScalarKind::Uint
58605860
) =>
58615861
{
5862-
let size = *size;
5863-
let scalar = *scalar;
58645862
// De-duplicate per (fun, arg type) like other wrapped math functions
58655863
let wrapped = WrappedFunction::Math {
58665864
fun,

naga/tests/out/msl/wgsl-functions.msl

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,22 @@ metal::float2 test_fma(
1313
return metal::fma(a, b, c);
1414
}
1515

16+
int naga_dot_int2(metal::int2 a, metal::int2 b) {
17+
return ( + a.x * b.x + a.y * b.y);
18+
}
19+
20+
uint naga_dot_uint3(metal::uint3 a, metal::uint3 b) {
21+
return ( + a.x * b.x + a.y * b.y + a.z * b.z);
22+
}
23+
1624
int test_integer_dot_product(
1725
) {
1826
metal::int2 a_2_ = metal::int2(1);
1927
metal::int2 b_2_ = metal::int2(1);
20-
int c_2_ = ( + a_2_.x * b_2_.x + a_2_.y * b_2_.y);
28+
int c_2_ = naga_dot_int2(a_2_, b_2_);
2129
metal::uint3 a_3_ = metal::uint3(1u);
2230
metal::uint3 b_3_ = metal::uint3(1u);
23-
uint c_3_ = ( + a_3_.x * b_3_.x + a_3_.y * b_3_.y + a_3_.z * b_3_.z);
31+
uint c_3_ = naga_dot_uint3(a_3_, b_3_);
2432
return 32;
2533
}
2634

naga/tests/out/msl/wgsl-int64.msl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ long naga_abs(long val) {
4343
return metal::select(as_type<long>(-as_type<ulong>(val)), val, val >= 0);
4444
}
4545

46+
long naga_dot_long2(metal::long2 a, metal::long2 b) {
47+
return ( + a.x * b.x + a.y * b.y);
48+
}
49+
4650
long int64_function(
4751
long x,
4852
thread long& private_variable,
@@ -111,11 +115,9 @@ long int64_function(
111115
long _e130 = val;
112116
val = as_type<long>(as_type<ulong>(_e130) + as_type<ulong>(metal::clamp(_e126, _e127, _e128)));
113117
long _e132 = val;
114-
metal::long2 _e133 = metal::long2(_e132);
115118
long _e134 = val;
116-
metal::long2 _e135 = metal::long2(_e134);
117119
long _e137 = val;
118-
val = as_type<long>(as_type<ulong>(_e137) + as_type<ulong>(( + _e133.x * _e135.x + _e133.y * _e135.y)));
120+
val = as_type<long>(as_type<ulong>(_e137) + as_type<ulong>(naga_dot_long2(metal::long2(_e132), metal::long2(_e134))));
119121
long _e139 = val;
120122
long _e140 = val;
121123
long _e142 = val;
@@ -135,6 +137,10 @@ ulong naga_f2u64(float value) {
135137
return static_cast<ulong>(metal::clamp(value, 0.0, 18446743000000000000.0));
136138
}
137139

140+
ulong naga_dot_ulong2(metal::ulong2 a, metal::ulong2 b) {
141+
return ( + a.x * b.x + a.y * b.y);
142+
}
143+
138144
ulong uint64_function(
139145
ulong x_1,
140146
constant UniformCompatible& input_uniform,
@@ -199,11 +205,9 @@ ulong uint64_function(
199205
ulong _e125 = val_1;
200206
val_1 = _e125 + metal::clamp(_e121, _e122, _e123);
201207
ulong _e127 = val_1;
202-
metal::ulong2 _e128 = metal::ulong2(_e127);
203208
ulong _e129 = val_1;
204-
metal::ulong2 _e130 = metal::ulong2(_e129);
205209
ulong _e132 = val_1;
206-
val_1 = _e132 + ( + _e128.x * _e130.x + _e128.y * _e130.y);
210+
val_1 = _e132 + naga_dot_ulong2(metal::ulong2(_e127), metal::ulong2(_e129));
207211
ulong _e134 = val_1;
208212
ulong _e135 = val_1;
209213
ulong _e137 = val_1;

0 commit comments

Comments
 (0)