From 1c90d193f48a6968855562062236c6d907a744e7 Mon Sep 17 00:00:00 2001 From: SupaMaggie70Incorporated Date: Thu, 14 Aug 2025 12:53:07 -0500 Subject: [PATCH 01/89] Initial commit --- docs/api-specs/mesh_shading.md | 32 ++-- naga-cli/src/bin/naga.rs | 25 +++ naga/src/back/dot/mod.rs | 19 +++ naga/src/back/glsl/features.rs | 1 + naga/src/back/glsl/mod.rs | 23 ++- naga/src/back/hlsl/conv.rs | 3 + naga/src/back/hlsl/mod.rs | 3 +- naga/src/back/hlsl/writer.rs | 19 ++- naga/src/back/msl/mod.rs | 5 + naga/src/back/msl/writer.rs | 20 ++- naga/src/back/pipeline_constants.rs | 45 ++++++ naga/src/back/wgsl/writer.rs | 5 +- naga/src/common/wgsl/to_wgsl.rs | 8 +- naga/src/compact/mod.rs | 56 +++++++ naga/src/compact/statements.rs | 34 ++++ naga/src/front/glsl/functions.rs | 4 + naga/src/front/glsl/mod.rs | 2 +- naga/src/front/glsl/variables.rs | 1 + naga/src/front/interpolator.rs | 1 + naga/src/front/spv/function.rs | 2 + naga/src/front/spv/mod.rs | 4 + naga/src/ir/mod.rs | 76 ++++++++- naga/src/proc/mod.rs | 3 + naga/src/proc/terminator.rs | 1 + naga/src/valid/analyzer.rs | 102 +++++++++++- naga/src/valid/function.rs | 42 +++++ naga/src/valid/handles.rs | 16 ++ naga/src/valid/interface.rs | 232 ++++++++++++++++++++++++++-- naga/src/valid/mod.rs | 2 + naga/src/valid/type.rs | 9 +- wgpu-core/src/validation.rs | 4 +- wgpu-hal/src/vulkan/adapter.rs | 3 + 32 files changed, 754 insertions(+), 48 deletions(-) diff --git a/docs/api-specs/mesh_shading.md b/docs/api-specs/mesh_shading.md index 8c979890b78..ee14f99e757 100644 --- a/docs/api-specs/mesh_shading.md +++ b/docs/api-specs/mesh_shading.md @@ -80,32 +80,36 @@ This shader stage can be selected by marking a function with `@task`. Task shade The output of this determines how many workgroups of mesh shaders will be dispatched. Once dispatched, global id variables will be local to the task shader workgroup dispatch, and mesh shaders won't know the position of their dispatch among all mesh shader dispatches unless this is passed through the payload. The output may be zero to skip dispatching any mesh shader workgroups for the task shader workgroup. -If task shaders are marked with `@payload(someVar)`, where `someVar` is global variable declared like `var someVar: `, task shaders may write to `someVar`. This payload is passed to the mesh shader workgroup that is invoked. The mesh shader can skip declaring `@payload` to ignore this input. +If task shaders are marked with `@payload(someVar)`, where `someVar` is global variable declared like `var someVar: `, task shaders may use `someVar` as if it is a read-write workgroup storage variable. This payload is passed to the mesh shader workgroup that is invoked. The mesh shader can skip declaring `@payload` to ignore this input. ### Mesh shader This shader stage can be selected by marking a function with `@mesh`. Mesh shaders must not return anything. -Mesh shaders can be marked with `@payload(someVar)` similar to task shaders. Unlike task shaders, mesh shaders cannot write to this workgroup memory. Declaring `@payload` in a pipeline with no task shader, in a pipeline with a task shader that doesn't declare `@payload`, or in a task shader with an `@payload` that is statically sized and smaller than the mesh shader payload is illegal. +Mesh shaders can be marked with `@payload(someVar)` similar to task shaders. Unlike task shaders, mesh shaders cannot write to this memory. Declaring `@payload` in a pipeline with no task shader, in a pipeline with a task shader that doesn't declare `@payload`, or in a task shader with an `@payload` that is statically sized and smaller than the mesh shader payload is illegal. -Mesh shaders must be marked with `@vertex_output(OutputType, numOutputs)`, where `numOutputs` is the maximum number of vertices to be output by a mesh shader, and `OutputType` is the data associated with vertices, similar to a standard vertex shader output. +Mesh shaders must be marked with `@vertex_output(OutputType, numOutputs)`, where `numOutputs` is the maximum number of vertices to be output by a mesh shader, and `OutputType` is the data associated with vertices, similar to a standard vertex shader output, and must be a struct. Mesh shaders must also be marked with `@primitive_output(OutputType, numOutputs)`, which is similar to `@vertex_output` except it describes the primitive outputs. ### Mesh shader outputs -Primitive outputs from mesh shaders have some additional builtins they can set. These include `@builtin(cull_primitive)`, which must be a boolean value. If this is set to true, then the primitive is skipped during rendering. +Vertex outputs from mesh shaders function identically to outputs of vertex shaders, and as such must have a field with `@builtin(position)`. + +Primitive outputs from mesh shaders have some additional builtins they can set. These include `@builtin(cull_primitive)`, which must be a boolean value. If this is set to true, then the primitive is skipped during rendering. All non-builtin primitive outputs must be decorated with `@per_primitive`. Mesh shader primitive outputs must also specify exactly one of `@builtin(triangle_indices)`, `@builtin(line_indices)`, or `@builtin(point_index)`. This determines the output topology of the mesh shader, and must match the output topology of the pipeline descriptor the mesh shader is used with. These must be of type `vec3`, `vec2`, and `u32` respectively. When setting this, each of the indices must be less than the number of vertices declared in `setMeshOutputs`. Additionally, the `@location` attributes from the vertex and primitive outputs can't overlap. -Before setting any vertices or indices, or exiting, the mesh shader must call `setMeshOutputs(numVertices: u32, numIndices: u32)`, which declares the number of vertices and indices that will be written to. These must be less than the corresponding maximums set in `@vertex_output` and `@primitive_output`. The mesh shader must then write to exactly these numbers of vertices and primitives. +Before setting any vertices or indices, or exiting, the mesh shader must call `setMeshOutputs(numVertices: u32, numIndices: u32)`, which declares the number of vertices and indices that will be written to. These must be less than the corresponding maximums set in `@vertex_output` and `@primitive_output`. The mesh shader must then write to exactly these numbers of vertices and primitives. A varying member with `@per_primitive` cannot be used in function interfaces except as the primitive output for mesh shaders or as input for fragment shaders. The mesh shader can write to vertices using the `setVertex(idx: u32, vertex: VertexOutput)` where `VertexOutput` is replaced with the vertex type declared in `@vertex_output`, and `idx` is the index of the vertex to write. Similarly, the mesh shader can write to vertices using `setPrimitive(idx: u32, primitive: PrimitiveOutput)`. These can be written to multiple times, however unsynchronized writes are undefined behavior. The primitives and indices are shared across the entire mesh shader workgroup. ### Fragment shader -Fragment shaders may now be passed the primitive info from a mesh shader the same was as they are passed vertex inputs, for example `fn fs_main(vertex: VertexOutput, primitive: PrimitiveOutput)`. The primitive state is part of the fragment input and must match the output of the mesh shader in the pipeline. +Fragment shaders can access vertex output data as if it is from a vertex shader. They can also access primitive output data, provided the input is decorated with `@per_primitive`. The `@per_primitive` attribute can be applied to a value directly, such as `@per_primitive @location(1) value: vec4`, to a struct such as `@per_primitive primitive_input: PrimitiveInput` where `PrimitiveInput` is a struct containing fields decorated with `@location` and `@builtin`, or to members of a struct that are themselves decorated with `@location` or `@builtin`. + +The primitive state is part of the fragment input and must match the output of the mesh shader in the pipeline. Using `@per_primitive` also requires enabling the mesh shader extension. Additionally, the locations of vertex and primitive input cannot overlap. ### Full example @@ -115,9 +119,9 @@ The following is a full example of WGSL shaders that could be used to create a m enable mesh_shading; const positions = array( - vec4(0.,-1.,0.,1.), - vec4(-1.,1.,0.,1.), - vec4(1.,1.,0.,1.) + vec4(0.,1.,0.,1.), + vec4(-1.,-1.,0.,1.), + vec4(1.,-1.,0.,1.) ); const colors = array( vec4(0.,1.,0.,1.), @@ -128,7 +132,7 @@ struct TaskPayload { colorMask: vec4, visible: bool, } -var taskPayload: TaskPayload; +var taskPayload: TaskPayload; var workgroupData: f32; struct VertexOutput { @builtin(position) position: vec4, @@ -137,14 +141,12 @@ struct VertexOutput { struct PrimitiveOutput { @builtin(triangle_indices) index: vec3, @builtin(cull_primitive) cull: bool, - @location(1) colorMask: vec4, + @per_primitive @location(1) colorMask: vec4, } struct PrimitiveInput { - @location(1) colorMask: vec4, + @per_primitive @location(1) colorMask: vec4, } -fn test_function(input: u32) { -} @task @payload(taskPayload) @workgroup_size(1) @@ -163,8 +165,6 @@ fn ms_main(@builtin(local_invocation_index) index: u32, @builtin(global_invocati workgroupData = 2.0; var v: VertexOutput; - test_function(1); - v.position = positions[0]; v.color = colors[0] * taskPayload.colorMask; setVertex(0, v); diff --git a/naga-cli/src/bin/naga.rs b/naga-cli/src/bin/naga.rs index 44369e9df7d..171d970166e 100644 --- a/naga-cli/src/bin/naga.rs +++ b/naga-cli/src/bin/naga.rs @@ -64,6 +64,12 @@ struct Args { #[argh(option)] shader_model: Option, + /// the SPIR-V version to use if targeting SPIR-V + /// + /// For example, 1.0, 1.4, etc + #[argh(option)] + spirv_version: Option, + /// the shader stage, for example 'frag', 'vert', or 'compute'. /// if the shader stage is unspecified it will be derived from /// the file extension. @@ -189,6 +195,22 @@ impl FromStr for ShaderModelArg { } } +#[derive(Debug, Clone)] +struct SpirvVersionArg(u8, u8); + +impl FromStr for SpirvVersionArg { + type Err = String; + + fn from_str(s: &str) -> Result { + let dot = s + .find(".") + .ok_or_else(|| "Missing dot separator".to_owned())?; + let major = s[..dot].parse::().map_err(|e| e.to_string())?; + let minor = s[dot + 1..].parse::().map_err(|e| e.to_string())?; + Ok(Self(major, minor)) + } +} + /// Newtype so we can implement [`FromStr`] for `ShaderSource`. #[derive(Debug, Clone, Copy)] struct ShaderStage(naga::ShaderStage); @@ -465,6 +487,9 @@ fn run() -> anyhow::Result<()> { if let Some(ref version) = args.metal_version { params.msl.lang_version = version.0; } + if let Some(ref version) = args.spirv_version { + params.spv_out.lang_version = (version.0, version.1); + } params.keep_coordinate_space = args.keep_coordinate_space; params.dot.cfg_only = args.dot_cfg_only; diff --git a/naga/src/back/dot/mod.rs b/naga/src/back/dot/mod.rs index 826dad1c219..1f1396eccff 100644 --- a/naga/src/back/dot/mod.rs +++ b/naga/src/back/dot/mod.rs @@ -307,6 +307,25 @@ impl StatementGraph { crate::RayQueryFunction::Terminate => "RayQueryTerminate", } } + S::MeshFunction(crate::MeshFunction::SetMeshOutputs { + vertex_count, + primitive_count, + }) => { + self.dependencies.push((id, vertex_count, "vertex_count")); + self.dependencies + .push((id, primitive_count, "primitive_count")); + "SetMeshOutputs" + } + S::MeshFunction(crate::MeshFunction::SetVertex { index, value }) => { + self.dependencies.push((id, index, "index")); + self.dependencies.push((id, value, "value")); + "SetVertex" + } + S::MeshFunction(crate::MeshFunction::SetPrimitive { index, value }) => { + self.dependencies.push((id, index, "index")); + self.dependencies.push((id, value, "value")); + "SetPrimitive" + } S::SubgroupBallot { result, predicate } => { if let Some(predicate) = predicate { self.dependencies.push((id, predicate, "predicate")); diff --git a/naga/src/back/glsl/features.rs b/naga/src/back/glsl/features.rs index a6dfe4e3100..b884f08ac39 100644 --- a/naga/src/back/glsl/features.rs +++ b/naga/src/back/glsl/features.rs @@ -610,6 +610,7 @@ impl Writer<'_, W> { interpolation, sampling, blend_src, + per_primitive: _, } => { if interpolation == Some(Interpolation::Linear) { self.features.request(Features::NOPERSPECTIVE_QUALIFIER); diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index e78af74c844..1af18528944 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -139,7 +139,8 @@ impl crate::AddressSpace { | crate::AddressSpace::Uniform | crate::AddressSpace::Storage { .. } | crate::AddressSpace::Handle - | crate::AddressSpace::PushConstant => false, + | crate::AddressSpace::PushConstant + | crate::AddressSpace::TaskPayload => false, } } } @@ -1300,6 +1301,9 @@ impl<'a, W: Write> Writer<'a, W> { crate::AddressSpace::Storage { .. } => { self.write_interface_block(handle, global)?; } + crate::AddressSpace::TaskPayload => { + self.write_interface_block(handle, global)?; + } // A global variable in the `Function` address space is a // contradiction in terms. crate::AddressSpace::Function => unreachable!(), @@ -1614,6 +1618,7 @@ impl<'a, W: Write> Writer<'a, W> { interpolation, sampling, blend_src, + per_primitive: _, } => (location, interpolation, sampling, blend_src), crate::Binding::BuiltIn(built_in) => { match built_in { @@ -1732,6 +1737,7 @@ impl<'a, W: Write> Writer<'a, W> { interpolation: None, sampling: None, blend_src, + per_primitive: false, }, stage: self.entry_point.stage, options: VaryingOptions::from_writer_options(self.options, output), @@ -2669,6 +2675,11 @@ impl<'a, W: Write> Writer<'a, W> { self.write_image_atomic(ctx, image, coordinate, array_index, fun, value)? } Statement::RayQuery { .. } => unreachable!(), + Statement::MeshFunction( + crate::MeshFunction::SetMeshOutputs { .. } + | crate::MeshFunction::SetVertex { .. } + | crate::MeshFunction::SetPrimitive { .. }, + ) => unreachable!(), Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; let res_name = Baked(result).to_string(); @@ -5247,6 +5258,15 @@ const fn glsl_built_in(built_in: crate::BuiltIn, options: VaryingOptions) -> &'s Bi::SubgroupId => "gl_SubgroupID", Bi::SubgroupSize => "gl_SubgroupSize", Bi::SubgroupInvocationId => "gl_SubgroupInvocationID", + // mesh + // TODO: figure out how to map these to glsl things as glsl treats them as arrays + Bi::CullPrimitive + | Bi::PointIndex + | Bi::LineIndices + | Bi::TriangleIndices + | Bi::MeshTaskSize => { + unimplemented!() + } } } @@ -5262,6 +5282,7 @@ const fn glsl_storage_qualifier(space: crate::AddressSpace) -> Option<&'static s As::Handle => Some("uniform"), As::WorkGroup => Some("shared"), As::PushConstant => Some("uniform"), + As::TaskPayload => unreachable!(), } } diff --git a/naga/src/back/hlsl/conv.rs b/naga/src/back/hlsl/conv.rs index ed40cbe5102..d6ccc5ec6e4 100644 --- a/naga/src/back/hlsl/conv.rs +++ b/naga/src/back/hlsl/conv.rs @@ -183,6 +183,9 @@ impl crate::BuiltIn { Self::PointSize | Self::ViewIndex | Self::PointCoord | Self::DrawID => { return Err(Error::Custom(format!("Unsupported builtin {self:?}"))) } + Self::CullPrimitive => "SV_CullPrimitive", + Self::PointIndex | Self::LineIndices | Self::TriangleIndices => unimplemented!(), + Self::MeshTaskSize => unreachable!(), }) } } diff --git a/naga/src/back/hlsl/mod.rs b/naga/src/back/hlsl/mod.rs index 8df06cf1323..f357c02bb3f 100644 --- a/naga/src/back/hlsl/mod.rs +++ b/naga/src/back/hlsl/mod.rs @@ -283,7 +283,8 @@ impl crate::ShaderStage { Self::Vertex => "vs", Self::Fragment => "ps", Self::Compute => "cs", - Self::Task | Self::Mesh => unreachable!(), + Self::Task => "ts", + Self::Mesh => "ms", } } } diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 357b8597521..9401766448f 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -507,7 +507,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_wrapped_functions(module, &ctx)?; - if ep.stage == ShaderStage::Compute { + if ep.stage.compute_like() { // HLSL is calling workgroup size "num threads" let num_threads = ep.workgroup_size; writeln!( @@ -967,6 +967,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { self.write_type(module, global.ty)?; "" } + crate::AddressSpace::TaskPayload => unimplemented!(), crate::AddressSpace::Uniform => { // constant buffer declarations are expected to be inlined, e.g. // `cbuffer foo: register(b0) { field1: type1; }` @@ -2599,6 +2600,19 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { writeln!(self.out, ".Abort();")?; } }, + Statement::MeshFunction(crate::MeshFunction::SetMeshOutputs { + vertex_count, + primitive_count, + }) => { + write!(self.out, "{level}SetMeshOutputCounts(")?; + self.write_expr(module, vertex_count, func_ctx)?; + write!(self.out, ", ")?; + self.write_expr(module, primitive_count, func_ctx)?; + write!(self.out, ");")?; + } + Statement::MeshFunction( + crate::MeshFunction::SetVertex { .. } | crate::MeshFunction::SetPrimitive { .. }, + ) => unimplemented!(), Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; let name = Baked(result).to_string(); @@ -3076,7 +3090,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { crate::AddressSpace::Function | crate::AddressSpace::Private | crate::AddressSpace::WorkGroup - | crate::AddressSpace::PushConstant, + | crate::AddressSpace::PushConstant + | crate::AddressSpace::TaskPayload, ) | None => true, Some(crate::AddressSpace::Uniform) => { diff --git a/naga/src/back/msl/mod.rs b/naga/src/back/msl/mod.rs index 7bc8289b9b8..8a2e07635b8 100644 --- a/naga/src/back/msl/mod.rs +++ b/naga/src/back/msl/mod.rs @@ -494,6 +494,7 @@ impl Options { interpolation, sampling, blend_src, + per_primitive: _, } => match mode { LocationMode::VertexInput => Ok(ResolvedBinding::Attribute(location)), LocationMode::FragmentOutput => { @@ -651,6 +652,10 @@ impl ResolvedBinding { Bi::CullDistance | Bi::ViewIndex | Bi::DrawID => { return Err(Error::UnsupportedBuiltIn(built_in)) } + Bi::CullPrimitive => "primitive_culled", + // TODO: figure out how to make this written as a function call + Bi::PointIndex | Bi::LineIndices | Bi::TriangleIndices => unimplemented!(), + Bi::MeshTaskSize => unreachable!(), }; write!(out, "{name}")?; } diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 2525855cd70..a6b80a2dd27 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -578,7 +578,8 @@ impl crate::AddressSpace { | Self::Private | Self::WorkGroup | Self::PushConstant - | Self::Handle => true, + | Self::Handle + | Self::TaskPayload => true, Self::Function => false, } } @@ -591,6 +592,7 @@ impl crate::AddressSpace { // may end up with "const" even if the binding is read-write, // and that should be OK. Self::Storage { .. } => true, + Self::TaskPayload => unimplemented!(), // These should always be read-write. Self::Private | Self::WorkGroup => false, // These translate to `constant` address space, no need for qualifiers. @@ -607,6 +609,7 @@ impl crate::AddressSpace { Self::Storage { .. } => Some("device"), Self::Private | Self::Function => Some("thread"), Self::WorkGroup => Some("threadgroup"), + Self::TaskPayload => Some("object_data"), } } } @@ -4020,6 +4023,14 @@ impl Writer { } } } + // TODO: write emitters for these + crate::Statement::MeshFunction(crate::MeshFunction::SetMeshOutputs { .. }) => { + unimplemented!() + } + crate::Statement::MeshFunction( + crate::MeshFunction::SetVertex { .. } + | crate::MeshFunction::SetPrimitive { .. }, + ) => unimplemented!(), crate::Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; let name = self.namer.call(""); @@ -6169,7 +6180,7 @@ template LocationMode::Uniform, false, ), - crate::ShaderStage::Task | crate::ShaderStage::Mesh => unreachable!(), + crate::ShaderStage::Task | crate::ShaderStage::Mesh => unimplemented!(), }; // Should this entry point be modified to do vertex pulling? @@ -6232,6 +6243,9 @@ template break; } } + crate::AddressSpace::TaskPayload => { + unimplemented!() + } crate::AddressSpace::Function | crate::AddressSpace::Private | crate::AddressSpace::WorkGroup => {} @@ -7159,7 +7173,7 @@ mod workgroup_mem_init { fun_info: &valid::FunctionInfo, ) -> bool { options.zero_initialize_workgroup_memory - && ep.stage == crate::ShaderStage::Compute + && ep.stage.compute_like() && module.global_variables.iter().any(|(handle, var)| { !fun_info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup }) diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index d2b3ed70eda..c009082a3c9 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -39,6 +39,8 @@ pub enum PipelineConstantError { ValidationError(#[from] WithSpan), #[error("workgroup_size override isn't strictly positive")] NegativeWorkgroupSize, + #[error("max vertices or max primitives is negative")] + NegativeMeshOutputMax, } /// Compact `module` and replace all overrides with constants. @@ -243,6 +245,7 @@ pub fn process_overrides<'a>( for ep in entry_points.iter_mut() { process_function(&mut module, &override_map, &mut layouter, &mut ep.function)?; process_workgroup_size_override(&mut module, &adjusted_global_expressions, ep)?; + process_mesh_shader_overrides(&mut module, &adjusted_global_expressions, ep)?; } module.entry_points = entry_points; module.overrides = overrides; @@ -296,6 +299,28 @@ fn process_workgroup_size_override( Ok(()) } +fn process_mesh_shader_overrides( + module: &mut Module, + adjusted_global_expressions: &HandleVec>, + ep: &mut crate::EntryPoint, +) -> Result<(), PipelineConstantError> { + if let Some(ref mut mesh_info) = ep.mesh_info { + if let Some(r#override) = mesh_info.max_vertices_override { + mesh_info.max_vertices = module + .to_ctx() + .eval_expr_to_u32(adjusted_global_expressions[r#override]) + .map_err(|_| PipelineConstantError::NegativeWorkgroupSize)?; + } + if let Some(r#override) = mesh_info.max_primitives_override { + mesh_info.max_primitives = module + .to_ctx() + .eval_expr_to_u32(adjusted_global_expressions[r#override]) + .map_err(|_| PipelineConstantError::NegativeWorkgroupSize)?; + } + } + Ok(()) +} + /// Add a [`Constant`] to `module` for the override `old_h`. /// /// Add the new `Constant` to `override_map` and `adjusted_constant_initializers`. @@ -835,6 +860,26 @@ fn adjust_stmt(new_pos: &HandleVec>, stmt: &mut S crate::RayQueryFunction::Terminate => {} } } + Statement::MeshFunction(crate::MeshFunction::SetMeshOutputs { + ref mut vertex_count, + ref mut primitive_count, + }) => { + adjust(vertex_count); + adjust(primitive_count); + } + Statement::MeshFunction( + crate::MeshFunction::SetVertex { + ref mut index, + ref mut value, + } + | crate::MeshFunction::SetPrimitive { + ref mut index, + ref mut value, + }, + ) => { + adjust(index); + adjust(value); + } Statement::Break | Statement::Continue | Statement::Kill diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index 8982242daca..245bc40dd5d 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -207,7 +207,7 @@ impl Writer { Attribute::Stage(ShaderStage::Compute), Attribute::WorkGroupSize(ep.workgroup_size), ], - ShaderStage::Task | ShaderStage::Mesh => unreachable!(), + ShaderStage::Mesh | ShaderStage::Task => unreachable!(), }; self.write_attributes(&attributes)?; @@ -856,6 +856,7 @@ impl Writer { } } Statement::RayQuery { .. } => unreachable!(), + Statement::MeshFunction(..) => unreachable!(), Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; let res_name = Baked(result).to_string(); @@ -1822,6 +1823,7 @@ fn map_binding_to_attribute(binding: &crate::Binding) -> Vec { interpolation, sampling, blend_src: None, + per_primitive: _, } => vec![ Attribute::Location(location), Attribute::Interpolate(interpolation, sampling), @@ -1831,6 +1833,7 @@ fn map_binding_to_attribute(binding: &crate::Binding) -> Vec { interpolation, sampling, blend_src: Some(blend_src), + per_primitive: _, } => vec![ Attribute::Location(location), Attribute::BlendSrc(blend_src), diff --git a/naga/src/common/wgsl/to_wgsl.rs b/naga/src/common/wgsl/to_wgsl.rs index 035c4eafb32..dc891aa5a3f 100644 --- a/naga/src/common/wgsl/to_wgsl.rs +++ b/naga/src/common/wgsl/to_wgsl.rs @@ -188,7 +188,12 @@ impl TryToWgsl for crate::BuiltIn { | Bi::PointSize | Bi::DrawID | Bi::PointCoord - | Bi::WorkGroupSize => return None, + | Bi::WorkGroupSize + | Bi::CullPrimitive + | Bi::TriangleIndices + | Bi::LineIndices + | Bi::MeshTaskSize + | Bi::PointIndex => return None, }) } } @@ -352,6 +357,7 @@ pub const fn address_space_str( As::WorkGroup => "workgroup", As::Handle => return (None, None), As::Function => "function", + As::TaskPayload => return (None, None), }), None, ) diff --git a/naga/src/compact/mod.rs b/naga/src/compact/mod.rs index d059ba21e4f..a7d3d463f11 100644 --- a/naga/src/compact/mod.rs +++ b/naga/src/compact/mod.rs @@ -221,6 +221,45 @@ pub fn compact(module: &mut crate::Module, keep_unused: KeepUnused) { } } + for entry in &module.entry_points { + if let Some(task_payload) = entry.task_payload { + module_tracer.global_variables_used.insert(task_payload); + } + if let Some(ref mesh_info) = entry.mesh_info { + module_tracer + .types_used + .insert(mesh_info.vertex_output_type); + module_tracer + .types_used + .insert(mesh_info.primitive_output_type); + if let Some(max_vertices_override) = mesh_info.max_vertices_override { + module_tracer + .global_expressions_used + .insert(max_vertices_override); + } + if let Some(max_primitives_override) = mesh_info.max_primitives_override { + module_tracer + .global_expressions_used + .insert(max_primitives_override); + } + } + if entry.stage == crate::ShaderStage::Task || entry.stage == crate::ShaderStage::Mesh { + // u32 should always be there if the module is valid, as it is e.g. the type of some expressions + let u32_type = module + .types + .iter() + .find_map(|tuple| { + if tuple.1.inner == crate::TypeInner::Scalar(crate::Scalar::U32) { + Some(tuple.0) + } else { + None + } + }) + .unwrap(); + module_tracer.types_used.insert(u32_type); + } + } + module_tracer.type_expression_tandem(); // Now that we know what is used and what is never touched, @@ -342,6 +381,23 @@ pub fn compact(module: &mut crate::Module, keep_unused: KeepUnused) { &module_map, &mut reused_named_expressions, ); + if let Some(ref mut task_payload) = entry.task_payload { + module_map.globals.adjust(task_payload); + } + if let Some(ref mut mesh_info) = entry.mesh_info { + module_map.types.adjust(&mut mesh_info.vertex_output_type); + module_map + .types + .adjust(&mut mesh_info.primitive_output_type); + if let Some(ref mut max_vertices_override) = mesh_info.max_vertices_override { + module_map.global_expressions.adjust(max_vertices_override); + } + if let Some(ref mut max_primitives_override) = mesh_info.max_primitives_override { + module_map + .global_expressions + .adjust(max_primitives_override); + } + } } } diff --git a/naga/src/compact/statements.rs b/naga/src/compact/statements.rs index 39d6065f5f0..b370501baca 100644 --- a/naga/src/compact/statements.rs +++ b/naga/src/compact/statements.rs @@ -117,6 +117,20 @@ impl FunctionTracer<'_> { self.expressions_used.insert(query); self.trace_ray_query_function(fun); } + St::MeshFunction(crate::MeshFunction::SetMeshOutputs { + vertex_count, + primitive_count, + }) => { + self.expressions_used.insert(vertex_count); + self.expressions_used.insert(primitive_count); + } + St::MeshFunction( + crate::MeshFunction::SetPrimitive { index, value } + | crate::MeshFunction::SetVertex { index, value }, + ) => { + self.expressions_used.insert(index); + self.expressions_used.insert(value); + } St::SubgroupBallot { result, predicate } => { if let Some(predicate) = predicate { self.expressions_used.insert(predicate); @@ -335,6 +349,26 @@ impl FunctionMap { adjust(query); self.adjust_ray_query_function(fun); } + St::MeshFunction(crate::MeshFunction::SetMeshOutputs { + ref mut vertex_count, + ref mut primitive_count, + }) => { + adjust(vertex_count); + adjust(primitive_count); + } + St::MeshFunction( + crate::MeshFunction::SetVertex { + ref mut index, + ref mut value, + } + | crate::MeshFunction::SetPrimitive { + ref mut index, + ref mut value, + }, + ) => { + adjust(index); + adjust(value); + } St::SubgroupBallot { ref mut result, ref mut predicate, diff --git a/naga/src/front/glsl/functions.rs b/naga/src/front/glsl/functions.rs index 7de7364cd40..ba096a82b3b 100644 --- a/naga/src/front/glsl/functions.rs +++ b/naga/src/front/glsl/functions.rs @@ -1377,6 +1377,8 @@ impl Frontend { result: ty.map(|ty| FunctionResult { ty, binding: None }), ..Default::default() }, + mesh_info: None, + task_payload: None, }); Ok(()) @@ -1446,6 +1448,7 @@ impl Context<'_> { interpolation, sampling: None, blend_src: None, + per_primitive: false, }; location += 1; @@ -1482,6 +1485,7 @@ impl Context<'_> { interpolation, sampling: None, blend_src: None, + per_primitive: false, }; location += 1; binding diff --git a/naga/src/front/glsl/mod.rs b/naga/src/front/glsl/mod.rs index 876add46a1c..e5eda6b3ad9 100644 --- a/naga/src/front/glsl/mod.rs +++ b/naga/src/front/glsl/mod.rs @@ -107,7 +107,7 @@ impl ShaderMetadata { self.version = 0; self.profile = Profile::Core; self.stage = stage; - self.workgroup_size = [u32::from(stage == ShaderStage::Compute); 3]; + self.workgroup_size = [u32::from(stage.compute_like()); 3]; self.early_fragment_tests = false; self.extensions.clear(); } diff --git a/naga/src/front/glsl/variables.rs b/naga/src/front/glsl/variables.rs index ef98143b769..98871bd2f81 100644 --- a/naga/src/front/glsl/variables.rs +++ b/naga/src/front/glsl/variables.rs @@ -465,6 +465,7 @@ impl Frontend { interpolation, sampling, blend_src, + per_primitive: false, }, handle, storage, diff --git a/naga/src/front/interpolator.rs b/naga/src/front/interpolator.rs index e23cae0e7c2..126e860426c 100644 --- a/naga/src/front/interpolator.rs +++ b/naga/src/front/interpolator.rs @@ -44,6 +44,7 @@ impl crate::Binding { interpolation: ref mut interpolation @ None, ref mut sampling, blend_src: _, + per_primitive: _, } = *self { match ty.scalar_kind() { diff --git a/naga/src/front/spv/function.rs b/naga/src/front/spv/function.rs index 67cbf05f04f..48b23e7c4c4 100644 --- a/naga/src/front/spv/function.rs +++ b/naga/src/front/spv/function.rs @@ -596,6 +596,8 @@ impl> super::Frontend { workgroup_size: ep.workgroup_size, workgroup_size_overrides: None, function, + mesh_info: None, + task_payload: None, }); Ok(()) diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index 960437ece58..396318f14dc 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -263,6 +263,7 @@ impl Decoration { interpolation, sampling, blend_src: None, + per_primitive: false, }), _ => Err(Error::MissingDecoration(spirv::Decoration::Location)), } @@ -4613,6 +4614,7 @@ impl> Frontend { | S::Atomic { .. } | S::ImageAtomic { .. } | S::RayQuery { .. } + | S::MeshFunction(..) | S::SubgroupBallot { .. } | S::SubgroupCollectiveOperation { .. } | S::SubgroupGather { .. } => {} @@ -4894,6 +4896,8 @@ impl> Frontend { spirv::ExecutionModel::Vertex => crate::ShaderStage::Vertex, spirv::ExecutionModel::Fragment => crate::ShaderStage::Fragment, spirv::ExecutionModel::GLCompute => crate::ShaderStage::Compute, + spirv::ExecutionModel::TaskEXT => crate::ShaderStage::Task, + spirv::ExecutionModel::MeshEXT => crate::ShaderStage::Mesh, _ => return Err(Error::UnsupportedExecutionModel(exec_model as u32)), }, name, diff --git a/naga/src/ir/mod.rs b/naga/src/ir/mod.rs index 257445952b8..a182bf0e064 100644 --- a/naga/src/ir/mod.rs +++ b/naga/src/ir/mod.rs @@ -329,6 +329,16 @@ pub enum ShaderStage { Mesh, } +impl ShaderStage { + // TODO: make more things respect this + pub const fn compute_like(self) -> bool { + match self { + Self::Vertex | Self::Fragment => false, + Self::Compute | Self::Task | Self::Mesh => true, + } + } +} + /// Addressing space of variables. #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] @@ -363,6 +373,8 @@ pub enum AddressSpace { /// /// [`SHADER_FLOAT16`]: crate::valid::Capabilities::SHADER_FLOAT16 PushConstant, + /// Task shader to mesh shader payload + TaskPayload, } /// Built-in inputs and outputs. @@ -373,7 +385,7 @@ pub enum AddressSpace { pub enum BuiltIn { Position { invariant: bool }, ViewIndex, - // vertex + // vertex (and often mesh) BaseInstance, BaseVertex, ClipDistance, @@ -386,10 +398,10 @@ pub enum BuiltIn { FragDepth, PointCoord, FrontFacing, - PrimitiveIndex, + PrimitiveIndex, // Also for mesh output SampleIndex, SampleMask, - // compute + // compute (and task/mesh) GlobalInvocationId, LocalInvocationId, LocalInvocationIndex, @@ -401,6 +413,12 @@ pub enum BuiltIn { SubgroupId, SubgroupSize, SubgroupInvocationId, + // mesh + MeshTaskSize, + CullPrimitive, + PointIndex, + LineIndices, + TriangleIndices, } /// Number of bytes per scalar. @@ -966,6 +984,7 @@ pub enum Binding { /// Optional `blend_src` index used for dual source blending. /// See blend_src: Option, + per_primitive: bool, }, } @@ -1935,7 +1954,9 @@ pub enum Statement { /// [`Loop`] statement. /// /// [`Loop`]: Statement::Loop - Return { value: Option> }, + Return { + value: Option>, + }, /// Aborts the current shader execution. /// @@ -2141,6 +2162,7 @@ pub enum Statement { /// The specific operation we're performing on `query`. fun: RayQueryFunction, }, + MeshFunction(MeshFunction), /// Calculate a bitmask using a boolean from each active thread in the subgroup SubgroupBallot { /// The [`SubgroupBallotResult`] expression representing this load's result. @@ -2314,6 +2336,9 @@ pub struct EntryPoint { pub workgroup_size_overrides: Option<[Option>; 3]>, /// The entrance function. pub function: Function, + /// The information relating to a mesh shader + pub mesh_info: Option, + pub task_payload: Option>, } /// Return types predeclared for the frexp, modf, and atomicCompareExchangeWeak built-in functions. @@ -2578,3 +2603,46 @@ pub struct Module { /// Doc comments. pub doc_comments: Option>, } + +#[derive(Debug, Clone, Copy)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum MeshOutputTopology { + Points, + Lines, + Triangles, +} +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +#[allow(dead_code)] +pub struct MeshStageInfo { + pub topology: MeshOutputTopology, + pub max_vertices: u32, + pub max_vertices_override: Option>, + pub max_primitives: u32, + pub max_primitives_override: Option>, + pub vertex_output_type: Handle, + pub primitive_output_type: Handle, +} + +#[derive(Debug, Clone, Copy)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum MeshFunction { + SetMeshOutputs { + vertex_count: Handle, + primitive_count: Handle, + }, + SetVertex { + index: Handle, + value: Handle, + }, + SetPrimitive { + index: Handle, + value: Handle, + }, +} diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index 413e49c1eed..434c6e3f724 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -177,6 +177,9 @@ impl super::AddressSpace { crate::AddressSpace::Storage { access } => access, crate::AddressSpace::Handle => Sa::LOAD, crate::AddressSpace::PushConstant => Sa::LOAD, + // TaskPayload isn't always writable, but this is checked for elsewhere, + // when not using multiple payloads and matching the entry payload is checked. + crate::AddressSpace::TaskPayload => Sa::LOAD | Sa::STORE, } } } diff --git a/naga/src/proc/terminator.rs b/naga/src/proc/terminator.rs index b29ccb054a3..f76d4c06a3b 100644 --- a/naga/src/proc/terminator.rs +++ b/naga/src/proc/terminator.rs @@ -36,6 +36,7 @@ pub fn ensure_block_returns(block: &mut crate::Block) { | S::ImageStore { .. } | S::Call { .. } | S::RayQuery { .. } + | S::MeshFunction(..) | S::Atomic { .. } | S::ImageAtomic { .. } | S::WorkGroupUniformLoad { .. } diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index 95ae40dcdb4..101ea046487 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -85,6 +85,16 @@ struct FunctionUniformity { exit: ExitFlags, } +/// Mesh shader related characteristics of a function. +#[derive(Debug, Clone, Default)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +#[cfg_attr(test, derive(PartialEq))] +pub struct FunctionMeshShaderInfo { + pub vertex_type: Option<(Handle, Handle)>, + pub primitive_type: Option<(Handle, Handle)>, +} + impl ops::BitOr for FunctionUniformity { type Output = Self; fn bitor(self, other: Self) -> Self { @@ -302,6 +312,8 @@ pub struct FunctionInfo { /// See [`DiagnosticFilterNode`] for details on how the tree is represented and used in /// validation. diagnostic_filter_leaf: Option>, + + pub mesh_shader_info: FunctionMeshShaderInfo, } impl FunctionInfo { @@ -372,6 +384,14 @@ impl FunctionInfo { info.uniformity.non_uniform_result } + pub fn insert_global_use( + &mut self, + global_use: GlobalUse, + global: Handle, + ) { + self.global_uses[global.index()] |= global_use; + } + /// Record a use of `expr` for its value. /// /// This is used for almost all expression references. Anything @@ -482,6 +502,8 @@ impl FunctionInfo { *mine |= *other; } + self.try_update_mesh_info(&callee.mesh_shader_info)?; + Ok(FunctionUniformity { result: callee.uniformity.clone(), exit: if callee.may_kill { @@ -635,7 +657,8 @@ impl FunctionInfo { // local data is non-uniform As::Function | As::Private => false, // workgroup memory is exclusively accessed by the group - As::WorkGroup => true, + // task payload memory is very similar to workgroup memory + As::WorkGroup | As::TaskPayload => true, // uniform data As::Uniform | As::PushConstant => true, // storage data is only uniform when read-only @@ -1113,6 +1136,34 @@ impl FunctionInfo { } FunctionUniformity::new() } + S::MeshFunction(func) => match &func { + // TODO: double check all of this uniformity stuff. I frankly don't fully understand all of it. + &crate::MeshFunction::SetMeshOutputs { + vertex_count, + primitive_count, + } => { + let _ = self.add_ref(vertex_count); + let _ = self.add_ref(primitive_count); + FunctionUniformity::new() + } + &crate::MeshFunction::SetVertex { index, value } + | &crate::MeshFunction::SetPrimitive { index, value } => { + let _ = self.add_ref(index); + let _ = self.add_ref(value); + let ty = + self.expressions[value.index()].ty.clone().handle().ok_or( + FunctionError::InvalidMeshShaderOutputType(value).with_span(), + )?; + + if matches!(func, crate::MeshFunction::SetVertex { .. }) { + self.try_update_mesh_vertex_type(ty, value)?; + } else { + self.try_update_mesh_primitive_type(ty, value)?; + }; + + FunctionUniformity::new() + } + }, S::SubgroupBallot { result: _, predicate, @@ -1158,6 +1209,53 @@ impl FunctionInfo { } Ok(combined_uniformity) } + + fn try_update_mesh_vertex_type( + &mut self, + ty: Handle, + value: Handle, + ) -> Result<(), WithSpan> { + if let &Some(ref existing) = &self.mesh_shader_info.vertex_type { + if existing.0 != ty { + return Err( + FunctionError::ConflictingMeshOutputTypes(existing.1, value).with_span() + ); + } + } else { + self.mesh_shader_info.vertex_type = Some((ty, value)); + } + Ok(()) + } + + fn try_update_mesh_primitive_type( + &mut self, + ty: Handle, + value: Handle, + ) -> Result<(), WithSpan> { + if let &Some(ref existing) = &self.mesh_shader_info.primitive_type { + if existing.0 != ty { + return Err( + FunctionError::ConflictingMeshOutputTypes(existing.1, value).with_span() + ); + } + } else { + self.mesh_shader_info.primitive_type = Some((ty, value)); + } + Ok(()) + } + + fn try_update_mesh_info( + &mut self, + other: &FunctionMeshShaderInfo, + ) -> Result<(), WithSpan> { + if let &Some(ref other_vertex) = &other.vertex_type { + self.try_update_mesh_vertex_type(other_vertex.0, other_vertex.1)?; + } + if let &Some(ref other_primitive) = &other.vertex_type { + self.try_update_mesh_primitive_type(other_primitive.0, other_primitive.1)?; + } + Ok(()) + } } impl ModuleInfo { @@ -1193,6 +1291,7 @@ impl ModuleInfo { sampling: crate::FastHashSet::default(), dual_source_blending: false, diagnostic_filter_leaf: fun.diagnostic_filter_leaf, + mesh_shader_info: FunctionMeshShaderInfo::default(), }; let resolve_context = ResolveContext::with_locals(module, &fun.local_variables, &fun.arguments); @@ -1326,6 +1425,7 @@ fn uniform_control_flow() { sampling: crate::FastHashSet::default(), dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: FunctionMeshShaderInfo::default(), }; let resolve_context = ResolveContext { constants: &Arena::new(), diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index dc19e191764..0ae2ffdb54f 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -217,6 +217,14 @@ pub enum FunctionError { EmitResult(Handle), #[error("Expression not visited by the appropriate statement")] UnvisitedExpression(Handle), + #[error("Expression {0:?} should be u32, but isn't")] + InvalidMeshFunctionCall(Handle), + #[error("Mesh output types differ from {0:?} to {1:?}")] + ConflictingMeshOutputTypes(Handle, Handle), + #[error("Task payload variables differ from {0:?} to {1:?}")] + ConflictingTaskPayloadVariables(Handle, Handle), + #[error("Mesh shader output at {0:?} is not a user-defined struct")] + InvalidMeshShaderOutputType(Handle), } bitflags::bitflags! { @@ -1539,6 +1547,40 @@ impl super::Validator { crate::RayQueryFunction::Terminate => {} } } + S::MeshFunction(func) => { + let ensure_u32 = + |expr: Handle| -> Result<(), WithSpan> { + let u32_ty = TypeResolution::Value(Ti::Scalar(crate::Scalar::U32)); + let ty = context + .resolve_type_impl(expr, &self.valid_expression_set) + .map_err_inner(|source| { + FunctionError::Expression { + source, + handle: expr, + } + .with_span_handle(expr, context.expressions) + })?; + if !context.compare_types(&u32_ty, ty) { + return Err(FunctionError::InvalidMeshFunctionCall(expr) + .with_span_handle(expr, context.expressions)); + } + Ok(()) + }; + match func { + crate::MeshFunction::SetMeshOutputs { + vertex_count, + primitive_count, + } => { + ensure_u32(vertex_count)?; + ensure_u32(primitive_count)?; + } + crate::MeshFunction::SetVertex { index, value: _ } + | crate::MeshFunction::SetPrimitive { index, value: _ } => { + ensure_u32(index)?; + // TODO: ensure it is correct for the value + } + } + } S::SubgroupBallot { result, predicate } => { stages &= self.subgroup_stages; if !self.capabilities.contains(super::Capabilities::SUBGROUP) { diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index e8a69013434..a0153e9398c 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -801,6 +801,22 @@ impl super::Validator { } Ok(()) } + crate::Statement::MeshFunction(func) => match func { + crate::MeshFunction::SetMeshOutputs { + vertex_count, + primitive_count, + } => { + validate_expr(vertex_count)?; + validate_expr(primitive_count)?; + Ok(()) + } + crate::MeshFunction::SetVertex { index, value } + | crate::MeshFunction::SetPrimitive { index, value } => { + validate_expr(index)?; + validate_expr(value)?; + Ok(()) + } + }, crate::Statement::SubgroupBallot { result, predicate } => { validate_expr_opt(predicate)?; validate_expr(result)?; diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index 7c8cc903139..51167a4810d 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -92,6 +92,10 @@ pub enum VaryingError { }, #[error("Workgroup size is multi dimensional, `@builtin(subgroup_id)` and `@builtin(subgroup_invocation_id)` are not supported.")] InvalidMultiDimensionalSubgroupBuiltIn, + #[error("The `@per_primitive` attribute can only be used in fragment shader inputs or mesh shader primitive outputs")] + InvalidPerPrimitive, + #[error("Non-builtin members of a mesh primitive output struct must be decorated with `@per_primitive`")] + MissingPerPrimitive, } #[derive(Clone, Debug, thiserror::Error)] @@ -123,6 +127,26 @@ pub enum EntryPointError { InvalidIntegerInterpolation { location: u32 }, #[error(transparent)] Function(#[from] FunctionError), + #[error("Non mesh shader entry point cannot have mesh shader attributes")] + UnexpectedMeshShaderAttributes, + #[error("Non mesh/task shader entry point cannot have task payload attribute")] + UnexpectedTaskPayload, + #[error("Task payload must be declared with `var`")] + TaskPayloadWrongAddressSpace, + #[error("For a task payload to be used, it must be declared with @payload")] + WrongTaskPayloadUsed, + #[error("A function can only set vertex and primitive types that correspond to the mesh shader attributes")] + WrongMeshOutputType, + #[error("Only mesh shader entry points can write to mesh output vertices and primitives")] + UnexpectedMeshShaderOutput, + #[error("Mesh shader entry point cannot have a return type")] + UnexpectedMeshShaderEntryResult, + #[error("Task shader entry point must return @builtin(mesh_task_size) vec3")] + WrongTaskShaderEntryResult, + #[error("Mesh output type must be a user-defined struct.")] + InvalidMeshOutputType, + #[error("Mesh primitive outputs must have exactly one of `@builtin(triangle_indices)`, `@builtin(line_indices)`, or `@builtin(point_index)`")] + InvalidMeshPrimitiveOutputType, } fn storage_usage(access: crate::StorageAccess) -> GlobalUse { @@ -139,6 +163,13 @@ fn storage_usage(access: crate::StorageAccess) -> GlobalUse { storage_usage } +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +enum MeshOutputType { + None, + VertexOutput, + PrimitiveOutput, +} + struct VaryingContext<'a> { stage: crate::ShaderStage, output: bool, @@ -149,6 +180,7 @@ struct VaryingContext<'a> { built_ins: &'a mut crate::FastHashSet, capabilities: Capabilities, flags: super::ValidationFlags, + mesh_output_type: MeshOutputType, } impl VaryingContext<'_> { @@ -236,10 +268,9 @@ impl VaryingContext<'_> { ), Bi::Position { .. } => ( match self.stage { - St::Vertex => self.output, + St::Vertex | St::Mesh => self.output, St::Fragment => !self.output, - St::Compute => false, - St::Task | St::Mesh => unreachable!(), + St::Compute | St::Task => false, }, *ty_inner == Ti::Vector { @@ -276,7 +307,7 @@ impl VaryingContext<'_> { *ty_inner == Ti::Scalar(crate::Scalar::U32), ), Bi::LocalInvocationIndex => ( - self.stage == St::Compute && !self.output, + self.stage.compute_like() && !self.output, *ty_inner == Ti::Scalar(crate::Scalar::U32), ), Bi::GlobalInvocationId @@ -284,7 +315,7 @@ impl VaryingContext<'_> { | Bi::WorkGroupId | Bi::WorkGroupSize | Bi::NumWorkGroups => ( - self.stage == St::Compute && !self.output, + self.stage.compute_like() && !self.output, *ty_inner == Ti::Vector { size: Vs::Tri, @@ -292,17 +323,48 @@ impl VaryingContext<'_> { }, ), Bi::NumSubgroups | Bi::SubgroupId => ( - self.stage == St::Compute && !self.output, + self.stage.compute_like() && !self.output, *ty_inner == Ti::Scalar(crate::Scalar::U32), ), Bi::SubgroupSize | Bi::SubgroupInvocationId => ( match self.stage { - St::Compute | St::Fragment => !self.output, + St::Compute | St::Fragment | St::Task | St::Mesh => !self.output, St::Vertex => false, - St::Task | St::Mesh => unreachable!(), }, *ty_inner == Ti::Scalar(crate::Scalar::U32), ), + Bi::CullPrimitive => ( + self.mesh_output_type == MeshOutputType::PrimitiveOutput, + *ty_inner == Ti::Scalar(crate::Scalar::BOOL), + ), + Bi::PointIndex => ( + self.mesh_output_type == MeshOutputType::PrimitiveOutput, + *ty_inner == Ti::Scalar(crate::Scalar::U32), + ), + Bi::LineIndices => ( + self.mesh_output_type == MeshOutputType::PrimitiveOutput, + *ty_inner + == Ti::Vector { + size: Vs::Bi, + scalar: crate::Scalar::U32, + }, + ), + Bi::TriangleIndices => ( + self.mesh_output_type == MeshOutputType::PrimitiveOutput, + *ty_inner + == Ti::Vector { + size: Vs::Tri, + scalar: crate::Scalar::U32, + }, + ), + Bi::MeshTaskSize => ( + self.stage == St::Task && self.output, + *ty_inner + == Ti::Vector { + size: Vs::Tri, + scalar: crate::Scalar::U32, + }, + ), }; if !visible { @@ -318,6 +380,7 @@ impl VaryingContext<'_> { interpolation, sampling, blend_src, + per_primitive, } => { // Only IO-shareable types may be stored in locations. if !self.type_info[ty.index()] @@ -326,6 +389,14 @@ impl VaryingContext<'_> { { return Err(VaryingError::NotIOShareableType(ty)); } + if !per_primitive && self.mesh_output_type == MeshOutputType::PrimitiveOutput { + return Err(VaryingError::MissingPerPrimitive); + } else if per_primitive + && ((self.stage != crate::ShaderStage::Fragment || self.output) + && self.mesh_output_type != MeshOutputType::PrimitiveOutput) + { + return Err(VaryingError::InvalidPerPrimitive); + } if let Some(blend_src) = blend_src { // `blend_src` is only valid if dual source blending was explicitly enabled, @@ -390,11 +461,12 @@ impl VaryingContext<'_> { } } + // TODO: update this to reflect the fact that per-primitive outputs aren't interpolated for fragment and mesh stages let needs_interpolation = match self.stage { crate::ShaderStage::Vertex => self.output, crate::ShaderStage::Fragment => !self.output, - crate::ShaderStage::Compute => false, - crate::ShaderStage::Task | crate::ShaderStage::Mesh => unreachable!(), + crate::ShaderStage::Compute | crate::ShaderStage::Task => false, + crate::ShaderStage::Mesh => self.output, }; // It doesn't make sense to specify a sampling when `interpolation` is `Flat`, but @@ -595,7 +667,9 @@ impl super::Validator { TypeFlags::CONSTRUCTIBLE | TypeFlags::CREATION_RESOLVED, false, ), - crate::AddressSpace::WorkGroup => (TypeFlags::DATA | TypeFlags::SIZED, false), + crate::AddressSpace::WorkGroup | crate::AddressSpace::TaskPayload => { + (TypeFlags::DATA | TypeFlags::SIZED, false) + } crate::AddressSpace::PushConstant => { if !self.capabilities.contains(Capabilities::PUSH_CONSTANT) { return Err(GlobalVariableError::UnsupportedCapability( @@ -671,7 +745,7 @@ impl super::Validator { } } - if ep.stage == crate::ShaderStage::Compute { + if ep.stage.compute_like() { if ep .workgroup_size .iter() @@ -683,10 +757,30 @@ impl super::Validator { return Err(EntryPointError::UnexpectedWorkgroupSize.with_span()); } + if ep.stage != crate::ShaderStage::Mesh && ep.mesh_info.is_some() { + return Err(EntryPointError::UnexpectedMeshShaderAttributes.with_span()); + } + let mut info = self .validate_function(&ep.function, module, mod_info, true) .map_err(WithSpan::into_other)?; + if let Some(handle) = ep.task_payload { + if ep.stage != crate::ShaderStage::Task && ep.stage != crate::ShaderStage::Mesh { + return Err(EntryPointError::UnexpectedTaskPayload.with_span()); + } + if module.global_variables[handle].space != crate::AddressSpace::TaskPayload { + return Err(EntryPointError::TaskPayloadWrongAddressSpace.with_span()); + } + // Make sure that this is always present in the outputted shader + let uses = if ep.stage == crate::ShaderStage::Mesh { + GlobalUse::READ + } else { + GlobalUse::READ | GlobalUse::WRITE + }; + info.insert_global_use(uses, handle); + } + { use super::ShaderStages; @@ -694,7 +788,8 @@ impl super::Validator { crate::ShaderStage::Vertex => ShaderStages::VERTEX, crate::ShaderStage::Fragment => ShaderStages::FRAGMENT, crate::ShaderStage::Compute => ShaderStages::COMPUTE, - crate::ShaderStage::Task | crate::ShaderStage::Mesh => unreachable!(), + crate::ShaderStage::Mesh => ShaderStages::MESH, + crate::ShaderStage::Task => ShaderStages::TASK, }; if !info.available_stages.contains(stage_bit) { @@ -716,6 +811,7 @@ impl super::Validator { built_ins: &mut argument_built_ins, capabilities: self.capabilities, flags: self.flags, + mesh_output_type: MeshOutputType::None, }; ctx.validate(ep, fa.ty, fa.binding.as_ref()) .map_err_inner(|e| EntryPointError::Argument(index as u32, e).with_span())?; @@ -734,6 +830,7 @@ impl super::Validator { built_ins: &mut result_built_ins, capabilities: self.capabilities, flags: self.flags, + mesh_output_type: MeshOutputType::None, }; ctx.validate(ep, fr.ty, fr.binding.as_ref()) .map_err_inner(|e| EntryPointError::Result(e).with_span())?; @@ -742,11 +839,26 @@ impl super::Validator { { return Err(EntryPointError::MissingVertexOutputPosition.with_span()); } + if ep.stage == crate::ShaderStage::Mesh + && (!result_built_ins.is_empty() || !self.location_mask.is_empty()) + { + return Err(EntryPointError::UnexpectedMeshShaderEntryResult.with_span()); + } + // Cannot have any other built-ins or @location outputs as those are per-vertex or per-primitive + if ep.stage == crate::ShaderStage::Task + && (!result_built_ins.contains(&crate::BuiltIn::MeshTaskSize) + || result_built_ins.len() != 1 + || !self.location_mask.is_empty()) + { + return Err(EntryPointError::WrongTaskShaderEntryResult.with_span()); + } if !self.blend_src_mask.is_empty() { info.dual_source_blending = true; } } else if ep.stage == crate::ShaderStage::Vertex { return Err(EntryPointError::MissingVertexOutputPosition.with_span()); + } else if ep.stage == crate::ShaderStage::Task { + return Err(EntryPointError::WrongTaskShaderEntryResult.with_span()); } { @@ -764,6 +876,13 @@ impl super::Validator { } } + if let Some(task_payload) = ep.task_payload { + if module.global_variables[task_payload].space != crate::AddressSpace::TaskPayload { + return Err(EntryPointError::TaskPayloadWrongAddressSpace + .with_span_handle(task_payload, &module.global_variables)); + } + } + self.ep_resource_bindings.clear(); for (var_handle, var) in module.global_variables.iter() { let usage = info[var_handle]; @@ -771,6 +890,13 @@ impl super::Validator { continue; } + if var.space == crate::AddressSpace::TaskPayload { + if ep.task_payload != Some(var_handle) { + return Err(EntryPointError::WrongTaskPayloadUsed + .with_span_handle(var_handle, &module.global_variables)); + } + } + let allowed_usage = match var.space { crate::AddressSpace::Function => unreachable!(), crate::AddressSpace::Uniform => GlobalUse::READ | GlobalUse::QUERY, @@ -792,6 +918,15 @@ impl super::Validator { crate::AddressSpace::Private | crate::AddressSpace::WorkGroup => { GlobalUse::READ | GlobalUse::WRITE | GlobalUse::QUERY } + crate::AddressSpace::TaskPayload => { + GlobalUse::READ + | GlobalUse::QUERY + | if ep.stage == crate::ShaderStage::Task { + GlobalUse::WRITE + } else { + GlobalUse::empty() + } + } crate::AddressSpace::PushConstant => GlobalUse::READ, }; if !allowed_usage.contains(usage) { @@ -811,6 +946,77 @@ impl super::Validator { } } + if let &Some(ref mesh_info) = &ep.mesh_info { + // Technically it is allowed to not output anything + // TODO: check that only the allowed builtins are used here + if let Some(used_vertex_type) = info.mesh_shader_info.vertex_type { + if used_vertex_type.0 != mesh_info.vertex_output_type { + return Err(EntryPointError::WrongMeshOutputType + .with_span_handle(mesh_info.vertex_output_type, &module.types)); + } + } + if let Some(used_primitive_type) = info.mesh_shader_info.primitive_type { + if used_primitive_type.0 != mesh_info.primitive_output_type { + return Err(EntryPointError::WrongMeshOutputType + .with_span_handle(mesh_info.primitive_output_type, &module.types)); + } + } + + for (ty, mesh_output_type) in [ + (mesh_info.vertex_output_type, MeshOutputType::VertexOutput), + ( + mesh_info.primitive_output_type, + MeshOutputType::PrimitiveOutput, + ), + ] { + if !matches!(module.types[ty].inner, crate::TypeInner::Struct { .. }) { + return Err( + EntryPointError::InvalidMeshOutputType.with_span_handle(ty, &module.types) + ); + } + let mut result_built_ins = crate::FastHashSet::default(); + let mut ctx = VaryingContext { + stage: ep.stage, + output: true, + types: &module.types, + type_info: &self.types, + location_mask: &mut self.location_mask, + blend_src_mask: &mut self.blend_src_mask, + built_ins: &mut result_built_ins, + capabilities: self.capabilities, + flags: self.flags, + mesh_output_type, + }; + ctx.validate(ep, ty, None) + .map_err_inner(|e| EntryPointError::Result(e).with_span())?; + if mesh_output_type == MeshOutputType::PrimitiveOutput { + let mut num_indices_builtins = 0; + if result_built_ins.contains(&crate::BuiltIn::PointIndex) { + num_indices_builtins += 1; + } + if result_built_ins.contains(&crate::BuiltIn::LineIndices) { + num_indices_builtins += 1; + } + if result_built_ins.contains(&crate::BuiltIn::TriangleIndices) { + num_indices_builtins += 1; + } + if num_indices_builtins != 1 { + return Err(EntryPointError::InvalidMeshPrimitiveOutputType + .with_span_handle(ty, &module.types)); + } + } else if mesh_output_type == MeshOutputType::VertexOutput + && !result_built_ins.contains(&crate::BuiltIn::Position { invariant: false }) + { + return Err(EntryPointError::MissingVertexOutputPosition + .with_span_handle(ty, &module.types)); + } + } + } else if info.mesh_shader_info.vertex_type.is_some() + || info.mesh_shader_info.primitive_type.is_some() + { + return Err(EntryPointError::UnexpectedMeshShaderOutput.with_span()); + } + Ok(info) } } diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index fe45d3bfb07..babea985244 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -240,6 +240,8 @@ bitflags::bitflags! { const VERTEX = 0x1; const FRAGMENT = 0x2; const COMPUTE = 0x4; + const MESH = 0x8; + const TASK = 0x10; } } diff --git a/naga/src/valid/type.rs b/naga/src/valid/type.rs index e8b83ff08f3..aa0633e1852 100644 --- a/naga/src/valid/type.rs +++ b/naga/src/valid/type.rs @@ -220,9 +220,12 @@ const fn ptr_space_argument_flag(space: crate::AddressSpace) -> TypeFlags { use crate::AddressSpace as As; match space { As::Function | As::Private => TypeFlags::ARGUMENT, - As::Uniform | As::Storage { .. } | As::Handle | As::PushConstant | As::WorkGroup => { - TypeFlags::empty() - } + As::Uniform + | As::Storage { .. } + | As::Handle + | As::PushConstant + | As::WorkGroup + | As::TaskPayload => TypeFlags::empty(), } } diff --git a/wgpu-core/src/validation.rs b/wgpu-core/src/validation.rs index 2c2f4b36c44..ae199f2c703 100644 --- a/wgpu-core/src/validation.rs +++ b/wgpu-core/src/validation.rs @@ -1085,6 +1085,8 @@ impl Interface { wgt::ShaderStages::VERTEX => naga::ShaderStage::Vertex, wgt::ShaderStages::FRAGMENT => naga::ShaderStage::Fragment, wgt::ShaderStages::COMPUTE => naga::ShaderStage::Compute, + wgt::ShaderStages::MESH => naga::ShaderStage::Mesh, + wgt::ShaderStages::TASK => naga::ShaderStage::Task, _ => unreachable!(), } } @@ -1229,7 +1231,7 @@ impl Interface { } // check workgroup size limits - if shader_stage == naga::ShaderStage::Compute { + if shader_stage.compute_like() { let max_workgroup_size_limits = [ self.limits.max_compute_workgroup_size_x, self.limits.max_compute_workgroup_size_y, diff --git a/wgpu-hal/src/vulkan/adapter.rs b/wgpu-hal/src/vulkan/adapter.rs index bb4e2a9d4ae..51381ce4f75 100644 --- a/wgpu-hal/src/vulkan/adapter.rs +++ b/wgpu-hal/src/vulkan/adapter.rs @@ -2099,6 +2099,9 @@ impl super::Adapter { if features.contains(wgt::Features::EXPERIMENTAL_RAY_HIT_VERTEX_RETURN) { capabilities.push(spv::Capability::RayQueryPositionFetchKHR) } + if features.contains(wgt::Features::EXPERIMENTAL_MESH_SHADER) { + capabilities.push(spv::Capability::MeshShadingEXT); + } if self.private_caps.shader_integer_dot_product { // See . capabilities.extend(&[ From 8c3e550d30ba44eec07f9bc0b3c0301e33a38f29 Mon Sep 17 00:00:00 2001 From: SupaMaggie70Incorporated Date: Thu, 14 Aug 2025 12:53:21 -0500 Subject: [PATCH 02/89] Other initial changes --- naga/src/back/spv/block.rs | 1 + naga/src/back/spv/helpers.rs | 1 + naga/src/back/spv/writer.rs | 6 ++++++ naga/src/front/wgsl/lower/mod.rs | 3 +++ 4 files changed, 11 insertions(+) diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index 0cd414bfbeb..148626ce6bd 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -3654,6 +3654,7 @@ impl BlockContext<'_> { } => { self.write_subgroup_gather(mode, argument, result, &mut block)?; } + Statement::MeshFunction(_) => unreachable!(), } } diff --git a/naga/src/back/spv/helpers.rs b/naga/src/back/spv/helpers.rs index 84e130efaa3..f6d26794e70 100644 --- a/naga/src/back/spv/helpers.rs +++ b/naga/src/back/spv/helpers.rs @@ -54,6 +54,7 @@ pub(super) const fn map_storage_class(space: crate::AddressSpace) -> spirv::Stor crate::AddressSpace::Uniform => spirv::StorageClass::Uniform, crate::AddressSpace::WorkGroup => spirv::StorageClass::Workgroup, crate::AddressSpace::PushConstant => spirv::StorageClass::PushConstant, + crate::AddressSpace::TaskPayload => unreachable!(), } } diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index 0688eb6c975..2a294a92275 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -1927,6 +1927,7 @@ impl Writer { interpolation, sampling, blend_src, + per_primitive: _, } => { self.decorate(id, Decoration::Location, &[location]); @@ -2076,6 +2077,11 @@ impl Writer { )?; BuiltIn::SubgroupLocalInvocationId } + Bi::MeshTaskSize + | Bi::CullPrimitive + | Bi::PointIndex + | Bi::LineIndices + | Bi::TriangleIndices => unreachable!(), }; self.decorate(id, Decoration::BuiltIn, &[built_in as u32]); diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index e90d7eab0a8..2066d7cf2c8 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -1527,6 +1527,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { workgroup_size, workgroup_size_overrides, function, + mesh_info: None, + task_payload: None, }); Ok(LoweredGlobalDecl::EntryPoint( ctx.module.entry_points.len() - 1, @@ -4069,6 +4071,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { interpolation, sampling, blend_src, + per_primitive: false, }; binding.apply_default_interpolation(&ctx.module.types[ty].inner); Some(binding) From 85bbc5a0bbb8958e0d2d8bf977e7dd00effafaeb Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Thu, 14 Aug 2025 13:24:44 -0500 Subject: [PATCH 03/89] Updated shader snapshots --- naga/tests/out/analysis/spv-shadow.info.ron | 18 ++- naga/tests/out/analysis/wgsl-access.info.ron | 114 +++++++++++++++--- naga/tests/out/analysis/wgsl-collatz.info.ron | 12 +- .../out/analysis/wgsl-overrides.info.ron | 6 +- .../analysis/wgsl-storage-textures.info.ron | 12 +- naga/tests/out/ir/spv-fetch_depth.compact.ron | 2 + naga/tests/out/ir/spv-fetch_depth.ron | 2 + naga/tests/out/ir/spv-shadow.compact.ron | 5 + naga/tests/out/ir/spv-shadow.ron | 5 + .../out/ir/spv-spec-constants.compact.ron | 6 + naga/tests/out/ir/spv-spec-constants.ron | 6 + naga/tests/out/ir/wgsl-access.compact.ron | 7 ++ naga/tests/out/ir/wgsl-access.ron | 7 ++ naga/tests/out/ir/wgsl-collatz.compact.ron | 2 + naga/tests/out/ir/wgsl-collatz.ron | 2 + .../out/ir/wgsl-const_assert.compact.ron | 2 + naga/tests/out/ir/wgsl-const_assert.ron | 2 + .../out/ir/wgsl-diagnostic-filter.compact.ron | 2 + naga/tests/out/ir/wgsl-diagnostic-filter.ron | 2 + .../out/ir/wgsl-index-by-value.compact.ron | 2 + naga/tests/out/ir/wgsl-index-by-value.ron | 2 + .../tests/out/ir/wgsl-local-const.compact.ron | 2 + naga/tests/out/ir/wgsl-local-const.ron | 2 + naga/tests/out/ir/wgsl-must-use.compact.ron | 2 + naga/tests/out/ir/wgsl-must-use.ron | 2 + ...ides-atomicCompareExchangeWeak.compact.ron | 2 + ...sl-overrides-atomicCompareExchangeWeak.ron | 2 + .../ir/wgsl-overrides-ray-query.compact.ron | 2 + .../tests/out/ir/wgsl-overrides-ray-query.ron | 2 + naga/tests/out/ir/wgsl-overrides.compact.ron | 2 + naga/tests/out/ir/wgsl-overrides.ron | 2 + .../out/ir/wgsl-storage-textures.compact.ron | 4 + naga/tests/out/ir/wgsl-storage-textures.ron | 4 + ...l-template-list-trailing-comma.compact.ron | 2 + .../ir/wgsl-template-list-trailing-comma.ron | 2 + .../out/ir/wgsl-texture-external.compact.ron | 7 ++ naga/tests/out/ir/wgsl-texture-external.ron | 7 ++ .../ir/wgsl-types_with_comments.compact.ron | 2 + .../tests/out/ir/wgsl-types_with_comments.ron | 2 + 39 files changed, 241 insertions(+), 27 deletions(-) diff --git a/naga/tests/out/analysis/spv-shadow.info.ron b/naga/tests/out/analysis/spv-shadow.info.ron index 6ddda61f5c6..b08a28438ed 100644 --- a/naga/tests/out/analysis/spv-shadow.info.ron +++ b/naga/tests/out/analysis/spv-shadow.info.ron @@ -18,7 +18,7 @@ functions: [ ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(1), requirements: (""), @@ -413,10 +413,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(1), requirements: (""), @@ -1591,12 +1595,16 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ], entry_points: [ ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(1), requirements: (""), @@ -1685,6 +1693,10 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ], const_expression_types: [ diff --git a/naga/tests/out/analysis/wgsl-access.info.ron b/naga/tests/out/analysis/wgsl-access.info.ron index 319f62bdf13..d297b09a404 100644 --- a/naga/tests/out/analysis/wgsl-access.info.ron +++ b/naga/tests/out/analysis/wgsl-access.info.ron @@ -42,7 +42,7 @@ functions: [ ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -1197,10 +1197,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -2523,10 +2527,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(0), requirements: (""), @@ -2563,10 +2571,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(0), requirements: (""), @@ -2612,10 +2624,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -2655,10 +2671,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -2749,10 +2769,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -2870,10 +2894,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(0), requirements: (""), @@ -2922,10 +2950,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -2977,10 +3009,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(0), requirements: (""), @@ -3029,10 +3065,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -3084,10 +3124,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(0), requirements: (""), @@ -3148,10 +3192,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(2), requirements: (""), @@ -3221,10 +3269,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(2), requirements: (""), @@ -3297,10 +3349,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -3397,10 +3453,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(1), requirements: (""), @@ -3593,12 +3653,16 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ], entry_points: [ ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(0), requirements: (""), @@ -4290,10 +4354,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -4742,10 +4810,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(0), requirements: (""), @@ -4812,6 +4884,10 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ], const_expression_types: [ diff --git a/naga/tests/out/analysis/wgsl-collatz.info.ron b/naga/tests/out/analysis/wgsl-collatz.info.ron index 7ec5799d758..2796f544510 100644 --- a/naga/tests/out/analysis/wgsl-collatz.info.ron +++ b/naga/tests/out/analysis/wgsl-collatz.info.ron @@ -8,7 +8,7 @@ functions: [ ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(3), requirements: (""), @@ -275,12 +275,16 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ], entry_points: [ ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: Some(3), requirements: (""), @@ -430,6 +434,10 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ], const_expression_types: [], diff --git a/naga/tests/out/analysis/wgsl-overrides.info.ron b/naga/tests/out/analysis/wgsl-overrides.info.ron index 0e0ae318042..a76c9c89c9b 100644 --- a/naga/tests/out/analysis/wgsl-overrides.info.ron +++ b/naga/tests/out/analysis/wgsl-overrides.info.ron @@ -8,7 +8,7 @@ entry_points: [ ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -201,6 +201,10 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ], const_expression_types: [ diff --git a/naga/tests/out/analysis/wgsl-storage-textures.info.ron b/naga/tests/out/analysis/wgsl-storage-textures.info.ron index fbbf7206c33..35b5a7e320c 100644 --- a/naga/tests/out/analysis/wgsl-storage-textures.info.ron +++ b/naga/tests/out/analysis/wgsl-storage-textures.info.ron @@ -11,7 +11,7 @@ entry_points: [ ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -184,10 +184,14 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), - available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), uniformity: ( non_uniform_result: None, requirements: (""), @@ -396,6 +400,10 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), ), ], const_expression_types: [], diff --git a/naga/tests/out/ir/spv-fetch_depth.compact.ron b/naga/tests/out/ir/spv-fetch_depth.compact.ron index 1fbee2deb35..98f4426c3eb 100644 --- a/naga/tests/out/ir/spv-fetch_depth.compact.ron +++ b/naga/tests/out/ir/spv-fetch_depth.compact.ron @@ -196,6 +196,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/spv-fetch_depth.ron b/naga/tests/out/ir/spv-fetch_depth.ron index 186f78354ad..104de852c17 100644 --- a/naga/tests/out/ir/spv-fetch_depth.ron +++ b/naga/tests/out/ir/spv-fetch_depth.ron @@ -266,6 +266,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/spv-shadow.compact.ron b/naga/tests/out/ir/spv-shadow.compact.ron index b49cd9b55be..bed86a5334d 100644 --- a/naga/tests/out/ir/spv-shadow.compact.ron +++ b/naga/tests/out/ir/spv-shadow.compact.ron @@ -974,6 +974,7 @@ interpolation: Some(Perspective), sampling: Some(Center), blend_src: None, + per_primitive: false, )), ), ( @@ -984,6 +985,7 @@ interpolation: Some(Perspective), sampling: Some(Center), blend_src: None, + per_primitive: false, )), ), ], @@ -994,6 +996,7 @@ interpolation: None, sampling: None, blend_src: None, + per_primitive: false, )), )), local_variables: [], @@ -1032,6 +1035,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/spv-shadow.ron b/naga/tests/out/ir/spv-shadow.ron index e1f0f60b6bb..bdda1d18566 100644 --- a/naga/tests/out/ir/spv-shadow.ron +++ b/naga/tests/out/ir/spv-shadow.ron @@ -1252,6 +1252,7 @@ interpolation: Some(Perspective), sampling: Some(Center), blend_src: None, + per_primitive: false, )), ), ( @@ -1262,6 +1263,7 @@ interpolation: Some(Perspective), sampling: Some(Center), blend_src: None, + per_primitive: false, )), ), ], @@ -1272,6 +1274,7 @@ interpolation: None, sampling: None, blend_src: None, + per_primitive: false, )), )), local_variables: [], @@ -1310,6 +1313,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/spv-spec-constants.compact.ron b/naga/tests/out/ir/spv-spec-constants.compact.ron index 3fa6ffef4ff..67eb29c2475 100644 --- a/naga/tests/out/ir/spv-spec-constants.compact.ron +++ b/naga/tests/out/ir/spv-spec-constants.compact.ron @@ -151,6 +151,7 @@ interpolation: Some(Perspective), sampling: Some(Center), blend_src: None, + per_primitive: false, )), offset: 0, ), @@ -510,6 +511,7 @@ interpolation: None, sampling: None, blend_src: None, + per_primitive: false, )), ), ( @@ -520,6 +522,7 @@ interpolation: None, sampling: None, blend_src: None, + per_primitive: false, )), ), ( @@ -530,6 +533,7 @@ interpolation: None, sampling: None, blend_src: None, + per_primitive: false, )), ), ], @@ -613,6 +617,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/spv-spec-constants.ron b/naga/tests/out/ir/spv-spec-constants.ron index 94c90aa78f9..51686aa20eb 100644 --- a/naga/tests/out/ir/spv-spec-constants.ron +++ b/naga/tests/out/ir/spv-spec-constants.ron @@ -242,6 +242,7 @@ interpolation: Some(Perspective), sampling: Some(Center), blend_src: None, + per_primitive: false, )), offset: 0, ), @@ -616,6 +617,7 @@ interpolation: None, sampling: None, blend_src: None, + per_primitive: false, )), ), ( @@ -626,6 +628,7 @@ interpolation: None, sampling: None, blend_src: None, + per_primitive: false, )), ), ( @@ -636,6 +639,7 @@ interpolation: None, sampling: None, blend_src: None, + per_primitive: false, )), ), ], @@ -719,6 +723,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-access.compact.ron b/naga/tests/out/ir/wgsl-access.compact.ron index 30e88984f3c..c3df0c8c500 100644 --- a/naga/tests/out/ir/wgsl-access.compact.ron +++ b/naga/tests/out/ir/wgsl-access.compact.ron @@ -2655,6 +2655,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ( name: "foo_frag", @@ -2672,6 +2674,7 @@ interpolation: Some(Perspective), sampling: Some(Center), blend_src: None, + per_primitive: false, )), )), local_variables: [], @@ -2848,6 +2851,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ( name: "foo_compute", @@ -2907,6 +2912,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-access.ron b/naga/tests/out/ir/wgsl-access.ron index 30e88984f3c..c3df0c8c500 100644 --- a/naga/tests/out/ir/wgsl-access.ron +++ b/naga/tests/out/ir/wgsl-access.ron @@ -2655,6 +2655,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ( name: "foo_frag", @@ -2672,6 +2674,7 @@ interpolation: Some(Perspective), sampling: Some(Center), blend_src: None, + per_primitive: false, )), )), local_variables: [], @@ -2848,6 +2851,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ( name: "foo_compute", @@ -2907,6 +2912,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-collatz.compact.ron b/naga/tests/out/ir/wgsl-collatz.compact.ron index f72c652d032..fc4daaa1296 100644 --- a/naga/tests/out/ir/wgsl-collatz.compact.ron +++ b/naga/tests/out/ir/wgsl-collatz.compact.ron @@ -334,6 +334,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-collatz.ron b/naga/tests/out/ir/wgsl-collatz.ron index f72c652d032..fc4daaa1296 100644 --- a/naga/tests/out/ir/wgsl-collatz.ron +++ b/naga/tests/out/ir/wgsl-collatz.ron @@ -334,6 +334,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-const_assert.compact.ron b/naga/tests/out/ir/wgsl-const_assert.compact.ron index 2816364f88b..648f4ff9bc9 100644 --- a/naga/tests/out/ir/wgsl-const_assert.compact.ron +++ b/naga/tests/out/ir/wgsl-const_assert.compact.ron @@ -34,6 +34,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-const_assert.ron b/naga/tests/out/ir/wgsl-const_assert.ron index 2816364f88b..648f4ff9bc9 100644 --- a/naga/tests/out/ir/wgsl-const_assert.ron +++ b/naga/tests/out/ir/wgsl-const_assert.ron @@ -34,6 +34,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-diagnostic-filter.compact.ron b/naga/tests/out/ir/wgsl-diagnostic-filter.compact.ron index c5746696d52..9a2bf193f30 100644 --- a/naga/tests/out/ir/wgsl-diagnostic-filter.compact.ron +++ b/naga/tests/out/ir/wgsl-diagnostic-filter.compact.ron @@ -73,6 +73,8 @@ ], diagnostic_filter_leaf: Some(0), ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [ diff --git a/naga/tests/out/ir/wgsl-diagnostic-filter.ron b/naga/tests/out/ir/wgsl-diagnostic-filter.ron index c5746696d52..9a2bf193f30 100644 --- a/naga/tests/out/ir/wgsl-diagnostic-filter.ron +++ b/naga/tests/out/ir/wgsl-diagnostic-filter.ron @@ -73,6 +73,8 @@ ], diagnostic_filter_leaf: Some(0), ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [ diff --git a/naga/tests/out/ir/wgsl-index-by-value.compact.ron b/naga/tests/out/ir/wgsl-index-by-value.compact.ron index a4f84a7a6b2..addd0e5871c 100644 --- a/naga/tests/out/ir/wgsl-index-by-value.compact.ron +++ b/naga/tests/out/ir/wgsl-index-by-value.compact.ron @@ -465,6 +465,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-index-by-value.ron b/naga/tests/out/ir/wgsl-index-by-value.ron index a4f84a7a6b2..addd0e5871c 100644 --- a/naga/tests/out/ir/wgsl-index-by-value.ron +++ b/naga/tests/out/ir/wgsl-index-by-value.ron @@ -465,6 +465,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-local-const.compact.ron b/naga/tests/out/ir/wgsl-local-const.compact.ron index 512972657ed..0e4e2e4d40e 100644 --- a/naga/tests/out/ir/wgsl-local-const.compact.ron +++ b/naga/tests/out/ir/wgsl-local-const.compact.ron @@ -100,6 +100,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-local-const.ron b/naga/tests/out/ir/wgsl-local-const.ron index 512972657ed..0e4e2e4d40e 100644 --- a/naga/tests/out/ir/wgsl-local-const.ron +++ b/naga/tests/out/ir/wgsl-local-const.ron @@ -100,6 +100,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-must-use.compact.ron b/naga/tests/out/ir/wgsl-must-use.compact.ron index a701a6805da..16e925f2fb8 100644 --- a/naga/tests/out/ir/wgsl-must-use.compact.ron +++ b/naga/tests/out/ir/wgsl-must-use.compact.ron @@ -201,6 +201,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-must-use.ron b/naga/tests/out/ir/wgsl-must-use.ron index a701a6805da..16e925f2fb8 100644 --- a/naga/tests/out/ir/wgsl-must-use.ron +++ b/naga/tests/out/ir/wgsl-must-use.ron @@ -201,6 +201,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-overrides-atomicCompareExchangeWeak.compact.ron b/naga/tests/out/ir/wgsl-overrides-atomicCompareExchangeWeak.compact.ron index 640ee25ca49..28a824bb035 100644 --- a/naga/tests/out/ir/wgsl-overrides-atomicCompareExchangeWeak.compact.ron +++ b/naga/tests/out/ir/wgsl-overrides-atomicCompareExchangeWeak.compact.ron @@ -128,6 +128,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-overrides-atomicCompareExchangeWeak.ron b/naga/tests/out/ir/wgsl-overrides-atomicCompareExchangeWeak.ron index 640ee25ca49..28a824bb035 100644 --- a/naga/tests/out/ir/wgsl-overrides-atomicCompareExchangeWeak.ron +++ b/naga/tests/out/ir/wgsl-overrides-atomicCompareExchangeWeak.ron @@ -128,6 +128,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-overrides-ray-query.compact.ron b/naga/tests/out/ir/wgsl-overrides-ray-query.compact.ron index f65e8f186db..152a45008c5 100644 --- a/naga/tests/out/ir/wgsl-overrides-ray-query.compact.ron +++ b/naga/tests/out/ir/wgsl-overrides-ray-query.compact.ron @@ -263,6 +263,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-overrides-ray-query.ron b/naga/tests/out/ir/wgsl-overrides-ray-query.ron index f65e8f186db..152a45008c5 100644 --- a/naga/tests/out/ir/wgsl-overrides-ray-query.ron +++ b/naga/tests/out/ir/wgsl-overrides-ray-query.ron @@ -263,6 +263,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-overrides.compact.ron b/naga/tests/out/ir/wgsl-overrides.compact.ron index 81221ff7941..fe136e71e4d 100644 --- a/naga/tests/out/ir/wgsl-overrides.compact.ron +++ b/naga/tests/out/ir/wgsl-overrides.compact.ron @@ -221,6 +221,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-overrides.ron b/naga/tests/out/ir/wgsl-overrides.ron index 81221ff7941..fe136e71e4d 100644 --- a/naga/tests/out/ir/wgsl-overrides.ron +++ b/naga/tests/out/ir/wgsl-overrides.ron @@ -221,6 +221,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-storage-textures.compact.ron b/naga/tests/out/ir/wgsl-storage-textures.compact.ron index ec63fecac27..68c867a19e2 100644 --- a/naga/tests/out/ir/wgsl-storage-textures.compact.ron +++ b/naga/tests/out/ir/wgsl-storage-textures.compact.ron @@ -218,6 +218,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ( name: "csStore", @@ -315,6 +317,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-storage-textures.ron b/naga/tests/out/ir/wgsl-storage-textures.ron index ec63fecac27..68c867a19e2 100644 --- a/naga/tests/out/ir/wgsl-storage-textures.ron +++ b/naga/tests/out/ir/wgsl-storage-textures.ron @@ -218,6 +218,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ( name: "csStore", @@ -315,6 +317,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-template-list-trailing-comma.compact.ron b/naga/tests/out/ir/wgsl-template-list-trailing-comma.compact.ron index a8208c09b86..db619dff836 100644 --- a/naga/tests/out/ir/wgsl-template-list-trailing-comma.compact.ron +++ b/naga/tests/out/ir/wgsl-template-list-trailing-comma.compact.ron @@ -190,6 +190,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-template-list-trailing-comma.ron b/naga/tests/out/ir/wgsl-template-list-trailing-comma.ron index a8208c09b86..db619dff836 100644 --- a/naga/tests/out/ir/wgsl-template-list-trailing-comma.ron +++ b/naga/tests/out/ir/wgsl-template-list-trailing-comma.ron @@ -190,6 +190,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-texture-external.compact.ron b/naga/tests/out/ir/wgsl-texture-external.compact.ron index dbffbddcdc7..379e76566c5 100644 --- a/naga/tests/out/ir/wgsl-texture-external.compact.ron +++ b/naga/tests/out/ir/wgsl-texture-external.compact.ron @@ -360,6 +360,7 @@ interpolation: Some(Perspective), sampling: Some(Center), blend_src: None, + per_primitive: false, )), )), local_variables: [], @@ -382,6 +383,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ( name: "vertex_main", @@ -418,6 +421,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ( name: "compute_main", @@ -449,6 +454,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-texture-external.ron b/naga/tests/out/ir/wgsl-texture-external.ron index dbffbddcdc7..379e76566c5 100644 --- a/naga/tests/out/ir/wgsl-texture-external.ron +++ b/naga/tests/out/ir/wgsl-texture-external.ron @@ -360,6 +360,7 @@ interpolation: Some(Perspective), sampling: Some(Center), blend_src: None, + per_primitive: false, )), )), local_variables: [], @@ -382,6 +383,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ( name: "vertex_main", @@ -418,6 +421,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ( name: "compute_main", @@ -449,6 +454,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-types_with_comments.compact.ron b/naga/tests/out/ir/wgsl-types_with_comments.compact.ron index 7186209f00e..7c0d856946f 100644 --- a/naga/tests/out/ir/wgsl-types_with_comments.compact.ron +++ b/naga/tests/out/ir/wgsl-types_with_comments.compact.ron @@ -116,6 +116,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], diff --git a/naga/tests/out/ir/wgsl-types_with_comments.ron b/naga/tests/out/ir/wgsl-types_with_comments.ron index 480b0d2337f..34e44cb9653 100644 --- a/naga/tests/out/ir/wgsl-types_with_comments.ron +++ b/naga/tests/out/ir/wgsl-types_with_comments.ron @@ -172,6 +172,8 @@ ], diagnostic_filter_leaf: None, ), + mesh_info: None, + task_payload: None, ), ], diagnostic_filters: [], From ccf84676ce22129a3199c022e24cd46591e71284 Mon Sep 17 00:00:00 2001 From: SupaMaggie70Incorporated Date: Sat, 16 Aug 2025 20:09:18 -0500 Subject: [PATCH 04/89] Added new HLSL limitation --- naga/src/valid/interface.rs | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index 51167a4810d..0e2a2583f0f 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -147,6 +147,8 @@ pub enum EntryPointError { InvalidMeshOutputType, #[error("Mesh primitive outputs must have exactly one of `@builtin(triangle_indices)`, `@builtin(line_indices)`, or `@builtin(point_index)`")] InvalidMeshPrimitiveOutputType, + #[error("Task payload must not be zero-sized")] + ZeroSizedTaskPayload, } fn storage_usage(access: crate::StorageAccess) -> GlobalUse { @@ -881,6 +883,13 @@ impl super::Validator { return Err(EntryPointError::TaskPayloadWrongAddressSpace .with_span_handle(task_payload, &module.global_variables)); } + let var = &module.global_variables[task_payload]; + let ty = &module.types[var.ty].inner; + // HLSL doesn't allow zero sized payloads. + if ty.try_size(module.to_ctx()) == Some(0) { + return Err(EntryPointError::ZeroSizedTaskPayload + .with_span_handle(task_payload, &module.global_variables)); + } } self.ep_resource_bindings.clear(); From e55c02f2e3d75ba607f9f9b1886b69eb0c65cea9 Mon Sep 17 00:00:00 2001 From: SupaMaggie70Incorporated Date: Sat, 16 Aug 2025 20:20:36 -0500 Subject: [PATCH 05/89] Moved error to global variable error --- naga/src/valid/interface.rs | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index 0e2a2583f0f..16c09f6dc7c 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -43,6 +43,8 @@ pub enum GlobalVariableError { StorageAddressSpaceWriteOnlyNotSupported, #[error("Type is not valid for use as a push constant")] InvalidPushConstantType(#[source] PushConstantError), + #[error("Task payload must not be zero-sized")] + ZeroSizedTaskPayload, } #[derive(Clone, Debug, thiserror::Error)] @@ -147,8 +149,6 @@ pub enum EntryPointError { InvalidMeshOutputType, #[error("Mesh primitive outputs must have exactly one of `@builtin(triangle_indices)`, `@builtin(line_indices)`, or `@builtin(point_index)`")] InvalidMeshPrimitiveOutputType, - #[error("Task payload must not be zero-sized")] - ZeroSizedTaskPayload, } fn storage_usage(access: crate::StorageAccess) -> GlobalUse { @@ -704,6 +704,14 @@ impl super::Validator { } } + if var.space == crate::AddressSpace::TaskPayload { + let ty = &gctx.types[var.ty].inner; + // HLSL doesn't allow zero sized payloads. + if ty.try_size(gctx) == Some(0) { + return Err(GlobalVariableError::ZeroSizedTaskPayload); + } + } + if let Some(init) = var.init { match var.space { crate::AddressSpace::Private | crate::AddressSpace::Function => {} @@ -883,13 +891,6 @@ impl super::Validator { return Err(EntryPointError::TaskPayloadWrongAddressSpace .with_span_handle(task_payload, &module.global_variables)); } - let var = &module.global_variables[task_payload]; - let ty = &module.types[var.ty].inner; - // HLSL doesn't allow zero sized payloads. - if ty.try_size(module.to_ctx()) == Some(0) { - return Err(EntryPointError::ZeroSizedTaskPayload - .with_span_handle(task_payload, &module.global_variables)); - } } self.ep_resource_bindings.clear(); From 0f6da753722c1585ecc2089f5e4d121f03a02cd3 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Wed, 20 Aug 2025 10:46:27 -0500 Subject: [PATCH 06/89] Added docs to per_primitive --- naga/src/ir/mod.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/naga/src/ir/mod.rs b/naga/src/ir/mod.rs index a182bf0e064..12a0fecf5c8 100644 --- a/naga/src/ir/mod.rs +++ b/naga/src/ir/mod.rs @@ -984,6 +984,13 @@ pub enum Binding { /// Optional `blend_src` index used for dual source blending. /// See blend_src: Option, + /// Whether the binding is a per-primitive binding for use with mesh shaders. + /// This is required to match for mesh and fragment shader stages. + /// This is merely an extra attribute on a binding. You still may not have + /// a per-vertex and per-primitive input with the same location. + /// + /// Per primitive values are not interpolated at all and are not dependent on the vertices + /// or pixel location. For example, it may be used to store a non-interpolated normal vector. per_primitive: bool, }, } From 3017214d9bb12b6021d561045d9fb9ea3485f70c Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Wed, 20 Aug 2025 11:08:34 -0500 Subject: [PATCH 07/89] Added a little bit more docs here and there in IR --- naga/src/ir/mod.rs | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/naga/src/ir/mod.rs b/naga/src/ir/mod.rs index 12a0fecf5c8..2856872db27 100644 --- a/naga/src/ir/mod.rs +++ b/naga/src/ir/mod.rs @@ -325,6 +325,7 @@ pub enum ShaderStage { Vertex, Fragment, Compute, + // Mesh shader stages Task, Mesh, } @@ -1961,9 +1962,7 @@ pub enum Statement { /// [`Loop`] statement. /// /// [`Loop`]: Statement::Loop - Return { - value: Option>, - }, + Return { value: Option> }, /// Aborts the current shader execution. /// @@ -2169,6 +2168,7 @@ pub enum Statement { /// The specific operation we're performing on `query`. fun: RayQueryFunction, }, + /// A mesh shader intrinsic MeshFunction(MeshFunction), /// Calculate a bitmask using a boolean from each active thread in the subgroup SubgroupBallot { @@ -2345,6 +2345,7 @@ pub struct EntryPoint { pub function: Function, /// The information relating to a mesh shader pub mesh_info: Option, + /// The unique global variable used as a task payload from task shader to mesh shader pub task_payload: Option>, } @@ -2620,6 +2621,7 @@ pub enum MeshOutputTopology { Lines, Triangles, } + #[derive(Debug, Clone)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] @@ -2635,6 +2637,7 @@ pub struct MeshStageInfo { pub primitive_output_type: Handle, } +/// Mesh shader intrinsics #[derive(Debug, Clone, Copy)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] From 198437b71d2bb39756c5a5133b8e19235553a1f6 Mon Sep 17 00:00:00 2001 From: SupaMaggie70Incorporated Date: Wed, 20 Aug 2025 12:37:38 -0500 Subject: [PATCH 08/89] Adding validation to ensure that task shaders have a task payload --- naga/src/valid/interface.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index 16c09f6dc7c..1fed0fda529 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -149,6 +149,8 @@ pub enum EntryPointError { InvalidMeshOutputType, #[error("Mesh primitive outputs must have exactly one of `@builtin(triangle_indices)`, `@builtin(line_indices)`, or `@builtin(point_index)`")] InvalidMeshPrimitiveOutputType, + #[error("Task shaders must declare a task payload output")] + ExpectedTaskPayload, } fn storage_usage(access: crate::StorageAccess) -> GlobalUse { @@ -891,6 +893,8 @@ impl super::Validator { return Err(EntryPointError::TaskPayloadWrongAddressSpace .with_span_handle(task_payload, &module.global_variables)); } + } else if ep.stage == crate::ShaderStage::Task { + return Err(EntryPointError::ExpectedTaskPayload.with_span()); } self.ep_resource_bindings.clear(); From 64000e4d976edb7397bdd2e71e940a4db0a19c39 Mon Sep 17 00:00:00 2001 From: SupaMaggie70Incorporated Date: Wed, 20 Aug 2025 12:42:01 -0500 Subject: [PATCH 09/89] Updated spec to reflect the change to payload variables --- docs/api-specs/mesh_shading.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/api-specs/mesh_shading.md b/docs/api-specs/mesh_shading.md index ee14f99e757..e1f28d43e91 100644 --- a/docs/api-specs/mesh_shading.md +++ b/docs/api-specs/mesh_shading.md @@ -80,12 +80,12 @@ This shader stage can be selected by marking a function with `@task`. Task shade The output of this determines how many workgroups of mesh shaders will be dispatched. Once dispatched, global id variables will be local to the task shader workgroup dispatch, and mesh shaders won't know the position of their dispatch among all mesh shader dispatches unless this is passed through the payload. The output may be zero to skip dispatching any mesh shader workgroups for the task shader workgroup. -If task shaders are marked with `@payload(someVar)`, where `someVar` is global variable declared like `var someVar: `, task shaders may use `someVar` as if it is a read-write workgroup storage variable. This payload is passed to the mesh shader workgroup that is invoked. The mesh shader can skip declaring `@payload` to ignore this input. +Task shaders must be marked with `@payload(someVar)`, where `someVar` is global variable declared like `var someVar: `. Task shaders may use `someVar` as if it is a read-write workgroup storage variable. This payload is passed to the mesh shader workgroup that is invoked. ### Mesh shader This shader stage can be selected by marking a function with `@mesh`. Mesh shaders must not return anything. -Mesh shaders can be marked with `@payload(someVar)` similar to task shaders. Unlike task shaders, mesh shaders cannot write to this memory. Declaring `@payload` in a pipeline with no task shader, in a pipeline with a task shader that doesn't declare `@payload`, or in a task shader with an `@payload` that is statically sized and smaller than the mesh shader payload is illegal. +Mesh shaders can be marked with `@payload(someVar)` similar to task shaders. Unlike task shaders, this is optional, and mesh shaders cannot write to this memory. Declaring `@payload` in a pipeline with no task shader or in a task shader with an `@payload` that is statically sized and differently than the mesh shader payload is illegal. The `@payload` attribute can only be ignored in pipelines that don't have a task shader. Mesh shaders must be marked with `@vertex_output(OutputType, numOutputs)`, where `numOutputs` is the maximum number of vertices to be output by a mesh shader, and `OutputType` is the data associated with vertices, similar to a standard vertex shader output, and must be a struct. From b572ec7e231d466457aec0d17aa7a11ceffd313d Mon Sep 17 00:00:00 2001 From: SupaMaggie70Incorporated Date: Sat, 23 Aug 2025 20:08:16 -0500 Subject: [PATCH 10/89] Updated the mesh shading spec because it was goofy --- docs/api-specs/mesh_shading.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/api-specs/mesh_shading.md b/docs/api-specs/mesh_shading.md index e1f28d43e91..e9b6df3710d 100644 --- a/docs/api-specs/mesh_shading.md +++ b/docs/api-specs/mesh_shading.md @@ -2,8 +2,8 @@ 🧪Experimental🧪 -`wgpu` supports an experimental version of mesh shading. The extensions allow for acceleration structures to be created and built (with -`Features::EXPERIMENTAL_MESH_SHADER` enabled) and interacted with in shaders. Currently `naga` has no support for mesh shaders beyond recognizing the additional shader stages. +`wgpu` supports an experimental version of mesh shading when `Features::EXPERIMENTAL_MESH_SHADER` is enabled. +Currently `naga` has no support for parsing or writing mesh shaders. For this reason, all shaders must be created with `Device::create_shader_module_passthrough`. **Note**: The features documented here may have major bugs in them and are expected to be subject From 45fbacc17f24b83b29fab9b3c1e1472d32edd4f6 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sat, 23 Aug 2025 23:21:54 -0500 Subject: [PATCH 11/89] Wait did I break it --- wgpu-hal/src/metal/adapter.rs | 12 +++++++++--- wgpu-hal/src/metal/command.rs | 32 ++++++++++++++++++++++++++------ wgpu-hal/src/metal/device.rs | 28 +++++++++++++--------------- wgpu-hal/src/metal/mod.rs | 21 ++++++++++++++------- 4 files changed, 62 insertions(+), 31 deletions(-) diff --git a/wgpu-hal/src/metal/adapter.rs b/wgpu-hal/src/metal/adapter.rs index 02dfc0fe601..9517f0b4dd6 100644 --- a/wgpu-hal/src/metal/adapter.rs +++ b/wgpu-hal/src/metal/adapter.rs @@ -606,6 +606,8 @@ impl super::PrivateCapabilities { } let argument_buffers = device.argument_buffers_support(); + let mesh_shaders = device.supports_family(MTLGPUFamily::Apple7) + || device.supports_family(MTLGPUFamily::Mac2); Self { family_check, @@ -902,6 +904,7 @@ impl super::PrivateCapabilities { && (device.supports_family(MTLGPUFamily::Apple7) || device.supports_family(MTLGPUFamily::Mac2)), supports_shared_event: version.at_least((10, 14), (12, 0), os_is_mac), + mesh_shaders, } } @@ -1003,6 +1006,8 @@ impl super::PrivateCapabilities { features.insert(F::SUBGROUP | F::SUBGROUP_BARRIER); } + features.set(F::EXPERIMENTAL_MESH_SHADER, self.mesh_shaders); + features } @@ -1079,10 +1084,11 @@ impl super::PrivateCapabilities { max_buffer_size: self.max_buffer_size, max_non_sampler_bindings: u32::MAX, - max_task_workgroup_total_count: 0, - max_task_workgroups_per_dimension: 0, + // See https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf, Maximum threadgroups per mesh shader grid + max_task_workgroup_total_count: 1024, + max_task_workgroups_per_dimension: 1024, max_mesh_multiview_count: 0, - max_mesh_output_layers: 0, + max_mesh_output_layers: self.max_texture_layers as u32, max_blas_primitive_count: 0, // When added: 2^28 from https://developer.apple.com/documentation/metal/mtlaccelerationstructureusage/extendedlimits max_blas_geometry_count: 0, // When added: 2^24 diff --git a/wgpu-hal/src/metal/command.rs b/wgpu-hal/src/metal/command.rs index 72a799a0275..2b66343c478 100644 --- a/wgpu-hal/src/metal/command.rs +++ b/wgpu-hal/src/metal/command.rs @@ -906,11 +906,22 @@ impl crate::CommandEncoder for super::CommandEncoder { unsafe fn set_render_pipeline(&mut self, pipeline: &super::RenderPipeline) { self.state.raw_primitive_type = pipeline.raw_primitive_type; - self.state.stage_infos.vs.assign_from(&pipeline.vs_info); + match pipeline.vs_info { + Some(ref info) => self.state.stage_infos.vs.assign_from(info), + None => self.state.stage_infos.vs.clear(), + } match pipeline.fs_info { Some(ref info) => self.state.stage_infos.fs.assign_from(info), None => self.state.stage_infos.fs.clear(), } + match pipeline.ts_info { + Some(ref info) => self.state.stage_infos.ts.assign_from(info), + None => self.state.stage_infos.vs.clear(), + } + match pipeline.ms_info { + Some(ref info) => self.state.stage_infos.ms.assign_from(info), + None => self.state.stage_infos.fs.clear(), + } let encoder = self.state.render.as_ref().unwrap(); encoder.set_render_pipeline_state(&pipeline.raw); @@ -937,7 +948,7 @@ impl crate::CommandEncoder for super::CommandEncoder { ); } } - if pipeline.fs_lib.is_some() { + if pipeline.fs_info.is_some() { if let Some((index, sizes)) = self .state .make_sizes_buffer_update(naga::ShaderStage::Fragment, &mut self.temp.binding_sizes) @@ -1111,11 +1122,20 @@ impl crate::CommandEncoder for super::CommandEncoder { unsafe fn draw_mesh_tasks( &mut self, - _group_count_x: u32, - _group_count_y: u32, - _group_count_z: u32, + group_count_x: u32, + group_count_y: u32, + group_count_z: u32, ) { - unreachable!() + let encoder = self.state.render.as_ref().unwrap(); + encoder.draw_mesh_threadgroups( + MTLSize { + width: group_count_x as u64, + height: group_count_y as u64, + depth: group_count_z as u64, + }, + todo!(), + todo!(), + ); } unsafe fn draw_indirect( diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index 6af8ad3062d..97878960a36 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -1078,7 +1078,7 @@ impl crate::Device for super::Device { conv::map_primitive_topology(desc.primitive.topology); // Vertex shader - let (vs_lib, vs_info) = { + let vs_info = { let mut vertex_buffer_mappings = Vec::::new(); for (i, vbl) in desc_vertex_buffers.iter().enumerate() { let mut attributes = Vec::::new(); @@ -1124,18 +1124,17 @@ impl crate::Device for super::Device { ); } - let info = super::PipelineStageInfo { + super::PipelineStageInfo { push_constants: desc.layout.push_constants_infos.vs, sizes_slot: desc.layout.per_stage_map.vs.sizes_buffer, sized_bindings: vs.sized_bindings, vertex_buffer_mappings, - }; - - (vs.library, info) + library: Some(vs.library), + } }; // Fragment shader - let (fs_lib, fs_info) = match desc.fragment_stage { + let fs_info = match desc.fragment_stage { Some(ref stage) => { let fs = self.load_shader( stage, @@ -1153,14 +1152,13 @@ impl crate::Device for super::Device { ); } - let info = super::PipelineStageInfo { + Some(super::PipelineStageInfo { push_constants: desc.layout.push_constants_infos.fs, sizes_slot: desc.layout.per_stage_map.fs.sizes_buffer, sized_bindings: fs.sized_bindings, vertex_buffer_mappings: vec![], - }; - - (Some(fs.library), Some(info)) + library: Some(fs.library), + }) } None => { // TODO: This is a workaround for what appears to be a Metal validation bug @@ -1168,7 +1166,7 @@ impl crate::Device for super::Device { if desc.color_targets.is_empty() && desc.depth_stencil.is_none() { descriptor.set_depth_attachment_pixel_format(MTLPixelFormat::Depth32Float); } - (None, None) + None } }; @@ -1302,10 +1300,10 @@ impl crate::Device for super::Device { Ok(super::RenderPipeline { raw, - vs_lib, - fs_lib, - vs_info, + vs_info: Some(vs_info), fs_info, + ts_info: None, + ms_info: None, raw_primitive_type, raw_triangle_fill_mode, raw_front_winding: conv::map_winding(desc.primitive.front_face), @@ -1373,6 +1371,7 @@ impl crate::Device for super::Device { } let cs_info = super::PipelineStageInfo { + library: Some(cs.library), push_constants: desc.layout.push_constants_infos.cs, sizes_slot: desc.layout.per_stage_map.cs.sizes_buffer, sized_bindings: cs.sized_bindings, @@ -1400,7 +1399,6 @@ impl crate::Device for super::Device { Ok(super::ComputePipeline { raw, cs_info, - cs_lib: cs.library, work_group_size: cs.wg_size, work_group_memory_sizes: cs.wg_memory_sizes, }) diff --git a/wgpu-hal/src/metal/mod.rs b/wgpu-hal/src/metal/mod.rs index 00223b2f778..ec4ae11cdef 100644 --- a/wgpu-hal/src/metal/mod.rs +++ b/wgpu-hal/src/metal/mod.rs @@ -300,6 +300,7 @@ struct PrivateCapabilities { int64_atomics: bool, float_atomics: bool, supports_shared_event: bool, + mesh_shaders: bool, } #[derive(Clone, Debug)] @@ -604,12 +605,16 @@ struct MultiStageData { vs: T, fs: T, cs: T, + ts: T, + ms: T, } const NAGA_STAGES: MultiStageData = MultiStageData { vs: naga::ShaderStage::Vertex, fs: naga::ShaderStage::Fragment, cs: naga::ShaderStage::Compute, + ts: naga::ShaderStage::Task, + ms: naga::ShaderStage::Mesh, }; impl ops::Index for MultiStageData { @@ -630,6 +635,8 @@ impl MultiStageData { vs: fun(&self.vs), fs: fun(&self.fs), cs: fun(&self.cs), + ts: fun(&self.ts), + ms: fun(&self.ms), } } fn map(self, fun: impl Fn(T) -> Y) -> MultiStageData { @@ -637,6 +644,8 @@ impl MultiStageData { vs: fun(self.vs), fs: fun(self.fs), cs: fun(self.cs), + ts: fun(self.ts), + ms: fun(self.ms), } } fn iter<'a>(&'a self) -> impl Iterator { @@ -811,6 +820,8 @@ impl crate::DynShaderModule for ShaderModule {} #[derive(Debug, Default)] struct PipelineStageInfo { + #[allow(dead_code)] + library: Option, push_constants: Option, /// The buffer argument table index at which we pass runtime-sized arrays' buffer sizes. @@ -849,12 +860,10 @@ impl PipelineStageInfo { #[derive(Debug)] pub struct RenderPipeline { raw: metal::RenderPipelineState, - #[allow(dead_code)] - vs_lib: metal::Library, - #[allow(dead_code)] - fs_lib: Option, - vs_info: PipelineStageInfo, + vs_info: Option, fs_info: Option, + ts_info: Option, + ms_info: Option, raw_primitive_type: MTLPrimitiveType, raw_triangle_fill_mode: MTLTriangleFillMode, raw_front_winding: MTLWinding, @@ -871,8 +880,6 @@ impl crate::DynRenderPipeline for RenderPipeline {} #[derive(Debug)] pub struct ComputePipeline { raw: metal::ComputePipelineState, - #[allow(dead_code)] - cs_lib: metal::Library, cs_info: PipelineStageInfo, work_group_size: MTLSize, work_group_memory_sizes: Vec, From 611c01a4566a3e6bb48dbf599df063db2b2b6449 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sun, 24 Aug 2025 00:11:50 -0500 Subject: [PATCH 12/89] More work --- wgpu-hal/src/metal/command.rs | 59 +++-- wgpu-hal/src/metal/device.rs | 393 +++++++++++++++++++++------------- wgpu-hal/src/metal/mod.rs | 5 +- 3 files changed, 299 insertions(+), 158 deletions(-) diff --git a/wgpu-hal/src/metal/command.rs b/wgpu-hal/src/metal/command.rs index 2b66343c478..37beb41a9a3 100644 --- a/wgpu-hal/src/metal/command.rs +++ b/wgpu-hal/src/metal/command.rs @@ -21,7 +21,6 @@ impl Default for super::CommandState { compute: None, raw_primitive_type: MTLPrimitiveType::Point, index: None, - raw_wg_size: MTLSize::new(0, 0, 0), stage_infos: Default::default(), storage_buffer_length_map: Default::default(), vertex_buffer_size_map: Default::default(), @@ -936,7 +935,7 @@ impl crate::CommandEncoder for super::CommandEncoder { encoder.set_depth_bias(bias.constant as f32, bias.slope_scale, bias.clamp); } - { + if pipeline.vs_info.is_some() { if let Some((index, sizes)) = self .state .make_sizes_buffer_update(naga::ShaderStage::Vertex, &mut self.temp.binding_sizes) @@ -960,6 +959,30 @@ impl crate::CommandEncoder for super::CommandEncoder { ); } } + if pipeline.ts_info.is_some() { + if let Some((index, sizes)) = self + .state + .make_sizes_buffer_update(naga::ShaderStage::Task, &mut self.temp.binding_sizes) + { + encoder.set_object_bytes( + index as _, + (sizes.len() * WORD_SIZE) as u64, + sizes.as_ptr().cast(), + ); + } + } + if pipeline.ms_info.is_some() { + if let Some((index, sizes)) = self + .state + .make_sizes_buffer_update(naga::ShaderStage::Mesh, &mut self.temp.binding_sizes) + { + encoder.set_mesh_bytes( + index as _, + (sizes.len() * WORD_SIZE) as u64, + sizes.as_ptr().cast(), + ); + } + } } unsafe fn set_index_buffer<'a>( @@ -1133,8 +1156,8 @@ impl crate::CommandEncoder for super::CommandEncoder { height: group_count_y as u64, depth: group_count_z as u64, }, - todo!(), - todo!(), + self.state.stage_infos.ts.raw_wg_size, + self.state.stage_infos.ms.raw_wg_size, ); } @@ -1174,11 +1197,20 @@ impl crate::CommandEncoder for super::CommandEncoder { unsafe fn draw_mesh_tasks_indirect( &mut self, - _buffer: &::Buffer, - _offset: wgt::BufferAddress, - _draw_count: u32, + buffer: &::Buffer, + mut offset: wgt::BufferAddress, + draw_count: u32, ) { - unreachable!() + let encoder = self.state.render.as_ref().unwrap(); + for _ in 0..draw_count { + encoder.draw_mesh_threadgroups_with_indirect_buffer( + &buffer.raw, + offset, + self.state.stage_infos.ts.raw_wg_size, + self.state.stage_infos.ms.raw_wg_size, + ); + offset += size_of::() as wgt::BufferAddress; + } } unsafe fn draw_indirect_count( @@ -1210,7 +1242,7 @@ impl crate::CommandEncoder for super::CommandEncoder { _count_offset: wgt::BufferAddress, _max_count: u32, ) { - unreachable!() + //TODO } // compute @@ -1286,7 +1318,6 @@ impl crate::CommandEncoder for super::CommandEncoder { } unsafe fn set_compute_pipeline(&mut self, pipeline: &super::ComputePipeline) { - self.state.raw_wg_size = pipeline.work_group_size; self.state.stage_infos.cs.assign_from(&pipeline.cs_info); let encoder = self.state.compute.as_ref().unwrap(); @@ -1330,13 +1361,17 @@ impl crate::CommandEncoder for super::CommandEncoder { height: count[1] as u64, depth: count[2] as u64, }; - encoder.dispatch_thread_groups(raw_count, self.state.raw_wg_size); + encoder.dispatch_thread_groups(raw_count, self.state.stage_infos.cs.raw_wg_size); } } unsafe fn dispatch_indirect(&mut self, buffer: &super::Buffer, offset: wgt::BufferAddress) { let encoder = self.state.compute.as_ref().unwrap(); - encoder.dispatch_thread_groups_indirect(&buffer.raw, offset, self.state.raw_wg_size); + encoder.dispatch_thread_groups_indirect( + &buffer.raw, + offset, + self.state.stage_infos.cs.raw_wg_size, + ); } unsafe fn build_acceleration_structures<'a, T>( diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index 97878960a36..6474136f4d7 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -18,6 +18,11 @@ use metal::{ type DeviceResult = Result; +enum MetalGenericRenderPipelineDescriptor { + Standard(metal::RenderPipelineDescriptor), + Mesh(metal::MeshRenderPipelineDescriptor), +} + struct CompiledShader { library: metal::Library, function: metal::Function, @@ -1054,83 +1059,207 @@ impl crate::Device for super::Device { super::PipelineCache, >, ) -> Result { - let (desc_vertex_stage, desc_vertex_buffers) = match &desc.vertex_processor { - crate::VertexProcessor::Standard { - vertex_buffers, - vertex_stage, - } => (vertex_stage, *vertex_buffers), - crate::VertexProcessor::Mesh { .. } => unreachable!(), - }; - objc::rc::autoreleasepool(|| { - let descriptor = metal::RenderPipelineDescriptor::new(); - - let raw_triangle_fill_mode = match desc.primitive.polygon_mode { - wgt::PolygonMode::Fill => MTLTriangleFillMode::Fill, - wgt::PolygonMode::Line => MTLTriangleFillMode::Lines, - wgt::PolygonMode::Point => panic!( - "{:?} is not enabled for this backend", - wgt::Features::POLYGON_MODE_POINT - ), - }; - let (primitive_class, raw_primitive_type) = conv::map_primitive_topology(desc.primitive.topology); - // Vertex shader - let vs_info = { - let mut vertex_buffer_mappings = Vec::::new(); - for (i, vbl) in desc_vertex_buffers.iter().enumerate() { - let mut attributes = Vec::::new(); - for attribute in vbl.attributes.iter() { - attributes.push(naga::back::msl::AttributeMapping { - shader_location: attribute.shader_location, - offset: attribute.offset as u32, - format: convert_vertex_format_to_naga(attribute.format), - }); - } + let vs_info; + let ts_info; + let ms_info; + let descriptor = match desc.vertex_processor { + crate::VertexProcessor::Standard { + vertex_buffers, + ref vertex_stage, + } => { + let descriptor = metal::RenderPipelineDescriptor::new(); + ts_info = None; + ms_info = None; + vs_info = Some({ + let mut vertex_buffer_mappings = + Vec::::new(); + for (i, vbl) in vertex_buffers.iter().enumerate() { + let mut attributes = Vec::::new(); + for attribute in vbl.attributes.iter() { + attributes.push(naga::back::msl::AttributeMapping { + shader_location: attribute.shader_location, + offset: attribute.offset as u32, + format: convert_vertex_format_to_naga(attribute.format), + }); + } + + vertex_buffer_mappings.push(naga::back::msl::VertexBufferMapping { + id: self.shared.private_caps.max_vertex_buffers - 1 - i as u32, + stride: if vbl.array_stride > 0 { + vbl.array_stride.try_into().unwrap() + } else { + vbl.attributes + .iter() + .map(|attribute| attribute.offset + attribute.format.size()) + .max() + .unwrap_or(0) + .try_into() + .unwrap() + }, + indexed_by_vertex: (vbl.step_mode + == wgt::VertexStepMode::Vertex {}), + attributes, + }); + } - vertex_buffer_mappings.push(naga::back::msl::VertexBufferMapping { - id: self.shared.private_caps.max_vertex_buffers - 1 - i as u32, - stride: if vbl.array_stride > 0 { - vbl.array_stride.try_into().unwrap() - } else { - vbl.attributes - .iter() - .map(|attribute| attribute.offset + attribute.format.size()) - .max() - .unwrap_or(0) - .try_into() - .unwrap() - }, - indexed_by_vertex: (vbl.step_mode == wgt::VertexStepMode::Vertex {}), - attributes, + let vs = self.load_shader( + vertex_stage, + &vertex_buffer_mappings, + desc.layout, + primitive_class, + naga::ShaderStage::Vertex, + )?; + + descriptor.set_vertex_function(Some(&vs.function)); + if self.shared.private_caps.supports_mutability { + Self::set_buffers_mutability( + descriptor.vertex_buffers().unwrap(), + vs.immutable_buffer_mask, + ); + } + + super::PipelineStageInfo { + push_constants: desc.layout.push_constants_infos.vs, + sizes_slot: desc.layout.per_stage_map.vs.sizes_buffer, + sized_bindings: vs.sized_bindings, + vertex_buffer_mappings, + library: Some(vs.library), + raw_wg_size: Default::default(), + } }); - } + if desc.layout.total_counters.vs.buffers + (vertex_buffers.len() as u32) + > self.shared.private_caps.max_vertex_buffers + { + let msg = format!( + "pipeline needs too many buffers in the vertex stage: {} vertex and {} layout", + vertex_buffers.len(), + desc.layout.total_counters.vs.buffers + ); + return Err(crate::PipelineError::Linkage( + wgt::ShaderStages::VERTEX, + msg, + )); + } - let vs = self.load_shader( - desc_vertex_stage, - &vertex_buffer_mappings, - desc.layout, - primitive_class, - naga::ShaderStage::Vertex, - )?; - - descriptor.set_vertex_function(Some(&vs.function)); - if self.shared.private_caps.supports_mutability { - Self::set_buffers_mutability( - descriptor.vertex_buffers().unwrap(), - vs.immutable_buffer_mask, - ); - } + if !vertex_buffers.is_empty() { + let vertex_descriptor = metal::VertexDescriptor::new(); + for (i, vb) in vertex_buffers.iter().enumerate() { + let buffer_index = + self.shared.private_caps.max_vertex_buffers as u64 - 1 - i as u64; + let buffer_desc = + vertex_descriptor.layouts().object_at(buffer_index).unwrap(); + + // Metal expects the stride to be the actual size of the attributes. + // The semantics of array_stride == 0 can be achieved by setting + // the step function to constant and rate to 0. + if vb.array_stride == 0 { + let stride = vb + .attributes + .iter() + .map(|attribute| attribute.offset + attribute.format.size()) + .max() + .unwrap_or(0); + buffer_desc.set_stride(wgt::math::align_to(stride, 4)); + buffer_desc.set_step_function(MTLVertexStepFunction::Constant); + buffer_desc.set_step_rate(0); + } else { + buffer_desc.set_stride(vb.array_stride); + buffer_desc.set_step_function(conv::map_step_mode(vb.step_mode)); + } - super::PipelineStageInfo { - push_constants: desc.layout.push_constants_infos.vs, - sizes_slot: desc.layout.per_stage_map.vs.sizes_buffer, - sized_bindings: vs.sized_bindings, - vertex_buffer_mappings, - library: Some(vs.library), + for at in vb.attributes { + let attribute_desc = vertex_descriptor + .attributes() + .object_at(at.shader_location as u64) + .unwrap(); + attribute_desc.set_format(conv::map_vertex_format(at.format)); + attribute_desc.set_buffer_index(buffer_index); + attribute_desc.set_offset(at.offset); + } + } + descriptor.set_vertex_descriptor(Some(vertex_descriptor)); + } + todo!() } + crate::VertexProcessor::Mesh { + ref task_stage, + ref mesh_stage, + } => { + vs_info = None; + let descriptor = metal::MeshRenderPipelineDescriptor::new(); + if let Some(ref task_stage) = task_stage { + let ts = self.load_shader( + task_stage, + &[], + desc.layout, + primitive_class, + naga::ShaderStage::Task, + )?; + descriptor.set_mesh_function(Some(&ts.function)); + if self.shared.private_caps.supports_mutability { + Self::set_buffers_mutability( + descriptor.mesh_buffers().unwrap(), + ts.immutable_buffer_mask, + ); + } + ts_info = Some(super::PipelineStageInfo { + push_constants: desc.layout.push_constants_infos.ts, + sizes_slot: desc.layout.per_stage_map.ts.sizes_buffer, + sized_bindings: ts.sized_bindings, + vertex_buffer_mappings: vec![], + library: Some(ts.library), + raw_wg_size: Default::default(), + }); + } else { + ts_info = None; + } + { + let ms = self.load_shader( + mesh_stage, + &[], + desc.layout, + primitive_class, + naga::ShaderStage::Mesh, + )?; + descriptor.set_mesh_function(Some(&ms.function)); + if self.shared.private_caps.supports_mutability { + Self::set_buffers_mutability( + descriptor.mesh_buffers().unwrap(), + ms.immutable_buffer_mask, + ); + } + ms_info = Some(super::PipelineStageInfo { + push_constants: desc.layout.push_constants_infos.ms, + sizes_slot: desc.layout.per_stage_map.ms.sizes_buffer, + sized_bindings: ms.sized_bindings, + vertex_buffer_mappings: vec![], + library: Some(ms.library), + raw_wg_size: Default::default(), + }); + } + MetalGenericRenderPipelineDescriptor::Mesh(descriptor) + } + }; + macro_rules! descriptor_fn { + ($method:ident $( ( $($args:expr),* ) )? ) => { + match descriptor { + MetalGenericRenderPipelineDescriptor::Standard(ref inner) => inner.$method$(($($args),*))?, + MetalGenericRenderPipelineDescriptor::Mesh(ref inner) => inner.$method$(($($args),*))?, + } + }; + } + + let raw_triangle_fill_mode = match desc.primitive.polygon_mode { + wgt::PolygonMode::Fill => MTLTriangleFillMode::Fill, + wgt::PolygonMode::Line => MTLTriangleFillMode::Lines, + wgt::PolygonMode::Point => panic!( + "{:?} is not enabled for this backend", + wgt::Features::POLYGON_MODE_POINT + ), }; // Fragment shader @@ -1144,10 +1273,10 @@ impl crate::Device for super::Device { naga::ShaderStage::Fragment, )?; - descriptor.set_fragment_function(Some(&fs.function)); + descriptor_fn!(set_fragment_function(Some(&fs.function))); if self.shared.private_caps.supports_mutability { Self::set_buffers_mutability( - descriptor.fragment_buffers().unwrap(), + descriptor_fn!(fragment_buffers()).unwrap(), fs.immutable_buffer_mask, ); } @@ -1158,20 +1287,25 @@ impl crate::Device for super::Device { sized_bindings: fs.sized_bindings, vertex_buffer_mappings: vec![], library: Some(fs.library), + raw_wg_size: Default::default(), }) } None => { // TODO: This is a workaround for what appears to be a Metal validation bug // A pixel format is required even though no attachments are provided if desc.color_targets.is_empty() && desc.depth_stencil.is_none() { - descriptor.set_depth_attachment_pixel_format(MTLPixelFormat::Depth32Float); + descriptor_fn!(set_depth_attachment_pixel_format( + MTLPixelFormat::Depth32Float + )); } None } }; for (i, ct) in desc.color_targets.iter().enumerate() { - let at_descriptor = descriptor.color_attachments().object_at(i as u64).unwrap(); + let at_descriptor = descriptor_fn!(color_attachments()) + .object_at(i as u64) + .unwrap(); let ct = if let Some(color_target) = ct.as_ref() { color_target } else { @@ -1203,10 +1337,10 @@ impl crate::Device for super::Device { let raw_format = self.shared.private_caps.map_format(ds.format); let aspects = crate::FormatAspects::from(ds.format); if aspects.contains(crate::FormatAspects::DEPTH) { - descriptor.set_depth_attachment_pixel_format(raw_format); + descriptor_fn!(set_depth_attachment_pixel_format(raw_format)); } if aspects.contains(crate::FormatAspects::STENCIL) { - descriptor.set_stencil_attachment_pixel_format(raw_format); + descriptor_fn!(set_stencil_attachment_pixel_format(raw_format)); } let ds_descriptor = create_depth_stencil_desc(ds); @@ -1220,90 +1354,61 @@ impl crate::Device for super::Device { None => None, }; - if desc.layout.total_counters.vs.buffers + (desc_vertex_buffers.len() as u32) - > self.shared.private_caps.max_vertex_buffers - { - let msg = format!( - "pipeline needs too many buffers in the vertex stage: {} vertex and {} layout", - desc_vertex_buffers.len(), - desc.layout.total_counters.vs.buffers - ); - return Err(crate::PipelineError::Linkage( - wgt::ShaderStages::VERTEX, - msg, - )); - } - - if !desc_vertex_buffers.is_empty() { - let vertex_descriptor = metal::VertexDescriptor::new(); - for (i, vb) in desc_vertex_buffers.iter().enumerate() { - let buffer_index = - self.shared.private_caps.max_vertex_buffers as u64 - 1 - i as u64; - let buffer_desc = vertex_descriptor.layouts().object_at(buffer_index).unwrap(); - - // Metal expects the stride to be the actual size of the attributes. - // The semantics of array_stride == 0 can be achieved by setting - // the step function to constant and rate to 0. - if vb.array_stride == 0 { - let stride = vb - .attributes - .iter() - .map(|attribute| attribute.offset + attribute.format.size()) - .max() - .unwrap_or(0); - buffer_desc.set_stride(wgt::math::align_to(stride, 4)); - buffer_desc.set_step_function(MTLVertexStepFunction::Constant); - buffer_desc.set_step_rate(0); - } else { - buffer_desc.set_stride(vb.array_stride); - buffer_desc.set_step_function(conv::map_step_mode(vb.step_mode)); + if desc.multisample.count != 1 { + //TODO: handle sample mask + match descriptor { + MetalGenericRenderPipelineDescriptor::Standard(ref inner) => { + inner.set_sample_count(desc.multisample.count as u64); } - - for at in vb.attributes { - let attribute_desc = vertex_descriptor - .attributes() - .object_at(at.shader_location as u64) - .unwrap(); - attribute_desc.set_format(conv::map_vertex_format(at.format)); - attribute_desc.set_buffer_index(buffer_index); - attribute_desc.set_offset(at.offset); + MetalGenericRenderPipelineDescriptor::Mesh(ref inner) => { + inner.set_raster_sample_count(desc.multisample.count as u64); } } - descriptor.set_vertex_descriptor(Some(vertex_descriptor)); - } - - if desc.multisample.count != 1 { - //TODO: handle sample mask - descriptor.set_sample_count(desc.multisample.count as u64); - descriptor - .set_alpha_to_coverage_enabled(desc.multisample.alpha_to_coverage_enabled); + descriptor_fn!(set_alpha_to_coverage_enabled( + desc.multisample.alpha_to_coverage_enabled + )); //descriptor.set_alpha_to_one_enabled(desc.multisample.alpha_to_one_enabled); } if let Some(name) = desc.label { - descriptor.set_label(name); + descriptor_fn!(set_label(name)); } - let raw = self - .shared - .device - .lock() - .new_render_pipeline_state(&descriptor) - .map_err(|e| { - crate::PipelineError::Linkage( - wgt::ShaderStages::VERTEX | wgt::ShaderStages::FRAGMENT, - format!("new_render_pipeline_state: {e:?}"), - ) - })?; + let raw = match descriptor { + MetalGenericRenderPipelineDescriptor::Standard(d) => self + .shared + .device + .lock() + .new_render_pipeline_state(&d) + .map_err(|e| { + crate::PipelineError::Linkage( + wgt::ShaderStages::VERTEX | wgt::ShaderStages::FRAGMENT, + format!("new_render_pipeline_state: {e:?}"), + ) + })?, + MetalGenericRenderPipelineDescriptor::Mesh(d) => self + .shared + .device + .lock() + .new_mesh_render_pipeline_state(&d) + .map_err(|e| { + crate::PipelineError::Linkage( + wgt::ShaderStages::TASK + | wgt::ShaderStages::MESH + | wgt::ShaderStages::FRAGMENT, + format!("new_render_pipeline_state: {e:?}"), + ) + })?, + }; self.counters.render_pipelines.add(1); Ok(super::RenderPipeline { raw, - vs_info: Some(vs_info), + vs_info, fs_info, - ts_info: None, - ms_info: None, + ts_info, + ms_info, raw_primitive_type, raw_triangle_fill_mode, raw_front_winding: conv::map_winding(desc.primitive.front_face), @@ -1376,6 +1481,7 @@ impl crate::Device for super::Device { sizes_slot: desc.layout.per_stage_map.cs.sizes_buffer, sized_bindings: cs.sized_bindings, vertex_buffer_mappings: vec![], + raw_wg_size: cs.wg_size, }; if let Some(name) = desc.label { @@ -1399,7 +1505,6 @@ impl crate::Device for super::Device { Ok(super::ComputePipeline { raw, cs_info, - work_group_size: cs.wg_size, work_group_memory_sizes: cs.wg_memory_sizes, }) }) diff --git a/wgpu-hal/src/metal/mod.rs b/wgpu-hal/src/metal/mod.rs index ec4ae11cdef..a9d9e19b57b 100644 --- a/wgpu-hal/src/metal/mod.rs +++ b/wgpu-hal/src/metal/mod.rs @@ -836,6 +836,9 @@ struct PipelineStageInfo { /// Info on all bound vertex buffers. vertex_buffer_mappings: Vec, + + /// The workgroup size for compute, task or mesh stages + raw_wg_size: MTLSize, } impl PipelineStageInfo { @@ -881,7 +884,6 @@ impl crate::DynRenderPipeline for RenderPipeline {} pub struct ComputePipeline { raw: metal::ComputePipelineState, cs_info: PipelineStageInfo, - work_group_size: MTLSize, work_group_memory_sizes: Vec, } @@ -956,7 +958,6 @@ struct CommandState { compute: Option, raw_primitive_type: MTLPrimitiveType, index: Option, - raw_wg_size: MTLSize, stage_infos: MultiStageData, /// Sizes of currently bound [`wgt::BufferBindingType::Storage`] buffers. From 3d36680bca124a3d61a7e426553c71a6bdb4eab6 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sun, 24 Aug 2025 00:13:03 -0500 Subject: [PATCH 13/89] Oops --- wgpu-hal/src/metal/device.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index 6474136f4d7..4f1154c42c3 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -1183,7 +1183,7 @@ impl crate::Device for super::Device { } descriptor.set_vertex_descriptor(Some(vertex_descriptor)); } - todo!() + MetalGenericRenderPipelineDescriptor::Standard(descriptor) } crate::VertexProcessor::Mesh { ref task_stage, From c9c39fd4ab74d7d0516c3e556fb86e9935dde031 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sun, 24 Aug 2025 00:22:57 -0500 Subject: [PATCH 14/89] Another refactor --- wgpu-hal/src/metal/adapter.rs | 5 ++--- wgpu-hal/src/metal/command.rs | 6 ++++-- wgpu-hal/src/metal/device.rs | 11 ++++++----- wgpu-hal/src/metal/mod.rs | 4 +++- 4 files changed, 15 insertions(+), 11 deletions(-) diff --git a/wgpu-hal/src/metal/adapter.rs b/wgpu-hal/src/metal/adapter.rs index 9517f0b4dd6..d298ee7da15 100644 --- a/wgpu-hal/src/metal/adapter.rs +++ b/wgpu-hal/src/metal/adapter.rs @@ -606,8 +606,6 @@ impl super::PrivateCapabilities { } let argument_buffers = device.argument_buffers_support(); - let mesh_shaders = device.supports_family(MTLGPUFamily::Apple7) - || device.supports_family(MTLGPUFamily::Mac2); Self { family_check, @@ -904,7 +902,8 @@ impl super::PrivateCapabilities { && (device.supports_family(MTLGPUFamily::Apple7) || device.supports_family(MTLGPUFamily::Mac2)), supports_shared_event: version.at_least((10, 14), (12, 0), os_is_mac), - mesh_shaders, + mesh_shaders: device.supports_family(MTLGPUFamily::Apple7) + || device.supports_family(MTLGPUFamily::Mac2), } } diff --git a/wgpu-hal/src/metal/command.rs b/wgpu-hal/src/metal/command.rs index 37beb41a9a3..db282a8d91e 100644 --- a/wgpu-hal/src/metal/command.rs +++ b/wgpu-hal/src/metal/command.rs @@ -1335,14 +1335,16 @@ impl crate::CommandEncoder for super::CommandEncoder { } // update the threadgroup memory sizes - while self.state.work_group_memory_sizes.len() < pipeline.work_group_memory_sizes.len() { + while self.state.work_group_memory_sizes.len() + < pipeline.cs_info.work_group_memory_sizes.len() + { self.state.work_group_memory_sizes.push(0); } for (index, (cur_size, pipeline_size)) in self .state .work_group_memory_sizes .iter_mut() - .zip(pipeline.work_group_memory_sizes.iter()) + .zip(pipeline.cs_info.work_group_memory_sizes.iter()) .enumerate() { let size = pipeline_size.next_multiple_of(16); diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index 4f1154c42c3..ee1a74b2131 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -1129,6 +1129,7 @@ impl crate::Device for super::Device { vertex_buffer_mappings, library: Some(vs.library), raw_wg_size: Default::default(), + work_group_memory_sizes: vec![], } }); if desc.layout.total_counters.vs.buffers + (vertex_buffers.len() as u32) @@ -1213,6 +1214,7 @@ impl crate::Device for super::Device { vertex_buffer_mappings: vec![], library: Some(ts.library), raw_wg_size: Default::default(), + work_group_memory_sizes: vec![], }); } else { ts_info = None; @@ -1239,6 +1241,7 @@ impl crate::Device for super::Device { vertex_buffer_mappings: vec![], library: Some(ms.library), raw_wg_size: Default::default(), + work_group_memory_sizes: vec![], }); } MetalGenericRenderPipelineDescriptor::Mesh(descriptor) @@ -1288,6 +1291,7 @@ impl crate::Device for super::Device { vertex_buffer_mappings: vec![], library: Some(fs.library), raw_wg_size: Default::default(), + work_group_memory_sizes: vec![], }) } None => { @@ -1482,6 +1486,7 @@ impl crate::Device for super::Device { sized_bindings: cs.sized_bindings, vertex_buffer_mappings: vec![], raw_wg_size: cs.wg_size, + work_group_memory_sizes: cs.wg_memory_sizes, }; if let Some(name) = desc.label { @@ -1502,11 +1507,7 @@ impl crate::Device for super::Device { self.counters.compute_pipelines.add(1); - Ok(super::ComputePipeline { - raw, - cs_info, - work_group_memory_sizes: cs.wg_memory_sizes, - }) + Ok(super::ComputePipeline { raw, cs_info }) }) } diff --git a/wgpu-hal/src/metal/mod.rs b/wgpu-hal/src/metal/mod.rs index a9d9e19b57b..c2d2a80a214 100644 --- a/wgpu-hal/src/metal/mod.rs +++ b/wgpu-hal/src/metal/mod.rs @@ -839,6 +839,9 @@ struct PipelineStageInfo { /// The workgroup size for compute, task or mesh stages raw_wg_size: MTLSize, + + /// The workgroup memory sizes for compute task or mesh stages + work_group_memory_sizes: Vec, } impl PipelineStageInfo { @@ -884,7 +887,6 @@ impl crate::DynRenderPipeline for RenderPipeline {} pub struct ComputePipeline { raw: metal::ComputePipelineState, cs_info: PipelineStageInfo, - work_group_memory_sizes: Vec, } unsafe impl Send for ComputePipeline {} From fb330288f734898ac0e6a000ba32e4f3d23e3b4b Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sun, 24 Aug 2025 01:49:43 -0500 Subject: [PATCH 15/89] Another slight refactor --- wgpu-hal/src/metal/command.rs | 8 ++++---- wgpu-hal/src/metal/mod.rs | 4 +++- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/wgpu-hal/src/metal/command.rs b/wgpu-hal/src/metal/command.rs index db282a8d91e..a91035b642f 100644 --- a/wgpu-hal/src/metal/command.rs +++ b/wgpu-hal/src/metal/command.rs @@ -24,7 +24,6 @@ impl Default for super::CommandState { stage_infos: Default::default(), storage_buffer_length_map: Default::default(), vertex_buffer_size_map: Default::default(), - work_group_memory_sizes: Vec::new(), push_constants: Vec::new(), pending_timer_queries: Vec::new(), } @@ -149,7 +148,6 @@ impl super::CommandState { self.stage_infos.vs.clear(); self.stage_infos.fs.clear(); self.stage_infos.cs.clear(); - self.work_group_memory_sizes.clear(); self.push_constants.clear(); } @@ -1335,13 +1333,15 @@ impl crate::CommandEncoder for super::CommandEncoder { } // update the threadgroup memory sizes - while self.state.work_group_memory_sizes.len() + while self.state.stage_infos.cs.work_group_memory_sizes.len() < pipeline.cs_info.work_group_memory_sizes.len() { - self.state.work_group_memory_sizes.push(0); + self.state.stage_infos.cs.work_group_memory_sizes.push(0); } for (index, (cur_size, pipeline_size)) in self .state + .stage_infos + .cs .work_group_memory_sizes .iter_mut() .zip(pipeline.cs_info.work_group_memory_sizes.iter()) diff --git a/wgpu-hal/src/metal/mod.rs b/wgpu-hal/src/metal/mod.rs index c2d2a80a214..c4d9992e7db 100644 --- a/wgpu-hal/src/metal/mod.rs +++ b/wgpu-hal/src/metal/mod.rs @@ -850,6 +850,9 @@ impl PipelineStageInfo { self.sizes_slot = None; self.sized_bindings.clear(); self.vertex_buffer_mappings.clear(); + self.library = None; + self.work_group_memory_sizes.clear(); + self.raw_wg_size = Default::default(); } fn assign_from(&mut self, other: &Self) { @@ -985,7 +988,6 @@ struct CommandState { vertex_buffer_size_map: FastHashMap, - work_group_memory_sizes: Vec, push_constants: Vec, /// Timer query that should be executed when the next pass starts. From ece1ea10c1c9c3cebbc2db4ec1742e5aac1eb289 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sun, 24 Aug 2025 01:58:22 -0500 Subject: [PATCH 16/89] Another slight refactor --- wgpu-hal/src/metal/command.rs | 2 ++ wgpu-hal/src/metal/mod.rs | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/wgpu-hal/src/metal/command.rs b/wgpu-hal/src/metal/command.rs index a91035b642f..a83540a9a37 100644 --- a/wgpu-hal/src/metal/command.rs +++ b/wgpu-hal/src/metal/command.rs @@ -148,6 +148,8 @@ impl super::CommandState { self.stage_infos.vs.clear(); self.stage_infos.fs.clear(); self.stage_infos.cs.clear(); + self.stage_infos.ts.clear(); + self.stage_infos.ms.clear(); self.push_constants.clear(); } diff --git a/wgpu-hal/src/metal/mod.rs b/wgpu-hal/src/metal/mod.rs index c4d9992e7db..1e7b5281240 100644 --- a/wgpu-hal/src/metal/mod.rs +++ b/wgpu-hal/src/metal/mod.rs @@ -863,6 +863,11 @@ impl PipelineStageInfo { self.vertex_buffer_mappings.clear(); self.vertex_buffer_mappings .extend_from_slice(&other.vertex_buffer_mappings); + self.library = Some(other.library.as_ref().unwrap().clone()); + self.raw_wg_size = other.raw_wg_size; + self.work_group_memory_sizes.clear(); + self.work_group_memory_sizes + .extend_from_slice(&other.work_group_memory_sizes); } } From 47c187b40ed14a08112ad0221dff9295c6190259 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sun, 24 Aug 2025 01:59:49 -0500 Subject: [PATCH 17/89] Fixed it --- wgpu-hal/src/metal/command.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wgpu-hal/src/metal/command.rs b/wgpu-hal/src/metal/command.rs index a83540a9a37..542287983e9 100644 --- a/wgpu-hal/src/metal/command.rs +++ b/wgpu-hal/src/metal/command.rs @@ -915,11 +915,11 @@ impl crate::CommandEncoder for super::CommandEncoder { } match pipeline.ts_info { Some(ref info) => self.state.stage_infos.ts.assign_from(info), - None => self.state.stage_infos.vs.clear(), + None => self.state.stage_infos.ts.clear(), } match pipeline.ms_info { Some(ref info) => self.state.stage_infos.ms.assign_from(info), - None => self.state.stage_infos.fs.clear(), + None => self.state.stage_infos.ms.clear(), } let encoder = self.state.render.as_ref().unwrap(); From 8bc63b662a542971d6b8f3c19284e35fc3417592 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sun, 24 Aug 2025 02:33:22 -0500 Subject: [PATCH 18/89] Worked a little more on trying to add it to example --- examples/features/src/mesh_shader/mod.rs | 32 ++++++-- .../features/src/mesh_shader/shader.metal | 74 +++++++++++++++++++ wgpu-types/src/lib.rs | 4 +- 3 files changed, 102 insertions(+), 8 deletions(-) create mode 100644 examples/features/src/mesh_shader/shader.metal diff --git a/examples/features/src/mesh_shader/mod.rs b/examples/features/src/mesh_shader/mod.rs index 675150f5106..e21e7ae2c95 100644 --- a/examples/features/src/mesh_shader/mod.rs +++ b/examples/features/src/mesh_shader/mod.rs @@ -33,13 +33,25 @@ fn compile_glsl( } } +fn compile_msl(device: &wgpu::Device, entry: &str) -> wgpu::ShaderModule { + unsafe { + device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough { + entry_point: entry.to_owned(), + label: None, + msl: Some(std::borrow::Cow::Borrowed(include_str!("shader.metal"))), + num_workgroups: (1, 1, 1), + ..Default::default() + }) + } +} + pub struct Example { pipeline: wgpu::RenderPipeline, } impl crate::framework::Example for Example { fn init( config: &wgpu::SurfaceConfiguration, - _adapter: &wgpu::Adapter, + adapter: &wgpu::Adapter, device: &wgpu::Device, _queue: &wgpu::Queue, ) -> Self { @@ -48,11 +60,19 @@ impl crate::framework::Example for Example { bind_group_layouts: &[], push_constant_ranges: &[], }); - let (ts, ms, fs) = ( - compile_glsl(device, include_bytes!("shader.task"), "task"), - compile_glsl(device, include_bytes!("shader.mesh"), "mesh"), - compile_glsl(device, include_bytes!("shader.frag"), "frag"), - ); + let (ts, ms, fs) = if adapter.get_info().backend == wgpu::Backend::Metal { + ( + compile_msl(device, "taskShader"), + compile_msl(device, "meshShader"), + compile_msl(device, "fragShader"), + ) + } else { + ( + compile_glsl(device, include_bytes!("shader.task"), "task"), + compile_glsl(device, include_bytes!("shader.mesh"), "mesh"), + compile_glsl(device, include_bytes!("shader.frag"), "frag"), + ) + }; let pipeline = device.create_mesh_pipeline(&wgpu::MeshPipelineDescriptor { label: None, layout: Some(&pipeline_layout), diff --git a/examples/features/src/mesh_shader/shader.metal b/examples/features/src/mesh_shader/shader.metal new file mode 100644 index 00000000000..0a563132a19 --- /dev/null +++ b/examples/features/src/mesh_shader/shader.metal @@ -0,0 +1,74 @@ +using namespace metal; + +struct OutVertex { + float4 Position [[position]]; + float4 Color; +}; + +struct OutPrimitive { + float4 ColorMask [[flat]]; + bool CullPrimitive; +}; + +struct InVertex { + float4 Color; +}; + +struct InPrimitive { + float4 ColorMask [[flat]]; +}; + +struct PayloadData { + float4 ColorMask; + bool Visible; +}; + +using Meshlet = metal::mesh; + + +constant float4 positions[3] = { + float4(0.0, 1.0, 0.0, 1.0), + float4(-1.0, -1.0, 0.0, 1.0), + float4(1.0, -1.0, 0.0, 1.0) +}; + +constant float4 colors[3] = { + float4(0.0, 1.0, 0.0, 1.0), + float4(0.0, 0.0, 1.0, 1.0), + float4(1.0, 0.0, 0.0, 1.0) +}; + + +[[object]] +void taskShader(uint3 tid [[thread_position_in_grid]], object_data PayloadData &outPayload [[payload]], grid_properties grid) { + outPayload.ColorMask = float4(1.0, 1.0, 0.0, 1.0); + outPayload.Visible = true; + grid.set_threadgroups_per_grid(uint3(3, 1, 1)); +} + +[[mesh, topology(triangle)]] +void meshShader( + object_data PayloadData const& payload [[payload]], + Meshlet out, +) +{ + out.set_primitive_count(1); + + for(int i = 0;i < 3;i++) { + OutVertex vert; + vert.Position = positions[i]; + vert.Color = colors[i] * payload.ColorMask; + mesh.set_vertex(i, vert); + out.set_index(i, i); + } + + triangles[0] = uint3(0, 1, 2); + OutPrimitive prim; + prim.ColorMask = float4(1.0, 0.0, 0.0, 1.0); + prim.CullPrimitive = !payload.Visible; + out.set_primitive(0, prim); +} + +fragment float4 fragShader(OutVertex inVertex [[stage_in]], OutPrimitive inPrimitive [[stage_in]]) { + return inVertex.Color * inPrimitive.ColorMask; +} diff --git a/wgpu-types/src/lib.rs b/wgpu-types/src/lib.rs index ea2a09eb62a..828136a690c 100644 --- a/wgpu-types/src/lib.rs +++ b/wgpu-types/src/lib.rs @@ -979,8 +979,8 @@ impl Limits { // Literally just made this up as 256^2 or 2^16. // My GPU supports 2^22, and compute shaders don't have this kind of limit. // This very likely is never a real limiter - max_task_workgroup_total_count: 65536, - max_task_workgroups_per_dimension: 256, + max_task_workgroup_total_count: 1024, + max_task_workgroups_per_dimension: 1024, // llvmpipe reports 0 multiview count, which just means no multiview is allowed max_mesh_multiview_count: 0, // llvmpipe once again requires this to be 8. An RTX 3060 supports well over 1024. From 55d6bf3b3ab94c3ea16d09856bc393d98a9b67ed Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sun, 24 Aug 2025 02:40:34 -0500 Subject: [PATCH 19/89] Fixed metal shader --- examples/features/src/mesh_shader/shader.metal | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/examples/features/src/mesh_shader/shader.metal b/examples/features/src/mesh_shader/shader.metal index 0a563132a19..65edc83e442 100644 --- a/examples/features/src/mesh_shader/shader.metal +++ b/examples/features/src/mesh_shader/shader.metal @@ -18,6 +18,11 @@ struct InPrimitive { float4 ColorMask [[flat]]; }; +struct FragmentIn { + InVertex vert; + InPrimitive prim; +}; + struct PayloadData { float4 ColorMask; bool Visible; @@ -40,16 +45,16 @@ constant float4 colors[3] = { [[object]] -void taskShader(uint3 tid [[thread_position_in_grid]], object_data PayloadData &outPayload [[payload]], grid_properties grid) { +void taskShader(uint3 tid [[thread_position_in_grid]], object_data PayloadData &outPayload [[payload]], mesh_grid_properties grid) { outPayload.ColorMask = float4(1.0, 1.0, 0.0, 1.0); outPayload.Visible = true; grid.set_threadgroups_per_grid(uint3(3, 1, 1)); } -[[mesh, topology(triangle)]] +[[mesh]] void meshShader( object_data PayloadData const& payload [[payload]], - Meshlet out, + Meshlet out ) { out.set_primitive_count(1); @@ -58,17 +63,16 @@ void meshShader( OutVertex vert; vert.Position = positions[i]; vert.Color = colors[i] * payload.ColorMask; - mesh.set_vertex(i, vert); + out.set_vertex(i, vert); out.set_index(i, i); } - triangles[0] = uint3(0, 1, 2); OutPrimitive prim; prim.ColorMask = float4(1.0, 0.0, 0.0, 1.0); prim.CullPrimitive = !payload.Visible; out.set_primitive(0, prim); } -fragment float4 fragShader(OutVertex inVertex [[stage_in]], OutPrimitive inPrimitive [[stage_in]]) { - return inVertex.Color * inPrimitive.ColorMask; +fragment float4 fragShader(FragmentIn data [[stage_in]]) { + return data.vert.Color * data.prim.ColorMask; } From edfd494cbd71ffa8b9b2dd82583e91a9e5c68a18 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sun, 24 Aug 2025 02:56:17 -0500 Subject: [PATCH 20/89] Fixed some passthrough stuff, now it runs (uggh) --- wgpu-hal/src/metal/device.rs | 326 ++++++++++++++++++----------------- wgpu-hal/src/metal/mod.rs | 3 +- 2 files changed, 174 insertions(+), 155 deletions(-) diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index ee1a74b2131..3a48c9e8ead 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -133,176 +133,194 @@ impl super::Device { primitive_class: MTLPrimitiveTopologyClass, naga_stage: naga::ShaderStage, ) -> Result { - let naga_shader = if let ShaderModuleSource::Naga(naga) = &stage.module.source { - naga - } else { - panic!("load_shader required a naga shader"); - }; - let stage_bit = map_naga_stage(naga_stage); - let (module, module_info) = naga::back::pipeline_constants::process_overrides( - &naga_shader.module, - &naga_shader.info, - Some((naga_stage, stage.entry_point)), - stage.constants, - ) - .map_err(|e| crate::PipelineError::PipelineConstants(stage_bit, format!("MSL: {e:?}")))?; - - let ep_resources = &layout.per_stage_map[naga_stage]; - - let bounds_check_policy = if stage.module.bounds_checks.bounds_checks { - naga::proc::BoundsCheckPolicy::Restrict - } else { - naga::proc::BoundsCheckPolicy::Unchecked - }; + match stage.module.source { + ShaderModuleSource::Naga(ref naga_shader) => { + let stage_bit = map_naga_stage(naga_stage); + let (module, module_info) = naga::back::pipeline_constants::process_overrides( + &naga_shader.module, + &naga_shader.info, + Some((naga_stage, stage.entry_point)), + stage.constants, + ) + .map_err(|e| { + crate::PipelineError::PipelineConstants(stage_bit, format!("MSL: {e:?}")) + })?; - let options = naga::back::msl::Options { - lang_version: match self.shared.private_caps.msl_version { - MTLLanguageVersion::V1_0 => (1, 0), - MTLLanguageVersion::V1_1 => (1, 1), - MTLLanguageVersion::V1_2 => (1, 2), - MTLLanguageVersion::V2_0 => (2, 0), - MTLLanguageVersion::V2_1 => (2, 1), - MTLLanguageVersion::V2_2 => (2, 2), - MTLLanguageVersion::V2_3 => (2, 3), - MTLLanguageVersion::V2_4 => (2, 4), - MTLLanguageVersion::V3_0 => (3, 0), - MTLLanguageVersion::V3_1 => (3, 1), - }, - inline_samplers: Default::default(), - spirv_cross_compatibility: false, - fake_missing_bindings: false, - per_entry_point_map: naga::back::msl::EntryPointResourceMap::from([( - stage.entry_point.to_owned(), - ep_resources.clone(), - )]), - bounds_check_policies: naga::proc::BoundsCheckPolicies { - index: bounds_check_policy, - buffer: bounds_check_policy, - image_load: bounds_check_policy, - // TODO: support bounds checks on binding arrays - binding_array: naga::proc::BoundsCheckPolicy::Unchecked, - }, - zero_initialize_workgroup_memory: stage.zero_initialize_workgroup_memory, - force_loop_bounding: stage.module.bounds_checks.force_loop_bounding, - }; + let ep_resources = &layout.per_stage_map[naga_stage]; - let pipeline_options = naga::back::msl::PipelineOptions { - entry_point: Some((naga_stage, stage.entry_point.to_owned())), - allow_and_force_point_size: match primitive_class { - MTLPrimitiveTopologyClass::Point => true, - _ => false, - }, - vertex_pulling_transform: true, - vertex_buffer_mappings: vertex_buffer_mappings.to_vec(), - }; + let bounds_check_policy = if stage.module.bounds_checks.bounds_checks { + naga::proc::BoundsCheckPolicy::Restrict + } else { + naga::proc::BoundsCheckPolicy::Unchecked + }; - let (source, info) = - naga::back::msl::write_string(&module, &module_info, &options, &pipeline_options) - .map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("MSL: {e:?}")))?; + let options = naga::back::msl::Options { + lang_version: match self.shared.private_caps.msl_version { + MTLLanguageVersion::V1_0 => (1, 0), + MTLLanguageVersion::V1_1 => (1, 1), + MTLLanguageVersion::V1_2 => (1, 2), + MTLLanguageVersion::V2_0 => (2, 0), + MTLLanguageVersion::V2_1 => (2, 1), + MTLLanguageVersion::V2_2 => (2, 2), + MTLLanguageVersion::V2_3 => (2, 3), + MTLLanguageVersion::V2_4 => (2, 4), + MTLLanguageVersion::V3_0 => (3, 0), + MTLLanguageVersion::V3_1 => (3, 1), + }, + inline_samplers: Default::default(), + spirv_cross_compatibility: false, + fake_missing_bindings: false, + per_entry_point_map: naga::back::msl::EntryPointResourceMap::from([( + stage.entry_point.to_owned(), + ep_resources.clone(), + )]), + bounds_check_policies: naga::proc::BoundsCheckPolicies { + index: bounds_check_policy, + buffer: bounds_check_policy, + image_load: bounds_check_policy, + // TODO: support bounds checks on binding arrays + binding_array: naga::proc::BoundsCheckPolicy::Unchecked, + }, + zero_initialize_workgroup_memory: stage.zero_initialize_workgroup_memory, + force_loop_bounding: stage.module.bounds_checks.force_loop_bounding, + }; - log::debug!( - "Naga generated shader for entry point '{}' and stage {:?}\n{}", - stage.entry_point, - naga_stage, - &source - ); + let pipeline_options = naga::back::msl::PipelineOptions { + entry_point: Some((naga_stage, stage.entry_point.to_owned())), + allow_and_force_point_size: match primitive_class { + MTLPrimitiveTopologyClass::Point => true, + _ => false, + }, + vertex_pulling_transform: true, + vertex_buffer_mappings: vertex_buffer_mappings.to_vec(), + }; - let options = metal::CompileOptions::new(); - options.set_language_version(self.shared.private_caps.msl_version); + let (source, info) = naga::back::msl::write_string( + &module, + &module_info, + &options, + &pipeline_options, + ) + .map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("MSL: {e:?}")))?; - if self.shared.private_caps.supports_preserve_invariance { - options.set_preserve_invariance(true); - } + log::debug!( + "Naga generated shader for entry point '{}' and stage {:?}\n{}", + stage.entry_point, + naga_stage, + &source + ); - let library = self - .shared - .device - .lock() - .new_library_with_source(source.as_ref(), &options) - .map_err(|err| { - log::warn!("Naga generated shader:\n{source}"); - crate::PipelineError::Linkage(stage_bit, format!("Metal: {err}")) - })?; - - let ep_index = module - .entry_points - .iter() - .position(|ep| ep.stage == naga_stage && ep.name == stage.entry_point) - .ok_or(crate::PipelineError::EntryPoint(naga_stage))?; - let ep = &module.entry_points[ep_index]; - let translated_ep_name = info.entry_point_names[0] - .as_ref() - .map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{e}")))?; - - let wg_size = MTLSize { - width: ep.workgroup_size[0] as _, - height: ep.workgroup_size[1] as _, - depth: ep.workgroup_size[2] as _, - }; + let options = metal::CompileOptions::new(); + options.set_language_version(self.shared.private_caps.msl_version); - let function = library - .get_function(translated_ep_name, None) - .map_err(|e| { - log::error!("get_function: {e:?}"); - crate::PipelineError::EntryPoint(naga_stage) - })?; - - // collect sizes indices, immutable buffers, and work group memory sizes - let ep_info = &module_info.get_entry_point(ep_index); - let mut wg_memory_sizes = Vec::new(); - let mut sized_bindings = Vec::new(); - let mut immutable_buffer_mask = 0; - for (var_handle, var) in module.global_variables.iter() { - match var.space { - naga::AddressSpace::WorkGroup => { - if !ep_info[var_handle].is_empty() { - let size = module.types[var.ty].inner.size(module.to_ctx()); - wg_memory_sizes.push(size); - } + if self.shared.private_caps.supports_preserve_invariance { + options.set_preserve_invariance(true); } - naga::AddressSpace::Uniform | naga::AddressSpace::Storage { .. } => { - let br = match var.binding { - Some(br) => br, - None => continue, - }; - let storage_access_store = match var.space { - naga::AddressSpace::Storage { access } => { - access.contains(naga::StorageAccess::STORE) + + let library = self + .shared + .device + .lock() + .new_library_with_source(source.as_ref(), &options) + .map_err(|err| { + log::warn!("Naga generated shader:\n{source}"); + crate::PipelineError::Linkage(stage_bit, format!("Metal: {err}")) + })?; + + let ep_index = module + .entry_points + .iter() + .position(|ep| ep.stage == naga_stage && ep.name == stage.entry_point) + .ok_or(crate::PipelineError::EntryPoint(naga_stage))?; + let ep = &module.entry_points[ep_index]; + let translated_ep_name = info.entry_point_names[0] + .as_ref() + .map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("{e}")))?; + + let wg_size = MTLSize { + width: ep.workgroup_size[0] as _, + height: ep.workgroup_size[1] as _, + depth: ep.workgroup_size[2] as _, + }; + + let function = library + .get_function(translated_ep_name, None) + .map_err(|e| { + log::error!("get_function: {e:?}"); + crate::PipelineError::EntryPoint(naga_stage) + })?; + + // collect sizes indices, immutable buffers, and work group memory sizes + let ep_info = &module_info.get_entry_point(ep_index); + let mut wg_memory_sizes = Vec::new(); + let mut sized_bindings = Vec::new(); + let mut immutable_buffer_mask = 0; + for (var_handle, var) in module.global_variables.iter() { + match var.space { + naga::AddressSpace::WorkGroup => { + if !ep_info[var_handle].is_empty() { + let size = module.types[var.ty].inner.size(module.to_ctx()); + wg_memory_sizes.push(size); + } } - _ => false, - }; + naga::AddressSpace::Uniform | naga::AddressSpace::Storage { .. } => { + let br = match var.binding { + Some(br) => br, + None => continue, + }; + let storage_access_store = match var.space { + naga::AddressSpace::Storage { access } => { + access.contains(naga::StorageAccess::STORE) + } + _ => false, + }; - // check for an immutable buffer - if !ep_info[var_handle].is_empty() && !storage_access_store { - let slot = ep_resources.resources[&br].buffer.unwrap(); - immutable_buffer_mask |= 1 << slot; - } + // check for an immutable buffer + if !ep_info[var_handle].is_empty() && !storage_access_store { + let slot = ep_resources.resources[&br].buffer.unwrap(); + immutable_buffer_mask |= 1 << slot; + } - let mut dynamic_array_container_ty = var.ty; - if let naga::TypeInner::Struct { ref members, .. } = module.types[var.ty].inner - { - dynamic_array_container_ty = members.last().unwrap().ty; - } - if let naga::TypeInner::Array { - size: naga::ArraySize::Dynamic, - .. - } = module.types[dynamic_array_container_ty].inner - { - sized_bindings.push(br); + let mut dynamic_array_container_ty = var.ty; + if let naga::TypeInner::Struct { ref members, .. } = + module.types[var.ty].inner + { + dynamic_array_container_ty = members.last().unwrap().ty; + } + if let naga::TypeInner::Array { + size: naga::ArraySize::Dynamic, + .. + } = module.types[dynamic_array_container_ty].inner + { + sized_bindings.push(br); + } + } + _ => {} } } - _ => {} + + Ok(CompiledShader { + library, + function, + wg_size, + wg_memory_sizes, + sized_bindings, + immutable_buffer_mask, + }) } + ShaderModuleSource::Passthrough(ref shader) => Ok(CompiledShader { + library: shader.library.clone(), + function: shader.function.clone(), + wg_size: MTLSize { + width: shader.num_workgroups.0 as u64, + height: shader.num_workgroups.1 as u64, + depth: shader.num_workgroups.2 as u64, + }, + wg_memory_sizes: vec![], + sized_bindings: vec![], + immutable_buffer_mask: 0, + }), } - - Ok(CompiledShader { - library, - function, - wg_size, - wg_memory_sizes, - sized_bindings, - immutable_buffer_mask, - }) } fn set_buffers_mutability( diff --git a/wgpu-hal/src/metal/mod.rs b/wgpu-hal/src/metal/mod.rs index 1e7b5281240..fda7e001906 100644 --- a/wgpu-hal/src/metal/mod.rs +++ b/wgpu-hal/src/metal/mod.rs @@ -624,7 +624,8 @@ impl ops::Index for MultiStageData { naga::ShaderStage::Vertex => &self.vs, naga::ShaderStage::Fragment => &self.fs, naga::ShaderStage::Compute => &self.cs, - naga::ShaderStage::Task | naga::ShaderStage::Mesh => unreachable!(), + naga::ShaderStage::Task => &self.ts, + naga::ShaderStage::Mesh => &self.ms, } } } From d4725b1f14cd6f8c6ba00068884e0dea22a71343 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sun, 24 Aug 2025 13:54:15 -0500 Subject: [PATCH 21/89] Small update to test shader (still blank screen) --- examples/features/src/mesh_shader/shader.metal | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/features/src/mesh_shader/shader.metal b/examples/features/src/mesh_shader/shader.metal index 65edc83e442..5c99fffc231 100644 --- a/examples/features/src/mesh_shader/shader.metal +++ b/examples/features/src/mesh_shader/shader.metal @@ -2,20 +2,20 @@ using namespace metal; struct OutVertex { float4 Position [[position]]; - float4 Color; + float4 Color [[user(locn0)]]; }; struct OutPrimitive { - float4 ColorMask [[flat]]; - bool CullPrimitive; + float4 ColorMask [[flat]] [[user(locn1)]]; + bool CullPrimitive [[primitive_culled]]; }; struct InVertex { - float4 Color; + float4 Color [[user(locn0)]]; }; struct InPrimitive { - float4 ColorMask [[flat]]; + float4 ColorMask [[flat]] [[user(locn1)]]; }; struct FragmentIn { From 7efae60bd9138c37f095c946fd4b349b0f454eb6 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sun, 24 Aug 2025 13:56:20 -0500 Subject: [PATCH 22/89] Another quick update to the shader --- examples/features/src/mesh_shader/shader.metal | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/features/src/mesh_shader/shader.metal b/examples/features/src/mesh_shader/shader.metal index 5c99fffc231..4c7da503832 100644 --- a/examples/features/src/mesh_shader/shader.metal +++ b/examples/features/src/mesh_shader/shader.metal @@ -11,7 +11,6 @@ struct OutPrimitive { }; struct InVertex { - float4 Color [[user(locn0)]]; }; struct InPrimitive { @@ -19,8 +18,8 @@ struct InPrimitive { }; struct FragmentIn { - InVertex vert; - InPrimitive prim; + float4 Color [[user(locn0)]]; + float4 ColorMask [[flat]] [[user(locn1)]]; }; struct PayloadData { @@ -74,5 +73,5 @@ void meshShader( } fragment float4 fragShader(FragmentIn data [[stage_in]]) { - return data.vert.Color * data.prim.ColorMask; + return data.Color * data.ColorMask; } From bd79d513438e30f1fe845490ddda835f6f9f46ef Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sun, 24 Aug 2025 13:59:38 -0500 Subject: [PATCH 23/89] Made mesh shader tests get skipped on metal due to not having MSL passthrough yet --- tests/tests/wgpu-gpu/mesh_shader/mod.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/tests/wgpu-gpu/mesh_shader/mod.rs b/tests/tests/wgpu-gpu/mesh_shader/mod.rs index 4dd897129f6..ae705c92341 100644 --- a/tests/tests/wgpu-gpu/mesh_shader/mod.rs +++ b/tests/tests/wgpu-gpu/mesh_shader/mod.rs @@ -86,6 +86,9 @@ fn mesh_pipeline_build( frag: Option<&[u8]>, draw: bool, ) { + if ctx.adapter.get_info().backend != wgpu::Backend::Vulkan { + return; + } let device = &ctx.device; let (_depth_image, depth_view, depth_state) = create_depth(device); let task = task.map(|t| compile_glsl(device, t, "task")); @@ -160,6 +163,9 @@ pub enum DrawType { } fn mesh_draw(ctx: &TestingContext, draw_type: DrawType) { + if ctx.adapter.get_info().backend != wgpu::Backend::Vulkan { + return; + } let device = &ctx.device; let (_depth_image, depth_view, depth_state) = create_depth(device); let task = compile_glsl(device, BASIC_TASK, "task"); From 760de4b59e007cbde0d0d80e761d7ea839b7b476 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sun, 24 Aug 2025 14:33:19 -0500 Subject: [PATCH 24/89] Add changelog entry --- CHANGELOG.md | 3 +++ 1 file changed, 3 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 11fb072dcab..2f679924d06 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -110,6 +110,9 @@ This allows using precompiled shaders without manually checking which backend's - Allow disabling waiting for latency waitable object. By @marcpabst in [#7400](https://github.com/gfx-rs/wgpu/pull/7400) +#### Metal +- Add support for mesh shaders. By @SupaMaggie70Incorporated in [#8139](https://github.com/gfx-rs/wgpu/pull/8139) + ### Bug Fixes #### General From 3f56df6b4842a14a79301ec8eeeeac309deac01b Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sun, 24 Aug 2025 15:10:47 -0500 Subject: [PATCH 25/89] Made some stuff more generic (bind groups & push constants) --- wgpu-hal/src/metal/command.rs | 300 +++++++++++++++++----------------- 1 file changed, 148 insertions(+), 152 deletions(-) diff --git a/wgpu-hal/src/metal/command.rs b/wgpu-hal/src/metal/command.rs index 542287983e9..1e4ac8d2419 100644 --- a/wgpu-hal/src/metal/command.rs +++ b/wgpu-hal/src/metal/command.rs @@ -672,168 +672,150 @@ impl crate::CommandEncoder for super::CommandEncoder { dynamic_offsets: &[wgt::DynamicOffset], ) { let bg_info = &layout.bind_group_infos[group_index as usize]; - - if let Some(ref encoder) = self.state.render { - let mut changes_sizes_buffer = false; - for index in 0..group.counters.vs.buffers { - let buf = &group.buffers[index as usize]; - let mut offset = buf.offset; - if let Some(dyn_index) = buf.dynamic_index { - offset += dynamic_offsets[dyn_index as usize] as wgt::BufferAddress; + let render_encoder = self.state.render.clone(); + let compute_encoder = self.state.compute.clone(); + let mut update_stage = + |stage: naga::ShaderStage, + render_encoder: Option<&metal::RenderCommandEncoder>, + compute_encoder: Option<&metal::ComputeCommandEncoder>| { + let buffers = match stage { + naga::ShaderStage::Vertex => group.counters.vs.buffers, + naga::ShaderStage::Fragment => group.counters.fs.buffers, + naga::ShaderStage::Task => group.counters.ts.buffers, + naga::ShaderStage::Mesh => group.counters.ms.buffers, + naga::ShaderStage::Compute => group.counters.cs.buffers, + }; + let mut changes_sizes_buffer = false; + for index in 0..buffers { + let buf = &group.buffers[index as usize]; + let mut offset = buf.offset; + if let Some(dyn_index) = buf.dynamic_index { + offset += dynamic_offsets[dyn_index as usize] as wgt::BufferAddress; + } + let a1 = (bg_info.base_resource_indices.vs.buffers + index) as u64; + let a2 = Some(buf.ptr.as_native()); + let a3 = offset; + match stage { + naga::ShaderStage::Vertex => { + render_encoder.unwrap().set_vertex_buffer(a1, a2, a3) + } + naga::ShaderStage::Fragment => { + render_encoder.unwrap().set_fragment_buffer(a1, a2, a3) + } + naga::ShaderStage::Task => { + render_encoder.unwrap().set_object_buffer(a1, a2, a3) + } + naga::ShaderStage::Mesh => { + render_encoder.unwrap().set_mesh_buffer(a1, a2, a3) + } + naga::ShaderStage::Compute => { + compute_encoder.unwrap().set_buffer(a1, a2, a3) + } + } + if let Some(size) = buf.binding_size { + let br = naga::ResourceBinding { + group: group_index, + binding: buf.binding_location, + }; + self.state.storage_buffer_length_map.insert(br, size); + changes_sizes_buffer = true; + } } - encoder.set_vertex_buffer( - (bg_info.base_resource_indices.vs.buffers + index) as u64, - Some(buf.ptr.as_native()), - offset, - ); - if let Some(size) = buf.binding_size { - let br = naga::ResourceBinding { - group: group_index, - binding: buf.binding_location, - }; - self.state.storage_buffer_length_map.insert(br, size); - changes_sizes_buffer = true; + if changes_sizes_buffer { + if let Some((index, sizes)) = self + .state + .make_sizes_buffer_update(stage, &mut self.temp.binding_sizes) + { + let a1 = index as _; + let a2 = (sizes.len() * WORD_SIZE) as u64; + let a3 = sizes.as_ptr().cast(); + match stage { + naga::ShaderStage::Vertex => { + render_encoder.unwrap().set_vertex_bytes(a1, a2, a3) + } + naga::ShaderStage::Fragment => { + render_encoder.unwrap().set_fragment_bytes(a1, a2, a3) + } + naga::ShaderStage::Task => { + render_encoder.unwrap().set_object_bytes(a1, a2, a3) + } + naga::ShaderStage::Mesh => { + render_encoder.unwrap().set_mesh_bytes(a1, a2, a3) + } + naga::ShaderStage::Compute => { + compute_encoder.unwrap().set_bytes(a1, a2, a3) + } + } + } } - } - if changes_sizes_buffer { - if let Some((index, sizes)) = self.state.make_sizes_buffer_update( - naga::ShaderStage::Vertex, - &mut self.temp.binding_sizes, - ) { - encoder.set_vertex_bytes( - index as _, - (sizes.len() * WORD_SIZE) as u64, - sizes.as_ptr().cast(), - ); + let samplers = match stage { + naga::ShaderStage::Vertex => group.counters.vs.samplers, + naga::ShaderStage::Fragment => group.counters.fs.samplers, + naga::ShaderStage::Task => group.counters.ts.samplers, + naga::ShaderStage::Mesh => group.counters.ms.samplers, + naga::ShaderStage::Compute => group.counters.cs.samplers, + }; + for index in 0..samplers { + let res = group.samplers[(group.counters.vs.samplers + index) as usize]; + let a1 = (bg_info.base_resource_indices.fs.samplers + index) as u64; + let a2 = Some(res.as_native()); + match stage { + naga::ShaderStage::Vertex => { + render_encoder.unwrap().set_vertex_sampler_state(a1, a2) + } + naga::ShaderStage::Fragment => { + render_encoder.unwrap().set_fragment_sampler_state(a1, a2) + } + naga::ShaderStage::Task => { + render_encoder.unwrap().set_object_sampler_state(a1, a2) + } + naga::ShaderStage::Mesh => { + render_encoder.unwrap().set_mesh_sampler_state(a1, a2) + } + naga::ShaderStage::Compute => { + compute_encoder.unwrap().set_sampler_state(a1, a2) + } + } } - } - changes_sizes_buffer = false; - for index in 0..group.counters.fs.buffers { - let buf = &group.buffers[(group.counters.vs.buffers + index) as usize]; - let mut offset = buf.offset; - if let Some(dyn_index) = buf.dynamic_index { - offset += dynamic_offsets[dyn_index as usize] as wgt::BufferAddress; - } - encoder.set_fragment_buffer( - (bg_info.base_resource_indices.fs.buffers + index) as u64, - Some(buf.ptr.as_native()), - offset, - ); - if let Some(size) = buf.binding_size { - let br = naga::ResourceBinding { - group: group_index, - binding: buf.binding_location, - }; - self.state.storage_buffer_length_map.insert(br, size); - changes_sizes_buffer = true; - } - } - if changes_sizes_buffer { - if let Some((index, sizes)) = self.state.make_sizes_buffer_update( - naga::ShaderStage::Fragment, - &mut self.temp.binding_sizes, - ) { - encoder.set_fragment_bytes( - index as _, - (sizes.len() * WORD_SIZE) as u64, - sizes.as_ptr().cast(), - ); + let textures = match stage { + naga::ShaderStage::Vertex => group.counters.vs.textures, + naga::ShaderStage::Fragment => group.counters.fs.textures, + naga::ShaderStage::Task => group.counters.ts.textures, + naga::ShaderStage::Mesh => group.counters.ms.textures, + naga::ShaderStage::Compute => group.counters.cs.textures, + }; + for index in 0..textures { + let res = group.textures[index as usize]; + let a1 = (bg_info.base_resource_indices.vs.textures + index) as u64; + let a2 = Some(res.as_native()); + match stage { + naga::ShaderStage::Vertex => { + render_encoder.unwrap().set_vertex_texture(a1, a2) + } + naga::ShaderStage::Fragment => { + render_encoder.unwrap().set_fragment_texture(a1, a2) + } + naga::ShaderStage::Task => { + render_encoder.unwrap().set_object_texture(a1, a2) + } + naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_texture(a1, a2), + naga::ShaderStage::Compute => compute_encoder.unwrap().set_texture(a1, a2), + } } - } - - for index in 0..group.counters.vs.samplers { - let res = group.samplers[index as usize]; - encoder.set_vertex_sampler_state( - (bg_info.base_resource_indices.vs.samplers + index) as u64, - Some(res.as_native()), - ); - } - for index in 0..group.counters.fs.samplers { - let res = group.samplers[(group.counters.vs.samplers + index) as usize]; - encoder.set_fragment_sampler_state( - (bg_info.base_resource_indices.fs.samplers + index) as u64, - Some(res.as_native()), - ); - } - - for index in 0..group.counters.vs.textures { - let res = group.textures[index as usize]; - encoder.set_vertex_texture( - (bg_info.base_resource_indices.vs.textures + index) as u64, - Some(res.as_native()), - ); - } - for index in 0..group.counters.fs.textures { - let res = group.textures[(group.counters.vs.textures + index) as usize]; - encoder.set_fragment_texture( - (bg_info.base_resource_indices.fs.textures + index) as u64, - Some(res.as_native()), - ); - } - + }; + if let Some(encoder) = render_encoder { + update_stage(naga::ShaderStage::Vertex, Some(&encoder), None); + update_stage(naga::ShaderStage::Fragment, Some(&encoder), None); + update_stage(naga::ShaderStage::Task, Some(&encoder), None); + update_stage(naga::ShaderStage::Mesh, Some(&encoder), None); // Call useResource on all textures and buffers used indirectly so they are alive for (resource, use_info) in group.resources_to_use.iter() { encoder.use_resource_at(resource.as_native(), use_info.uses, use_info.stages); } } - - if let Some(ref encoder) = self.state.compute { - let index_base = super::ResourceData { - buffers: group.counters.vs.buffers + group.counters.fs.buffers, - samplers: group.counters.vs.samplers + group.counters.fs.samplers, - textures: group.counters.vs.textures + group.counters.fs.textures, - }; - - let mut changes_sizes_buffer = false; - for index in 0..group.counters.cs.buffers { - let buf = &group.buffers[(index_base.buffers + index) as usize]; - let mut offset = buf.offset; - if let Some(dyn_index) = buf.dynamic_index { - offset += dynamic_offsets[dyn_index as usize] as wgt::BufferAddress; - } - encoder.set_buffer( - (bg_info.base_resource_indices.cs.buffers + index) as u64, - Some(buf.ptr.as_native()), - offset, - ); - if let Some(size) = buf.binding_size { - let br = naga::ResourceBinding { - group: group_index, - binding: buf.binding_location, - }; - self.state.storage_buffer_length_map.insert(br, size); - changes_sizes_buffer = true; - } - } - if changes_sizes_buffer { - if let Some((index, sizes)) = self.state.make_sizes_buffer_update( - naga::ShaderStage::Compute, - &mut self.temp.binding_sizes, - ) { - encoder.set_bytes( - index as _, - (sizes.len() * WORD_SIZE) as u64, - sizes.as_ptr().cast(), - ); - } - } - - for index in 0..group.counters.cs.samplers { - let res = group.samplers[(index_base.samplers + index) as usize]; - encoder.set_sampler_state( - (bg_info.base_resource_indices.cs.samplers + index) as u64, - Some(res.as_native()), - ); - } - for index in 0..group.counters.cs.textures { - let res = group.textures[(index_base.textures + index) as usize]; - encoder.set_texture( - (bg_info.base_resource_indices.cs.textures + index) as u64, - Some(res.as_native()), - ); - } - + if let Some(encoder) = compute_encoder { + update_stage(naga::ShaderStage::Compute, None, Some(&encoder)); // Call useResource on all textures and buffers used indirectly so they are alive for (resource, use_info) in group.resources_to_use.iter() { if !use_info.visible_in_compute { @@ -881,6 +863,20 @@ impl crate::CommandEncoder for super::CommandEncoder { state_pc.as_ptr().cast(), ) } + if stages.contains(wgt::ShaderStages::TASK) { + self.state.render.as_ref().unwrap().set_object_bytes( + layout.push_constants_infos.ts.unwrap().buffer_index as _, + (layout.total_push_constants as usize * WORD_SIZE) as _, + state_pc.as_ptr().cast(), + ) + } + if stages.contains(wgt::ShaderStages::MESH) { + self.state.render.as_ref().unwrap().set_object_bytes( + layout.push_constants_infos.ms.unwrap().buffer_index as _, + (layout.total_push_constants as usize * WORD_SIZE) as _, + state_pc.as_ptr().cast(), + ) + } } unsafe fn insert_debug_marker(&mut self, label: &str) { From d6931d2f0a2217b8a8ba28f23e15cae0b89f54f2 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sun, 24 Aug 2025 15:27:35 -0500 Subject: [PATCH 26/89] Applied some fixes --- wgpu-hal/src/metal/command.rs | 11 ++++++----- wgpu-hal/src/metal/device.rs | 4 ++-- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/wgpu-hal/src/metal/command.rs b/wgpu-hal/src/metal/command.rs index 1e4ac8d2419..2ebf80f0f26 100644 --- a/wgpu-hal/src/metal/command.rs +++ b/wgpu-hal/src/metal/command.rs @@ -1146,12 +1146,13 @@ impl crate::CommandEncoder for super::CommandEncoder { group_count_z: u32, ) { let encoder = self.state.render.as_ref().unwrap(); + let size = MTLSize { + width: group_count_x as u64, + height: group_count_y as u64, + depth: group_count_z as u64, + }; encoder.draw_mesh_threadgroups( - MTLSize { - width: group_count_x as u64, - height: group_count_y as u64, - depth: group_count_z as u64, - }, + size, self.state.stage_infos.ts.raw_wg_size, self.state.stage_infos.ms.raw_wg_size, ); diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index ca16a222efb..70753a3ff6c 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -1264,7 +1264,7 @@ impl crate::Device for super::Device { sized_bindings: ts.sized_bindings, vertex_buffer_mappings: vec![], library: Some(ts.library), - raw_wg_size: Default::default(), + raw_wg_size: ts.wg_size, work_group_memory_sizes: vec![], }); } else { @@ -1291,7 +1291,7 @@ impl crate::Device for super::Device { sized_bindings: ms.sized_bindings, vertex_buffer_mappings: vec![], library: Some(ms.library), - raw_wg_size: Default::default(), + raw_wg_size: ms.wg_size, work_group_memory_sizes: vec![], }); } From 7bec4dd3fed42a01b2a6f3ecb35dd965a23ccbb0 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Sun, 24 Aug 2025 17:36:41 -0700 Subject: [PATCH 27/89] some doc tweaks --- wgpu/src/api/render_pass.rs | 22 ++++++++++++++++++- wgpu/src/api/render_pipeline.rs | 38 +++++++++++++++++++++++++++++++-- 2 files changed, 57 insertions(+), 3 deletions(-) diff --git a/wgpu/src/api/render_pass.rs b/wgpu/src/api/render_pass.rs index 8163b4261f0..5779d1a0ff3 100644 --- a/wgpu/src/api/render_pass.rs +++ b/wgpu/src/api/render_pass.rs @@ -226,7 +226,27 @@ impl RenderPass<'_> { self.inner.draw_indexed(indices, base_vertex, instances); } - /// Draws using a mesh shader pipeline + /// Draws using a mesh shader pipeline. + /// + /// The current pipeline must be a mesh shader pipeline. + /// + /// If the current pipeline has a task shader, run it with an invocation for + /// every `vec3(i, j, k)` where `i`, `j`, and `k` are between `0` and + /// `group_count_x`, `group_count_y`, and `group_count_z`. Each invocation's + /// return value indicates a set of mesh shaders to invoke, and passes + /// payload values for them to consume. TODO: provide specifics on return value + /// + /// If the current pipeline lacks a task shader, run its mesh shader with an + /// invocation for every `vec3(i, j, k)` where `i`, `j`, and `k` are + /// between `0` and `group_count_x`, `group_count_y`, and `group_count_z`. + /// + /// Each mesh shader invocation's return value produces a list of primitives + /// to draw. TODO: provide specifics on return value + /// + /// Each primitive is then rendered with the current pipeline's fragment + /// shader, if present. Otherwise, [No Color Output mode] is used. + /// + /// [No Color Output mode]: https://www.w3.org/TR/webgpu/#no-color-output pub fn draw_mesh_tasks(&mut self, group_count_x: u32, group_count_y: u32, group_count_z: u32) { self.inner .draw_mesh_tasks(group_count_x, group_count_y, group_count_z); diff --git a/wgpu/src/api/render_pipeline.rs b/wgpu/src/api/render_pipeline.rs index e887bb4b97e..07ec909b28c 100644 --- a/wgpu/src/api/render_pipeline.rs +++ b/wgpu/src/api/render_pipeline.rs @@ -238,7 +238,41 @@ static_assertions::assert_impl_all!(RenderPipelineDescriptor<'_>: Send, Sync); /// Describes a mesh shader (graphics) pipeline. /// -/// For use with [`Device::create_mesh_pipeline`]. +/// For use with [`Device::create_mesh_pipeline`]. A mesh pipeline is very much +/// like a render pipeline, except that instead of [`RenderPass::draw`] it is +/// invoked with [`RenderPass::draw_mesh_tasks`], and instead of a vertex shader +/// and a fragment shader: +/// +/// - [`task`] specifies an optional task shader entry point, which generates +/// groups of mesh shaders to dispatch. +/// +/// - [`mesh`] specifies a mesh shader entry point, which generates groups of +/// primitives to draw +/// +/// - [`fragment`] specifies as fragment shader for drawing those primitive, +/// just like in an ordinary render pipeline. +/// +/// The key difference is that, whereas a vertex shader is invoked on the +/// elements of vertex buffers, the task shader gets to decide how many mesh +/// shader invocations to make, and then each mesh shader invocation gets to +/// decide which primitives it wants to generate, and what their vertex +/// attributes are. Task and mesh shaders can use whatever they please as +/// inputs, like a compute shader. (Fancy [vertex formats] are up to the mesh +/// shader to implement itself.) +/// +/// A mesh pipeline is invoked by [`RenderPass::draw_mesh_tasks`], which looks +/// like a compute shader dispatch with [`ComputePass::dispatch_workgroups`]: +/// you pass `x`, `y`, and `z` values indicating the number of task shaders to +/// invoke in parallel. TODO: what is the output of a task shader? +/// +/// If the task shader is omitted, then the (`x`, `y`, `z`) parameters to +/// `draw_mesh_tasks` are used to decide how many invocations of the mesh shader +/// to invoke directly. +/// +/// [vertex formats]: wgpu_types::VertexFormat +/// [`task`]: Self::task +/// [`mesh`]: Self::mesh +/// [`fragment`]: Self::fragment #[derive(Clone, Debug)] pub struct MeshPipelineDescriptor<'a> { /// Debug label of the pipeline. This will show up in graphics debuggers for easy identification. @@ -263,7 +297,7 @@ pub struct MeshPipelineDescriptor<'a> { /// /// [default layout]: https://www.w3.org/TR/webgpu/#default-pipeline-layout pub layout: Option<&'a PipelineLayout>, - /// The compiled task stage, its entry point, and the color targets. + /// The compiled task stage and its entry point. pub task: Option>, /// The compiled mesh stage and its entry point pub mesh: MeshState<'a>, From 2fcb8539c2d6e22c10e769035c185442cfe23226 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Mon, 25 Aug 2025 01:27:22 -0500 Subject: [PATCH 28/89] Tried to clarify docs a little --- wgpu/src/api/render_pass.rs | 32 +++++++++++++++++++------------- wgpu/src/api/render_pipeline.rs | 18 ++++++++++-------- 2 files changed, 29 insertions(+), 21 deletions(-) diff --git a/wgpu/src/api/render_pass.rs b/wgpu/src/api/render_pass.rs index 5779d1a0ff3..a832e380fbf 100644 --- a/wgpu/src/api/render_pass.rs +++ b/wgpu/src/api/render_pass.rs @@ -228,23 +228,29 @@ impl RenderPass<'_> { /// Draws using a mesh shader pipeline. /// - /// The current pipeline must be a mesh shader pipeline. + /// The current pipeline must be a mesh shader pipeline. /// - /// If the current pipeline has a task shader, run it with an invocation for + /// If the current pipeline has a task shader, run it with an workgroup for /// every `vec3(i, j, k)` where `i`, `j`, and `k` are between `0` and - /// `group_count_x`, `group_count_y`, and `group_count_z`. Each invocation's - /// return value indicates a set of mesh shaders to invoke, and passes - /// payload values for them to consume. TODO: provide specifics on return value - /// - /// If the current pipeline lacks a task shader, run its mesh shader with an - /// invocation for every `vec3(i, j, k)` where `i`, `j`, and `k` are + /// `group_count_x`, `group_count_y`, and `group_count_z`. The invocation with + /// index zero in each group is responsible for determining the mesh shader dispatch. + /// Its return value indicates the number of workgroups of mesh shaders to invoke. It also + /// passes a payload value for them to consume. Because each task workgroup is essentially + /// a mesh shader draw call, mesh workgroups dispatched by different task workgroups + /// cannot interact in any way, and `workgroup_id` corresponds to its location in the + /// calling specific task shader's dispatch group. + /// + /// If the current pipeline lacks a task shader, run its mesh shader with a + /// workgroup for every `vec3(i, j, k)` where `i`, `j`, and `k` are /// between `0` and `group_count_x`, `group_count_y`, and `group_count_z`. /// - /// Each mesh shader invocation's return value produces a list of primitives - /// to draw. TODO: provide specifics on return value - /// - /// Each primitive is then rendered with the current pipeline's fragment - /// shader, if present. Otherwise, [No Color Output mode] is used. + /// Each mesh shader workgroup outputs a set of vertices and indices for primitives. + /// The indices outputted correspond to the vertices outputted by that same workgroup; + /// there is no global vertex buffer. These primitives are passed to the rasterizer and + /// essentially treated like a vertex shader output, except that the mesh shader may + /// choose to cull specific primitives or pass per-primitive non-interpolated values + /// to the mesh shader. As such, each primitive is then rendered with the current + /// pipeline's fragment shader, if present. Otherwise, [No Color Output mode] is used. /// /// [No Color Output mode]: https://www.w3.org/TR/webgpu/#no-color-output pub fn draw_mesh_tasks(&mut self, group_count_x: u32, group_count_y: u32, group_count_z: u32) { diff --git a/wgpu/src/api/render_pipeline.rs b/wgpu/src/api/render_pipeline.rs index 07ec909b28c..be16d91f27a 100644 --- a/wgpu/src/api/render_pipeline.rs +++ b/wgpu/src/api/render_pipeline.rs @@ -243,31 +243,33 @@ static_assertions::assert_impl_all!(RenderPipelineDescriptor<'_>: Send, Sync); /// invoked with [`RenderPass::draw_mesh_tasks`], and instead of a vertex shader /// and a fragment shader: /// -/// - [`task`] specifies an optional task shader entry point, which generates -/// groups of mesh shaders to dispatch. +/// - [`task`] specifies an optional task shader entry point, which determines how +/// many groups of mesh shaders to dispatch. /// /// - [`mesh`] specifies a mesh shader entry point, which generates groups of /// primitives to draw /// -/// - [`fragment`] specifies as fragment shader for drawing those primitive, +/// - [`fragment`] specifies as fragment shader for drawing those primitives, /// just like in an ordinary render pipeline. /// /// The key difference is that, whereas a vertex shader is invoked on the /// elements of vertex buffers, the task shader gets to decide how many mesh -/// shader invocations to make, and then each mesh shader invocation gets to +/// shader workgroups to make, and then each mesh shader workgroup gets to /// decide which primitives it wants to generate, and what their vertex /// attributes are. Task and mesh shaders can use whatever they please as -/// inputs, like a compute shader. (Fancy [vertex formats] are up to the mesh -/// shader to implement itself.) +/// inputs, like a compute shader. However, they cannot use specialized vertex +/// or index buffers. /// /// A mesh pipeline is invoked by [`RenderPass::draw_mesh_tasks`], which looks /// like a compute shader dispatch with [`ComputePass::dispatch_workgroups`]: /// you pass `x`, `y`, and `z` values indicating the number of task shaders to -/// invoke in parallel. TODO: what is the output of a task shader? +/// invoke in parallel. The output value of the first thread in a task shader +/// workgroup determines how many mesh workgroups should be dispatched from there. +/// Those mesh workgroups also get a special payload passed from the task shader. /// /// If the task shader is omitted, then the (`x`, `y`, `z`) parameters to /// `draw_mesh_tasks` are used to decide how many invocations of the mesh shader -/// to invoke directly. +/// to invoke directly, without a task payload. /// /// [vertex formats]: wgpu_types::VertexFormat /// [`task`]: Self::task From 8bfe1067e8658f728166d2a84d8be6dc64e47476 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Mon, 25 Aug 2025 02:10:41 -0500 Subject: [PATCH 29/89] Tried to update spec --- docs/api-specs/mesh_shading.md | 27 ++++++++++++++++++++++++++- 1 file changed, 26 insertions(+), 1 deletion(-) diff --git a/docs/api-specs/mesh_shading.md b/docs/api-specs/mesh_shading.md index e9b6df3710d..24a4cde2cda 100644 --- a/docs/api-specs/mesh_shading.md +++ b/docs/api-specs/mesh_shading.md @@ -11,6 +11,31 @@ to breaking changes, suggestions for the API exposed by this should be posted on ***This is not*** a thorough explanation of mesh shading and how it works. Those wishing to understand mesh shading more broadly should look elsewhere first. +## Mesh shaders overview + +### What are mesh shaders +Mesh shaders are a new kind of rasterization pipeline intended to address some of the shortfalls with the vertex shader pipeline. The core idea of mesh shaders is that the GPU decides how to render the many small parts of a scene instead of the CPU issuing a draw call for every small part or issuing an inefficient monolithic draw call for a large part of the scene. + +Mesh shaders are specifically designed to be used with **meshlet rendering**, a technique where every object is split into many subobjects called meshlets that are each rendered with their own parameters. With the standard vertex pipeline, each draw call specifies an exact number of primitives to render and the same parameters for all vertex shaders on an entire object (or even multiple objects). This doesn't leave room for different LODs for different parts of an object, for example a closer part having more detail, nor does it allow culling smaller sections (or primitives) of objects. With mesh shaders, each task workgroup might get assigned to a single object. It can then analyze the different meshlets(sections) of that object, determine which are visible and should actually be rendered, and for those meshlets determine what LOD to use based on the distance from the camera. It can then dispatch a mesh workgroup for each meshlet, with each mesh workgroup then reading the data for that specific LOD of its meshlet, determining which and how many vertices and primitives to output, determining which remaining primitives need to be culled, and passing the resulting primitives to the rasterizer. + +Mesh shaders are most effective in scenes with many polygons. They can allow skipping processing of entire groups of primitives that are facing away from the camera or otherwise occluded, which reduces the number of primitives that need to be processed by more than half in most cases, and they can reduce the number of primitives that need to be processed for more distant objects. Scenes that are not bottlenecked by geometry (perhaps instead by fragment processing or post processing) will not see much benefit from using them. + +Mesh shaders were first shown off in [NVIDIA's asteroids demo](https://www.youtube.com/watch?v=CRfZYJ_sk5E). Now, they form the basis for [Unreal Engine's Nanite](https://www.unrealengine.com/en-US/blog/unreal-engine-5-is-now-available-in-preview#Nanite). + +### Mesh shader pipeline +A mesh shader pipeline is just like a standard render pipeline, except that the vertex shader stage is replaced by a mesh shader stage (and optionally a task shader stage). This functions as follows: + +* If there is a task shader stage, task shader workgroups are invoked first, with the number of workgroups determined by the draw call. Each task shader workgroup outputs a workgroup size and a task payload. A dispatch group of mesh shaders with the given workgroup size is then invoked with the task payload as a parameter. +* Otherwise, a single dispatch group of mesh shaders with workgroup size given by the draw call is invoked. +* Each mesh shader dispatch group functions exactly as a compute dispatch group, except that it has special outputs and may take a task payload as input. Mesh dispatch groups invoked by different task shader workgroups cannot interact. +* Each workgroup within the mesh shader dispatch group can output vertices and primitives + * It determines how many vertices and primitives to write and then sets those vertices and primitives. + * Primitives have an indices field which determines the indices of the vertices of that primitive. The indices are based on the output of that mesh shader workgroup only; there is no sharing of vertices across workgroups (no vertex or index buffer equivalents). + * Primitives can then be culled by setting the appropriate builtin + * Each vertex output functions exactly as the output from a vertex shader would + * There can also be per-primitive outputs passed to fragment shaders; these are not interpolated or based on the vertices of the primitive in any way. +* Once all of the primitives are written, those that weren't culled are are rasterized. From this point forward, the only difference from a standard render pipeline is that there may be some per-primitive inputs passed to fragment shaders. + ## `wgpu` API ### New `wgpu` functions @@ -101,7 +126,7 @@ Mesh shader primitive outputs must also specify exactly one of `@builtin(triangl Additionally, the `@location` attributes from the vertex and primitive outputs can't overlap. -Before setting any vertices or indices, or exiting, the mesh shader must call `setMeshOutputs(numVertices: u32, numIndices: u32)`, which declares the number of vertices and indices that will be written to. These must be less than the corresponding maximums set in `@vertex_output` and `@primitive_output`. The mesh shader must then write to exactly these numbers of vertices and primitives. A varying member with `@per_primitive` cannot be used in function interfaces except as the primitive output for mesh shaders or as input for fragment shaders. +Before exiting, the mesh shader must call `setMeshOutputs(numVertices: u32, numIndices: u32)`, which declares the number of vertices and indices that will be written to. These must be less than the corresponding maximums set in `@vertex_output` and `@primitive_output`. The mesh shader must then write to exactly this range of vertices and primitives. A varying member with `@per_primitive` cannot be used in function interfaces except as a primitive output for mesh shaders or as input for fragment shaders. The mesh shader can write to vertices using the `setVertex(idx: u32, vertex: VertexOutput)` where `VertexOutput` is replaced with the vertex type declared in `@vertex_output`, and `idx` is the index of the vertex to write. Similarly, the mesh shader can write to vertices using `setPrimitive(idx: u32, primitive: PrimitiveOutput)`. These can be written to multiple times, however unsynchronized writes are undefined behavior. The primitives and indices are shared across the entire mesh shader workgroup. From 6ccaeec5e96e50abc136acffda6841d92d52036d Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Mon, 25 Aug 2025 02:14:45 -0500 Subject: [PATCH 30/89] Removed a warning --- docs/api-specs/mesh_shading.md | 2 -- 1 file changed, 2 deletions(-) diff --git a/docs/api-specs/mesh_shading.md b/docs/api-specs/mesh_shading.md index 24a4cde2cda..df0a5149f9f 100644 --- a/docs/api-specs/mesh_shading.md +++ b/docs/api-specs/mesh_shading.md @@ -9,8 +9,6 @@ For this reason, all shaders must be created with `Device::create_shader_module_ **Note**: The features documented here may have major bugs in them and are expected to be subject to breaking changes, suggestions for the API exposed by this should be posted on [the mesh-shading issue](https://github.com/gfx-rs/wgpu/issues/7197). -***This is not*** a thorough explanation of mesh shading and how it works. Those wishing to understand mesh shading more broadly should look elsewhere first. - ## Mesh shaders overview ### What are mesh shaders From 5b7ba116b70380827a14bf1da4e67a023529703e Mon Sep 17 00:00:00 2001 From: SupaMaggie70Incorporated Date: Mon, 25 Aug 2025 13:34:27 -0500 Subject: [PATCH 31/89] Addressed comment about docs mistake --- wgpu/src/api/render_pass.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/wgpu/src/api/render_pass.rs b/wgpu/src/api/render_pass.rs index a832e380fbf..0c3acad7ac8 100644 --- a/wgpu/src/api/render_pass.rs +++ b/wgpu/src/api/render_pass.rs @@ -249,7 +249,7 @@ impl RenderPass<'_> { /// there is no global vertex buffer. These primitives are passed to the rasterizer and /// essentially treated like a vertex shader output, except that the mesh shader may /// choose to cull specific primitives or pass per-primitive non-interpolated values - /// to the mesh shader. As such, each primitive is then rendered with the current + /// to the fragment shader. As such, each primitive is then rendered with the current /// pipeline's fragment shader, if present. Otherwise, [No Color Output mode] is used. /// /// [No Color Output mode]: https://www.w3.org/TR/webgpu/#no-color-output From 46576462ebd75cbf0d25f2f5a1a4d79cb2ec8af5 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Tue, 2 Sep 2025 08:11:38 -0700 Subject: [PATCH 32/89] Review in progress - Extensive revisions to `docs/api-specs/mesh_shading.md`. - Doc comments. - Ensure `Module` stays at the bottom of `ir/mod.rs`. - Avoid a clone. - Rename some arguments to be more specific. - Minor readability tweaks. --- docs/api-specs/mesh_shading.md | 113 +++++++++++++++++++++++-------- naga/src/ir/mod.rs | 115 ++++++++++++++++++-------------- naga/src/valid/analyzer.rs | 9 +-- naga/src/valid/interface.rs | 19 +++--- wgpu/src/api/render_pass.rs | 6 +- wgpu/src/api/render_pipeline.rs | 21 ++++-- 6 files changed, 184 insertions(+), 99 deletions(-) diff --git a/docs/api-specs/mesh_shading.md b/docs/api-specs/mesh_shading.md index df0a5149f9f..fcead0898bb 100644 --- a/docs/api-specs/mesh_shading.md +++ b/docs/api-specs/mesh_shading.md @@ -11,7 +11,8 @@ to breaking changes, suggestions for the API exposed by this should be posted on ## Mesh shaders overview -### What are mesh shaders +### What are mesh shaders? + Mesh shaders are a new kind of rasterization pipeline intended to address some of the shortfalls with the vertex shader pipeline. The core idea of mesh shaders is that the GPU decides how to render the many small parts of a scene instead of the CPU issuing a draw call for every small part or issuing an inefficient monolithic draw call for a large part of the scene. Mesh shaders are specifically designed to be used with **meshlet rendering**, a technique where every object is split into many subobjects called meshlets that are each rendered with their own parameters. With the standard vertex pipeline, each draw call specifies an exact number of primitives to render and the same parameters for all vertex shaders on an entire object (or even multiple objects). This doesn't leave room for different LODs for different parts of an object, for example a closer part having more detail, nor does it allow culling smaller sections (or primitives) of objects. With mesh shaders, each task workgroup might get assigned to a single object. It can then analyze the different meshlets(sections) of that object, determine which are visible and should actually be rendered, and for those meshlets determine what LOD to use based on the distance from the camera. It can then dispatch a mesh workgroup for each meshlet, with each mesh workgroup then reading the data for that specific LOD of its meshlet, determining which and how many vertices and primitives to output, determining which remaining primitives need to be culled, and passing the resulting primitives to the rasterizer. @@ -21,18 +22,51 @@ Mesh shaders are most effective in scenes with many polygons. They can allow ski Mesh shaders were first shown off in [NVIDIA's asteroids demo](https://www.youtube.com/watch?v=CRfZYJ_sk5E). Now, they form the basis for [Unreal Engine's Nanite](https://www.unrealengine.com/en-US/blog/unreal-engine-5-is-now-available-in-preview#Nanite). ### Mesh shader pipeline -A mesh shader pipeline is just like a standard render pipeline, except that the vertex shader stage is replaced by a mesh shader stage (and optionally a task shader stage). This functions as follows: - -* If there is a task shader stage, task shader workgroups are invoked first, with the number of workgroups determined by the draw call. Each task shader workgroup outputs a workgroup size and a task payload. A dispatch group of mesh shaders with the given workgroup size is then invoked with the task payload as a parameter. -* Otherwise, a single dispatch group of mesh shaders with workgroup size given by the draw call is invoked. -* Each mesh shader dispatch group functions exactly as a compute dispatch group, except that it has special outputs and may take a task payload as input. Mesh dispatch groups invoked by different task shader workgroups cannot interact. -* Each workgroup within the mesh shader dispatch group can output vertices and primitives - * It determines how many vertices and primitives to write and then sets those vertices and primitives. - * Primitives have an indices field which determines the indices of the vertices of that primitive. The indices are based on the output of that mesh shader workgroup only; there is no sharing of vertices across workgroups (no vertex or index buffer equivalents). - * Primitives can then be culled by setting the appropriate builtin - * Each vertex output functions exactly as the output from a vertex shader would - * There can also be per-primitive outputs passed to fragment shaders; these are not interpolated or based on the vertices of the primitive in any way. -* Once all of the primitives are written, those that weren't culled are are rasterized. From this point forward, the only difference from a standard render pipeline is that there may be some per-primitive inputs passed to fragment shaders. + +With the current pipeline set to a mesh pipeline, a draw command like +`render_pass.draw_mesh_tasks(x, y, z)` takes the following steps: + +* If the pipeline has a task shader stage: + + * Dispatch a grid of task shader workgroups, where `x`, `y`, and `z` give + the number of workgroups along each axis of the grid. Each task shader + workgroup produces a mesh shader workgroup grid size `(mx, my, mz)` and a + task payload value `mp`. + + * For each task shader workgroup, dispatch a grid of mesh shader workgroups, + where `mx`, `my`, and `mz` give the number of workgroups along each axis + of the grid. Pass `mp` to each of these workgroup's mesh shader + invocations. + +* Alternatively, if the pipeline does not have a task shader stage: + + * Dispatch a single grid of mesh shader workgroups, where `x`, `y`, and `z` + give the number of workgroups along each axis of the grid. These mesh + shaders receive no task payload value. + +* Each mesh shader workgroup produces a list of output vertices, and a list of + primitives built from those vertices. The workgroup can supply per-primitive + values as well, if needed. Each primitive selects its vertices by index, like + an indexed draw call, from among the vertices generated by this workgroup. + + Unlike a grid of ordinary compute shader workgroups collaborating to build + vertex and index data in common storage buffers, the vertices and primitives + produced by a mesh shader workgroup are entirely private to that workgroup, + and are not accessible by other workgroups. + +* Primitives produced by a mesh shader workgroup can have a culling flag. If a + primitive's culling flag is false, it is skipped during rasterization. + +* The primitives produced by all mesh shader workgroups are then rasterized in + the usual way, with each fragment shader invocation handling one pixel. + + Attributes from the vertices produced by the mesh shader workgroup are + provided to the fragment shader with interpolation applied as appropriate. + + If the mesh shader workgroup supplied per-primitive values, these are + available to each primitive's fragment shader invocations. Per-primitive + values are never interpolated; fragment shaders simply receive the values + the mesh shader workgroup associated with their primitive. ## `wgpu` API @@ -99,34 +133,57 @@ Using any of these features in a `wgsl` program will require adding the `enable Two new shader stages will be added to `WGSL`. Fragment shaders are also modified slightly. Both task shaders and mesh shaders are allowed to use any compute-specific functionality, such as subgroup operations. ### Task shader -This shader stage can be selected by marking a function with `@task`. Task shaders must return a `vec3` as their output type. Similar to compute shaders, task shaders run in a workgroup. The output must be uniform across all threads in a workgroup. -The output of this determines how many workgroups of mesh shaders will be dispatched. Once dispatched, global id variables will be local to the task shader workgroup dispatch, and mesh shaders won't know the position of their dispatch among all mesh shader dispatches unless this is passed through the payload. The output may be zero to skip dispatching any mesh shader workgroups for the task shader workgroup. +A function with the `@task` attribute is a **task shader entry point**. A mesh shader pipeline may optionally specify a task shader entry point, and if it does, mesh draw commands using that pipeline dispatch a **task shader grid** of workgroups running the task shader entry point. Like compute shader dispatches, the three-component size passed to `draw_mesh_tasks`, or drawn from the indirect buffer for its indirect variants, specifies the size of the task shader grid as the number of workgroups along each of the grid's three axes. -Task shaders must be marked with `@payload(someVar)`, where `someVar` is global variable declared like `var someVar: `. Task shaders may use `someVar` as if it is a read-write workgroup storage variable. This payload is passed to the mesh shader workgroup that is invoked. +A task shader entry point must have a `@workgroup_size` attribute, meeting the same requirements as one appearing on a compute shader entry point. + +A task shader entry point must return a `vec3` value. The return value of each workgroup's first invocation (that is, the one whose `local_invocation_index` is `0`) is taken as the size of a **mesh shader grid** to dispatch, measured in workgroups. (If the task shader entry point returns `vec3(0, 0, 0)`, then no mesh shaders are dispatched.) Mesh shader grids are described in the next section. + +If a task shader entry point has a `@payload(G)` property, then `G` must be the name of a global variable in the `task_payload` address space. Each task shader workgroup has its own instance of this variable, visible to all invocations in the workgroup. Whatever value the workgroup collectively stores in that global variable becomes the **task payload**, and is provided to all invocations in the mesh shader grid dispatched for the workgroup. + +Each task shader workgroup dispatches an independent mesh shader grid: in mesh shader invocations, `@builtin` values like `workgroup_id` and `global_invocation_id` describe the position of the workgroup and invocation within that grid; +and `@builtin(num_workgroups)` matches the task shader workgroup's return value. Mesh shaders dispatched for other task shader workgroups are not included in the count. If it is necessary for a mesh shader to know which task shader workgroup dispatched it, the task shader can include its own workgroup id in the task payload. ### Mesh shader -This shader stage can be selected by marking a function with `@mesh`. Mesh shaders must not return anything. -Mesh shaders can be marked with `@payload(someVar)` similar to task shaders. Unlike task shaders, this is optional, and mesh shaders cannot write to this memory. Declaring `@payload` in a pipeline with no task shader or in a task shader with an `@payload` that is statically sized and differently than the mesh shader payload is illegal. The `@payload` attribute can only be ignored in pipelines that don't have a task shader. +A function with the `@mesh` attribute is a **mesh shader entry point**. Mesh shaders must not return anything. + +Like compute shaders, mesh shaders are invoked in a grid of workgroups, called a **mesh shader grid**. If the mesh shader pipeline has a task shader, then each task shader workgroup determines the size of a mesh shader grid to be dispatched, as described above. Otherwise, the three-component size passed to `draw_mesh_tasks`, or drawn from the indirect buffer for its indirect variants, specifies the size of the mesh shader grid directly, as the number of workgroups along each of the grid's three axes. + +A mesh shader entry point must have a `@workgroup_size` attribute, meeting the same requirements as one appearing on a compute shader entry point. + +If the mesh shader pipeline has a task shader entry point with a `@payload(G)` attribute, then the pipeline's mesh shader entry point must also have a `@payload(G)` attribute, naming the same variable. Mesh shader invocations can read, but not write, this variable, which is initialized to whatever value was written to it by the task shader workgroup that dispatched this mesh shader grid. + +If the mesh shader pipeline does not have a task shader entry point, or the task shader entry point does not have a `@payload(G)` attribute, then the mesh shader entry point must not have any `@payload` attribute. + +A mesh shader entry point must have the following attributes: + +- `@vertex_output(V, NV)`: This indicates that the mesh shader workgroup will generate at most `NV` vertex values, each of type `V`. + +- `@primitive_output(P, NP)`: This indicates that the mesh shader workgroup will generate at most `NP` primitives, each of type `P`. + +Each mesh shader entry point invocation must call the `setMeshOutputs(numVertices: u32, numPrimitives: u32)` builtin function exactly once, in uniform control flow. The values passed by each workgroup's first invocation (that is, the one whose `local_invocation_index` is `0`) determine how many vertices (values of type `V`) and primitives (values of type `P`) the workgroup must produce. This call essentially establishes two implicit arrays of vertex and primitive values, shared across the workgroup, for invocations to populate. + +The `numVertices` and `numPrimitives` arguments must be no greater than `NV` and `NP` from the `@vertex_output` and `@primitive_output` attributes. -Mesh shaders must be marked with `@vertex_output(OutputType, numOutputs)`, where `numOutputs` is the maximum number of vertices to be output by a mesh shader, and `OutputType` is the data associated with vertices, similar to a standard vertex shader output, and must be a struct. +To produce vertex data, the workgroup as a whole must make `numVertices` calls to the `setVertex(i: u32, vertex: V)` builtin function. This establishes `vertex` as the value of the `i`'th vertex. `V` is the type given in the `@vertex_output` attribute. `V` must meet the same requirements as a struct type returned by a `@vertex` entry point: all members must have either `@builtin` or `@location` attributes, there must be a '@builtin(position)`, and so on. An invocation may only call `setVertex` after its call to `setMeshOutputs`. -Mesh shaders must also be marked with `@primitive_output(OutputType, numOutputs)`, which is similar to `@vertex_output` except it describes the primitive outputs. +To produce primitives, the workgroup as a whole must make `numPrimitives` calls to the `setPrimitive(i: u32, primitive: P)` builtin function. This establishes `primitive` as the value of the `i`'th primitive. `P` is the type given in the `@primitive_output` attribute. `P` must be a struct type, every member of which either has a `@location` or `@builtin` attribute. The following `@builtin` attributes are allowed: -### Mesh shader outputs +- `triangle_indices`, `line_indices`, or `point_index`: The annotated member must be of type `vec3`, `vec2`, or `u32`. -Vertex outputs from mesh shaders function identically to outputs of vertex shaders, and as such must have a field with `@builtin(position)`. + The member's components are indices (or, its value is an index) into the list of vertices generated by this workgroup, identifying the vertices of the primitive to be drawn. These indices must be less than the value of `numVertices` passed to `setMeshOutputs`. -Primitive outputs from mesh shaders have some additional builtins they can set. These include `@builtin(cull_primitive)`, which must be a boolean value. If this is set to true, then the primitive is skipped during rendering. All non-builtin primitive outputs must be decorated with `@per_primitive`. + The type `P` must contain exactly one member with one of these attributes, determining what sort of primitives the mesh shader generates. -Mesh shader primitive outputs must also specify exactly one of `@builtin(triangle_indices)`, `@builtin(line_indices)`, or `@builtin(point_index)`. This determines the output topology of the mesh shader, and must match the output topology of the pipeline descriptor the mesh shader is used with. These must be of type `vec3`, `vec2`, and `u32` respectively. When setting this, each of the indices must be less than the number of vertices declared in `setMeshOutputs`. +- `cull_primitive`: The annotated member must be of type `bool`. If it is true, then the primitive is skipped during rendering. -Additionally, the `@location` attributes from the vertex and primitive outputs can't overlap. +Every member of `P` with a `@location` attribute must either have a `@per_primitive` attribute, or be part of a struct type that appears in the primitive data as a struct member with the `@per_primitive` attribute. -Before exiting, the mesh shader must call `setMeshOutputs(numVertices: u32, numIndices: u32)`, which declares the number of vertices and indices that will be written to. These must be less than the corresponding maximums set in `@vertex_output` and `@primitive_output`. The mesh shader must then write to exactly this range of vertices and primitives. A varying member with `@per_primitive` cannot be used in function interfaces except as a primitive output for mesh shaders or as input for fragment shaders. +The `@location` attributes of `P` and `V` must not overlap, since they are merged to produce the user-defined inputs to the fragment shader. -The mesh shader can write to vertices using the `setVertex(idx: u32, vertex: VertexOutput)` where `VertexOutput` is replaced with the vertex type declared in `@vertex_output`, and `idx` is the index of the vertex to write. Similarly, the mesh shader can write to vertices using `setPrimitive(idx: u32, primitive: PrimitiveOutput)`. These can be written to multiple times, however unsynchronized writes are undefined behavior. The primitives and indices are shared across the entire mesh shader workgroup. +It is possible to write to the same vertex or primitive index repeatedly. Since the implicit arrays written by `setVertex` and `setPrimitive` are shared by the workgroup, data races on writes to the same index for a given type are undefined behavior. ### Fragment shader @@ -210,4 +267,4 @@ fn ms_main(@builtin(local_invocation_index) index: u32, @builtin(global_invocati fn fs_main(vertex: VertexOutput, primitive: PrimitiveInput) -> @location(0) vec4 { return vertex.color * primitive.colorMask; } -``` \ No newline at end of file +``` diff --git a/naga/src/ir/mod.rs b/naga/src/ir/mod.rs index 2856872db27..94159ae7bf6 100644 --- a/naga/src/ir/mod.rs +++ b/naga/src/ir/mod.rs @@ -320,14 +320,21 @@ pub enum ConservativeDepth { #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] -#[allow(missing_docs)] // The names are self evident pub enum ShaderStage { + /// A vertex shader, in a render pipeline. Vertex, - Fragment, - Compute, - // Mesh shader stages + + /// A task shader, in a mesh render pipeline. Task, + + /// A mesh shader, in a mesh render pipeline. Mesh, + + /// A fragment shader, in a render pipeline. + Fragment, + + /// Compute pipeline shader. + Compute, } impl ShaderStage { @@ -964,6 +971,9 @@ pub enum Binding { /// Indexed location. /// + /// This is a value passed to a [`Fragment`] shader from a [`Vertex`] or + /// [`Mesh`] shader. + /// /// Values passed from the [`Vertex`] stage to the [`Fragment`] stage must /// have their `interpolation` defaulted (i.e. not `None`) by the front end /// as appropriate for that language. @@ -977,6 +987,7 @@ pub enum Binding { /// interpolation must be `Flat`. /// /// [`Vertex`]: crate::ShaderStage::Vertex + /// [`Mesh`]: crate::ShaderStage::Mesh /// [`Fragment`]: crate::ShaderStage::Fragment Location { location: u32, @@ -1751,10 +1762,12 @@ pub enum Expression { query: Handle, committed: bool, }, + /// Result of a [`SubgroupBallot`] statement. /// /// [`SubgroupBallot`]: Statement::SubgroupBallot SubgroupBallotResult, + /// Result of a [`SubgroupCollectiveOperation`] or [`SubgroupGather`] statement. /// /// [`SubgroupCollectiveOperation`]: Statement::SubgroupCollectiveOperation @@ -2343,7 +2356,9 @@ pub struct EntryPoint { pub workgroup_size_overrides: Option<[Option>; 3]>, /// The entrance function. pub function: Function, - /// The information relating to a mesh shader + /// Information for [`Mesh`] shaders. + /// + /// [`Mesh`]: ShaderStage::Mesh pub mesh_info: Option, /// The unique global variable used as a task payload from task shader to mesh shader pub task_payload: Option>, @@ -2523,6 +2538,51 @@ pub struct DocComments { pub module: Vec, } +#[derive(Debug, Clone, Copy)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum MeshOutputTopology { + Points, + Lines, + Triangles, +} + +#[derive(Debug, Clone)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +#[allow(dead_code)] +pub struct MeshStageInfo { + pub topology: MeshOutputTopology, + pub max_vertices: u32, + pub max_vertices_override: Option>, + pub max_primitives: u32, + pub max_primitives_override: Option>, + pub vertex_output_type: Handle, + pub primitive_output_type: Handle, +} + +/// Mesh shader intrinsics +#[derive(Debug, Clone, Copy)] +#[cfg_attr(feature = "serialize", derive(Serialize))] +#[cfg_attr(feature = "deserialize", derive(Deserialize))] +#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] +pub enum MeshFunction { + SetMeshOutputs { + vertex_count: Handle, + primitive_count: Handle, + }, + SetVertex { + index: Handle, + value: Handle, + }, + SetPrimitive { + index: Handle, + value: Handle, + }, +} + /// Shader module. /// /// A module is a set of constants, global variables and functions, as well as @@ -2611,48 +2671,3 @@ pub struct Module { /// Doc comments. pub doc_comments: Option>, } - -#[derive(Debug, Clone, Copy)] -#[cfg_attr(feature = "serialize", derive(Serialize))] -#[cfg_attr(feature = "deserialize", derive(Deserialize))] -#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] -pub enum MeshOutputTopology { - Points, - Lines, - Triangles, -} - -#[derive(Debug, Clone)] -#[cfg_attr(feature = "serialize", derive(Serialize))] -#[cfg_attr(feature = "deserialize", derive(Deserialize))] -#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] -#[allow(dead_code)] -pub struct MeshStageInfo { - pub topology: MeshOutputTopology, - pub max_vertices: u32, - pub max_vertices_override: Option>, - pub max_primitives: u32, - pub max_primitives_override: Option>, - pub vertex_output_type: Handle, - pub primitive_output_type: Handle, -} - -/// Mesh shader intrinsics -#[derive(Debug, Clone, Copy)] -#[cfg_attr(feature = "serialize", derive(Serialize))] -#[cfg_attr(feature = "deserialize", derive(Deserialize))] -#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] -pub enum MeshFunction { - SetMeshOutputs { - vertex_count: Handle, - primitive_count: Handle, - }, - SetVertex { - index: Handle, - value: Handle, - }, - SetPrimitive { - index: Handle, - value: Handle, - }, -} diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index 101ea046487..6d9fd7f6a08 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -1151,7 +1151,7 @@ impl FunctionInfo { let _ = self.add_ref(index); let _ = self.add_ref(value); let ty = - self.expressions[value.index()].ty.clone().handle().ok_or( + self.expressions[value.index()].ty.handle().ok_or( FunctionError::InvalidMeshShaderOutputType(value).with_span(), )?; @@ -1244,14 +1244,15 @@ impl FunctionInfo { Ok(()) } + /// Update this function's mesh shader info, given that it calls `callee`. fn try_update_mesh_info( &mut self, - other: &FunctionMeshShaderInfo, + callee: &FunctionMeshShaderInfo, ) -> Result<(), WithSpan> { - if let &Some(ref other_vertex) = &other.vertex_type { + if let &Some(ref other_vertex) = &callee.vertex_type { self.try_update_mesh_vertex_type(other_vertex.0, other_vertex.1)?; } - if let &Some(ref other_primitive) = &other.vertex_type { + if let &Some(ref other_primitive) = &callee.vertex_type { self.try_update_mesh_primitive_type(other_primitive.0, other_primitive.1)?; } Ok(()) diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index 1fed0fda529..9f5cb278330 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -856,13 +856,14 @@ impl super::Validator { { return Err(EntryPointError::UnexpectedMeshShaderEntryResult.with_span()); } - // Cannot have any other built-ins or @location outputs as those are per-vertex or per-primitive - if ep.stage == crate::ShaderStage::Task - && (!result_built_ins.contains(&crate::BuiltIn::MeshTaskSize) - || result_built_ins.len() != 1 - || !self.location_mask.is_empty()) - { - return Err(EntryPointError::WrongTaskShaderEntryResult.with_span()); + // Task shaders must have a single `MeshTaskSize` output, and nothing else. + if ep.stage == crate::ShaderStage::Task { + let ok = result_built_ins.contains(&crate::BuiltIn::MeshTaskSize) + && result_built_ins.len() == 1 + && self.location_mask.is_empty(); + if !ok { + return Err(EntryPointError::WrongTaskShaderEntryResult.with_span()); + } } if !self.blend_src_mask.is_empty() { info.dual_source_blending = true; @@ -960,8 +961,10 @@ impl super::Validator { } } + // If this is a `Mesh` entry point, check its interface. if let &Some(ref mesh_info) = &ep.mesh_info { - // Technically it is allowed to not output anything + // Mesh shaders don't return any value. All their results are supplied through + // [`SetVertex`] and [`SetPrimitive`] calls. // TODO: check that only the allowed builtins are used here if let Some(used_vertex_type) = info.mesh_shader_info.vertex_type { if used_vertex_type.0 != mesh_info.vertex_output_type { diff --git a/wgpu/src/api/render_pass.rs b/wgpu/src/api/render_pass.rs index 9103264eed9..c73394db261 100644 --- a/wgpu/src/api/render_pass.rs +++ b/wgpu/src/api/render_pass.rs @@ -231,9 +231,9 @@ impl RenderPass<'_> { self.inner.draw_indexed(indices, base_vertex, instances); } - /// Draws using a mesh shader pipeline. + /// Draws using a mesh pipeline. /// - /// The current pipeline must be a mesh shader pipeline. + /// The current pipeline must be a mesh pipeline. /// /// If the current pipeline has a task shader, run it with an workgroup for /// every `vec3(i, j, k)` where `i`, `j`, and `k` are between `0` and @@ -290,7 +290,7 @@ impl RenderPass<'_> { .draw_indexed_indirect(&indirect_buffer.inner, indirect_offset); } - /// Draws using a mesh shader pipeline, + /// Draws using a mesh pipeline, /// based on the contents of the `indirect_buffer` /// /// This is like calling [`RenderPass::draw_mesh_tasks`] but the contents of the call are specified in the `indirect_buffer`. diff --git a/wgpu/src/api/render_pipeline.rs b/wgpu/src/api/render_pipeline.rs index be16d91f27a..35b74100d00 100644 --- a/wgpu/src/api/render_pipeline.rs +++ b/wgpu/src/api/render_pipeline.rs @@ -152,13 +152,15 @@ static_assertions::assert_impl_all!(FragmentState<'_>: Send, Sync); pub struct TaskState<'a> { /// The compiled shader module for this stage. pub module: &'a ShaderModule, - /// The name of the entry point in the compiled shader to use. + + /// The name of the task shader entry point in the shader module to use. /// - /// If [`Some`], there must be a vertex-stage shader entry point with this name in `module`. - /// Otherwise, expect exactly one vertex-stage entry point in `module`, which will be - /// selected. + /// If [`Some`], there must be a task shader entry point with the given name + /// in `module`. Otherwise, there must be exactly one task shader entry + /// point in `module`, which will be selected. pub entry_point: Option<&'a str>, - /// Advanced options for when this pipeline is compiled + + /// Advanced options for when this pipeline is compiled. /// /// This implements `Default`, and for most users can be set to `Default::default()` pub compilation_options: PipelineCompilationOptions<'a>, @@ -299,8 +301,15 @@ pub struct MeshPipelineDescriptor<'a> { /// /// [default layout]: https://www.w3.org/TR/webgpu/#default-pipeline-layout pub layout: Option<&'a PipelineLayout>, - /// The compiled task stage and its entry point. + + /// The mesh pipeline's task shader. + /// + /// If this is `None`, the mesh pipeline has no task shader. Executing a + /// mesh drawing command simply dispatches a grid of mesh shaders directly. + /// + /// [`draw_mesh_tasks`]: RenderPass::draw_mesh_tasks pub task: Option>, + /// The compiled mesh stage and its entry point pub mesh: MeshState<'a>, /// The properties of the pipeline at the primitive assembly and rasterization level. From 41b654ce811f9b88b95c83d7f9b8d88af48bff17 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Thu, 2 Oct 2025 00:28:47 -0700 Subject: [PATCH 33/89] mesh_shading.md: more tweaks --- docs/api-specs/mesh_shading.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/api-specs/mesh_shading.md b/docs/api-specs/mesh_shading.md index fcead0898bb..5990e63e871 100644 --- a/docs/api-specs/mesh_shading.md +++ b/docs/api-specs/mesh_shading.md @@ -151,19 +151,19 @@ A function with the `@mesh` attribute is a **mesh shader entry point**. Mesh sha Like compute shaders, mesh shaders are invoked in a grid of workgroups, called a **mesh shader grid**. If the mesh shader pipeline has a task shader, then each task shader workgroup determines the size of a mesh shader grid to be dispatched, as described above. Otherwise, the three-component size passed to `draw_mesh_tasks`, or drawn from the indirect buffer for its indirect variants, specifies the size of the mesh shader grid directly, as the number of workgroups along each of the grid's three axes. -A mesh shader entry point must have a `@workgroup_size` attribute, meeting the same requirements as one appearing on a compute shader entry point. - If the mesh shader pipeline has a task shader entry point with a `@payload(G)` attribute, then the pipeline's mesh shader entry point must also have a `@payload(G)` attribute, naming the same variable. Mesh shader invocations can read, but not write, this variable, which is initialized to whatever value was written to it by the task shader workgroup that dispatched this mesh shader grid. If the mesh shader pipeline does not have a task shader entry point, or the task shader entry point does not have a `@payload(G)` attribute, then the mesh shader entry point must not have any `@payload` attribute. A mesh shader entry point must have the following attributes: +- `@workgroup_size`: this has the same meaning as when it appears on a compute shader entry point. + - `@vertex_output(V, NV)`: This indicates that the mesh shader workgroup will generate at most `NV` vertex values, each of type `V`. - `@primitive_output(P, NP)`: This indicates that the mesh shader workgroup will generate at most `NP` primitives, each of type `P`. -Each mesh shader entry point invocation must call the `setMeshOutputs(numVertices: u32, numPrimitives: u32)` builtin function exactly once, in uniform control flow. The values passed by each workgroup's first invocation (that is, the one whose `local_invocation_index` is `0`) determine how many vertices (values of type `V`) and primitives (values of type `P`) the workgroup must produce. This call essentially establishes two implicit arrays of vertex and primitive values, shared across the workgroup, for invocations to populate. +Before generating any results, each mesh shader entry point invocation must call the `setMeshOutputs(numVertices: u32, numPrimitives: u32)` builtin function exactly once, in uniform control flow. The values passed by each workgroup's first invocation (that is, the one whose `local_invocation_index` is `0`) determine how many vertices (values of type `V`) and primitives (values of type `P`) the workgroup must produce. This call essentially establishes two implicit arrays of vertex and primitive values, shared across the workgroup, for invocations to populate. The `numVertices` and `numPrimitives` arguments must be no greater than `NV` and `NP` from the `@vertex_output` and `@primitive_output` attributes. From 33ed0a66f4baf09b9692631e8b36140daee238f5 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Thu, 2 Oct 2025 12:22:11 -0500 Subject: [PATCH 34/89] Ran cargo fmt --- naga/src/valid/analyzer.rs | 8 ++++---- naga/src/valid/interface.rs | 4 ++-- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index 6d9fd7f6a08..84390c3e5cd 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -1150,10 +1150,10 @@ impl FunctionInfo { | &crate::MeshFunction::SetPrimitive { index, value } => { let _ = self.add_ref(index); let _ = self.add_ref(value); - let ty = - self.expressions[value.index()].ty.handle().ok_or( - FunctionError::InvalidMeshShaderOutputType(value).with_span(), - )?; + let ty = self.expressions[value.index()] + .ty + .handle() + .ok_or(FunctionError::InvalidMeshShaderOutputType(value).with_span())?; if matches!(func, crate::MeshFunction::SetVertex { .. }) { self.try_update_mesh_vertex_type(ty, value)?; diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index 9f5cb278330..d40db4b45f8 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -859,8 +859,8 @@ impl super::Validator { // Task shaders must have a single `MeshTaskSize` output, and nothing else. if ep.stage == crate::ShaderStage::Task { let ok = result_built_ins.contains(&crate::BuiltIn::MeshTaskSize) - && result_built_ins.len() == 1 - && self.location_mask.is_empty(); + && result_built_ins.len() == 1 + && self.location_mask.is_empty(); if !ok { return Err(EntryPointError::WrongTaskShaderEntryResult.with_span()); } From 53ecb39b7171bfa13a153efc6233d3bb9e6e9adb Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Thu, 2 Oct 2025 13:03:04 -0500 Subject: [PATCH 35/89] Small tweaks --- docs/api-specs/mesh_shading.md | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/docs/api-specs/mesh_shading.md b/docs/api-specs/mesh_shading.md index 5990e63e871..c3f80e79a67 100644 --- a/docs/api-specs/mesh_shading.md +++ b/docs/api-specs/mesh_shading.md @@ -138,9 +138,9 @@ A function with the `@task` attribute is a **task shader entry point**. A mesh s A task shader entry point must have a `@workgroup_size` attribute, meeting the same requirements as one appearing on a compute shader entry point. -A task shader entry point must return a `vec3` value. The return value of each workgroup's first invocation (that is, the one whose `local_invocation_index` is `0`) is taken as the size of a **mesh shader grid** to dispatch, measured in workgroups. (If the task shader entry point returns `vec3(0, 0, 0)`, then no mesh shaders are dispatched.) Mesh shader grids are described in the next section. +A task shader entry point must also have a `@payload(G)` property, where `G` is the name of a global variable in the `task_payload` address space. Each task shader workgroup has its own instance of this variable, visible to all invocations in the workgroup. Whatever value the workgroup collectively stores in that global variable becomes the **task payload**, and is provided to all invocations in the mesh shader grid dispatched for the workgroup. -If a task shader entry point has a `@payload(G)` property, then `G` must be the name of a global variable in the `task_payload` address space. Each task shader workgroup has its own instance of this variable, visible to all invocations in the workgroup. Whatever value the workgroup collectively stores in that global variable becomes the **task payload**, and is provided to all invocations in the mesh shader grid dispatched for the workgroup. +A task shader entry point must return a `vec3` value. The return value of each workgroup's first invocation (that is, the one whose `local_invocation_index` is `0`) is taken as the size of a **mesh shader grid** to dispatch, measured in workgroups. (If the task shader entry point returns `vec3(0, 0, 0)`, then no mesh shaders are dispatched.) Mesh shader grids are described in the next section. Each task shader workgroup dispatches an independent mesh shader grid: in mesh shader invocations, `@builtin` values like `workgroup_id` and `global_invocation_id` describe the position of the workgroup and invocation within that grid; and `@builtin(num_workgroups)` matches the task shader workgroup's return value. Mesh shaders dispatched for other task shader workgroups are not included in the count. If it is necessary for a mesh shader to know which task shader workgroup dispatched it, the task shader can include its own workgroup id in the task payload. @@ -151,9 +151,9 @@ A function with the `@mesh` attribute is a **mesh shader entry point**. Mesh sha Like compute shaders, mesh shaders are invoked in a grid of workgroups, called a **mesh shader grid**. If the mesh shader pipeline has a task shader, then each task shader workgroup determines the size of a mesh shader grid to be dispatched, as described above. Otherwise, the three-component size passed to `draw_mesh_tasks`, or drawn from the indirect buffer for its indirect variants, specifies the size of the mesh shader grid directly, as the number of workgroups along each of the grid's three axes. -If the mesh shader pipeline has a task shader entry point with a `@payload(G)` attribute, then the pipeline's mesh shader entry point must also have a `@payload(G)` attribute, naming the same variable. Mesh shader invocations can read, but not write, this variable, which is initialized to whatever value was written to it by the task shader workgroup that dispatched this mesh shader grid. +If the mesh shader pipeline has a task shader entry point, then the pipeline's mesh shader entry point must also have a `@payload(G)` attribute, naming the same variable, and the sizes must match. Mesh shader invocations can read, but not write, this variable, which is initialized to whatever value was written to it by the task shader workgroup that dispatched this mesh shader grid. -If the mesh shader pipeline does not have a task shader entry point, or the task shader entry point does not have a `@payload(G)` attribute, then the mesh shader entry point must not have any `@payload` attribute. +If the mesh shader pipeline does not have a task shader entry point, then the mesh shader entry point must not have any `@payload` attribute. A mesh shader entry point must have the following attributes: @@ -167,7 +167,7 @@ Before generating any results, each mesh shader entry point invocation must call The `numVertices` and `numPrimitives` arguments must be no greater than `NV` and `NP` from the `@vertex_output` and `@primitive_output` attributes. -To produce vertex data, the workgroup as a whole must make `numVertices` calls to the `setVertex(i: u32, vertex: V)` builtin function. This establishes `vertex` as the value of the `i`'th vertex. `V` is the type given in the `@vertex_output` attribute. `V` must meet the same requirements as a struct type returned by a `@vertex` entry point: all members must have either `@builtin` or `@location` attributes, there must be a '@builtin(position)`, and so on. An invocation may only call `setVertex` after its call to `setMeshOutputs`. +To produce vertex data, the workgroup as a whole must make `numVertices` calls to the `setVertex(i: u32, vertex: V)` builtin function. This establishes `vertex` as the value of the `i`'th vertex. `V` is the type given in the `@vertex_output` attribute. `V` must meet the same requirements as a struct type returned by a `@vertex` entry point: all members must have either `@builtin` or `@location` attributes, there must be a `@builtin(position)`, and so on. An invocation may only call `setVertex` after its call to `setMeshOutputs`. To produce primitives, the workgroup as a whole must make `numPrimitives` calls to the `setPrimitive(i: u32, primitive: P)` builtin function. This establishes `primitive` as the value of the `i`'th primitive. `P` is the type given in the `@primitive_output` attribute. `P` must be a struct type, every member of which either has a `@location` or `@builtin` attribute. The following `@builtin` attributes are allowed: From 86d1877de42defc7f45609b2edfcebc3b94b32f9 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sat, 11 Oct 2025 00:07:59 -0500 Subject: [PATCH 36/89] MESH SHADERS ON METAL LMAO HAHA YESS --- wgpu-hal/src/metal/device.rs | 2 +- wgpu-types/src/lib.rs | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index ec663ce78a3..13b6853529c 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -1258,7 +1258,7 @@ impl crate::Device for super::Device { primitive_class, naga::ShaderStage::Task, )?; - descriptor.set_mesh_function(Some(&ts.function)); + descriptor.set_object_function(Some(&ts.function)); if self.shared.private_caps.supports_mutability { Self::set_buffers_mutability( descriptor.mesh_buffers().unwrap(), diff --git a/wgpu-types/src/lib.rs b/wgpu-types/src/lib.rs index 9b9cd1d3e02..3f062606040 100644 --- a/wgpu-types/src/lib.rs +++ b/wgpu-types/src/lib.rs @@ -1045,14 +1045,12 @@ impl Limits { #[must_use] pub const fn using_recommended_minimum_mesh_shader_values(self) -> Self { Self { - // Literally just made this up as 256^2 or 2^16. - // My GPU supports 2^22, and compute shaders don't have this kind of limit. - // This very likely is never a real limiter + // I believe this is a common limit for apple devices. I'm not entirely sure why. max_task_workgroup_total_count: 1024, max_task_workgroups_per_dimension: 1024, // llvmpipe reports 0 multiview count, which just means no multiview is allowed max_mesh_multiview_count: 0, - // llvmpipe once again requires this to be 8. An RTX 3060 supports well over 1024. + // llvmpipe once again requires this to be <=8. An RTX 3060 supports well over 1024. max_mesh_output_layers: 8, ..self } From 00c19fc39c515b98ea93645a5bd03c27fa46c29f Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sat, 11 Oct 2025 00:28:57 -0500 Subject: [PATCH 37/89] Looked over all except command.rs --- wgpu-hal/src/metal/device.rs | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index 13b6853529c..f4e7592e9d8 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -1111,7 +1111,7 @@ impl crate::Device for super::Device { >, ) -> Result { objc::rc::autoreleasepool(|| { - let (primitive_class, _raw_primitive_type) = + let (primitive_class, raw_primitive_type) = conv::map_primitive_topology(desc.primitive.topology); let vs_info; @@ -1323,9 +1323,6 @@ impl crate::Device for super::Device { ), }; - let (primitive_class, raw_primitive_type) = - conv::map_primitive_topology(desc.primitive.topology); - // Fragment shader let fs_info = match desc.fragment_stage { Some(ref stage) => { @@ -1461,7 +1458,7 @@ impl crate::Device for super::Device { wgt::ShaderStages::TASK | wgt::ShaderStages::MESH | wgt::ShaderStages::FRAGMENT, - format!("new_render_pipeline_state: {e:?}"), + format!("new_mesh_render_pipeline_state: {e:?}"), ) })?, }; From effe0f41a8bacdc23fbea5aeed8ea34f053b65ab Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sat, 11 Oct 2025 01:14:02 -0500 Subject: [PATCH 38/89] =?UTF-8?q?(Almost)=20everything=20passes=20?= =?UTF-8?q?=F0=9F=8E=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- wgpu-hal/src/metal/command.rs | 82 ++++++++++++++++++++++++++++++----- 1 file changed, 70 insertions(+), 12 deletions(-) diff --git a/wgpu-hal/src/metal/command.rs b/wgpu-hal/src/metal/command.rs index 0799a76ff28..19c4fe6ffeb 100644 --- a/wgpu-hal/src/metal/command.rs +++ b/wgpu-hal/src/metal/command.rs @@ -681,7 +681,15 @@ impl crate::CommandEncoder for super::CommandEncoder { let mut update_stage = |stage: naga::ShaderStage, render_encoder: Option<&metal::RenderCommandEncoder>, - compute_encoder: Option<&metal::ComputeCommandEncoder>| { + compute_encoder: Option<&metal::ComputeCommandEncoder>, + index_base: super::ResourceData| { + let resource_indices = match stage { + naga::ShaderStage::Vertex => &bg_info.base_resource_indices.vs, + naga::ShaderStage::Fragment => &bg_info.base_resource_indices.fs, + naga::ShaderStage::Task => &bg_info.base_resource_indices.ts, + naga::ShaderStage::Mesh => &bg_info.base_resource_indices.ms, + naga::ShaderStage::Compute => &bg_info.base_resource_indices.cs, + }; let buffers = match stage { naga::ShaderStage::Vertex => group.counters.vs.buffers, naga::ShaderStage::Fragment => group.counters.fs.buffers, @@ -691,12 +699,12 @@ impl crate::CommandEncoder for super::CommandEncoder { }; let mut changes_sizes_buffer = false; for index in 0..buffers { - let buf = &group.buffers[index as usize]; + let buf = &group.buffers[(index_base.buffers + index) as usize]; let mut offset = buf.offset; if let Some(dyn_index) = buf.dynamic_index { offset += dynamic_offsets[dyn_index as usize] as wgt::BufferAddress; } - let a1 = (bg_info.base_resource_indices.vs.buffers + index) as u64; + let a1 = (resource_indices.buffers + index) as u64; let a2 = Some(buf.ptr.as_native()); let a3 = offset; match stage { @@ -760,8 +768,8 @@ impl crate::CommandEncoder for super::CommandEncoder { naga::ShaderStage::Compute => group.counters.cs.samplers, }; for index in 0..samplers { - let res = group.samplers[(group.counters.vs.samplers + index) as usize]; - let a1 = (bg_info.base_resource_indices.fs.samplers + index) as u64; + let res = group.samplers[(index_base.samplers + index) as usize]; + let a1 = (resource_indices.samplers + index) as u64; let a2 = Some(res.as_native()); match stage { naga::ShaderStage::Vertex => { @@ -790,8 +798,8 @@ impl crate::CommandEncoder for super::CommandEncoder { naga::ShaderStage::Compute => group.counters.cs.textures, }; for index in 0..textures { - let res = group.textures[index as usize]; - let a1 = (bg_info.base_resource_indices.vs.textures + index) as u64; + let res = group.textures[(index_base.textures + index) as usize]; + let a1 = (resource_indices.textures + index) as u64; let a2 = Some(res.as_native()); match stage { naga::ShaderStage::Vertex => { @@ -809,17 +817,67 @@ impl crate::CommandEncoder for super::CommandEncoder { } }; if let Some(encoder) = render_encoder { - update_stage(naga::ShaderStage::Vertex, Some(&encoder), None); - update_stage(naga::ShaderStage::Fragment, Some(&encoder), None); - update_stage(naga::ShaderStage::Task, Some(&encoder), None); - update_stage(naga::ShaderStage::Mesh, Some(&encoder), None); + update_stage( + naga::ShaderStage::Vertex, + Some(&encoder), + None, + // All zeros, as vs comes first + super::ResourceData::default(), + ); + update_stage( + naga::ShaderStage::Task, + Some(&encoder), + None, + // All zeros, as ts comes first + super::ResourceData::default(), + ); + update_stage( + naga::ShaderStage::Mesh, + Some(&encoder), + None, + group.counters.ts.clone(), + ); + update_stage( + naga::ShaderStage::Fragment, + Some(&encoder), + None, + super::ResourceData { + buffers: group.counters.vs.buffers + + group.counters.ts.buffers + + group.counters.ms.buffers, + textures: group.counters.vs.textures + + group.counters.ts.textures + + group.counters.ms.textures, + samplers: group.counters.vs.samplers + + group.counters.ts.samplers + + group.counters.ms.samplers, + }, + ); // Call useResource on all textures and buffers used indirectly so they are alive for (resource, use_info) in group.resources_to_use.iter() { encoder.use_resource_at(resource.as_native(), use_info.uses, use_info.stages); } } if let Some(encoder) = compute_encoder { - update_stage(naga::ShaderStage::Compute, None, Some(&encoder)); + update_stage( + naga::ShaderStage::Compute, + None, + Some(&encoder), + super::ResourceData { + buffers: group.counters.vs.buffers + + group.counters.ts.buffers + + group.counters.ms.buffers + + group.counters.fs.buffers, + textures: group.counters.vs.textures + + group.counters.ts.textures + + group.counters.ms.textures + + group.counters.fs.textures, + samplers: group.counters.vs.samplers + + group.counters.ts.samplers + + group.counters.ms.samplers + + group.counters.fs.samplers, + }, + ); // Call useResource on all textures and buffers used indirectly so they are alive for (resource, use_info) in group.resources_to_use.iter() { if !use_info.visible_in_compute { From 27d595e91faa0f8438642103ada78cdee157e0bf Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sat, 11 Oct 2025 01:20:34 -0500 Subject: [PATCH 39/89] Another quick fix (still 2 failing) --- wgpu-hal/src/metal/mod.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/wgpu-hal/src/metal/mod.rs b/wgpu-hal/src/metal/mod.rs index fda7e001906..27e357cf4c9 100644 --- a/wgpu-hal/src/metal/mod.rs +++ b/wgpu-hal/src/metal/mod.rs @@ -653,11 +653,15 @@ impl MultiStageData { iter::once(&self.vs) .chain(iter::once(&self.fs)) .chain(iter::once(&self.cs)) + .chain(iter::once(&self.ts)) + .chain(iter::once(&self.ms)) } fn iter_mut<'a>(&'a mut self) -> impl Iterator { iter::once(&mut self.vs) .chain(iter::once(&mut self.fs)) .chain(iter::once(&mut self.cs)) + .chain(iter::once(&mut self.ts)) + .chain(iter::once(&mut self.ms)) } } From fb7e24c0d45a7e2b7662efb94876fffab14e8a15 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sat, 11 Oct 2025 01:32:40 -0500 Subject: [PATCH 40/89] Update changelog.md --- CHANGELOG.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 011f59392e2..13f4c1f1f43 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -78,6 +78,9 @@ SamplerDescriptor { - Texture now has `from_custom`. By @R-Cramer4 in [#8315](https://github.com/gfx-rs/wgpu/pull/8315). +#### Metal +- Add support for mesh shaders. By @SupaMaggie70Incorporated in [#8139](https://github.com/gfx-rs/wgpu/pull/8139) + ### Bug Fixes #### General @@ -313,9 +316,6 @@ By @wumpf in [#8282](https://github.com/gfx-rs/wgpu/pull/8282), [#8285](https:// - Allow disabling waiting for latency waitable object. By @marcpabst in [#7400](https://github.com/gfx-rs/wgpu/pull/7400) - Add mesh shader support, including to the example. By @SupaMaggie70Incorporated in [#8110](https://github.com/gfx-rs/wgpu/issues/8110) -#### Metal -- Add support for mesh shaders. By @SupaMaggie70Incorporated in [#8139](https://github.com/gfx-rs/wgpu/pull/8139) - ### Bug Fixes #### General From 233d76f09a8320c291e1fed5bb1a725e512f76b9 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sat, 11 Oct 2025 02:05:15 -0500 Subject: [PATCH 41/89] Added little bit to explain something --- wgpu-hal/src/metal/command.rs | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/wgpu-hal/src/metal/command.rs b/wgpu-hal/src/metal/command.rs index 19c4fe6ffeb..0b76cd47bfc 100644 --- a/wgpu-hal/src/metal/command.rs +++ b/wgpu-hal/src/metal/command.rs @@ -1029,7 +1029,33 @@ impl crate::CommandEncoder for super::CommandEncoder { ); } } - if pipeline.ms_info.is_some() { + if let Some(_ms_info) = &pipeline.ms_info { + // TODO: + // https://developer.apple.com/documentation/metal/mtlrendercommandencoder/setthreadgroupmemorylength(_:offset:index:) + // doesn't exist in current metal-rs version for some reason. Maybe put it off until objc2 arrives? + // Also, this will need to be added to the task stage + /* + // update the threadgroup memory sizes + while self.state.stage_infos.ms.work_group_memory_sizes.len() + < ms_info.work_group_memory_sizes.len() + { + self.state.stage_infos.ms.work_group_memory_sizes.push(0); + } + for (index, (cur_size, pipeline_size)) in self + .state + .stage_infos + .ms + .work_group_memory_sizes + .iter_mut() + .zip(ms_info.work_group_memory_sizes.iter()) + .enumerate() + { + let size = pipeline_size.next_multiple_of(16); + if *cur_size != size { + *cur_size = size; + encoder.set_threadgroup_memory_length(index as _, size as _); + } + }*/ if let Some((index, sizes)) = self .state .make_sizes_buffer_update(naga::ShaderStage::Mesh, &mut self.temp.binding_sizes) From 2d60810c4194d6fa000088d0661977bf4b534db6 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sat, 11 Oct 2025 02:34:07 -0500 Subject: [PATCH 42/89] More tiny incremental upgrades --- wgpu-hal/src/metal/command.rs | 5 ----- wgpu-hal/src/metal/device.rs | 4 ++-- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/wgpu-hal/src/metal/command.rs b/wgpu-hal/src/metal/command.rs index 0b76cd47bfc..d43f0aefed7 100644 --- a/wgpu-hal/src/metal/command.rs +++ b/wgpu-hal/src/metal/command.rs @@ -1420,11 +1420,6 @@ impl crate::CommandEncoder for super::CommandEncoder { } // update the threadgroup memory sizes - while self.state.stage_infos.cs.work_group_memory_sizes.len() - < pipeline.cs_info.work_group_memory_sizes.len() - { - self.state.stage_infos.cs.work_group_memory_sizes.push(0); - } for (index, (cur_size, pipeline_size)) in self .state .stage_infos diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index f4e7592e9d8..a19a14fb074 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -1272,7 +1272,7 @@ impl crate::Device for super::Device { vertex_buffer_mappings: vec![], library: Some(ts.library), raw_wg_size: ts.wg_size, - work_group_memory_sizes: vec![], + work_group_memory_sizes: ts.wg_memory_sizes, }); } else { ts_info = None; @@ -1299,7 +1299,7 @@ impl crate::Device for super::Device { vertex_buffer_mappings: vec![], library: Some(ms.library), raw_wg_size: ms.wg_size, - work_group_memory_sizes: vec![], + work_group_memory_sizes: ms.wg_memory_sizes, }); } MetalGenericRenderPipelineDescriptor::Mesh(descriptor) From 204c542f4af2b0b9099c2fd8a471434738786924 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sat, 11 Oct 2025 03:50:51 -0500 Subject: [PATCH 43/89] Am I ... whatever its bedtime --- wgpu-hal/src/metal/command.rs | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/wgpu-hal/src/metal/command.rs b/wgpu-hal/src/metal/command.rs index d43f0aefed7..3d9ffcf783f 100644 --- a/wgpu-hal/src/metal/command.rs +++ b/wgpu-hal/src/metal/command.rs @@ -1403,6 +1403,8 @@ impl crate::CommandEncoder for super::CommandEncoder { } unsafe fn set_compute_pipeline(&mut self, pipeline: &super::ComputePipeline) { + let previous_sizes = + core::mem::take(&mut self.state.stage_infos.cs.work_group_memory_sizes); self.state.stage_infos.cs.assign_from(&pipeline.cs_info); let encoder = self.state.compute.as_ref().unwrap(); @@ -1420,19 +1422,23 @@ impl crate::CommandEncoder for super::CommandEncoder { } // update the threadgroup memory sizes - for (index, (cur_size, pipeline_size)) in self + for (i, current_size) in self .state .stage_infos .cs .work_group_memory_sizes .iter_mut() - .zip(pipeline.cs_info.work_group_memory_sizes.iter()) .enumerate() { - let size = pipeline_size.next_multiple_of(16); - if *cur_size != size { - *cur_size = size; - encoder.set_threadgroup_memory_length(index as _, size as _); + let prev_size = if i < previous_sizes.len() { + previous_sizes[i] + } else { + u32::MAX + }; + let size: u32 = current_size.next_multiple_of(16); + *current_size = size; + if size != prev_size { + encoder.set_threadgroup_memory_length(i as _, size as _); } } } From 7e6dee67972975b5420793f9dab5dd2ae3d7e81c Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sat, 11 Oct 2025 14:22:50 -0500 Subject: [PATCH 44/89] Did some work --- wgpu-hal/src/metal/command.rs | 44 +++++++++++++++++------------------ 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/wgpu-hal/src/metal/command.rs b/wgpu-hal/src/metal/command.rs index 3d9ffcf783f..46cf52716c8 100644 --- a/wgpu-hal/src/metal/command.rs +++ b/wgpu-hal/src/metal/command.rs @@ -1017,27 +1017,10 @@ impl crate::CommandEncoder for super::CommandEncoder { ); } } - if pipeline.ts_info.is_some() { - if let Some((index, sizes)) = self - .state - .make_sizes_buffer_update(naga::ShaderStage::Task, &mut self.temp.binding_sizes) - { - encoder.set_object_bytes( - index as _, - (sizes.len() * WORD_SIZE) as u64, - sizes.as_ptr().cast(), - ); - } - } - if let Some(_ms_info) = &pipeline.ms_info { - // TODO: - // https://developer.apple.com/documentation/metal/mtlrendercommandencoder/setthreadgroupmemorylength(_:offset:index:) - // doesn't exist in current metal-rs version for some reason. Maybe put it off until objc2 arrives? - // Also, this will need to be added to the task stage - /* + if let Some(ts_info) = &pipeline.ts_info { // update the threadgroup memory sizes while self.state.stage_infos.ms.work_group_memory_sizes.len() - < ms_info.work_group_memory_sizes.len() + < ts_info.work_group_memory_sizes.len() { self.state.stage_infos.ms.work_group_memory_sizes.push(0); } @@ -1047,15 +1030,32 @@ impl crate::CommandEncoder for super::CommandEncoder { .ms .work_group_memory_sizes .iter_mut() - .zip(ms_info.work_group_memory_sizes.iter()) + .zip(ts_info.work_group_memory_sizes.iter()) .enumerate() { let size = pipeline_size.next_multiple_of(16); if *cur_size != size { *cur_size = size; - encoder.set_threadgroup_memory_length(index as _, size as _); + encoder.set_object_threadgroup_memory_length(index as _, size as _); } - }*/ + } + if let Some((index, sizes)) = self + .state + .make_sizes_buffer_update(naga::ShaderStage::Task, &mut self.temp.binding_sizes) + { + encoder.set_object_bytes( + index as _, + (sizes.len() * WORD_SIZE) as u64, + sizes.as_ptr().cast(), + ); + } + } + if let Some(_ms_info) = &pipeline.ms_info { + // So there isn't an equivalent to + // https://developer.apple.com/documentation/metal/mtlrendercommandencoder/setthreadgroupmemorylength(_:offset:index:) + // for mesh shaders. This is probably because the CPU has less control over the dispatch sizes and such. Interestingly + // it also affects mesh shaders without task/object shaders, even though none of compute, task or fragment shaders + // behave this way. if let Some((index, sizes)) = self .state .make_sizes_buffer_update(naga::ShaderStage::Mesh, &mut self.temp.binding_sizes) From c4e3eefe014ff92e5a362e226b606dee5587a27c Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Wed, 15 Oct 2025 16:24:35 -0700 Subject: [PATCH 45/89] [naga] Move definition of `ShaderStage::compute_like` to `proc`. Move the definition of `naga::ShaderStage::compute_like` from `naga::ir` into `naga::proc`. We generally want ot keep methods out of `naga::ir`, since the IR itself is complicated enough already. --- naga/src/ir/mod.rs | 10 ---------- naga/src/proc/mod.rs | 10 ++++++++++ 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/naga/src/ir/mod.rs b/naga/src/ir/mod.rs index 94159ae7bf6..ad03f542d09 100644 --- a/naga/src/ir/mod.rs +++ b/naga/src/ir/mod.rs @@ -337,16 +337,6 @@ pub enum ShaderStage { Compute, } -impl ShaderStage { - // TODO: make more things respect this - pub const fn compute_like(self) -> bool { - match self { - Self::Vertex | Self::Fragment => false, - Self::Compute | Self::Task | Self::Mesh => true, - } - } -} - /// Addressing space of variables. #[derive(Clone, Copy, Debug, Hash, Eq, Ord, PartialEq, PartialOrd)] #[cfg_attr(feature = "serialize", derive(Serialize))] diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index 5743e96a33e..7b90aa35512 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -631,6 +631,16 @@ pub fn flatten_compose<'arenas>( .take(size) } +impl super::ShaderStage { + // TODO: make more things respect this + pub const fn compute_like(self) -> bool { + match self { + Self::Vertex | Self::Fragment => false, + Self::Compute | Self::Task | Self::Mesh => true, + } + } +} + #[test] fn test_matrix_size() { let module = crate::Module::default(); From 8c9287d634f13fd47ec709406d83e952a54a496c Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Wed, 15 Oct 2025 17:37:59 -0700 Subject: [PATCH 46/89] Replace TODO comment with followup issue. --- naga/src/valid/interface.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index d40db4b45f8..f33e8fc8133 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -965,7 +965,6 @@ impl super::Validator { if let &Some(ref mesh_info) = &ep.mesh_info { // Mesh shaders don't return any value. All their results are supplied through // [`SetVertex`] and [`SetPrimitive`] calls. - // TODO: check that only the allowed builtins are used here if let Some(used_vertex_type) = info.mesh_shader_info.vertex_type { if used_vertex_type.0 != mesh_info.vertex_output_type { return Err(EntryPointError::WrongMeshOutputType From 3a8399de7ca78521606c6180cee5d217c4fc70e3 Mon Sep 17 00:00:00 2001 From: Inner Daemons <85136135+inner-daemons@users.noreply.github.com> Date: Wed, 15 Oct 2025 22:24:56 -0500 Subject: [PATCH 47/89] Update analyzer.rs Co-authored-by: Jim Blandy --- naga/src/valid/analyzer.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index 84390c3e5cd..5ce80f20fb9 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -1252,7 +1252,7 @@ impl FunctionInfo { if let &Some(ref other_vertex) = &callee.vertex_type { self.try_update_mesh_vertex_type(other_vertex.0, other_vertex.1)?; } - if let &Some(ref other_primitive) = &callee.vertex_type { + if let &Some(ref other_primitive) = &callee.primitive_type { self.try_update_mesh_primitive_type(other_primitive.0, other_primitive.1)?; } Ok(()) From d92fe673e65a91e5aee86539e16ce2248bbb5721 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Wed, 15 Oct 2025 23:08:59 -0500 Subject: [PATCH 48/89] Removed stuff in accordance with Jim's recommendation --- Cargo.lock | 4 ++-- naga/src/valid/interface.rs | 4 +--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 992defc7d5f..d8c550ff796 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3997,8 +3997,8 @@ dependencies = [ "fastrand", "getrandom 0.3.3", "once_cell", - "rustix 1.0.8", - "windows-sys 0.61.0", + "rustix 1.1.2", + "windows-sys 0.52.0", ] [[package]] diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index f33e8fc8133..6aebd33a64e 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -851,9 +851,7 @@ impl super::Validator { { return Err(EntryPointError::MissingVertexOutputPosition.with_span()); } - if ep.stage == crate::ShaderStage::Mesh - && (!result_built_ins.is_empty() || !self.location_mask.is_empty()) - { + if ep.stage == crate::ShaderStage::Mesh { return Err(EntryPointError::UnexpectedMeshShaderEntryResult.with_span()); } // Task shaders must have a single `MeshTaskSize` output, and nothing else. From 2dc409028517c9da3bfb5852fca04f5b33296e6d Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Wed, 15 Oct 2025 20:14:08 -0700 Subject: [PATCH 49/89] minor changes for readability --- naga/src/valid/interface.rs | 34 ++++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index 6aebd33a64e..550f200150a 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -393,13 +393,23 @@ impl VaryingContext<'_> { { return Err(VaryingError::NotIOShareableType(ty)); } - if !per_primitive && self.mesh_output_type == MeshOutputType::PrimitiveOutput { - return Err(VaryingError::MissingPerPrimitive); - } else if per_primitive - && ((self.stage != crate::ShaderStage::Fragment || self.output) - && self.mesh_output_type != MeshOutputType::PrimitiveOutput) - { - return Err(VaryingError::InvalidPerPrimitive); + + // Check whether `per_primitive` is appropriate for this stage and direction. + if self.mesh_output_type == MeshOutputType::PrimitiveOutput { + // All mesh shader `Location` outputs must be `per_primitive`. + if !per_primitive { + return Err(VaryingError::MissingPerPrimitive); + } + } else if self.stage == crate::ShaderStage::Fragment && !self.output { + // Fragment stage inputs may be `per_primitive`. We'll only + // know if these are correct when the whole mesh pipeline is + // created and we're paired with a specific mesh or vertex + // shader. + } else { + // All other `Location` bindings must not be `per_primitive`. + if per_primitive { + return Err(VaryingError::InvalidPerPrimitive); + } } if let Some(blend_src) = blend_src { @@ -959,18 +969,18 @@ impl super::Validator { } } - // If this is a `Mesh` entry point, check its interface. + // If this is a `Mesh` entry point, check the bindings of its vertex and primitive output types. if let &Some(ref mesh_info) = &ep.mesh_info { // Mesh shaders don't return any value. All their results are supplied through // [`SetVertex`] and [`SetPrimitive`] calls. - if let Some(used_vertex_type) = info.mesh_shader_info.vertex_type { - if used_vertex_type.0 != mesh_info.vertex_output_type { + if let Some((used_vertex_type, _)) = info.mesh_shader_info.vertex_type { + if used_vertex_type != mesh_info.vertex_output_type { return Err(EntryPointError::WrongMeshOutputType .with_span_handle(mesh_info.vertex_output_type, &module.types)); } } - if let Some(used_primitive_type) = info.mesh_shader_info.primitive_type { - if used_primitive_type.0 != mesh_info.primitive_output_type { + if let Some((used_primitive_type, _)) = info.mesh_shader_info.primitive_type { + if used_primitive_type != mesh_info.primitive_output_type { return Err(EntryPointError::WrongMeshOutputType .with_span_handle(mesh_info.primitive_output_type, &module.types)); } From 1ec734b3528b08c69bcf425f2c953274d5ea812a Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Wed, 15 Oct 2025 20:38:11 -0700 Subject: [PATCH 50/89] Pull mesh shader output type validation out into its own function. --- naga/src/valid/interface.rs | 113 ++++++++++++++++++++---------------- 1 file changed, 64 insertions(+), 49 deletions(-) diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index 550f200150a..891a87c5cbf 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -747,6 +747,58 @@ impl super::Validator { Ok(()) } + /// Validate the mesh shader output type `ty`, used as `mesh_output_type`. + fn validate_mesh_output_type( + &mut self, + ep: &crate::EntryPoint, + module: &crate::Module, + ty: Handle, + mesh_output_type: MeshOutputType, + ) -> Result<(), WithSpan> { + if !matches!(module.types[ty].inner, crate::TypeInner::Struct { .. }) { + return Err(EntryPointError::InvalidMeshOutputType.with_span_handle(ty, &module.types)); + } + let mut result_built_ins = crate::FastHashSet::default(); + let mut ctx = VaryingContext { + stage: ep.stage, + output: true, + types: &module.types, + type_info: &self.types, + location_mask: &mut self.location_mask, + blend_src_mask: &mut self.blend_src_mask, + built_ins: &mut result_built_ins, + capabilities: self.capabilities, + flags: self.flags, + mesh_output_type, + }; + ctx.validate(ep, ty, None) + .map_err_inner(|e| EntryPointError::Result(e).with_span())?; + if mesh_output_type == MeshOutputType::PrimitiveOutput { + let mut num_indices_builtins = 0; + if result_built_ins.contains(&crate::BuiltIn::PointIndex) { + num_indices_builtins += 1; + } + if result_built_ins.contains(&crate::BuiltIn::LineIndices) { + num_indices_builtins += 1; + } + if result_built_ins.contains(&crate::BuiltIn::TriangleIndices) { + num_indices_builtins += 1; + } + if num_indices_builtins != 1 { + return Err(EntryPointError::InvalidMeshPrimitiveOutputType + .with_span_handle(ty, &module.types)); + } + } else if mesh_output_type == MeshOutputType::VertexOutput + && !result_built_ins.contains(&crate::BuiltIn::Position { invariant: false }) + { + return Err( + EntryPointError::MissingVertexOutputPosition.with_span_handle(ty, &module.types) + ); + } + + Ok(()) + } + pub(super) fn validate_entry_point( &mut self, ep: &crate::EntryPoint, @@ -986,55 +1038,18 @@ impl super::Validator { } } - for (ty, mesh_output_type) in [ - (mesh_info.vertex_output_type, MeshOutputType::VertexOutput), - ( - mesh_info.primitive_output_type, - MeshOutputType::PrimitiveOutput, - ), - ] { - if !matches!(module.types[ty].inner, crate::TypeInner::Struct { .. }) { - return Err( - EntryPointError::InvalidMeshOutputType.with_span_handle(ty, &module.types) - ); - } - let mut result_built_ins = crate::FastHashSet::default(); - let mut ctx = VaryingContext { - stage: ep.stage, - output: true, - types: &module.types, - type_info: &self.types, - location_mask: &mut self.location_mask, - blend_src_mask: &mut self.blend_src_mask, - built_ins: &mut result_built_ins, - capabilities: self.capabilities, - flags: self.flags, - mesh_output_type, - }; - ctx.validate(ep, ty, None) - .map_err_inner(|e| EntryPointError::Result(e).with_span())?; - if mesh_output_type == MeshOutputType::PrimitiveOutput { - let mut num_indices_builtins = 0; - if result_built_ins.contains(&crate::BuiltIn::PointIndex) { - num_indices_builtins += 1; - } - if result_built_ins.contains(&crate::BuiltIn::LineIndices) { - num_indices_builtins += 1; - } - if result_built_ins.contains(&crate::BuiltIn::TriangleIndices) { - num_indices_builtins += 1; - } - if num_indices_builtins != 1 { - return Err(EntryPointError::InvalidMeshPrimitiveOutputType - .with_span_handle(ty, &module.types)); - } - } else if mesh_output_type == MeshOutputType::VertexOutput - && !result_built_ins.contains(&crate::BuiltIn::Position { invariant: false }) - { - return Err(EntryPointError::MissingVertexOutputPosition - .with_span_handle(ty, &module.types)); - } - } + self.validate_mesh_output_type( + ep, + module, + mesh_info.vertex_output_type, + MeshOutputType::VertexOutput, + )?; + self.validate_mesh_output_type( + ep, + module, + mesh_info.primitive_output_type, + MeshOutputType::PrimitiveOutput, + )?; } else if info.mesh_shader_info.vertex_type.is_some() || info.mesh_shader_info.primitive_type.is_some() { From 9ef0ed580e8cd21cba47f22ac7aad6b490339cff Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Thu, 16 Oct 2025 08:08:06 -0700 Subject: [PATCH 51/89] doc fixes --- naga/src/ir/mod.rs | 17 ++++++++++++----- naga/src/valid/analyzer.rs | 29 +++++++++++++++++++++++++++++ naga/src/valid/interface.rs | 15 ++++++++++----- 3 files changed, 51 insertions(+), 10 deletions(-) diff --git a/naga/src/ir/mod.rs b/naga/src/ir/mod.rs index ad03f542d09..a8a5d220463 100644 --- a/naga/src/ir/mod.rs +++ b/naga/src/ir/mod.rs @@ -983,16 +983,23 @@ pub enum Binding { location: u32, interpolation: Option, sampling: Option, + /// Optional `blend_src` index used for dual source blending. /// See blend_src: Option, + /// Whether the binding is a per-primitive binding for use with mesh shaders. - /// This is required to match for mesh and fragment shader stages. - /// This is merely an extra attribute on a binding. You still may not have - /// a per-vertex and per-primitive input with the same location. /// - /// Per primitive values are not interpolated at all and are not dependent on the vertices - /// or pixel location. For example, it may be used to store a non-interpolated normal vector. + /// This must be `true` if this binding is a mesh shader primitive output, or such + /// an output's corresponding fragment shader input. It must be `false` otherwise. + /// + /// A stage's outputs must all have unique `location` numbers, regardless of + /// whether they are per-primitive; a mesh shader's per-vertex and per-primitive + /// outputs share the same location numbering space. + /// + /// Per primitive values are not interpolated at all and are not dependent on the + /// vertices or pixel location. For example, it may be used to store a + /// non-interpolated normal vector. per_primitive: bool, }, } diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index 5ce80f20fb9..bbf00508e00 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -91,7 +91,16 @@ struct FunctionUniformity { #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] #[cfg_attr(test, derive(PartialEq))] pub struct FunctionMeshShaderInfo { + /// The type of value this function passes to [`SetVertex`], and the + /// expression that first established it. + /// + /// [`SetVertex`]: crate::ir::MeshFunction::SetVertex pub vertex_type: Option<(Handle, Handle)>, + + /// The type of value this function passes to [`SetPrimitive`], and the + /// expression that first established it. + /// + /// [`SetPrimitive`]: crate::ir::MeshFunction::SetPrimitive pub primitive_type: Option<(Handle, Handle)>, } @@ -313,6 +322,7 @@ pub struct FunctionInfo { /// validation. diagnostic_filter_leaf: Option>, + /// Mesh shader info for this function and its callees. pub mesh_shader_info: FunctionMeshShaderInfo, } @@ -502,6 +512,7 @@ impl FunctionInfo { *mine |= *other; } + // Inherit mesh output types from our callees. self.try_update_mesh_info(&callee.mesh_shader_info)?; Ok(FunctionUniformity { @@ -1210,6 +1221,15 @@ impl FunctionInfo { Ok(combined_uniformity) } + /// Note the type of value passed to [`SetVertex`]. + /// + /// Record that this function passed a value of type `ty` as the second + /// argument to the [`SetVertex`] builtin function. All calls to + /// `SetVertex` must pass the same type, and this must match the + /// function's [`vertex_output_type`]. + /// + /// [`SetVertex`]: crate::ir::MeshFunction::SetVertex + /// [`vertex_output_type`]: crate::ir::MeshStageInfo::vertex_output_type fn try_update_mesh_vertex_type( &mut self, ty: Handle, @@ -1227,6 +1247,15 @@ impl FunctionInfo { Ok(()) } + /// Note the type of value passed to [`SetPrimitive`]. + /// + /// Record that this function passed a value of type `ty` as the second + /// argument to the [`SetPrimitive`] builtin function. All calls to + /// `SetPrimitive` must pass the same type, and this must match the + /// function's [`primitive_output_type`]. + /// + /// [`SetPrimitive`]: crate::ir::MeshFunction::SetPrimitive + /// [`primitive_output_type`]: crate::ir::MeshStageInfo::primitive_output_type fn try_update_mesh_primitive_type( &mut self, ty: Handle, diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index 891a87c5cbf..5768f56e641 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -1021,7 +1021,8 @@ impl super::Validator { } } - // If this is a `Mesh` entry point, check the bindings of its vertex and primitive output types. + // If this is a `Mesh` entry point, check its vertex and primitive output types. + // We verified previously that only mesh shaders can have `mesh_info`. if let &Some(ref mesh_info) = &ep.mesh_info { // Mesh shaders don't return any value. All their results are supplied through // [`SetVertex`] and [`SetPrimitive`] calls. @@ -1050,10 +1051,14 @@ impl super::Validator { mesh_info.primitive_output_type, MeshOutputType::PrimitiveOutput, )?; - } else if info.mesh_shader_info.vertex_type.is_some() - || info.mesh_shader_info.primitive_type.is_some() - { - return Err(EntryPointError::UnexpectedMeshShaderOutput.with_span()); + } else { + // This is not a `Mesh` entry point, so ensure that it never tries to produce + // vertices or primitives. + if info.mesh_shader_info.vertex_type.is_some() + || info.mesh_shader_info.primitive_type.is_some() + { + return Err(EntryPointError::UnexpectedMeshShaderOutput.with_span()); + } } Ok(info) From 1173b0f578da4921a530f755f4cd85bb9b42cf62 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Thu, 16 Oct 2025 10:01:21 -0700 Subject: [PATCH 52/89] remove duplicated task payload validation --- naga/src/valid/interface.rs | 51 ++++++++++++++++++++++--------------- 1 file changed, 30 insertions(+), 21 deletions(-) diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index 5768f56e641..db6d800bd31 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -839,20 +839,38 @@ impl super::Validator { .validate_function(&ep.function, module, mod_info, true) .map_err(WithSpan::into_other)?; - if let Some(handle) = ep.task_payload { - if ep.stage != crate::ShaderStage::Task && ep.stage != crate::ShaderStage::Mesh { - return Err(EntryPointError::UnexpectedTaskPayload.with_span()); + // Validate the task shader payload. + match ep.stage { + // Task shaders must produce a payload. + crate::ShaderStage::Task => { + let Some(handle) = ep.task_payload else { + return Err(EntryPointError::ExpectedTaskPayload.with_span()); + }; + if module.global_variables[handle].space != crate::AddressSpace::TaskPayload { + return Err(EntryPointError::TaskPayloadWrongAddressSpace + .with_span_handle(handle, &module.global_variables)); + } + info.insert_global_use(GlobalUse::READ | GlobalUse::WRITE, handle); } - if module.global_variables[handle].space != crate::AddressSpace::TaskPayload { - return Err(EntryPointError::TaskPayloadWrongAddressSpace.with_span()); + + // Mesh shaders may accept a payload. + crate::ShaderStage::Mesh => { + if let Some(handle) = ep.task_payload { + if module.global_variables[handle].space != crate::AddressSpace::TaskPayload { + return Err(EntryPointError::TaskPayloadWrongAddressSpace + .with_span_handle(handle, &module.global_variables)); + } + info.insert_global_use(GlobalUse::READ, handle); + } + } + + // Other stages must not have a payload. + _ => { + if let Some(handle) = ep.task_payload { + return Err(EntryPointError::UnexpectedTaskPayload + .with_span_handle(handle, &module.global_variables)); + } } - // Make sure that this is always present in the outputted shader - let uses = if ep.stage == crate::ShaderStage::Mesh { - GlobalUse::READ - } else { - GlobalUse::READ | GlobalUse::WRITE - }; - info.insert_global_use(uses, handle); } { @@ -949,15 +967,6 @@ impl super::Validator { } } - if let Some(task_payload) = ep.task_payload { - if module.global_variables[task_payload].space != crate::AddressSpace::TaskPayload { - return Err(EntryPointError::TaskPayloadWrongAddressSpace - .with_span_handle(task_payload, &module.global_variables)); - } - } else if ep.stage == crate::ShaderStage::Task { - return Err(EntryPointError::ExpectedTaskPayload.with_span()); - } - self.ep_resource_bindings.clear(); for (var_handle, var) in module.global_variables.iter() { let usage = info[var_handle]; From 258e7e642ab414a318843e76877a12b9911bf72d Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Thu, 16 Oct 2025 15:43:01 -0500 Subject: [PATCH 53/89] Quick little changes --- naga/src/back/glsl/mod.rs | 2 +- naga/src/back/hlsl/writer.rs | 2 +- naga/src/back/mod.rs | 6 +++--- naga/src/back/pipeline_constants.rs | 4 ++-- naga/src/back/spv/writer.rs | 5 ++++- 5 files changed, 11 insertions(+), 8 deletions(-) diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index 6376b39c58b..37bf318c4f8 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -1879,7 +1879,7 @@ impl<'a, W: Write> Writer<'a, W> { writeln!(self.out, ") {{")?; if self.options.zero_initialize_workgroup_memory - && ctx.ty.is_compute_entry_point(self.module) + && ctx.ty.is_compute_like_entry_point(self.module) { self.write_workgroup_variables_initialization(&ctx)?; } diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 8d1aabded61..6f0ba814a52 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -1765,7 +1765,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { module: &Module, ) -> bool { self.options.zero_initialize_workgroup_memory - && func_ctx.ty.is_compute_entry_point(module) + && func_ctx.ty.is_compute_like_entry_point(module) && module.global_variables.iter().any(|(handle, var)| { !func_ctx.info[handle].is_empty() && var.space == crate::AddressSpace::WorkGroup }) diff --git a/naga/src/back/mod.rs b/naga/src/back/mod.rs index 0d13d63dd9b..8be763234e7 100644 --- a/naga/src/back/mod.rs +++ b/naga/src/back/mod.rs @@ -139,11 +139,11 @@ pub enum FunctionType { } impl FunctionType { - /// Returns true if the function is an entry point for a compute shader. - pub fn is_compute_entry_point(&self, module: &crate::Module) -> bool { + /// Returns true if the function is an entry point for a compute-like shader. + pub fn is_compute_like_entry_point(&self, module: &crate::Module) -> bool { match *self { FunctionType::EntryPoint(index) => { - module.entry_points[index as usize].stage == crate::ShaderStage::Compute + module.entry_points[index as usize].stage.compute_like() } FunctionType::Function(_) => false, } diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index c009082a3c9..109cc591e74 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -309,13 +309,13 @@ fn process_mesh_shader_overrides( mesh_info.max_vertices = module .to_ctx() .eval_expr_to_u32(adjusted_global_expressions[r#override]) - .map_err(|_| PipelineConstantError::NegativeWorkgroupSize)?; + .map_err(|_| PipelineConstantError::NegativeMeshOutputMax)?; } if let Some(r#override) = mesh_info.max_primitives_override { mesh_info.max_primitives = module .to_ctx() .eval_expr_to_u32(adjusted_global_expressions[r#override]) - .map_err(|_| PipelineConstantError::NegativeWorkgroupSize)?; + .map_err(|_| PipelineConstantError::NegativeMeshOutputMax)?; } } Ok(()) diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index 85d575cb9af..1e207fc7002 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -1094,7 +1094,10 @@ impl Writer { super::ZeroInitializeWorkgroupMemoryMode::Polyfill, Some( ref mut interface @ FunctionInterface { - stage: crate::ShaderStage::Compute, + stage: + crate::ShaderStage::Compute + | crate::ShaderStage::Mesh + | crate::ShaderStage::Task, .. }, ), From 8885c5def0e8b23150bbc54a7e9baa41b2ff2f28 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Thu, 16 Oct 2025 15:49:24 -0500 Subject: [PATCH 54/89] Another quick fix --- naga/src/valid/interface.rs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index db6d800bd31..8346e1e4ba9 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -405,11 +405,9 @@ impl VaryingContext<'_> { // know if these are correct when the whole mesh pipeline is // created and we're paired with a specific mesh or vertex // shader. - } else { + } else if per_primitive { // All other `Location` bindings must not be `per_primitive`. - if per_primitive { - return Err(VaryingError::InvalidPerPrimitive); - } + return Err(VaryingError::InvalidPerPrimitive); } if let Some(blend_src) = blend_src { From 1cc3e8516f691cd166c4906c993c60a0e02af9c0 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Thu, 16 Oct 2025 16:01:35 -0500 Subject: [PATCH 55/89] Quick fix --- naga/src/valid/interface.rs | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index 8346e1e4ba9..6d122a8b2c5 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -473,10 +473,9 @@ impl VaryingContext<'_> { } } - // TODO: update this to reflect the fact that per-primitive outputs aren't interpolated for fragment and mesh stages let needs_interpolation = match self.stage { crate::ShaderStage::Vertex => self.output, - crate::ShaderStage::Fragment => !self.output, + crate::ShaderStage::Fragment => !self.output && !per_primitive, crate::ShaderStage::Compute | crate::ShaderStage::Task => false, crate::ShaderStage::Mesh => self.output, }; From 3be2c256ce3f5330ea2cf200b88ca4f2c9b34700 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Thu, 16 Oct 2025 16:04:36 -0500 Subject: [PATCH 56/89] Removed unnecessary TODO statement --- naga/src/valid/function.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index 0ae2ffdb54f..4dca52b4687 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -1577,7 +1577,8 @@ impl super::Validator { crate::MeshFunction::SetVertex { index, value: _ } | crate::MeshFunction::SetPrimitive { index, value: _ } => { ensure_u32(index)?; - // TODO: ensure it is correct for the value + // Value is validated elsewhere (since the value type isn't known ahead of time but must match for a function + // and all functions it calls) } } } From 21d3cc703c127b40e52175c43d4d0110d975353b Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Thu, 16 Oct 2025 16:05:16 -0500 Subject: [PATCH 57/89] A --- naga/src/valid/function.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index 4dca52b4687..4caa6ffc451 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -1577,8 +1577,8 @@ impl super::Validator { crate::MeshFunction::SetVertex { index, value: _ } | crate::MeshFunction::SetPrimitive { index, value: _ } => { ensure_u32(index)?; - // Value is validated elsewhere (since the value type isn't known ahead of time but must match for a function - // and all functions it calls) + // Value is validated elsewhere (since the value type isn't known ahead of time but must match for all calls + // in a function or the function's called functions) } } } From d5c11d3b594a5aa8cdaea5f9c73934a3ba59f1c7 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Thu, 16 Oct 2025 16:09:59 -0500 Subject: [PATCH 58/89] Tried to be more expressive --- naga/src/valid/function.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index 4caa6ffc451..0216c6ef7f6 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -217,7 +217,7 @@ pub enum FunctionError { EmitResult(Handle), #[error("Expression not visited by the appropriate statement")] UnvisitedExpression(Handle), - #[error("Expression {0:?} should be u32, but isn't")] + #[error("Expression {0:?} in mesh shader intrinsic call should be `u32` (is the expression a signed integer?)")] InvalidMeshFunctionCall(Handle), #[error("Mesh output types differ from {0:?} to {1:?}")] ConflictingMeshOutputTypes(Handle, Handle), From e7faff660c927c075b409ada6c0b7c217ba77fe2 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Thu, 16 Oct 2025 20:36:59 -0500 Subject: [PATCH 59/89] Made functions only work in mesh shader entry points --- naga/src/valid/analyzer.rs | 56 ++++++++++++++++++++------------------ 1 file changed, 29 insertions(+), 27 deletions(-) diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index bbf00508e00..14554573c9f 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -1147,34 +1147,36 @@ impl FunctionInfo { } FunctionUniformity::new() } - S::MeshFunction(func) => match &func { - // TODO: double check all of this uniformity stuff. I frankly don't fully understand all of it. - &crate::MeshFunction::SetMeshOutputs { - vertex_count, - primitive_count, - } => { - let _ = self.add_ref(vertex_count); - let _ = self.add_ref(primitive_count); - FunctionUniformity::new() - } - &crate::MeshFunction::SetVertex { index, value } - | &crate::MeshFunction::SetPrimitive { index, value } => { - let _ = self.add_ref(index); - let _ = self.add_ref(value); - let ty = self.expressions[value.index()] - .ty - .handle() - .ok_or(FunctionError::InvalidMeshShaderOutputType(value).with_span())?; - - if matches!(func, crate::MeshFunction::SetVertex { .. }) { - self.try_update_mesh_vertex_type(ty, value)?; - } else { - self.try_update_mesh_primitive_type(ty, value)?; - }; - - FunctionUniformity::new() + S::MeshFunction(func) => { + self.available_stages |= ShaderStages::MESH; + match &func { + // TODO: double check all of this uniformity stuff. I frankly don't fully understand all of it. + &crate::MeshFunction::SetMeshOutputs { + vertex_count, + primitive_count, + } => { + let _ = self.add_ref(vertex_count); + let _ = self.add_ref(primitive_count); + FunctionUniformity::new() + } + &crate::MeshFunction::SetVertex { index, value } + | &crate::MeshFunction::SetPrimitive { index, value } => { + let _ = self.add_ref(index); + let _ = self.add_ref(value); + let ty = self.expressions[value.index()].ty.handle().ok_or( + FunctionError::InvalidMeshShaderOutputType(value).with_span(), + )?; + + if matches!(func, crate::MeshFunction::SetVertex { .. }) { + self.try_update_mesh_vertex_type(ty, value)?; + } else { + self.try_update_mesh_primitive_type(ty, value)?; + }; + + FunctionUniformity::new() + } } - }, + } S::SubgroupBallot { result: _, predicate, From 385535a8d0045fd8fec7e7a454924491824e6a83 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Thu, 16 Oct 2025 23:25:14 -0500 Subject: [PATCH 60/89] Various validation fix attempts --- naga/src/valid/handles.rs | 14 ++++++++++++++ naga/src/valid/interface.rs | 16 ++++++++++++++++ naga/src/valid/mod.rs | 4 +++- 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index a0153e9398c..adb9f355c11 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -233,6 +233,20 @@ impl super::Validator { validate_const_expr(size)?; } } + if let Some(task_payload) = entry_point.task_payload { + Self::validate_global_variable_handle(task_payload, global_variables)?; + } + if let Some(ref mesh_info) = entry_point.mesh_info { + validate_type(mesh_info.vertex_output_type)?; + validate_type(mesh_info.primitive_output_type)?; + for ov in mesh_info + .max_vertices_override + .iter() + .chain(mesh_info.max_primitives_override.iter()) + { + validate_const_expr(*ov)?; + } + } } for (function_handle, function) in functions.iter() { diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index 6d122a8b2c5..04c5d99babb 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -98,6 +98,8 @@ pub enum VaryingError { InvalidPerPrimitive, #[error("Non-builtin members of a mesh primitive output struct must be decorated with `@per_primitive`")] MissingPerPrimitive, + #[error("The `MESH_SHADER` capability must be enabled to use per-primitive fragment inputs.")] + PerPrimitiveNotAllowed, } #[derive(Clone, Debug, thiserror::Error)] @@ -151,6 +153,10 @@ pub enum EntryPointError { InvalidMeshPrimitiveOutputType, #[error("Task shaders must declare a task payload output")] ExpectedTaskPayload, + #[error( + "The `MESH_SHADER` capability must be enabled to compile mesh shaders and task shaders." + )] + MeshShaderCapabilityDisabled, } fn storage_usage(access: crate::StorageAccess) -> GlobalUse { @@ -386,6 +392,9 @@ impl VaryingContext<'_> { blend_src, per_primitive, } => { + if per_primitive && !self.capabilities.contains(Capabilities::MESH_SHADER) { + return Err(VaryingError::PerPrimitiveNotAllowed); + } // Only IO-shareable types may be stored in locations. if !self.type_info[ty.index()] .flags @@ -802,6 +811,13 @@ impl super::Validator { module: &crate::Module, mod_info: &ModuleInfo, ) -> Result> { + if matches!( + ep.stage, + crate::ShaderStage::Task | crate::ShaderStage::Mesh + ) && !self.capabilities.contains(Capabilities::MESH_SHADER) + { + return Err(EntryPointError::MeshShaderCapabilityDisabled.with_span()); + } if ep.early_depth_test.is_some() { let required = Capabilities::EARLY_DEPTH_TEST; if !self.capabilities.contains(required) { diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index eb707bcb383..d47d878ed4e 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -83,7 +83,7 @@ bitflags::bitflags! { #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] #[derive(Clone, Copy, Debug, Eq, PartialEq)] - pub struct Capabilities: u32 { + pub struct Capabilities: u64 { /// Support for [`AddressSpace::PushConstant`][1]. /// /// [1]: crate::AddressSpace::PushConstant @@ -186,6 +186,8 @@ bitflags::bitflags! { /// Support for `quantizeToF16`, `pack2x16float`, and `unpack2x16float`, which store /// `f16`-precision values in `f32`s. const SHADER_FLOAT16_IN_FLOAT32 = 1 << 28; + /// Support for task shaders, mesh shaders, and per-primitive fragment inputs + const MESH_SHADER = 1 << 29; } } From c3f9acd8427e2961de5b68014a9a347f8fbdc415 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Fri, 17 Oct 2025 13:27:30 -0500 Subject: [PATCH 61/89] Undid capabilities resize --- naga/src/valid/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index d47d878ed4e..2460a46df4b 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -83,7 +83,7 @@ bitflags::bitflags! { #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] #[derive(Clone, Copy, Debug, Eq, PartialEq)] - pub struct Capabilities: u64 { + pub struct Capabilities: u32 { /// Support for [`AddressSpace::PushConstant`][1]. /// /// [1]: crate::AddressSpace::PushConstant From d15ba19aa097ecaf52bbf1496a64032e69d97738 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Fri, 17 Oct 2025 13:33:32 -0500 Subject: [PATCH 62/89] WGSL PR is up :) --- naga/src/front/wgsl/error.rs | 34 + naga/src/front/wgsl/lower/mod.rs | 220 ++- naga/src/front/wgsl/parse/ast.rs | 11 + naga/src/front/wgsl/parse/conv.rs | 7 + .../wgsl/parse/directive/enable_extension.rs | 9 + naga/src/front/wgsl/parse/mod.rs | 81 +- naga/tests/in/wgsl/mesh-shader.toml | 19 + naga/tests/in/wgsl/mesh-shader.wgsl | 71 + .../out/analysis/wgsl-mesh-shader.info.ron | 1211 +++++++++++++++++ .../tests/out/ir/wgsl-mesh-shader.compact.ron | 846 ++++++++++++ naga/tests/out/ir/wgsl-mesh-shader.ron | 846 ++++++++++++ 11 files changed, 3313 insertions(+), 42 deletions(-) create mode 100644 naga/tests/in/wgsl/mesh-shader.toml create mode 100644 naga/tests/in/wgsl/mesh-shader.wgsl create mode 100644 naga/tests/out/analysis/wgsl-mesh-shader.info.ron create mode 100644 naga/tests/out/ir/wgsl-mesh-shader.compact.ron create mode 100644 naga/tests/out/ir/wgsl-mesh-shader.ron diff --git a/naga/src/front/wgsl/error.rs b/naga/src/front/wgsl/error.rs index 17dab5cb0ea..5fc69382447 100644 --- a/naga/src/front/wgsl/error.rs +++ b/naga/src/front/wgsl/error.rs @@ -406,6 +406,19 @@ pub(crate) enum Error<'a> { accept_span: Span, accept_type: String, }, + MissingMeshShaderInfo { + mesh_attribute_span: Span, + }, + OneMeshShaderAttribute { + attribute_span: Span, + }, + ExpectedGlobalVariable { + name_span: Span, + }, + MeshPrimitiveNoDefinedTopology { + attribute_span: Span, + struct_span: Span, + }, StructMemberTooLarge { member_name_span: Span, }, @@ -1370,6 +1383,27 @@ impl<'a> Error<'a> { ], notes: vec![], }, + Error::MissingMeshShaderInfo { mesh_attribute_span} => ParseError { + message: "mesh shader entry point is missing @vertex_output or @primitive_output".into(), + labels: vec![(mesh_attribute_span, "mesh shader entry declared here".into())], + notes: vec![], + }, + Error::OneMeshShaderAttribute { attribute_span } => ParseError { + message: "only one of @vertex_output or @primitive_output was given".into(), + labels: vec![(attribute_span, "only one of @vertex_output or @primitive_output is provided".into())], + notes: vec![], + }, + Error::ExpectedGlobalVariable { name_span } => ParseError { + message: "expected global variable".to_string(), + // TODO: I would like to also include the global declaration span + labels: vec![(name_span, "variable used here".into())], + notes: vec![], + }, + Error::MeshPrimitiveNoDefinedTopology { struct_span, attribute_span } => ParseError { + message: "mesh primitive struct must have exactly one of point indices, line indices, or triangle indices".to_string(), + labels: vec![(attribute_span, "primitive type declared here".into()), (struct_span, "primitive struct declared here".into())], + notes: vec![] + }, Error::StructMemberTooLarge { member_name_span } => ParseError { message: "struct member is too large".into(), labels: vec![(member_name_span, "this member exceeds the maximum size".into())], diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index 2066d7cf2c8..ef63e6aaea7 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -1479,47 +1479,147 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { .collect(); if let Some(ref entry) = f.entry_point { - let workgroup_size_info = if let Some(workgroup_size) = entry.workgroup_size { - // TODO: replace with try_map once stabilized - let mut workgroup_size_out = [1; 3]; - let mut workgroup_size_overrides_out = [None; 3]; - for (i, size) in workgroup_size.into_iter().enumerate() { - if let Some(size_expr) = size { - match self.const_u32(size_expr, &mut ctx.as_const()) { - Ok(value) => { - workgroup_size_out[i] = value.0; - } - Err(err) => { - if let Error::ConstantEvaluatorError(ref ty, _) = *err { - match **ty { - proc::ConstantEvaluatorError::OverrideExpr => { - workgroup_size_overrides_out[i] = - Some(self.workgroup_size_override( - size_expr, - &mut ctx.as_override(), - )?); - } - _ => { - return Err(err); + let (workgroup_size, workgroup_size_overrides) = + if let Some(workgroup_size) = entry.workgroup_size { + // TODO: replace with try_map once stabilized + let mut workgroup_size_out = [1; 3]; + let mut workgroup_size_overrides_out = [None; 3]; + for (i, size) in workgroup_size.into_iter().enumerate() { + if let Some(size_expr) = size { + match self.const_u32(size_expr, &mut ctx.as_const()) { + Ok(value) => { + workgroup_size_out[i] = value.0; + } + Err(err) => { + if let Error::ConstantEvaluatorError(ref ty, _) = *err { + match **ty { + proc::ConstantEvaluatorError::OverrideExpr => { + workgroup_size_overrides_out[i] = + Some(self.workgroup_size_override( + size_expr, + &mut ctx.as_override(), + )?); + } + _ => { + return Err(err); + } } + } else { + return Err(err); } - } else { - return Err(err); } } } } - } - if workgroup_size_overrides_out.iter().all(|x| x.is_none()) { - (workgroup_size_out, None) + if workgroup_size_overrides_out.iter().all(|x| x.is_none()) { + (workgroup_size_out, None) + } else { + (workgroup_size_out, Some(workgroup_size_overrides_out)) + } } else { - (workgroup_size_out, Some(workgroup_size_overrides_out)) + ([0; 3], None) + }; + + let mesh_info = if let Some(mesh_info) = entry.mesh_shader_info { + let mut const_u32 = |expr| match self.const_u32(expr, &mut ctx.as_const()) { + Ok(value) => Ok((value.0, None)), + Err(err) => { + if let Error::ConstantEvaluatorError(ref ty, _) = *err { + match **ty { + proc::ConstantEvaluatorError::OverrideExpr => Ok(( + 0, + Some( + // This is dubious but it seems the code isn't workgroup size specific + self.workgroup_size_override(expr, &mut ctx.as_override())?, + ), + )), + _ => Err(err), + } + } else { + Err(err) + } + } + }; + let (max_vertices, max_vertices_override) = const_u32(mesh_info.vertex_count)?; + let (max_primitives, max_primitives_override) = + const_u32(mesh_info.primitive_count)?; + let vertex_output_type = + self.resolve_ast_type(mesh_info.vertex_type.0, &mut ctx.as_const())?; + let primitive_output_type = + self.resolve_ast_type(mesh_info.primitive_type.0, &mut ctx.as_const())?; + + let mut topology = None; + let struct_span = ctx.module.types.get_span(primitive_output_type); + match &ctx.module.types[primitive_output_type].inner { + &ir::TypeInner::Struct { + ref members, + span: _, + } => { + for member in members { + let out_topology = match member.binding { + Some(ir::Binding::BuiltIn(ir::BuiltIn::TriangleIndices)) => { + Some(ir::MeshOutputTopology::Triangles) + } + Some(ir::Binding::BuiltIn(ir::BuiltIn::LineIndices)) => { + Some(ir::MeshOutputTopology::Lines) + } + _ => None, + }; + if out_topology.is_some() { + if topology.is_some() { + return Err(Box::new(Error::MeshPrimitiveNoDefinedTopology { + attribute_span: mesh_info.primitive_type.1, + struct_span, + })); + } + topology = out_topology; + } + } + } + _ => { + return Err(Box::new(Error::MeshPrimitiveNoDefinedTopology { + attribute_span: mesh_info.primitive_type.1, + struct_span, + })) + } } + let topology = if let Some(t) = topology { + t + } else { + return Err(Box::new(Error::MeshPrimitiveNoDefinedTopology { + attribute_span: mesh_info.primitive_type.1, + struct_span, + })); + }; + + Some(ir::MeshStageInfo { + max_vertices, + max_vertices_override, + max_primitives, + max_primitives_override, + + vertex_output_type, + primitive_output_type, + topology, + }) + } else { + None + }; + + let task_payload = if let Some((var_name, var_span)) = entry.task_payload { + Some(match ctx.globals.get(var_name) { + Some(&LoweredGlobalDecl::Var(handle)) => handle, + Some(_) => { + return Err(Box::new(Error::ExpectedGlobalVariable { + name_span: var_span, + })) + } + None => return Err(Box::new(Error::UnknownIdent(var_span, var_name))), + }) } else { - ([0; 3], None) + None }; - let (workgroup_size, workgroup_size_overrides) = workgroup_size_info; ctx.module.entry_points.push(ir::EntryPoint { name: f.name.name.to_string(), stage: entry.stage, @@ -1527,8 +1627,8 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { workgroup_size, workgroup_size_overrides, function, - mesh_info: None, - task_payload: None, + mesh_info, + task_payload, }); Ok(LoweredGlobalDecl::EntryPoint( ctx.module.entry_points.len() - 1, @@ -3132,6 +3232,59 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ); return Ok(Some(result)); } + + "setMeshOutputs" | "setVertex" | "setPrimitive" => { + let mut args = ctx.prepare_args(arguments, 2, span); + let arg1 = args.next()?; + let arg2 = args.next()?; + args.finish()?; + + let mut cast_u32 = |arg| { + // Try to convert abstract values to the known argument types + let expr = self.expression_for_abstract(arg, ctx)?; + let goal_ty = + ctx.ensure_type_exists(ir::TypeInner::Scalar(ir::Scalar::U32)); + ctx.try_automatic_conversions( + expr, + &proc::TypeResolution::Handle(goal_ty), + ctx.ast_expressions.get_span(arg), + ) + }; + + let arg1 = cast_u32(arg1)?; + let arg2 = if function.name == "setMeshOutputs" { + cast_u32(arg2)? + } else { + self.expression(arg2, ctx)? + }; + + let rctx = ctx.runtime_expression_ctx(span)?; + + // Emit all previous expressions, even if not used directly + rctx.block + .extend(rctx.emitter.finish(&rctx.function.expressions)); + rctx.block.push( + crate::Statement::MeshFunction(match function.name { + "setMeshOutputs" => crate::MeshFunction::SetMeshOutputs { + vertex_count: arg1, + primitive_count: arg2, + }, + "setVertex" => crate::MeshFunction::SetVertex { + index: arg1, + value: arg2, + }, + "setPrimitive" => crate::MeshFunction::SetPrimitive { + index: arg1, + value: arg2, + }, + _ => unreachable!(), + }), + span, + ); + rctx.emitter.start(&rctx.function.expressions); + + return Ok(None); + } _ => { return Err(Box::new(Error::UnknownIdent(function.span, function.name))) } @@ -4059,6 +4212,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { interpolation, sampling, blend_src, + per_primitive, }) => { let blend_src = if let Some(blend_src) = blend_src { Some(self.const_u32(blend_src, &mut ctx.as_const())?.0) @@ -4071,7 +4225,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { interpolation, sampling, blend_src, - per_primitive: false, + per_primitive, }; binding.apply_default_interpolation(&ctx.module.types[ty].inner); Some(binding) diff --git a/naga/src/front/wgsl/parse/ast.rs b/naga/src/front/wgsl/parse/ast.rs index 345e9c4c486..49ecddfdee5 100644 --- a/naga/src/front/wgsl/parse/ast.rs +++ b/naga/src/front/wgsl/parse/ast.rs @@ -128,6 +128,16 @@ pub struct EntryPoint<'a> { pub stage: crate::ShaderStage, pub early_depth_test: Option, pub workgroup_size: Option<[Option>>; 3]>, + pub mesh_shader_info: Option>, + pub task_payload: Option<(&'a str, Span)>, +} + +#[derive(Debug, Clone, Copy)] +pub struct EntryPointMeshShaderInfo<'a> { + pub vertex_count: Handle>, + pub primitive_count: Handle>, + pub vertex_type: (Handle>, Span), + pub primitive_type: (Handle>, Span), } #[cfg(doc)] @@ -152,6 +162,7 @@ pub enum Binding<'a> { interpolation: Option, sampling: Option, blend_src: Option>>, + per_primitive: bool, }, } diff --git a/naga/src/front/wgsl/parse/conv.rs b/naga/src/front/wgsl/parse/conv.rs index 30d0eb2d598..2bde001804e 100644 --- a/naga/src/front/wgsl/parse/conv.rs +++ b/naga/src/front/wgsl/parse/conv.rs @@ -16,6 +16,7 @@ pub fn map_address_space(word: &str, span: Span) -> Result<'_, crate::AddressSpa }), "push_constant" => Ok(crate::AddressSpace::PushConstant), "function" => Ok(crate::AddressSpace::Function), + "task_payload" => Ok(crate::AddressSpace::TaskPayload), _ => Err(Box::new(Error::UnknownAddressSpace(span))), } } @@ -49,6 +50,12 @@ pub fn map_built_in( "subgroup_id" => crate::BuiltIn::SubgroupId, "subgroup_size" => crate::BuiltIn::SubgroupSize, "subgroup_invocation_id" => crate::BuiltIn::SubgroupInvocationId, + // mesh + "cull_primitive" => crate::BuiltIn::CullPrimitive, + "point_index" => crate::BuiltIn::PointIndex, + "line_indices" => crate::BuiltIn::LineIndices, + "triangle_indices" => crate::BuiltIn::TriangleIndices, + "mesh_task_size" => crate::BuiltIn::MeshTaskSize, _ => return Err(Box::new(Error::UnknownBuiltin(span))), }; match built_in { diff --git a/naga/src/front/wgsl/parse/directive/enable_extension.rs b/naga/src/front/wgsl/parse/directive/enable_extension.rs index 38d6d6719ca..d376c114ff0 100644 --- a/naga/src/front/wgsl/parse/directive/enable_extension.rs +++ b/naga/src/front/wgsl/parse/directive/enable_extension.rs @@ -10,6 +10,7 @@ use alloc::boxed::Box; /// Tracks the status of every enable-extension known to Naga. #[derive(Clone, Debug, Eq, PartialEq)] pub struct EnableExtensions { + mesh_shader: bool, dual_source_blending: bool, /// Whether `enable f16;` was written earlier in the shader module. f16: bool, @@ -19,6 +20,7 @@ pub struct EnableExtensions { impl EnableExtensions { pub(crate) const fn empty() -> Self { Self { + mesh_shader: false, f16: false, dual_source_blending: false, clip_distances: false, @@ -28,6 +30,7 @@ impl EnableExtensions { /// Add an enable-extension to the set requested by a module. pub(crate) fn add(&mut self, ext: ImplementedEnableExtension) { let field = match ext { + ImplementedEnableExtension::MeshShader => &mut self.mesh_shader, ImplementedEnableExtension::DualSourceBlending => &mut self.dual_source_blending, ImplementedEnableExtension::F16 => &mut self.f16, ImplementedEnableExtension::ClipDistances => &mut self.clip_distances, @@ -38,6 +41,7 @@ impl EnableExtensions { /// Query whether an enable-extension tracked here has been requested. pub(crate) const fn contains(&self, ext: ImplementedEnableExtension) -> bool { match ext { + ImplementedEnableExtension::MeshShader => self.mesh_shader, ImplementedEnableExtension::DualSourceBlending => self.dual_source_blending, ImplementedEnableExtension::F16 => self.f16, ImplementedEnableExtension::ClipDistances => self.clip_distances, @@ -70,6 +74,7 @@ impl EnableExtension { const F16: &'static str = "f16"; const CLIP_DISTANCES: &'static str = "clip_distances"; const DUAL_SOURCE_BLENDING: &'static str = "dual_source_blending"; + const MESH_SHADER: &'static str = "mesh_shading"; const SUBGROUPS: &'static str = "subgroups"; const PRIMITIVE_INDEX: &'static str = "primitive_index"; @@ -81,6 +86,7 @@ impl EnableExtension { Self::DUAL_SOURCE_BLENDING => { Self::Implemented(ImplementedEnableExtension::DualSourceBlending) } + Self::MESH_SHADER => Self::Implemented(ImplementedEnableExtension::MeshShader), Self::SUBGROUPS => Self::Unimplemented(UnimplementedEnableExtension::Subgroups), Self::PRIMITIVE_INDEX => { Self::Unimplemented(UnimplementedEnableExtension::PrimitiveIndex) @@ -93,6 +99,7 @@ impl EnableExtension { pub const fn to_ident(self) -> &'static str { match self { Self::Implemented(kind) => match kind { + ImplementedEnableExtension::MeshShader => Self::MESH_SHADER, ImplementedEnableExtension::DualSourceBlending => Self::DUAL_SOURCE_BLENDING, ImplementedEnableExtension::F16 => Self::F16, ImplementedEnableExtension::ClipDistances => Self::CLIP_DISTANCES, @@ -126,6 +133,8 @@ pub enum ImplementedEnableExtension { /// /// [`enable clip_distances;`]: https://www.w3.org/TR/WGSL/#extension-clip_distances ClipDistances, + /// Enables the `mesh_shader` extension, native only + MeshShader, } /// A variant of [`EnableExtension::Unimplemented`]. diff --git a/naga/src/front/wgsl/parse/mod.rs b/naga/src/front/wgsl/parse/mod.rs index c01ba4de30f..29376614d6e 100644 --- a/naga/src/front/wgsl/parse/mod.rs +++ b/naga/src/front/wgsl/parse/mod.rs @@ -178,6 +178,7 @@ struct BindingParser<'a> { sampling: ParsedAttribute, invariant: ParsedAttribute, blend_src: ParsedAttribute>>, + per_primitive: ParsedAttribute<()>, } impl<'a> BindingParser<'a> { @@ -238,6 +239,9 @@ impl<'a> BindingParser<'a> { lexer.skip(Token::Separator(',')); lexer.expect(Token::Paren(')'))?; } + "per_primitive" => { + self.per_primitive.set((), name_span)?; + } _ => return Err(Box::new(Error::UnknownAttribute(name_span))), } Ok(()) @@ -251,9 +255,10 @@ impl<'a> BindingParser<'a> { self.sampling.value, self.invariant.value.unwrap_or_default(), self.blend_src.value, + self.per_primitive.value, ) { - (None, None, None, None, false, None) => Ok(None), - (Some(location), None, interpolation, sampling, false, blend_src) => { + (None, None, None, None, false, None, None) => Ok(None), + (Some(location), None, interpolation, sampling, false, blend_src, per_primitive) => { // Before handing over the completed `Module`, we call // `apply_default_interpolation` to ensure that the interpolation and // sampling have been explicitly specified on all vertex shader output and fragment @@ -263,17 +268,18 @@ impl<'a> BindingParser<'a> { interpolation, sampling, blend_src, + per_primitive: per_primitive.is_some(), })) } - (None, Some(crate::BuiltIn::Position { .. }), None, None, invariant, None) => { + (None, Some(crate::BuiltIn::Position { .. }), None, None, invariant, None, None) => { Ok(Some(ast::Binding::BuiltIn(crate::BuiltIn::Position { invariant, }))) } - (None, Some(built_in), None, None, false, None) => { + (None, Some(built_in), None, None, false, None, None) => { Ok(Some(ast::Binding::BuiltIn(built_in))) } - (_, _, _, _, _, _) => Err(Box::new(Error::InconsistentBinding(span))), + (_, _, _, _, _, _, _) => Err(Box::new(Error::InconsistentBinding(span))), } } } @@ -2790,12 +2796,15 @@ impl Parser { // read attributes let mut binding = None; let mut stage = ParsedAttribute::default(); - let mut compute_span = Span::new(0, 0); + let mut compute_like_span = Span::new(0, 0); let mut workgroup_size = ParsedAttribute::default(); let mut early_depth_test = ParsedAttribute::default(); let (mut bind_index, mut bind_group) = (ParsedAttribute::default(), ParsedAttribute::default()); let mut id = ParsedAttribute::default(); + let mut payload = ParsedAttribute::default(); + let mut vertex_output = ParsedAttribute::default(); + let mut primitive_output = ParsedAttribute::default(); let mut must_use: ParsedAttribute = ParsedAttribute::default(); @@ -2854,7 +2863,35 @@ impl Parser { } "compute" => { stage.set(ShaderStage::Compute, name_span)?; - compute_span = name_span; + compute_like_span = name_span; + } + "task" => { + stage.set(ShaderStage::Task, name_span)?; + compute_like_span = name_span; + } + "mesh" => { + stage.set(ShaderStage::Mesh, name_span)?; + compute_like_span = name_span; + } + "payload" => { + lexer.expect(Token::Paren('('))?; + payload.set(lexer.next_ident_with_span()?, name_span)?; + lexer.expect(Token::Paren(')'))?; + } + "vertex_output" | "primitive_output" => { + lexer.expect(Token::Paren('('))?; + let type_span = lexer.peek().1; + let r#type = self.type_decl(lexer, &mut ctx)?; + let type_span = lexer.span_from(type_span.to_range().unwrap().start); + lexer.expect(Token::Separator(','))?; + let max_output = self.general_expression(lexer, &mut ctx)?; + let end_span = lexer.expect_span(Token::Paren(')'))?; + let total_span = name_span.until(&end_span); + if name == "vertex_output" { + vertex_output.set((r#type, type_span, max_output), total_span)?; + } else if name == "primitive_output" { + primitive_output.set((r#type, type_span, max_output), total_span)?; + } } "workgroup_size" => { lexer.expect(Token::Paren('('))?; @@ -3020,13 +3057,39 @@ impl Parser { )?; Some(ast::GlobalDeclKind::Fn(ast::Function { entry_point: if let Some(stage) = stage.value { - if stage == ShaderStage::Compute && workgroup_size.value.is_none() { - return Err(Box::new(Error::MissingWorkgroupSize(compute_span))); + if stage.compute_like() && workgroup_size.value.is_none() { + return Err(Box::new(Error::MissingWorkgroupSize(compute_like_span))); } + if stage == ShaderStage::Mesh + && (vertex_output.value.is_none() || primitive_output.value.is_none()) + { + return Err(Box::new(Error::MissingMeshShaderInfo { + mesh_attribute_span: compute_like_span, + })); + } + let mesh_shader_info = match (vertex_output.value, primitive_output.value) { + (Some(vertex_output), Some(primitive_output)) => { + Some(ast::EntryPointMeshShaderInfo { + vertex_count: vertex_output.2, + primitive_count: primitive_output.2, + vertex_type: (vertex_output.0, vertex_output.1), + primitive_type: (primitive_output.0, primitive_output.1), + }) + } + (None, None) => None, + (Some(v), None) | (None, Some(v)) => { + return Err(Box::new(Error::OneMeshShaderAttribute { + attribute_span: v.1, + })) + } + }; + Some(ast::EntryPoint { stage, early_depth_test: early_depth_test.value, workgroup_size: workgroup_size.value, + mesh_shader_info, + task_payload: payload.value, }) } else { None diff --git a/naga/tests/in/wgsl/mesh-shader.toml b/naga/tests/in/wgsl/mesh-shader.toml new file mode 100644 index 00000000000..1f8b4e23baa --- /dev/null +++ b/naga/tests/in/wgsl/mesh-shader.toml @@ -0,0 +1,19 @@ +# Stolen from ray-query.toml + +god_mode = true +targets = "IR | ANALYSIS" + +[msl] +fake_missing_bindings = true +lang_version = [2, 4] +spirv_cross_compatibility = false +zero_initialize_workgroup_memory = false + +[hlsl] +shader_model = "V6_5" +fake_missing_bindings = true +zero_initialize_workgroup_memory = true + +[spv] +version = [1, 4] +capabilities = ["MeshShadingEXT"] diff --git a/naga/tests/in/wgsl/mesh-shader.wgsl b/naga/tests/in/wgsl/mesh-shader.wgsl new file mode 100644 index 00000000000..70fc2aec333 --- /dev/null +++ b/naga/tests/in/wgsl/mesh-shader.wgsl @@ -0,0 +1,71 @@ +enable mesh_shading; + +const positions = array( + vec4(0.,1.,0.,1.), + vec4(-1.,-1.,0.,1.), + vec4(1.,-1.,0.,1.) +); +const colors = array( + vec4(0.,1.,0.,1.), + vec4(0.,0.,1.,1.), + vec4(1.,0.,0.,1.) +); +struct TaskPayload { + colorMask: vec4, + visible: bool, +} +var taskPayload: TaskPayload; +var workgroupData: f32; +struct VertexOutput { + @builtin(position) position: vec4, + @location(0) color: vec4, +} +struct PrimitiveOutput { + @builtin(triangle_indices) index: vec3, + @builtin(cull_primitive) cull: bool, + @per_primitive @location(1) colorMask: vec4, +} +struct PrimitiveInput { + @per_primitive @location(1) colorMask: vec4, +} + +@task +@payload(taskPayload) +@workgroup_size(1) +fn ts_main() -> @builtin(mesh_task_size) vec3 { + workgroupData = 1.0; + taskPayload.colorMask = vec4(1.0, 1.0, 0.0, 1.0); + taskPayload.visible = true; + return vec3(3, 1, 1); +} +@mesh +@payload(taskPayload) +@vertex_output(VertexOutput, 3) @primitive_output(PrimitiveOutput, 1) +@workgroup_size(1) +fn ms_main(@builtin(local_invocation_index) index: u32, @builtin(global_invocation_id) id: vec3) { + setMeshOutputs(3, 1); + workgroupData = 2.0; + var v: VertexOutput; + + v.position = positions[0]; + v.color = colors[0] * taskPayload.colorMask; + setVertex(0, v); + + v.position = positions[1]; + v.color = colors[1] * taskPayload.colorMask; + setVertex(1, v); + + v.position = positions[2]; + v.color = colors[2] * taskPayload.colorMask; + setVertex(2, v); + + var p: PrimitiveOutput; + p.index = vec3(0, 1, 2); + p.cull = !taskPayload.visible; + p.colorMask = vec4(1.0, 0.0, 1.0, 1.0); + setPrimitive(0, p); +} +@fragment +fn fs_main(vertex: VertexOutput, primitive: PrimitiveInput) -> @location(0) vec4 { + return vertex.color * primitive.colorMask; +} diff --git a/naga/tests/out/analysis/wgsl-mesh-shader.info.ron b/naga/tests/out/analysis/wgsl-mesh-shader.info.ron new file mode 100644 index 00000000000..208e0aac84e --- /dev/null +++ b/naga/tests/out/analysis/wgsl-mesh-shader.info.ron @@ -0,0 +1,1211 @@ +( + type_flags: [ + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ], + functions: [], + entry_points: [ + ( + flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + may_kill: false, + sampling_set: [], + global_uses: [ + ("READ | WRITE"), + ("WRITE"), + ], + expressions: [ + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(1), + ty: Value(Pointer( + base: 0, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 3, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 1, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 3, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 2, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Bool, + width: 1, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(6), + ), + ], + sampling: [], + dual_source_blending: false, + diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), + ), + ( + flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + may_kill: false, + sampling_set: [], + global_uses: [ + ("READ"), + ("WRITE"), + ], + expressions: [ + ( + uniformity: ( + non_uniform_result: Some(0), + requirements: (""), + ), + ref_count: 0, + assignable_global: None, + ty: Handle(5), + ), + ( + uniformity: ( + non_uniform_result: Some(1), + requirements: (""), + ), + ref_count: 0, + assignable_global: None, + ty: Handle(6), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(1), + ty: Value(Pointer( + base: 0, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: Some(6), + requirements: (""), + ), + ref_count: 9, + assignable_global: None, + ty: Value(Pointer( + base: 4, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: Some(6), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 1, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: Some(6), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 1, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 3, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 1, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: Some(6), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(4), + ), + ( + uniformity: ( + non_uniform_result: Some(6), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 1, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: Some(6), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 1, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 3, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 1, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: Some(6), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(4), + ), + ( + uniformity: ( + non_uniform_result: Some(6), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 1, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: Some(6), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 1, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 3, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 1, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: Some(6), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(4), + ), + ( + uniformity: ( + non_uniform_result: Some(61), + requirements: (""), + ), + ref_count: 4, + assignable_global: None, + ty: Value(Pointer( + base: 7, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: Some(61), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 6, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(6), + ), + ( + uniformity: ( + non_uniform_result: Some(61), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 2, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 3, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(0), + ty: Value(Pointer( + base: 2, + space: TaskPayload, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(2), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(2), + ), + ( + uniformity: ( + non_uniform_result: Some(61), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Pointer( + base: 1, + space: Function, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Value(Scalar(( + kind: Uint, + width: 4, + ))), + ), + ( + uniformity: ( + non_uniform_result: Some(61), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(7), + ), + ], + sampling: [], + dual_source_blending: false, + diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: Some((4, 24)), + primitive_type: Some((7, 79)), + ), + ), + ( + flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE | MESH | TASK"), + uniformity: ( + non_uniform_result: Some(0), + requirements: (""), + ), + may_kill: false, + sampling_set: [], + global_uses: [ + (""), + (""), + ], + expressions: [ + ( + uniformity: ( + non_uniform_result: Some(0), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(4), + ), + ( + uniformity: ( + non_uniform_result: Some(1), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(8), + ), + ( + uniformity: ( + non_uniform_result: Some(0), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: Some(1), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ( + uniformity: ( + non_uniform_result: Some(0), + requirements: (""), + ), + ref_count: 1, + assignable_global: None, + ty: Handle(1), + ), + ], + sampling: [], + dual_source_blending: false, + diagnostic_filter_leaf: None, + mesh_shader_info: ( + vertex_type: None, + primitive_type: None, + ), + ), + ], + const_expression_types: [], +) \ No newline at end of file diff --git a/naga/tests/out/ir/wgsl-mesh-shader.compact.ron b/naga/tests/out/ir/wgsl-mesh-shader.compact.ron new file mode 100644 index 00000000000..38c79cba451 --- /dev/null +++ b/naga/tests/out/ir/wgsl-mesh-shader.compact.ron @@ -0,0 +1,846 @@ +( + types: [ + ( + name: None, + inner: Scalar(( + kind: Float, + width: 4, + )), + ), + ( + name: None, + inner: Vector( + size: Quad, + scalar: ( + kind: Float, + width: 4, + ), + ), + ), + ( + name: None, + inner: Scalar(( + kind: Bool, + width: 1, + )), + ), + ( + name: Some("TaskPayload"), + inner: Struct( + members: [ + ( + name: Some("colorMask"), + ty: 1, + binding: None, + offset: 0, + ), + ( + name: Some("visible"), + ty: 2, + binding: None, + offset: 16, + ), + ], + span: 32, + ), + ), + ( + name: Some("VertexOutput"), + inner: Struct( + members: [ + ( + name: Some("position"), + ty: 1, + binding: Some(BuiltIn(Position( + invariant: false, + ))), + offset: 0, + ), + ( + name: Some("color"), + ty: 1, + binding: Some(Location( + location: 0, + interpolation: Some(Perspective), + sampling: Some(Center), + blend_src: None, + per_primitive: false, + )), + offset: 16, + ), + ], + span: 32, + ), + ), + ( + name: None, + inner: Scalar(( + kind: Uint, + width: 4, + )), + ), + ( + name: None, + inner: Vector( + size: Tri, + scalar: ( + kind: Uint, + width: 4, + ), + ), + ), + ( + name: Some("PrimitiveOutput"), + inner: Struct( + members: [ + ( + name: Some("index"), + ty: 6, + binding: Some(BuiltIn(TriangleIndices)), + offset: 0, + ), + ( + name: Some("cull"), + ty: 2, + binding: Some(BuiltIn(CullPrimitive)), + offset: 12, + ), + ( + name: Some("colorMask"), + ty: 1, + binding: Some(Location( + location: 1, + interpolation: Some(Perspective), + sampling: Some(Center), + blend_src: None, + per_primitive: true, + )), + offset: 16, + ), + ], + span: 32, + ), + ), + ( + name: Some("PrimitiveInput"), + inner: Struct( + members: [ + ( + name: Some("colorMask"), + ty: 1, + binding: Some(Location( + location: 1, + interpolation: Some(Perspective), + sampling: Some(Center), + blend_src: None, + per_primitive: true, + )), + offset: 0, + ), + ], + span: 16, + ), + ), + ], + special_types: ( + ray_desc: None, + ray_intersection: None, + ray_vertex_return: None, + external_texture_params: None, + external_texture_transfer_function: None, + predeclared_types: {}, + ), + constants: [], + overrides: [], + global_variables: [ + ( + name: Some("taskPayload"), + space: TaskPayload, + binding: None, + ty: 3, + init: None, + ), + ( + name: Some("workgroupData"), + space: WorkGroup, + binding: None, + ty: 0, + init: None, + ), + ], + global_expressions: [], + functions: [], + entry_points: [ + ( + name: "ts_main", + stage: Task, + early_depth_test: None, + workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, + function: ( + name: Some("ts_main"), + arguments: [], + result: Some(( + ty: 6, + binding: Some(BuiltIn(MeshTaskSize)), + )), + local_variables: [], + expressions: [ + GlobalVariable(1), + Literal(F32(1.0)), + GlobalVariable(0), + AccessIndex( + base: 2, + index: 0, + ), + Literal(F32(1.0)), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 4, + 5, + 6, + 7, + ], + ), + GlobalVariable(0), + AccessIndex( + base: 9, + index: 1, + ), + Literal(Bool(true)), + Literal(U32(3)), + Literal(U32(1)), + Literal(U32(1)), + Compose( + ty: 6, + components: [ + 12, + 13, + 14, + ], + ), + ], + named_expressions: {}, + body: [ + Store( + pointer: 0, + value: 1, + ), + Emit(( + start: 3, + end: 4, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 8, + end: 9, + )), + Store( + pointer: 3, + value: 8, + ), + Emit(( + start: 10, + end: 11, + )), + Store( + pointer: 10, + value: 11, + ), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 15, + end: 16, + )), + Return( + value: Some(15), + ), + ], + diagnostic_filter_leaf: None, + ), + mesh_info: None, + task_payload: Some(0), + ), + ( + name: "ms_main", + stage: Mesh, + early_depth_test: None, + workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, + function: ( + name: Some("ms_main"), + arguments: [ + ( + name: Some("index"), + ty: 5, + binding: Some(BuiltIn(LocalInvocationIndex)), + ), + ( + name: Some("id"), + ty: 6, + binding: Some(BuiltIn(GlobalInvocationId)), + ), + ], + result: None, + local_variables: [ + ( + name: Some("v"), + ty: 4, + init: None, + ), + ( + name: Some("p"), + ty: 7, + init: None, + ), + ], + expressions: [ + FunctionArgument(0), + FunctionArgument(1), + Literal(U32(3)), + Literal(U32(1)), + GlobalVariable(1), + Literal(F32(2.0)), + LocalVariable(0), + AccessIndex( + base: 6, + index: 0, + ), + Literal(F32(0.0)), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 8, + 9, + 10, + 11, + ], + ), + AccessIndex( + base: 6, + index: 1, + ), + GlobalVariable(0), + AccessIndex( + base: 14, + index: 0, + ), + Load( + pointer: 15, + ), + Literal(F32(0.0)), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 17, + 18, + 19, + 20, + ], + ), + Binary( + op: Multiply, + left: 21, + right: 16, + ), + Literal(U32(0)), + Load( + pointer: 6, + ), + AccessIndex( + base: 6, + index: 0, + ), + Literal(F32(-1.0)), + Literal(F32(-1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 26, + 27, + 28, + 29, + ], + ), + AccessIndex( + base: 6, + index: 1, + ), + GlobalVariable(0), + AccessIndex( + base: 32, + index: 0, + ), + Load( + pointer: 33, + ), + Literal(F32(0.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 35, + 36, + 37, + 38, + ], + ), + Binary( + op: Multiply, + left: 39, + right: 34, + ), + Literal(U32(1)), + Load( + pointer: 6, + ), + AccessIndex( + base: 6, + index: 0, + ), + Literal(F32(1.0)), + Literal(F32(-1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 44, + 45, + 46, + 47, + ], + ), + AccessIndex( + base: 6, + index: 1, + ), + GlobalVariable(0), + AccessIndex( + base: 50, + index: 0, + ), + Load( + pointer: 51, + ), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 53, + 54, + 55, + 56, + ], + ), + Binary( + op: Multiply, + left: 57, + right: 52, + ), + Literal(U32(2)), + Load( + pointer: 6, + ), + LocalVariable(1), + AccessIndex( + base: 61, + index: 0, + ), + Literal(U32(0)), + Literal(U32(1)), + Literal(U32(2)), + Compose( + ty: 6, + components: [ + 63, + 64, + 65, + ], + ), + AccessIndex( + base: 61, + index: 1, + ), + GlobalVariable(0), + AccessIndex( + base: 68, + index: 1, + ), + Load( + pointer: 69, + ), + Unary( + op: LogicalNot, + expr: 70, + ), + AccessIndex( + base: 61, + index: 2, + ), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 73, + 74, + 75, + 76, + ], + ), + Literal(U32(0)), + Load( + pointer: 61, + ), + ], + named_expressions: { + 0: "index", + 1: "id", + }, + body: [ + MeshFunction(SetMeshOutputs( + vertex_count: 2, + primitive_count: 3, + )), + Store( + pointer: 4, + value: 5, + ), + Emit(( + start: 7, + end: 8, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 12, + end: 13, + )), + Store( + pointer: 7, + value: 12, + ), + Emit(( + start: 13, + end: 14, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 15, + end: 17, + )), + Emit(( + start: 21, + end: 23, + )), + Store( + pointer: 13, + value: 22, + ), + Emit(( + start: 24, + end: 25, + )), + MeshFunction(SetVertex( + index: 23, + value: 24, + )), + Emit(( + start: 25, + end: 26, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 30, + end: 31, + )), + Store( + pointer: 25, + value: 30, + ), + Emit(( + start: 31, + end: 32, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 33, + end: 35, + )), + Emit(( + start: 39, + end: 41, + )), + Store( + pointer: 31, + value: 40, + ), + Emit(( + start: 42, + end: 43, + )), + MeshFunction(SetVertex( + index: 41, + value: 42, + )), + Emit(( + start: 43, + end: 44, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 48, + end: 49, + )), + Store( + pointer: 43, + value: 48, + ), + Emit(( + start: 49, + end: 50, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 51, + end: 53, + )), + Emit(( + start: 57, + end: 59, + )), + Store( + pointer: 49, + value: 58, + ), + Emit(( + start: 60, + end: 61, + )), + MeshFunction(SetVertex( + index: 59, + value: 60, + )), + Emit(( + start: 62, + end: 63, + )), + Emit(( + start: 66, + end: 67, + )), + Store( + pointer: 62, + value: 66, + ), + Emit(( + start: 67, + end: 68, + )), + Emit(( + start: 69, + end: 72, + )), + Store( + pointer: 67, + value: 71, + ), + Emit(( + start: 72, + end: 73, + )), + Emit(( + start: 77, + end: 78, + )), + Store( + pointer: 72, + value: 77, + ), + Emit(( + start: 79, + end: 80, + )), + MeshFunction(SetPrimitive( + index: 78, + value: 79, + )), + Return( + value: None, + ), + ], + diagnostic_filter_leaf: None, + ), + mesh_info: Some(( + topology: Triangles, + max_vertices: 3, + max_vertices_override: None, + max_primitives: 1, + max_primitives_override: None, + vertex_output_type: 4, + primitive_output_type: 7, + )), + task_payload: Some(0), + ), + ( + name: "fs_main", + stage: Fragment, + early_depth_test: None, + workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, + function: ( + name: Some("fs_main"), + arguments: [ + ( + name: Some("vertex"), + ty: 4, + binding: None, + ), + ( + name: Some("primitive"), + ty: 8, + binding: None, + ), + ], + result: Some(( + ty: 1, + binding: Some(Location( + location: 0, + interpolation: Some(Perspective), + sampling: Some(Center), + blend_src: None, + per_primitive: false, + )), + )), + local_variables: [], + expressions: [ + FunctionArgument(0), + FunctionArgument(1), + AccessIndex( + base: 0, + index: 1, + ), + AccessIndex( + base: 1, + index: 0, + ), + Binary( + op: Multiply, + left: 2, + right: 3, + ), + ], + named_expressions: { + 0: "vertex", + 1: "primitive", + }, + body: [ + Emit(( + start: 2, + end: 5, + )), + Return( + value: Some(4), + ), + ], + diagnostic_filter_leaf: None, + ), + mesh_info: None, + task_payload: None, + ), + ], + diagnostic_filters: [], + diagnostic_filter_leaf: None, + doc_comments: None, +) \ No newline at end of file diff --git a/naga/tests/out/ir/wgsl-mesh-shader.ron b/naga/tests/out/ir/wgsl-mesh-shader.ron new file mode 100644 index 00000000000..38c79cba451 --- /dev/null +++ b/naga/tests/out/ir/wgsl-mesh-shader.ron @@ -0,0 +1,846 @@ +( + types: [ + ( + name: None, + inner: Scalar(( + kind: Float, + width: 4, + )), + ), + ( + name: None, + inner: Vector( + size: Quad, + scalar: ( + kind: Float, + width: 4, + ), + ), + ), + ( + name: None, + inner: Scalar(( + kind: Bool, + width: 1, + )), + ), + ( + name: Some("TaskPayload"), + inner: Struct( + members: [ + ( + name: Some("colorMask"), + ty: 1, + binding: None, + offset: 0, + ), + ( + name: Some("visible"), + ty: 2, + binding: None, + offset: 16, + ), + ], + span: 32, + ), + ), + ( + name: Some("VertexOutput"), + inner: Struct( + members: [ + ( + name: Some("position"), + ty: 1, + binding: Some(BuiltIn(Position( + invariant: false, + ))), + offset: 0, + ), + ( + name: Some("color"), + ty: 1, + binding: Some(Location( + location: 0, + interpolation: Some(Perspective), + sampling: Some(Center), + blend_src: None, + per_primitive: false, + )), + offset: 16, + ), + ], + span: 32, + ), + ), + ( + name: None, + inner: Scalar(( + kind: Uint, + width: 4, + )), + ), + ( + name: None, + inner: Vector( + size: Tri, + scalar: ( + kind: Uint, + width: 4, + ), + ), + ), + ( + name: Some("PrimitiveOutput"), + inner: Struct( + members: [ + ( + name: Some("index"), + ty: 6, + binding: Some(BuiltIn(TriangleIndices)), + offset: 0, + ), + ( + name: Some("cull"), + ty: 2, + binding: Some(BuiltIn(CullPrimitive)), + offset: 12, + ), + ( + name: Some("colorMask"), + ty: 1, + binding: Some(Location( + location: 1, + interpolation: Some(Perspective), + sampling: Some(Center), + blend_src: None, + per_primitive: true, + )), + offset: 16, + ), + ], + span: 32, + ), + ), + ( + name: Some("PrimitiveInput"), + inner: Struct( + members: [ + ( + name: Some("colorMask"), + ty: 1, + binding: Some(Location( + location: 1, + interpolation: Some(Perspective), + sampling: Some(Center), + blend_src: None, + per_primitive: true, + )), + offset: 0, + ), + ], + span: 16, + ), + ), + ], + special_types: ( + ray_desc: None, + ray_intersection: None, + ray_vertex_return: None, + external_texture_params: None, + external_texture_transfer_function: None, + predeclared_types: {}, + ), + constants: [], + overrides: [], + global_variables: [ + ( + name: Some("taskPayload"), + space: TaskPayload, + binding: None, + ty: 3, + init: None, + ), + ( + name: Some("workgroupData"), + space: WorkGroup, + binding: None, + ty: 0, + init: None, + ), + ], + global_expressions: [], + functions: [], + entry_points: [ + ( + name: "ts_main", + stage: Task, + early_depth_test: None, + workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, + function: ( + name: Some("ts_main"), + arguments: [], + result: Some(( + ty: 6, + binding: Some(BuiltIn(MeshTaskSize)), + )), + local_variables: [], + expressions: [ + GlobalVariable(1), + Literal(F32(1.0)), + GlobalVariable(0), + AccessIndex( + base: 2, + index: 0, + ), + Literal(F32(1.0)), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 4, + 5, + 6, + 7, + ], + ), + GlobalVariable(0), + AccessIndex( + base: 9, + index: 1, + ), + Literal(Bool(true)), + Literal(U32(3)), + Literal(U32(1)), + Literal(U32(1)), + Compose( + ty: 6, + components: [ + 12, + 13, + 14, + ], + ), + ], + named_expressions: {}, + body: [ + Store( + pointer: 0, + value: 1, + ), + Emit(( + start: 3, + end: 4, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 8, + end: 9, + )), + Store( + pointer: 3, + value: 8, + ), + Emit(( + start: 10, + end: 11, + )), + Store( + pointer: 10, + value: 11, + ), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 15, + end: 16, + )), + Return( + value: Some(15), + ), + ], + diagnostic_filter_leaf: None, + ), + mesh_info: None, + task_payload: Some(0), + ), + ( + name: "ms_main", + stage: Mesh, + early_depth_test: None, + workgroup_size: (1, 1, 1), + workgroup_size_overrides: None, + function: ( + name: Some("ms_main"), + arguments: [ + ( + name: Some("index"), + ty: 5, + binding: Some(BuiltIn(LocalInvocationIndex)), + ), + ( + name: Some("id"), + ty: 6, + binding: Some(BuiltIn(GlobalInvocationId)), + ), + ], + result: None, + local_variables: [ + ( + name: Some("v"), + ty: 4, + init: None, + ), + ( + name: Some("p"), + ty: 7, + init: None, + ), + ], + expressions: [ + FunctionArgument(0), + FunctionArgument(1), + Literal(U32(3)), + Literal(U32(1)), + GlobalVariable(1), + Literal(F32(2.0)), + LocalVariable(0), + AccessIndex( + base: 6, + index: 0, + ), + Literal(F32(0.0)), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 8, + 9, + 10, + 11, + ], + ), + AccessIndex( + base: 6, + index: 1, + ), + GlobalVariable(0), + AccessIndex( + base: 14, + index: 0, + ), + Load( + pointer: 15, + ), + Literal(F32(0.0)), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 17, + 18, + 19, + 20, + ], + ), + Binary( + op: Multiply, + left: 21, + right: 16, + ), + Literal(U32(0)), + Load( + pointer: 6, + ), + AccessIndex( + base: 6, + index: 0, + ), + Literal(F32(-1.0)), + Literal(F32(-1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 26, + 27, + 28, + 29, + ], + ), + AccessIndex( + base: 6, + index: 1, + ), + GlobalVariable(0), + AccessIndex( + base: 32, + index: 0, + ), + Load( + pointer: 33, + ), + Literal(F32(0.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 35, + 36, + 37, + 38, + ], + ), + Binary( + op: Multiply, + left: 39, + right: 34, + ), + Literal(U32(1)), + Load( + pointer: 6, + ), + AccessIndex( + base: 6, + index: 0, + ), + Literal(F32(1.0)), + Literal(F32(-1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 44, + 45, + 46, + 47, + ], + ), + AccessIndex( + base: 6, + index: 1, + ), + GlobalVariable(0), + AccessIndex( + base: 50, + index: 0, + ), + Load( + pointer: 51, + ), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 53, + 54, + 55, + 56, + ], + ), + Binary( + op: Multiply, + left: 57, + right: 52, + ), + Literal(U32(2)), + Load( + pointer: 6, + ), + LocalVariable(1), + AccessIndex( + base: 61, + index: 0, + ), + Literal(U32(0)), + Literal(U32(1)), + Literal(U32(2)), + Compose( + ty: 6, + components: [ + 63, + 64, + 65, + ], + ), + AccessIndex( + base: 61, + index: 1, + ), + GlobalVariable(0), + AccessIndex( + base: 68, + index: 1, + ), + Load( + pointer: 69, + ), + Unary( + op: LogicalNot, + expr: 70, + ), + AccessIndex( + base: 61, + index: 2, + ), + Literal(F32(1.0)), + Literal(F32(0.0)), + Literal(F32(1.0)), + Literal(F32(1.0)), + Compose( + ty: 1, + components: [ + 73, + 74, + 75, + 76, + ], + ), + Literal(U32(0)), + Load( + pointer: 61, + ), + ], + named_expressions: { + 0: "index", + 1: "id", + }, + body: [ + MeshFunction(SetMeshOutputs( + vertex_count: 2, + primitive_count: 3, + )), + Store( + pointer: 4, + value: 5, + ), + Emit(( + start: 7, + end: 8, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 12, + end: 13, + )), + Store( + pointer: 7, + value: 12, + ), + Emit(( + start: 13, + end: 14, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 15, + end: 17, + )), + Emit(( + start: 21, + end: 23, + )), + Store( + pointer: 13, + value: 22, + ), + Emit(( + start: 24, + end: 25, + )), + MeshFunction(SetVertex( + index: 23, + value: 24, + )), + Emit(( + start: 25, + end: 26, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 30, + end: 31, + )), + Store( + pointer: 25, + value: 30, + ), + Emit(( + start: 31, + end: 32, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 33, + end: 35, + )), + Emit(( + start: 39, + end: 41, + )), + Store( + pointer: 31, + value: 40, + ), + Emit(( + start: 42, + end: 43, + )), + MeshFunction(SetVertex( + index: 41, + value: 42, + )), + Emit(( + start: 43, + end: 44, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 48, + end: 49, + )), + Store( + pointer: 43, + value: 48, + ), + Emit(( + start: 49, + end: 50, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 0, + end: 0, + )), + Emit(( + start: 51, + end: 53, + )), + Emit(( + start: 57, + end: 59, + )), + Store( + pointer: 49, + value: 58, + ), + Emit(( + start: 60, + end: 61, + )), + MeshFunction(SetVertex( + index: 59, + value: 60, + )), + Emit(( + start: 62, + end: 63, + )), + Emit(( + start: 66, + end: 67, + )), + Store( + pointer: 62, + value: 66, + ), + Emit(( + start: 67, + end: 68, + )), + Emit(( + start: 69, + end: 72, + )), + Store( + pointer: 67, + value: 71, + ), + Emit(( + start: 72, + end: 73, + )), + Emit(( + start: 77, + end: 78, + )), + Store( + pointer: 72, + value: 77, + ), + Emit(( + start: 79, + end: 80, + )), + MeshFunction(SetPrimitive( + index: 78, + value: 79, + )), + Return( + value: None, + ), + ], + diagnostic_filter_leaf: None, + ), + mesh_info: Some(( + topology: Triangles, + max_vertices: 3, + max_vertices_override: None, + max_primitives: 1, + max_primitives_override: None, + vertex_output_type: 4, + primitive_output_type: 7, + )), + task_payload: Some(0), + ), + ( + name: "fs_main", + stage: Fragment, + early_depth_test: None, + workgroup_size: (0, 0, 0), + workgroup_size_overrides: None, + function: ( + name: Some("fs_main"), + arguments: [ + ( + name: Some("vertex"), + ty: 4, + binding: None, + ), + ( + name: Some("primitive"), + ty: 8, + binding: None, + ), + ], + result: Some(( + ty: 1, + binding: Some(Location( + location: 0, + interpolation: Some(Perspective), + sampling: Some(Center), + blend_src: None, + per_primitive: false, + )), + )), + local_variables: [], + expressions: [ + FunctionArgument(0), + FunctionArgument(1), + AccessIndex( + base: 0, + index: 1, + ), + AccessIndex( + base: 1, + index: 0, + ), + Binary( + op: Multiply, + left: 2, + right: 3, + ), + ], + named_expressions: { + 0: "vertex", + 1: "primitive", + }, + body: [ + Emit(( + start: 2, + end: 5, + )), + Return( + value: Some(4), + ), + ], + diagnostic_filter_leaf: None, + ), + mesh_info: None, + task_payload: None, + ), + ], + diagnostic_filters: [], + diagnostic_filter_leaf: None, + doc_comments: None, +) \ No newline at end of file From f14e0f0b5cee3439b348f8be1f64d65691e475f4 Mon Sep 17 00:00:00 2001 From: Inner Daemons <85136135+inner-daemons@users.noreply.github.com> Date: Fri, 17 Oct 2025 14:14:21 -0500 Subject: [PATCH 63/89] Update naga/src/ir/mod.rs Co-authored-by: Erich Gubler --- naga/src/ir/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/naga/src/ir/mod.rs b/naga/src/ir/mod.rs index a8a5d220463..151bd36b694 100644 --- a/naga/src/ir/mod.rs +++ b/naga/src/ir/mod.rs @@ -2178,7 +2178,7 @@ pub enum Statement { /// The specific operation we're performing on `query`. fun: RayQueryFunction, }, - /// A mesh shader intrinsic + /// A mesh shader intrinsic. MeshFunction(MeshFunction), /// Calculate a bitmask using a boolean from each active thread in the subgroup SubgroupBallot { From 7e12d30c0b29f65bdce3985eb910c2f5e6aad89e Mon Sep 17 00:00:00 2001 From: Inner Daemons <85136135+inner-daemons@users.noreply.github.com> Date: Fri, 17 Oct 2025 14:14:28 -0500 Subject: [PATCH 64/89] Update naga/src/front/wgsl/error.rs Co-authored-by: Erich Gubler --- naga/src/front/wgsl/error.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/naga/src/front/wgsl/error.rs b/naga/src/front/wgsl/error.rs index 5fc69382447..26505c20478 100644 --- a/naga/src/front/wgsl/error.rs +++ b/naga/src/front/wgsl/error.rs @@ -1384,12 +1384,12 @@ impl<'a> Error<'a> { notes: vec![], }, Error::MissingMeshShaderInfo { mesh_attribute_span} => ParseError { - message: "mesh shader entry point is missing @vertex_output or @primitive_output".into(), + message: "mesh shader entry point is missing `@vertex_output` or `@primitive_output`".into(), labels: vec![(mesh_attribute_span, "mesh shader entry declared here".into())], notes: vec![], }, Error::OneMeshShaderAttribute { attribute_span } => ParseError { - message: "only one of @vertex_output or @primitive_output was given".into(), + message: "only one of `@vertex_output` or `@primitive_output` was given".into(), labels: vec![(attribute_span, "only one of @vertex_output or @primitive_output is provided".into())], notes: vec![], }, From ce517bb48c99f21be2214b1c01b2501e04c41342 Mon Sep 17 00:00:00 2001 From: Inner Daemons <85136135+inner-daemons@users.noreply.github.com> Date: Fri, 17 Oct 2025 14:14:40 -0500 Subject: [PATCH 65/89] Update naga/src/ir/mod.rs Co-authored-by: Erich Gubler --- naga/src/ir/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/naga/src/ir/mod.rs b/naga/src/ir/mod.rs index 151bd36b694..6f5857861a8 100644 --- a/naga/src/ir/mod.rs +++ b/naga/src/ir/mod.rs @@ -997,7 +997,7 @@ pub enum Binding { /// whether they are per-primitive; a mesh shader's per-vertex and per-primitive /// outputs share the same location numbering space. /// - /// Per primitive values are not interpolated at all and are not dependent on the + /// Per-primitive values are not interpolated at all and are not dependent on the /// vertices or pixel location. For example, it may be used to store a /// non-interpolated normal vector. per_primitive: bool, From 083959e4129b4088d71c205320afde601ce9f327 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Fri, 17 Oct 2025 14:16:12 -0500 Subject: [PATCH 66/89] Other Erich suggestion --- naga/src/front/wgsl/error.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/naga/src/front/wgsl/error.rs b/naga/src/front/wgsl/error.rs index 26505c20478..004528dbe91 100644 --- a/naga/src/front/wgsl/error.rs +++ b/naga/src/front/wgsl/error.rs @@ -1384,7 +1384,7 @@ impl<'a> Error<'a> { notes: vec![], }, Error::MissingMeshShaderInfo { mesh_attribute_span} => ParseError { - message: "mesh shader entry point is missing `@vertex_output` or `@primitive_output`".into(), + message: "mesh shader entry point is missing both `@vertex_output` and `@primitive_output`".into(), labels: vec![(mesh_attribute_span, "mesh shader entry declared here".into())], notes: vec![], }, From 16aa7d059926ca7f8f47bdfc5c7c27cc717b09c5 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Fri, 17 Oct 2025 14:50:22 -0500 Subject: [PATCH 67/89] Updated docs & validation for some builtins --- naga/src/ir/mod.rs | 43 +++++++++++++++++++++++++++++++------ naga/src/valid/interface.rs | 25 +++++++++++++-------- 2 files changed, 53 insertions(+), 15 deletions(-) diff --git a/naga/src/ir/mod.rs b/naga/src/ir/mod.rs index 6f5857861a8..3c2d1942d7c 100644 --- a/naga/src/ir/mod.rs +++ b/naga/src/ir/mod.rs @@ -381,41 +381,72 @@ pub enum AddressSpace { #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum BuiltIn { + /// Written in vertex/mesh shaders, read in fragment shaders Position { invariant: bool }, + /// Read in task, mesh, vertex, and fragment shaders ViewIndex, - // vertex (and often mesh) + + /// Read in vertex shaders BaseInstance, + /// Read in vertex shaders BaseVertex, + /// Written in vertex & mesh shaders ClipDistance, + /// Written in vertex & mesh shaders CullDistance, + /// Read in vertex shaders InstanceIndex, + /// Written in vertex & mesh shaders PointSize, + /// Read in vertex shaders VertexIndex, + /// Read in vertex & task shaders, or mesh shaders in pipelines without task shaders DrawID, - // fragment + + /// Written in fragment shaders FragDepth, + /// Read in fragment shaders PointCoord, + /// Read in fragment shaders FrontFacing, - PrimitiveIndex, // Also for mesh output + /// Read in fragment shaders, in the future may written in mesh shaders + PrimitiveIndex, + /// Read in fragment shaders SampleIndex, + /// Read or written in fragment shaders SampleMask, - // compute (and task/mesh) + + /// Read in compute, task, and mesh shaders GlobalInvocationId, + /// Read in compute, task, and mesh shaders LocalInvocationId, + /// Read in compute, task, and mesh shaders LocalInvocationIndex, + /// Read in compute, task, and mesh shaders WorkGroupId, + /// Read in compute, task, and mesh shaders WorkGroupSize, + /// Read in compute, task, and mesh shaders NumWorkGroups, - // subgroup + + /// Read in compute, task, and mesh shaders NumSubgroups, + /// Read in compute, task, and mesh shaders SubgroupId, + /// Read in compute, fragment, task, and mesh shaders SubgroupSize, + /// Read in compute, fragment, task, and mesh shaders SubgroupInvocationId, - // mesh + + /// Written in task shaders MeshTaskSize, + /// Written in mesh shaders CullPrimitive, + /// Written in mesh shaders PointIndex, + /// Written in mesh shaders LineIndices, + /// Written in mesh shaders TriangleIndices, } diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index 04c5d99babb..a4e0af99ccc 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -191,6 +191,7 @@ struct VaryingContext<'a> { capabilities: Capabilities, flags: super::ValidationFlags, mesh_output_type: MeshOutputType, + has_task_payload: bool, } impl VaryingContext<'_> { @@ -243,16 +244,20 @@ impl VaryingContext<'_> { } let (visible, type_good) = match built_in { - Bi::BaseInstance - | Bi::BaseVertex - | Bi::InstanceIndex - | Bi::VertexIndex - | Bi::DrawID => ( + Bi::BaseInstance | Bi::BaseVertex | Bi::InstanceIndex | Bi::VertexIndex => ( self.stage == St::Vertex && !self.output, *ty_inner == Ti::Scalar(crate::Scalar::U32), ), + Bi::DrawID => ( + // Always allowed in task/vertex stage. Allowed in mesh stage if there is no task stage in the pipeline. + (self.stage == St::Vertex + || self.stage == St::Task + || (self.stage == St::Mesh && !self.has_task_payload)) + && !self.output, + *ty_inner == Ti::Scalar(crate::Scalar::U32), + ), Bi::ClipDistance | Bi::CullDistance => ( - self.stage == St::Vertex && self.output, + (self.stage == St::Vertex || self.stage == St::Mesh) && self.output, match *ty_inner { Ti::Array { base, size, .. } => { self.types[base].inner == Ti::Scalar(crate::Scalar::F32) @@ -265,7 +270,7 @@ impl VaryingContext<'_> { }, ), Bi::PointSize => ( - self.stage == St::Vertex && self.output, + (self.stage == St::Vertex || self.stage == St::Mesh) && self.output, *ty_inner == Ti::Scalar(crate::Scalar::F32), ), Bi::PointCoord => ( @@ -290,9 +295,8 @@ impl VaryingContext<'_> { ), Bi::ViewIndex => ( match self.stage { - St::Vertex | St::Fragment => !self.output, + St::Vertex | St::Fragment | St::Task | St::Mesh => !self.output, St::Compute => false, - St::Task | St::Mesh => unreachable!(), }, *ty_inner == Ti::Scalar(crate::Scalar::I32), ), @@ -776,6 +780,7 @@ impl super::Validator { capabilities: self.capabilities, flags: self.flags, mesh_output_type, + has_task_payload: ep.task_payload.is_some(), }; ctx.validate(ep, ty, None) .map_err_inner(|e| EntryPointError::Result(e).with_span())?; @@ -917,6 +922,7 @@ impl super::Validator { capabilities: self.capabilities, flags: self.flags, mesh_output_type: MeshOutputType::None, + has_task_payload: ep.task_payload.is_some(), }; ctx.validate(ep, fa.ty, fa.binding.as_ref()) .map_err_inner(|e| EntryPointError::Argument(index as u32, e).with_span())?; @@ -936,6 +942,7 @@ impl super::Validator { capabilities: self.capabilities, flags: self.flags, mesh_output_type: MeshOutputType::None, + has_task_payload: ep.task_payload.is_some(), }; ctx.validate(ep, fr.ty, fr.binding.as_ref()) .map_err_inner(|e| EntryPointError::Result(e).with_span())?; From 76bfca00a1170673c45e9f07b57a59d259324bbd Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Fri, 17 Oct 2025 15:05:55 -0500 Subject: [PATCH 68/89] Added some docs & removed contentious "// TODO" --- naga/src/ir/mod.rs | 15 +++++++++++++++ naga/src/proc/mod.rs | 1 - 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/naga/src/ir/mod.rs b/naga/src/ir/mod.rs index 3c2d1942d7c..4b0769c2803 100644 --- a/naga/src/ir/mod.rs +++ b/naga/src/ir/mod.rs @@ -2566,28 +2566,40 @@ pub struct DocComments { pub module: Vec, } +/// The output topology for a mesh shader. Note that mesh shaders don't allow things like triangle-strips. #[derive(Debug, Clone, Copy)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum MeshOutputTopology { + /// Outputs individual vertices to be rendered as points. Points, + /// Outputs groups of 2 vertices to be renderedas lines . Lines, + /// Outputs groups of 3 vertices to be rendered as triangles. Triangles, } +/// Information specific to mesh shader entry points. #[derive(Debug, Clone)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] #[allow(dead_code)] pub struct MeshStageInfo { + /// The type of primitive outputted. pub topology: MeshOutputTopology, + /// The maximum number of vertices a mesh shader may output. pub max_vertices: u32, + /// If pipeline constants are used, the expressions that override `max_vertices` pub max_vertices_override: Option>, + /// The maximum number of primitives a mesh shader may output. pub max_primitives: u32, + /// If pipeline constants are used, the expressions that override `max_primitives` pub max_primitives_override: Option>, + /// The type used by vertex outputs, i.e. what is passed to `setVertex`. pub vertex_output_type: Handle, + /// The type used by primitive outputs, i.e. what is passed to `setPrimitive`. pub primitive_output_type: Handle, } @@ -2597,14 +2609,17 @@ pub struct MeshStageInfo { #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] pub enum MeshFunction { + /// Sets the number of vertices and primitives that will be outputted. SetMeshOutputs { vertex_count: Handle, primitive_count: Handle, }, + /// Sets the output vertex at a given index. SetVertex { index: Handle, value: Handle, }, + /// Sets the output primitive at a given index. SetPrimitive { index: Handle, value: Handle, diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index 7b90aa35512..eca63ee4fb5 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -632,7 +632,6 @@ pub fn flatten_compose<'arenas>( } impl super::ShaderStage { - // TODO: make more things respect this pub const fn compute_like(self) -> bool { match self { Self::Vertex | Self::Fragment => false, From 69b97953ce0fd98814e398a0b8aec4aa913045c2 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Wed, 29 Oct 2025 14:03:30 -0500 Subject: [PATCH 69/89] Some tweaks --- examples/features/src/lib.rs | 1 + examples/features/src/mesh_shader/mod.rs | 37 ++++++--- .../features/src/mesh_shader/screenshot.png | Bin 0 -> 34256 bytes tests/tests/wgpu-gpu/mesh_shader/mod.rs | 41 +++++++--- tests/tests/wgpu-gpu/mesh_shader/shader.metal | 77 ++++++++++++++++++ wgpu-types/src/features.rs | 5 +- wgpu-types/src/lib.rs | 2 +- 7 files changed, 138 insertions(+), 25 deletions(-) create mode 100644 examples/features/src/mesh_shader/screenshot.png create mode 100644 tests/tests/wgpu-gpu/mesh_shader/shader.metal diff --git a/examples/features/src/lib.rs b/examples/features/src/lib.rs index baacf6a6b39..05f2db5ef21 100644 --- a/examples/features/src/lib.rs +++ b/examples/features/src/lib.rs @@ -48,6 +48,7 @@ fn all_tests() -> Vec { cube::TEST, cube::TEST_LINES, hello_synchronization::tests::SYNC, + mesh_shader::TEST, mipmap::TEST, mipmap::TEST_QUERY, msaa_line::TEST, diff --git a/examples/features/src/mesh_shader/mod.rs b/examples/features/src/mesh_shader/mod.rs index 2916e0fadcf..33ea10ba59d 100644 --- a/examples/features/src/mesh_shader/mod.rs +++ b/examples/features/src/mesh_shader/mod.rs @@ -83,26 +83,23 @@ impl crate::framework::Example for Example { device: &wgpu::Device, _queue: &wgpu::Queue, ) -> Self { - let (ts, ms, fs) = if adapter.get_info().backend == wgpu::Backend::Vulkan { - ( + let (ts, ms, fs) = match adapter.get_info().backend { + wgpu::Backend::Vulkan => ( compile_glsl(device, "task"), compile_glsl(device, "mesh"), compile_glsl(device, "frag"), - ) - } else if adapter.get_info().backend == wgpu::Backend::Dx12 { - ( + ), + wgpu::Backend::Dx12 => ( compile_hlsl(device, "Task", "as"), compile_hlsl(device, "Mesh", "ms"), compile_hlsl(device, "Frag", "ps"), - ) - } else if adapter.get_info().backend == wgpu::Backend::Metal { - ( + ), + wgpu::Backend::Metal => ( compile_msl(device, "taskShader"), compile_msl(device, "meshShader"), compile_msl(device, "fragShader"), - ) - } else { - panic!("Example can only run on vulkan or dx12"); + ), + _ => panic!("Example can currently only run on vulkan, dx12 or metal"), }; let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { label: None, @@ -196,3 +193,21 @@ impl crate::framework::Example for Example { pub fn main() { crate::framework::run::("mesh_shader"); } + +#[cfg(test)] +#[wgpu_test::gpu_test] +pub static TEST: crate::framework::ExampleTestParams = crate::framework::ExampleTestParams { + name: "mesh_shader", + image_path: "/examples/features/src/mesh_shader/screenshot.png", + width: 1024, + height: 768, + optional_features: wgpu::Features::default(), + base_test_parameters: wgpu_test::TestParameters::default() + .features( + wgpu::Features::EXPERIMENTAL_MESH_SHADER + | wgpu::Features::EXPERIMENTAL_PASSTHROUGH_SHADERS, + ) + .limits(wgpu::Limits::defaults().using_recommended_minimum_mesh_shader_values()), + comparisons: &[wgpu_test::ComparisonType::Mean(0.01)], + _phantom: std::marker::PhantomData::, +}; diff --git a/examples/features/src/mesh_shader/screenshot.png b/examples/features/src/mesh_shader/screenshot.png new file mode 100644 index 0000000000000000000000000000000000000000..df76e1415048224c23423e3cb5610936c91f886d GIT binary patch literal 34256 zcmb_l3tY|T|Now@Qq71I(#HNVa&5wj(=OP>-6*$;C|jwlQKZf>mu+s%%q5qO%g;nO zsUeheF!ywJNl9_qY_)~T((Tka&;R}Tp0e)FDSN$YI(@&-^SR%jr*9@t9B0|STYHY< zEI<6<{iz&hgnt@x?M(2$bLYBPa@^V-AHM(2Ct>%VsNGWw$M2AqZua}dC4Wyy+DLp` zkiX~dNYiEO>~4(od-a1o!9NL1KltE-cCWu~^lm#N1Ea6|zMtlpksYTjth_I=KfliZ z`?vr$@SkmmpG;xh%0}I}KfNSD)Rb#(3$P`}^}HSQFTe@wD#|G;&c0&LW@`*N zuILg0H0X~g^#F4w(d=N8nA*%4Z-jM`{n3TvihEw@6&YbDml@#S1FMT7InMLaqyiRX zB})<%_esFZVV2JHT_8NlBuxDw|FZFDERk!x@+g!2lX*g2ZI1^=TM!g!+g0rB>&1N% z=^a=8&B5UXEc2R3w0;6TuxS$gX7IQImM_8@X9v@deFxx26S?Fjz{ov$7a!%WT9D49 zQIABs__w~{-+Ik2T_nNI#Q`>W#C4++Eh2i`oa4mT<^fx==`RO0B2Bs9U2r3JN(H`) zO{eVFR54IS?lccC6~Edqohg`RyfuD^x=7UNuy3lGMN1USaW}1MAfd;&qjy9qA;&3J zV?X(A+x`5ygjLcg!YXMkKKNVJ*23*?bX8sZ+yMgy2!TS6-7P5da?z=S)DvBKh$ zw2tLjY?aLNvk1QJJbc|hKcyPkmi_$vq!7x>foQdXkz^2TH0 z*}v1nfj_k`kSVZaQXe9F;HCB;#n?+;*NK46JONFP{ZAwEtlAKOkAlrt;=A0^W4gNh zO*~(|ilF)9Sgb}A_*V55+G_hXqJ?>4g38`LU-S*3wf$-ze0NNp3##Mwc%H-7f4zdO z)Uxjhs|4D~`<+7Uq^TT~Xoioz zc0fzNj4x@Y1u=w^LG>NN=?5Po={K%3LG-drey4sUx!%1dxwD8@@c7%rshvM{hfMg3 z29O|q%A;fuG?fXZ>ZK&@JHI61+bQ3DM;1aAi~t>*hQLFMIEqJu>2fftF`bGTB8ssktI z1>9ntim6j1Q6eh>>~|ip*V5S&&FRgeM0#i4A0!JV@2U?&*W=sME++4&g#aT;WKQqQ zF#>-TH&AODs z`(7;&mfCaT>GSaW$ws9zFEw-3*g4{-oALO)n}D`iv03PY4aD7u0`$>}QsrT=IGtF! z&#FL_qGHRdU1s8^FL>^}63epHr9B?*PL^9&Uq=-q^44l4?XfnI}>G(*S(?R;-fU z_fib$6fOLB_hJC~?Ij}o8~Zeh=`5lrWMVnJ`PRE?iB!Wn5qi5wcijxN=xw>46L*q` z(w!u75mq~)Dhk0r?-3E?z_>J>c^#dHAA-LoFqaXRGeb7M@7v>9)x#9!0wT}cY7rjr zHr0~!upH0jc8_I`fMaH4cLG*9pYXQSQm6M8TT}Jbv{PQRv)qA~VZ#EJ0>#8r!4~vr zhgeoVjf7V1f!C`$5qK4X0=7&v>e>ejuD&DNlH^NUd;LWCR&%9HqctISc@>quqn63Q znW{=*X&xvqQ!mVWUm)@o;Z^>e-=`Er($^C3+2l z$3cd~q3>z=^G)iAzlXP%l42y=5lcNSOW7#x-a%i13}FM(Sw1iisa_4x#a&4*ia)A@ z`PiqY>{)e()7UStZuMsUC^!3t(h3qn!QBU3yNX4p2>Zq$v;VZg8yK|pT zW=Gh#Q!E&4%Qe0W1YNt;A^Vh-$NhTCA=5S?_WNs#GNQvB$bf_mf!fFoNg@2Sa%%6LF9{v3O7QB;2tz-DMAIi8 zgda}9RIc@8cBbml$g#bgUQuoO!@oa%Puh`VOq2~2@RYsx3&a~sgs#En9zCJuqG50s z3idux*Mz9+Xjv*cFN^VcW9Y+u%Dxwt^(kdrMc=`W$^RfRiGNEAjUSW@EY1;NYwwxF z=R0@UB$AbFq@f;AggTslk_Phb6eZ9y2l!>)_z$TbmuXv504N$xzDu|kF5Kf)lJ`UP zDH8*wpAW*scOy&VsFjv`U7&DyW5!zA*917RLY|pN9I@NGJWhgoIF&6ZWmY6G(~F_G z-ev^-2k9qHXsSS(_+CWfFX*$wN0?+OdkduT76bssAISN@u-9l`%7+kNiY<2XVb(O_ zZSN`Jex*vVR-8rh`x|jea)Pv2LNUhppNL%T>xt z_*iz10Iy8eLZSd&$~9*RBaZm=QGKn9WoldHCSp`91&-pWT7FMuq?ol)0sb_sSwh1K5BilF?!on zY-ZIUief*Y5YDa6T1RJb+$>sa7s>e%7E+_`{8AI@+n#i$**FC{SShrY8j>vBtraK{TlvZHNnpkR5-TMI zoYPn)?5hHSoOl!ogL7kHNKA*q9U_HO@*cxMt13rWRsO-xB9M|@m*d6iPCPG}Cx9D? z*dczcCD<~kz6YMds+?v_VZzECWb{|^%i)n)!5Dos<1hemBIsmVVU?XhFhjPWnz#MZ z`Zk4MZ#u1)jrFQ|J6GMWA$)&RnzT{Alit8)q4 zF2FXol;p#cz<`%7V4BL}4j2Cd;0O7<~AaTr6n3;BdWt7{|NB|iss_Bx1s*maDW{pqz33weHpWF~hZ$?2BZ zLz*H<)X~y6VL*(a3t=6#oVZ`Zi3h}qFtd7Av|Fi6`0xmc$!8FgVm{0tq%~dxN}r2K z?|Hxq=?w{RI!hs4NH!$81vjSKY@*vRqT5rVThB&xJ0S*A9oF%fPGe5_jsVKp8CdLZ ztPN)n8C3ftte5*xuz&QvDgoQR{VYtdYPMWYx-z*!S;uQvkDqChm85uPJ?JhgUKDx!>r0dOib$G1R> z8f3jr-er~5O4vB@Y>G1j+c$x6=CdHIRk(qqM zoy!N1*bQ$ic9Yp76AfstfITm$%r75gS zCAO5dW>82PKcX%nm0HJMXbg&3*9xIK)VK0XJoJWgO8J4Ncs-Ym@l3R&n-kKmW(Qaco>C7 zTbq)|bDD9agW05mUyu{QQ`$U1DYT#du{ubSA7IIOyF!q5)~cZp1)4R$J%r?FYyDW1 zqKb5OP9yyUxh;&pcsbG2p}r(3M^EzK1lD`^q-ATNl~LeW8A@zwPbxXTk?}Dpm7QSG zux6x|D+t=IwdsgVL?TKHTY_DAm!X>GTHBq@h9k7g)a-SHTz7_53+~v+M6OnOWs>Ci zZ6e`+g@`+*A#s13^`6>Pyp(K+agDGdrmlJ{>IyG7@FM7r0#Hq@;<+U!wt0bEKRIyR zMv~USHg!#0IxAZ~%l~MCfLjVB@UE4wNeDJIR66nQ3qp9m+BCd-VTnm%M-U-+vqtG% zE!z&C^8IrJ=SCL5rk29fn1W(cZwgNK*)$5iNLv}DBuRXQrZ2%enat61zgQruay>pQ zUBr3V17B2~w6dGdE3N8-1T~sANnLd7!R{ys+b{&>;)zX*L?2=>NZ}?VBZG%1aiKt#?0$Csr^jb5>A`su z$Y0c&YXUj-l16zA+A%NyJpHigrH3QtGklDaYzl<5-ZDdM%3@ zR$N3L{!J|HRzK;Z;FPEIZLJ``76s$|T5U$mhaXxp_up7DlH;CT1l~q<$Z_IF6OAEc z`2#9Sc57Ovk1!>fo_pj}Q5vgj#YshOALYwAuTcrRTOF{=KJXWV)QgGKhE37~$hok0 zc0-B2CYQrC7G@Net9oE-(vw%}fPLZDZr&isXwqzy&KiVcq}ee3%ZbwEWHBz(4N?Y_ z$|7PicAZC#W6Ft(ARuy+I!njQv#ehEMEDa}Ie}Pd+N8c9|9Ay}7vG}gM-$fd)rAU?G4vOF1YDR4SdeVpr1m37h+?r&(Lt0FdYU&@Q54Nl ziO+pPt9Poix)R14J|Xuw!Hx1>)BJr~Cc)xCRMPaMFuk#1;S4!>TwRgei{$JcvoYu7 zx8THbhc!q4rI9j)Km0xMv0BU z>L1hUs14RF79n9n6gU^=1gjt9R~OLgu1&8l0Z~v()J8`{a&na9e~Z6>ScW+^BbK7Q z%0byBqDu~lCMh!HON>+(p!i%X`Y2E?U->pllC(ozvXjr`cld;M=+XQR(lb;Gj3KRc zeWkJ1#a*H6urnwBh_IkagEp2cUx;0tM0vu? z23)WcxSeF#0=Ez3-+T&X$RTCGgH0%df}$BOlzFH^inNa}&UJ2austYpsFDY;uwSw*+cTU)!)D&7P{1p#*GUBtf)bL42= zWXOOVBfmWE@8u|E-w>sNXlvon#0#5p!lIZg-lrXV1yFz20HbJI*4vTe;t;FLKW`Bu z3R@{3^`PR6N7IV3Ut{Rj-2$P1ZRn0zA^;@?U*BraC!!c*&U*;hgOY3AQg5x-m#AcE zMee_pphZlz))~~z+TPgLk;2bFgR2wgJ>3RRP!B^JYei(iDy61Yj7R%Z-VYix*sjHm zD5`xTx(MPKql_y4r%An3?B!ci9%d$-_(&t-DW}~s;0AYx`}^UYdZk_}%fzy!Y800$ z)%{j@Y6D$yG%3$pO{!Rnw@k#=!0MwqA%-(*WKK?z3l=XmL6x}+aQdP}oZzphf!&4n zp~7~fI0L4Y_Myu6+(ybIOj!2p2e}BfgMC)esEIGw>Wm*MMnA8Hf!#yZjK*OB^e>AG zHUZD4HL+J&!2|*5mk2xYZPVFl4gJQeg%Y(Fif6e-kO@57>|Q?`8fq)xb^Q`aU)tTy z6;CnU<3>p!%Hhx3=xE;6mIaQmN7a#CMonXn5};MrsW5!INsH8>R5sAX#Fq=}04k!6 zUe7~C+k%Rx$tmmw89QRTS6;B4f1)Tk&`InF)?okl(CW8p{@&S3GuEeEX25T=6kH zm*fZ^n|bX>`E?H$gF2@i6@INlOSG<9GCy_ZD4~fjMK~l$59Vr<2|uVkhKhdU=RfMz zRBEJl<(ajSKa>R`M=NF9p_7F4^MMDc$20qsGxZArHBgLNk#>Huqk#jp7!uqw0%N7V>~On`89C83PxM zdn0w6@P+Jmh106uLV@z!1TJ?kj8xd`>Fl)e%Cfzu^E^-fEnE>)+ey$SjC@!w8yNm< zPiVG>Q7}B2w%?G>_9WQ%$q)b4*KLf#Z)jS^`#$J_dm_rYeoxzS6_>-Wp$oy!5pB87 zPYh7Q;b1SP-RFCyv+9-js+j&6b$UHn5j4Sb?IghejyhXKF|#wnJc*))a=fsFDEi^H1y@6{tlDz`ct_FFM}>ypzB-TY%&$xxD)3y_y&G>;yv zw!}zBBiFZaeQH4A`pK-MWC|;J=zb((H}8~k?T9#-0MeN_gS(pWJrOec;bVK`9MnYJ z9x>Z$6NV<1^I--0J(=vhBN9m%+@2PVg4)>Ysd*>bl?sLS_=?T*p_WbA#t}*5)b_jUQDrI%$(+f_sY>hmYNHr&%9iNcb2}D@ zB1hw^x*r;&5t?vF?>OWoxvO-PUeEm+<>Hb~ZY}@?Ax8_KRGmoV#Q7vIjjNdr<>d$) zi8_!1p)qx9&J4&+Ec$U+`T8;Jm+%dZ(qi&|2lM}@F^omu$<>HgWo1l9)ACs6<5X8@ zXBCJRGZtZuBo_MI;rU$IbG!1%0^bn#H7bb&#Tk3qN65U3qG8e$a=>AzBOH5G^ycP*^tqjReVH5=R>I((-kS23<36nc;!TyMMb8HGk*#@I&rmIC0H?cq{?g-1W zTWeXeo&Lz}St<&+npMZNe5{uJz$Dk=(Ud}RgjUgubomJLkz@goNd(BNx&yL>0MRB) z%>jbxfImf;h@~mo+_5?k#j7NhrWcdC!l?&h!f~-c5^?zWS$9WO)PvbH8C19U!v|$O zN9Es7XXh8MF{nV1K)hQU+cpD`w1qf|oYhXE7bLIEv*q3_kG zV}~?K|1A*Sj4LM85__@4WV)7Q9aAEIkEeXYdW}kXTCBM6TSiWQYxTII0otLADX>7) z3_&Gozb12GRLoXY2X6-pB-4C=(wx05pZzBrvCfz~2BGnOT{o&6C92FKs(mXoTSg+X)UqV7vvo!BTKcLNmx2WMm9KHu_DpThFB6{|qSG{J}I+ds=O0{KwkHkRW z49sKA(e}}sQ6@GQu^RmJ3b=wM;;E)#*=tuKJY~ksA=03^sH0F|CqCH$9;yzdM=H0U<_R+&k6DoB!49fY!f-EXt~pL;tVUvp%Ni^v?>d z3oSU7c?2lkUQNC`5ybG7qW2A?W)S2=G>xdCfv`B^^SV_bt?}8O4%;I6KT6f z{ZKFyVe?{$7LHr!=<0=`<4dn{;s|h2+xnL^YbS3G@hiD-9t7JA>3Xb(bRCJzfY?g` z+j+otca`q2rEUYwA`OxJQ7Z+FW`eH-K^u;A#>NOd*_p<2FpIxeDLs(SmzglD z3&#ihRp2)&oB*1Qh7c0UgS&J|5fyfcs;gWv41Cxq9l#l@-dmx>k;?8|*ziVCZcCsY zKj&w~1z~z7421!miVNxNtWRP55dLHa-*s;`ic-U(bXj$zQ7I+duVGT7KE5bg^Q&Q1ryov3zs`f!dTumS11)Fa)jqOoqv<`nf3 zw!gNZ@X4JiEvSQgDRL<~|3gdYOvRYfL-=z(o`x;OqBT^Ui6D@LTXUY!hvfRlXO4&V ze$C}X=ob66K10jP3K|yJp1Ujxk*2na6OId#g-Z8JAP}IdH#CTr6JMb;Y(?#v8H>$m-qQ}|IW>m=RP+|ln5aJr#zZr3@ zvhKcbeoO|zDj~VfI;d879IslT=$GM~UEQ=zFV#BsD}(QV} zlOljYa0KBx&Q!2w1(>nH!UzDG^z zF34JvvZ1P`O zb|(#v{#UUGxrPJNGAr*fHk^kpA`6EY>A`u}c5is1P%pRTaa>VUm}MkC3LE=zGim9) zGk3Wz{PbK+0TnfbLW~Yqgc=;fQQliRNNto#bRv`CfV(+L%ov)fn+}!+8ZEe1I`H3v z!Q=BDc-m>FK>2})h5_h`rS@^9St<)>0UdLcPt{u%KG|qI_A)0Qs)OVn)MVnMiToo7 zLxYS=Php6OFxky{Kz8XgWlc-8e)Xz(JsX>Eux&6+_fT`Ph3p>|ES(KW(Rj_|JP@ny ztba^IKUK@{+h*%+#U&!D#;MLG>lDok#L&09+6WA_hdb#MW0R%h*?+ok9=UrQ+79m4 zoG%1oTK)3_dYRlisvZkXs+9o--kl&4o%JTMZ<^V&OZ=#j4hBN;&o*PREkxo_o9u(< zLX*2j@;gvDbbR%taLB}c?M=9W+H;vL#tP{4Z1S9mxh1%Y`sZ@=H4(PH@=4W7gkrr` z8i+q`Ccu3%J3zd?e)V1$ivDM{i!E9%&iPA%Y-$h4HdB$+hUBgV?CliI&RYiZE}-<} zdMWBDoUDJG^k7V6=gl{Avg4qn&D0-MGo^(-r$*x}HkzyQMF1OZQawG##+Z`8-6dfr z7Mz7{)ZXh^{`pMK<{L}SqW zL9_ocyX?T3V44AIHC;goIw|Ov_v!KH^9&OUkZq^4r$r0U$4}^8>}tdH{p!UbCB$T} zrxVTyZzOdTPzLX%8+NX&c1%pz|Gj1`?(GI#<7fi3n%aQPfXyRA>#a^nTd)PZ1svXm#mI7 zT*Z%-y&&=6r9@`3S#9UR{S85x*d`}l*%`wr7bb9f&?*>V+)U>J_o;l6N=L3RGcKdy z6e^M6mKo)ZzeoB#Gym}t=cb+TevxNfx+3NCRSE;SZ2_X-3{(;#hSzLx7GKX0`$Kh5cYiBR0%4Ig*H(N1(G&cVaXUw(ZjtNK73y=p z-<4E~61ieSXB>QLQ*YhWZl%nWUEUZuuFRs^ z$S@b9iAPwMQwQ0~)2v|Asu4*$Zk#RoZ^-X<*5gF>RmNxLn0lpAdB1@6sVZT*Um5!m z^^z+4A(?h6k5ie_?wc3i_7x=KSiqo`wEWyr_V8hx&*S!=JQIF4{;zdzozu?z8WgZ! z5|?sWmM6Su9ZM~v+K_Fhpm&J-5YqsIWLyD~+_{-TO0yaC)?Y%>{tjKBkVg&XJhjJ_ zUmR{m-Vc796<|%bF%;K74WMttd4vti8t21g&H}o(!4v(ICP$K9q(1pJBdl?xw78cJ z)z&F2^PkcQR2CA4&NIyYzcy-*e?p?qVuNH;5Z|b^#2*suh|3Xa4pWD$O8VO)y6DG) zE0UYYFSnsud`I;^P3D5~vQn04RvZQ$!g&jQGj7GmgUT(Vq8^(oRtbu|b*28Jbw1LY zz5-WhLWh<#;f109RuxRZ^gWBMa?e&6(P(FDTy^j8(eoEL7n&N<5T|Yo$h5qcMhM_M zLWlZ?EcA5Pguq;kz5mP3m=)t#I$gohS{}rO1|Bmv%zZ~6BvLgUt>uaFg5p=(lsJ+a zA7o1@NEPSjT8Y1EW`3l_@Y+_r`SJ1vUMV3yXli5HRL1}$aKtj-vYw7&KW{@CAgw!d zsAuhID)UdEJJl-p1{mb@l%R^oz#=V=H%L$kKVQE* zUqKF(VtAXCBITZ}Fxx0;&}c!@o2|&u_WV+GJ!V|luyezYK`qwlB~gXZyN!b<1#meQ z^`^4)ETT**5KSm6VCR?%(69p?tdWZ zPiMOxpJ$KMMf0t^akB%h)>?j?x)gCtU`P|Wdgw-Z+TODY0}p7&hWdk9 zY1u{|2$&pO8}jCFi8vl??|%{Z3+hyE)&}j-St%jqIevm;-&l&Tv?)g=*Db_hUy=K6 z?$+;Cw3_|}Bnog{6QnLV*Lv90-mBAMmr|Rn4)5XObUw83W|gD(6y&O1o8{_^`qB9b z;-$rAbU{z+TzJo{OuIs#ffiz8ebwVS(7u<~Jz&$a0iCKeuSZ$e|! zdbyB@!&h9)y&awD7V$PTRXKT(dG{DCu&LOwN5vv;H$J=Ae4l+|42p!k z!+U1!sES=D7);j!y|`!jpA%6of0|}s12w6C)w(DL1(Ta%iw8&0YR=L<99Uiewnv!#EkqxJFK2GUi zDp2{5uPSs4;J#>chyw8nk+AXvv!8L&vAyQIX-p>NFD0*kfbE5Ko!(WS+q+-n{?;p< zT2_=Z33*UaBBzQM1lPZet94wFR8>pL7aQ!sl!u+@1}@vm+|&5r*$b;YTO!Uz#m(dP z=!*>ft%G@jJ(o=yTrAdVT*d*HKcOz)v`Jm0zZZ7z>*UShYJNJw{^AB6AJIpwxe{Kb zR_-?b_L7NP?>5I?eH`P*CETkzY4-ExNnEEk@y6@&{o#{+FVPUY9=SV&4IT}$hQBWm zTz_a@&rys%r80*qGxNt=j%gwT40^!ruam85y|>Jf{{ByQM|U|5zUzmNbK(UwA0yts8>cgiWdW= z`X>lllYd<@2!A6KmqVaIf&Z_?hMQ=)wH;chUlzO%V%lBkh3iik_yr*q7!fs@;PunmPUDH8L+i+k&R*7&Oef z=%MqsGdbC|%pgvm{NSETwZFzHU=^L|sTsH6Yp!dI%RY#~wJ!6ma11k@Eo$m^0-d5b z9>*?Wrvxn?6s_AS)n=byp8sR9q(bc(C`evrSpVEeD~hmc9wS`4kdeIHD)*EGRBCM^ zx}$2?vH7OStK7SBI*k6)fP-r3<&#c{;rXY-BxjODamK~xEbE_4YlWjKv$DS18`mGP z8{LNsJj=~m+nGe6p)>Ngn9zm1UQ=9D3csqN+qX8}%G+~$?(fM5@}2lI|9JsrS5pH&}GIR<1#5{LNcm9t@t=L?18%u-q ze7f4+-pkI%ecu5##i!qJqXWSgwnh4wtLHqq&fY5+m?;g*t795V{NI)K4jC8t%Erh3 zYhfV^izEBx*=&3I(**Ig2!rCvL26|4buPGB_KZ!Iz9ksC)NkmCV>J^~x7$CQac=wX zH#dDr&FS#`+h0bHnx3}A;nB?@p-P*u%Lbx9-+g<2T5_W5bJ@DX_XMGNQ+VR&lh6-% zAW2LT?J_Rj*)|DQ%v`UDom0SeO=kX18oB9W1J^t&uJuh5V07lLAoNdMn*QRmNHsGY z+7bmYI+etRJbcqy{$w~t5!z6u@Ox8r@be?g|6SY;^K{Jwakj{y4c)B5%wh+Z+!HzX zpjmnZ`S)GUT&beJj8PaU$Z4Z{3)y}^*l&kQ=76(+PMUq)gA8)Y+Quc5W%7umL#^-z z~ww#{GHLFlY)XpT|R{WiLS$Kt9b{Q;A{QNDG#Fl4xSu5-&aq{Sn9mjC+_k$QT=&RMXX z88vsfS%=bPj&zmXWTE1YV=LSnVU=hbH=DZ^*tQ@fZLs0kys0b`gKjgfcI+Qy;8@kR zxXRFFUu4Wo$4^_wsr4HrpJ~1wGeG>8a~n+_4H5As&#Qm$s?jrYj64e4Y9ZXzu8A{m zw!-RFS(kUQojojWCU>iiF=JhJjV-G3_|(cfdJ-tr+5wAO>TY)hr6tEMb!vrwQl3eY zUSD*i^1ZGnxgYdP!Wsp`^$_d4uzF77B?tMlmx0>VPg-FHi)LeTZqJeE#e(tKZ3`gx z%rn`TdPpNRD-|8z*!O0|3Ew;i`HwFHv74P+i`XaWsR`UbxW8ZNMLh=*v%8y@y32`4aY0wTE{qtxZxrgf~ ztci0J|Ne53n?6ibVZ58W+{Vnt+caGfa>jFAt2wuB;oy4)Hm4DVG0j)UJsy{*o?j|0 zZ;-#}6?=el2UpmhH5>dgI2ZcpFTZuI<(!Z4E${D*CtrBE$hnseoo}_T`ele^+0X;Z zsjamr#}?UEcf4ZlmIXJ$s_c)?KdE#o%1C+Oujjm412>SWa<95ML7eas%;lV{Th}8$ zZbfZG7PQ1`%M#OD`)(Ti{QxFY#12k(EDO)Oe=ajz&mxur$LOwWlgrGTUmw^w-EsB$ z3F5++gIyQ(GTYW#mCf5P+Gl+#*z=O{+vd31{+}$P5$c!V6Z(gr{Wp4It1&d&esQMx zYbm_;nI3sbrltQ(lkh!RMjx~>o#vC9o1VU|(^?%|PWsV86Z=2opOarYg0@3=#BmTb!q2%@flG6@R{?R=4ZAxSqPfU zgd-EU3sk)nv4sBCRijFJsxb(bfAxh|<}}Q|;R7M4+p2$Y+0Drk%J)(GaK+2ud zt8e;u*-+K-lU-c&IVW593|)o`WG2gy>A9CB%(^|P`pedeYAlnb9yTt1q?ObKyrdqq z-=UIp*Xag{29EI90^2z^qQJj@AFiH?xru)&KMwQt@h-Ca#52w#Cu-lXM%UB#-bhN& zX*y?EykSB9*@PY2{3GwQRzTC)!(GC>Oh>U)`T6g|*^d>hI_Q>3MUKTHwd-pK?mJWz z&U!<$CuP|WzC|9>UH+Yw7g?H=B8=+c)a&8VdoD53w|ws`2~An{SH(7e_nQ~~P~Fe- zn3rmox8i2X;lI)fved(zUn?AZZMtp$+a=%cl6BC?roQ!P{yW(m{`+vu#P?6U>;2{b E13MvD<^TWy literal 0 HcmV?d00001 diff --git a/tests/tests/wgpu-gpu/mesh_shader/mod.rs b/tests/tests/wgpu-gpu/mesh_shader/mod.rs index 8a3218970f0..2288c4089d8 100644 --- a/tests/tests/wgpu-gpu/mesh_shader/mod.rs +++ b/tests/tests/wgpu-gpu/mesh_shader/mod.rs @@ -98,6 +98,18 @@ fn compile_hlsl( } } +fn compile_msl(device: &wgpu::Device, entry: &str) -> wgpu::ShaderModule { + unsafe { + device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough { + entry_point: entry.to_owned(), + label: None, + msl: Some(std::borrow::Cow::Borrowed(include_str!("shader.metal"))), + num_workgroups: (1, 1, 1), + ..Default::default() + }) + } +} + fn get_shaders( device: &wgpu::Device, backend: wgpu::Backend, @@ -114,8 +126,8 @@ fn get_shaders( // (In the case that the platform does support mesh shaders, the dummy // shader is used to avoid requiring EXPERIMENTAL_PASSTHROUGH_SHADERS.) let dummy_shader = device.create_shader_module(wgpu::include_wgsl!("non_mesh.wgsl")); - if backend == wgpu::Backend::Vulkan { - ( + match backend { + wgpu::Backend::Vulkan => ( info.use_task.then(|| compile_glsl(device, "task")), if info.use_mesh { compile_glsl(device, "mesh") @@ -123,9 +135,8 @@ fn get_shaders( dummy_shader }, info.use_frag.then(|| compile_glsl(device, "frag")), - ) - } else if backend == wgpu::Backend::Dx12 { - ( + ), + wgpu::Backend::Dx12 => ( info.use_task .then(|| compile_hlsl(device, "Task", "as", test_name)), if info.use_mesh { @@ -135,11 +146,21 @@ fn get_shaders( }, info.use_frag .then(|| compile_hlsl(device, "Frag", "ps", test_name)), - ) - } else { - assert!(!MESH_SHADER_BACKENDS.contains(Backends::from(backend))); - assert!(!info.use_task && !info.use_mesh && !info.use_frag); - (None, dummy_shader, None) + ), + wgpu::Backend::Metal => ( + info.use_task.then(|| compile_msl(device, "taskShader")), + if info.use_mesh { + compile_msl(device, "meshShader") + } else { + dummy_shader + }, + info.use_frag.then(|| compile_msl(device, "fragShader")), + ), + _ => { + assert!(!MESH_SHADER_BACKENDS.contains(Backends::from(backend))); + assert!(!info.use_task && !info.use_mesh && !info.use_frag); + (None, dummy_shader, None) + } } } diff --git a/tests/tests/wgpu-gpu/mesh_shader/shader.metal b/tests/tests/wgpu-gpu/mesh_shader/shader.metal new file mode 100644 index 00000000000..4c7da503832 --- /dev/null +++ b/tests/tests/wgpu-gpu/mesh_shader/shader.metal @@ -0,0 +1,77 @@ +using namespace metal; + +struct OutVertex { + float4 Position [[position]]; + float4 Color [[user(locn0)]]; +}; + +struct OutPrimitive { + float4 ColorMask [[flat]] [[user(locn1)]]; + bool CullPrimitive [[primitive_culled]]; +}; + +struct InVertex { +}; + +struct InPrimitive { + float4 ColorMask [[flat]] [[user(locn1)]]; +}; + +struct FragmentIn { + float4 Color [[user(locn0)]]; + float4 ColorMask [[flat]] [[user(locn1)]]; +}; + +struct PayloadData { + float4 ColorMask; + bool Visible; +}; + +using Meshlet = metal::mesh; + + +constant float4 positions[3] = { + float4(0.0, 1.0, 0.0, 1.0), + float4(-1.0, -1.0, 0.0, 1.0), + float4(1.0, -1.0, 0.0, 1.0) +}; + +constant float4 colors[3] = { + float4(0.0, 1.0, 0.0, 1.0), + float4(0.0, 0.0, 1.0, 1.0), + float4(1.0, 0.0, 0.0, 1.0) +}; + + +[[object]] +void taskShader(uint3 tid [[thread_position_in_grid]], object_data PayloadData &outPayload [[payload]], mesh_grid_properties grid) { + outPayload.ColorMask = float4(1.0, 1.0, 0.0, 1.0); + outPayload.Visible = true; + grid.set_threadgroups_per_grid(uint3(3, 1, 1)); +} + +[[mesh]] +void meshShader( + object_data PayloadData const& payload [[payload]], + Meshlet out +) +{ + out.set_primitive_count(1); + + for(int i = 0;i < 3;i++) { + OutVertex vert; + vert.Position = positions[i]; + vert.Color = colors[i] * payload.ColorMask; + out.set_vertex(i, vert); + out.set_index(i, i); + } + + OutPrimitive prim; + prim.ColorMask = float4(1.0, 0.0, 0.0, 1.0); + prim.CullPrimitive = !payload.Visible; + out.set_primitive(0, prim); +} + +fragment float4 fragShader(FragmentIn data [[stage_in]]) { + return data.Color * data.ColorMask; +} diff --git a/wgpu-types/src/features.rs b/wgpu-types/src/features.rs index c36a16c35ea..1741b7a14b1 100644 --- a/wgpu-types/src/features.rs +++ b/wgpu-types/src/features.rs @@ -1166,12 +1166,11 @@ bitflags_array! { /// This is a native only feature. const UNIFORM_BUFFER_BINDING_ARRAYS = 1 << 47; - /// Enables mesh shaders and task shaders in mesh shader pipelines. + /// Enables mesh shaders and task shaders in mesh shader pipelines. This extension does NOT imply support for + /// compiling mesh shaders at runtime. Rather, the user must use custom passthrough shaders. /// /// Supported platforms: /// - Vulkan (with [VK_EXT_mesh_shader](https://registry.khronos.org/vulkan/specs/latest/man/html/VK_EXT_mesh_shader.html)) - /// - /// Potential Platforms: /// - DX12 /// - Metal /// diff --git a/wgpu-types/src/lib.rs b/wgpu-types/src/lib.rs index 10545bc560d..f91c6547649 100644 --- a/wgpu-types/src/lib.rs +++ b/wgpu-types/src/lib.rs @@ -1045,7 +1045,7 @@ impl Limits { #[must_use] pub const fn using_recommended_minimum_mesh_shader_values(self) -> Self { Self { - // I believe this is a common limit for apple devices. I'm not entirely sure why. + // This is a common limit for apple devices. It's not immediately clear why. max_task_workgroup_total_count: 1024, max_task_workgroups_per_dimension: 1024, // llvmpipe reports 0 multiview count, which just means no multiview is allowed From 9f5e3ff2fe5d9dc0c42539ee87ad55fb777ddbf2 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Wed, 29 Oct 2025 14:06:02 -0500 Subject: [PATCH 70/89] Made tests actually run on metal --- tests/tests/wgpu-gpu/mesh_shader/mod.rs | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/tests/tests/wgpu-gpu/mesh_shader/mod.rs b/tests/tests/wgpu-gpu/mesh_shader/mod.rs index 2288c4089d8..1b79770b254 100644 --- a/tests/tests/wgpu-gpu/mesh_shader/mod.rs +++ b/tests/tests/wgpu-gpu/mesh_shader/mod.rs @@ -3,15 +3,11 @@ use std::{ process::Stdio, }; -use wgpu::{util::DeviceExt, Backends}; +use wgpu::util::DeviceExt; use wgpu_test::{ - fail, gpu_test, FailureCase, GpuTestConfiguration, GpuTestInitializer, TestParameters, - TestingContext, + fail, gpu_test, GpuTestConfiguration, GpuTestInitializer, TestParameters, TestingContext, }; -/// Backends that support mesh shaders -const MESH_SHADER_BACKENDS: Backends = Backends::DX12.union(Backends::VULKAN); - pub fn all_tests(tests: &mut Vec) { tests.extend([ MESH_PIPELINE_BASIC_MESH, @@ -157,7 +153,6 @@ fn get_shaders( info.use_frag.then(|| compile_msl(device, "fragShader")), ), _ => { - assert!(!MESH_SHADER_BACKENDS.contains(Backends::from(backend))); assert!(!info.use_task && !info.use_mesh && !info.use_frag); (None, dummy_shader, None) } @@ -396,7 +391,6 @@ fn mesh_draw(ctx: &TestingContext, draw_type: DrawType) { fn default_gpu_test_config(draw_type: DrawType) -> GpuTestConfiguration { GpuTestConfiguration::new().parameters( TestParameters::default() - .skip(FailureCase::backend(!MESH_SHADER_BACKENDS)) .test_features_limits() .features( wgpu::Features::EXPERIMENTAL_MESH_SHADER From e02e379a9dc56eeb8f5ddf203b6be1059c5afb93 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Wed, 29 Oct 2025 14:40:08 -0500 Subject: [PATCH 71/89] Tried to improve one part of the code --- wgpu-hal/src/metal/command.rs | 263 +++++++++++++++++----------------- 1 file changed, 131 insertions(+), 132 deletions(-) diff --git a/wgpu-hal/src/metal/command.rs b/wgpu-hal/src/metal/command.rs index 46cf52716c8..bf911362e58 100644 --- a/wgpu-hal/src/metal/command.rs +++ b/wgpu-hal/src/metal/command.rs @@ -143,6 +143,127 @@ impl super::CommandEncoder { self.state.reset(); self.leave_blit(); } + + /// Updates the bindings for a single shader stage, called in `set_bind_group`. + #[expect(clippy::too_many_arguments)] + fn update_bind_group_state( + &mut self, + stage: naga::ShaderStage, + render_encoder: Option<&metal::RenderCommandEncoder>, + compute_encoder: Option<&metal::ComputeCommandEncoder>, + index_base: super::ResourceData, + bg_info: &super::BindGroupLayoutInfo, + dynamic_offsets: &[wgt::DynamicOffset], + group_index: u32, + group: &super::BindGroup, + ) { + let resource_indices = match stage { + naga::ShaderStage::Vertex => &bg_info.base_resource_indices.vs, + naga::ShaderStage::Fragment => &bg_info.base_resource_indices.fs, + naga::ShaderStage::Task => &bg_info.base_resource_indices.ts, + naga::ShaderStage::Mesh => &bg_info.base_resource_indices.ms, + naga::ShaderStage::Compute => &bg_info.base_resource_indices.cs, + }; + let buffers = match stage { + naga::ShaderStage::Vertex => group.counters.vs.buffers, + naga::ShaderStage::Fragment => group.counters.fs.buffers, + naga::ShaderStage::Task => group.counters.ts.buffers, + naga::ShaderStage::Mesh => group.counters.ms.buffers, + naga::ShaderStage::Compute => group.counters.cs.buffers, + }; + let mut changes_sizes_buffer = false; + for index in 0..buffers { + let buf = &group.buffers[(index_base.buffers + index) as usize]; + let mut offset = buf.offset; + if let Some(dyn_index) = buf.dynamic_index { + offset += dynamic_offsets[dyn_index as usize] as wgt::BufferAddress; + } + let a1 = (resource_indices.buffers + index) as u64; + let a2 = Some(buf.ptr.as_native()); + let a3 = offset; + match stage { + naga::ShaderStage::Vertex => render_encoder.unwrap().set_vertex_buffer(a1, a2, a3), + naga::ShaderStage::Fragment => { + render_encoder.unwrap().set_fragment_buffer(a1, a2, a3) + } + naga::ShaderStage::Task => render_encoder.unwrap().set_object_buffer(a1, a2, a3), + naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_buffer(a1, a2, a3), + naga::ShaderStage::Compute => compute_encoder.unwrap().set_buffer(a1, a2, a3), + } + if let Some(size) = buf.binding_size { + let br = naga::ResourceBinding { + group: group_index, + binding: buf.binding_location, + }; + self.state.storage_buffer_length_map.insert(br, size); + changes_sizes_buffer = true; + } + } + if changes_sizes_buffer { + if let Some((index, sizes)) = self + .state + .make_sizes_buffer_update(stage, &mut self.temp.binding_sizes) + { + let a1 = index as _; + let a2 = (sizes.len() * WORD_SIZE) as u64; + let a3 = sizes.as_ptr().cast(); + match stage { + naga::ShaderStage::Vertex => { + render_encoder.unwrap().set_vertex_bytes(a1, a2, a3) + } + naga::ShaderStage::Fragment => { + render_encoder.unwrap().set_fragment_bytes(a1, a2, a3) + } + naga::ShaderStage::Task => render_encoder.unwrap().set_object_bytes(a1, a2, a3), + naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_bytes(a1, a2, a3), + naga::ShaderStage::Compute => compute_encoder.unwrap().set_bytes(a1, a2, a3), + } + } + } + let samplers = match stage { + naga::ShaderStage::Vertex => group.counters.vs.samplers, + naga::ShaderStage::Fragment => group.counters.fs.samplers, + naga::ShaderStage::Task => group.counters.ts.samplers, + naga::ShaderStage::Mesh => group.counters.ms.samplers, + naga::ShaderStage::Compute => group.counters.cs.samplers, + }; + for index in 0..samplers { + let res = group.samplers[(index_base.samplers + index) as usize]; + let a1 = (resource_indices.samplers + index) as u64; + let a2 = Some(res.as_native()); + match stage { + naga::ShaderStage::Vertex => { + render_encoder.unwrap().set_vertex_sampler_state(a1, a2) + } + naga::ShaderStage::Fragment => { + render_encoder.unwrap().set_fragment_sampler_state(a1, a2) + } + naga::ShaderStage::Task => render_encoder.unwrap().set_object_sampler_state(a1, a2), + naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_sampler_state(a1, a2), + naga::ShaderStage::Compute => compute_encoder.unwrap().set_sampler_state(a1, a2), + } + } + + let textures = match stage { + naga::ShaderStage::Vertex => group.counters.vs.textures, + naga::ShaderStage::Fragment => group.counters.fs.textures, + naga::ShaderStage::Task => group.counters.ts.textures, + naga::ShaderStage::Mesh => group.counters.ms.textures, + naga::ShaderStage::Compute => group.counters.cs.textures, + }; + for index in 0..textures { + let res = group.textures[(index_base.textures + index) as usize]; + let a1 = (resource_indices.textures + index) as u64; + let a2 = Some(res.as_native()); + match stage { + naga::ShaderStage::Vertex => render_encoder.unwrap().set_vertex_texture(a1, a2), + naga::ShaderStage::Fragment => render_encoder.unwrap().set_fragment_texture(a1, a2), + naga::ShaderStage::Task => render_encoder.unwrap().set_object_texture(a1, a2), + naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_texture(a1, a2), + naga::ShaderStage::Compute => compute_encoder.unwrap().set_texture(a1, a2), + } + } + } } impl super::CommandState { @@ -683,138 +804,16 @@ impl crate::CommandEncoder for super::CommandEncoder { render_encoder: Option<&metal::RenderCommandEncoder>, compute_encoder: Option<&metal::ComputeCommandEncoder>, index_base: super::ResourceData| { - let resource_indices = match stage { - naga::ShaderStage::Vertex => &bg_info.base_resource_indices.vs, - naga::ShaderStage::Fragment => &bg_info.base_resource_indices.fs, - naga::ShaderStage::Task => &bg_info.base_resource_indices.ts, - naga::ShaderStage::Mesh => &bg_info.base_resource_indices.ms, - naga::ShaderStage::Compute => &bg_info.base_resource_indices.cs, - }; - let buffers = match stage { - naga::ShaderStage::Vertex => group.counters.vs.buffers, - naga::ShaderStage::Fragment => group.counters.fs.buffers, - naga::ShaderStage::Task => group.counters.ts.buffers, - naga::ShaderStage::Mesh => group.counters.ms.buffers, - naga::ShaderStage::Compute => group.counters.cs.buffers, - }; - let mut changes_sizes_buffer = false; - for index in 0..buffers { - let buf = &group.buffers[(index_base.buffers + index) as usize]; - let mut offset = buf.offset; - if let Some(dyn_index) = buf.dynamic_index { - offset += dynamic_offsets[dyn_index as usize] as wgt::BufferAddress; - } - let a1 = (resource_indices.buffers + index) as u64; - let a2 = Some(buf.ptr.as_native()); - let a3 = offset; - match stage { - naga::ShaderStage::Vertex => { - render_encoder.unwrap().set_vertex_buffer(a1, a2, a3) - } - naga::ShaderStage::Fragment => { - render_encoder.unwrap().set_fragment_buffer(a1, a2, a3) - } - naga::ShaderStage::Task => { - render_encoder.unwrap().set_object_buffer(a1, a2, a3) - } - naga::ShaderStage::Mesh => { - render_encoder.unwrap().set_mesh_buffer(a1, a2, a3) - } - naga::ShaderStage::Compute => { - compute_encoder.unwrap().set_buffer(a1, a2, a3) - } - } - if let Some(size) = buf.binding_size { - let br = naga::ResourceBinding { - group: group_index, - binding: buf.binding_location, - }; - self.state.storage_buffer_length_map.insert(br, size); - changes_sizes_buffer = true; - } - } - if changes_sizes_buffer { - if let Some((index, sizes)) = self - .state - .make_sizes_buffer_update(stage, &mut self.temp.binding_sizes) - { - let a1 = index as _; - let a2 = (sizes.len() * WORD_SIZE) as u64; - let a3 = sizes.as_ptr().cast(); - match stage { - naga::ShaderStage::Vertex => { - render_encoder.unwrap().set_vertex_bytes(a1, a2, a3) - } - naga::ShaderStage::Fragment => { - render_encoder.unwrap().set_fragment_bytes(a1, a2, a3) - } - naga::ShaderStage::Task => { - render_encoder.unwrap().set_object_bytes(a1, a2, a3) - } - naga::ShaderStage::Mesh => { - render_encoder.unwrap().set_mesh_bytes(a1, a2, a3) - } - naga::ShaderStage::Compute => { - compute_encoder.unwrap().set_bytes(a1, a2, a3) - } - } - } - } - let samplers = match stage { - naga::ShaderStage::Vertex => group.counters.vs.samplers, - naga::ShaderStage::Fragment => group.counters.fs.samplers, - naga::ShaderStage::Task => group.counters.ts.samplers, - naga::ShaderStage::Mesh => group.counters.ms.samplers, - naga::ShaderStage::Compute => group.counters.cs.samplers, - }; - for index in 0..samplers { - let res = group.samplers[(index_base.samplers + index) as usize]; - let a1 = (resource_indices.samplers + index) as u64; - let a2 = Some(res.as_native()); - match stage { - naga::ShaderStage::Vertex => { - render_encoder.unwrap().set_vertex_sampler_state(a1, a2) - } - naga::ShaderStage::Fragment => { - render_encoder.unwrap().set_fragment_sampler_state(a1, a2) - } - naga::ShaderStage::Task => { - render_encoder.unwrap().set_object_sampler_state(a1, a2) - } - naga::ShaderStage::Mesh => { - render_encoder.unwrap().set_mesh_sampler_state(a1, a2) - } - naga::ShaderStage::Compute => { - compute_encoder.unwrap().set_sampler_state(a1, a2) - } - } - } - - let textures = match stage { - naga::ShaderStage::Vertex => group.counters.vs.textures, - naga::ShaderStage::Fragment => group.counters.fs.textures, - naga::ShaderStage::Task => group.counters.ts.textures, - naga::ShaderStage::Mesh => group.counters.ms.textures, - naga::ShaderStage::Compute => group.counters.cs.textures, - }; - for index in 0..textures { - let res = group.textures[(index_base.textures + index) as usize]; - let a1 = (resource_indices.textures + index) as u64; - let a2 = Some(res.as_native()); - match stage { - naga::ShaderStage::Vertex => { - render_encoder.unwrap().set_vertex_texture(a1, a2) - } - naga::ShaderStage::Fragment => { - render_encoder.unwrap().set_fragment_texture(a1, a2) - } - naga::ShaderStage::Task => { - render_encoder.unwrap().set_object_texture(a1, a2) - } - naga::ShaderStage::Mesh => render_encoder.unwrap().set_mesh_texture(a1, a2), - naga::ShaderStage::Compute => compute_encoder.unwrap().set_texture(a1, a2), - } - } + self.update_bind_group_state( + stage, + render_encoder, + compute_encoder, + index_base, + bg_info, + dynamic_offsets, + group_index, + group, + ); }; if let Some(encoder) = render_encoder { update_stage( From 6937fa26e63348ca0935daa83f7d6ce7ae0675e2 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Wed, 29 Oct 2025 15:02:56 -0500 Subject: [PATCH 72/89] Updated feature check to hopefully fix CI --- wgpu-hal/src/metal/adapter.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/wgpu-hal/src/metal/adapter.rs b/wgpu-hal/src/metal/adapter.rs index e3cfb8d134a..5efd393ac2d 100644 --- a/wgpu-hal/src/metal/adapter.rs +++ b/wgpu-hal/src/metal/adapter.rs @@ -902,8 +902,10 @@ impl super::PrivateCapabilities { && (device.supports_family(MTLGPUFamily::Apple7) || device.supports_family(MTLGPUFamily::Mac2)), supports_shared_event: version.at_least((10, 14), (12, 0), os_is_mac), - mesh_shaders: device.supports_family(MTLGPUFamily::Apple7) - || device.supports_family(MTLGPUFamily::Mac2), + mesh_shaders: family_check + && device.supports_family(MTLGPUFamily::Metal3) + && (device.supports_family(MTLGPUFamily::Apple7) + || device.supports_family(MTLGPUFamily::Mac2)), shader_barycentrics: device.supports_shader_barycentric_coordinates(), // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf#page=3 supports_memoryless_storage: if family_check { From 8bbcea0a26f6c32060ffb910aad646f4bbcfbed0 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Wed, 29 Oct 2025 22:42:46 -0500 Subject: [PATCH 73/89] Smartified mesh shader detection --- wgpu-hal/src/metal/adapter.rs | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/wgpu-hal/src/metal/adapter.rs b/wgpu-hal/src/metal/adapter.rs index 5efd393ac2d..8b460d26472 100644 --- a/wgpu-hal/src/metal/adapter.rs +++ b/wgpu-hal/src/metal/adapter.rs @@ -607,6 +607,9 @@ impl super::PrivateCapabilities { let argument_buffers = device.argument_buffers_support(); + // Lmao + let is_virtual = device.name().to_lowercase().contains("virtual"); + Self { family_check, msl_version: if os_is_xr || version.at_least((14, 0), (17, 0), os_is_mac) { @@ -903,9 +906,11 @@ impl super::PrivateCapabilities { || device.supports_family(MTLGPUFamily::Mac2)), supports_shared_event: version.at_least((10, 14), (12, 0), os_is_mac), mesh_shaders: family_check - && device.supports_family(MTLGPUFamily::Metal3) - && (device.supports_family(MTLGPUFamily::Apple7) - || device.supports_family(MTLGPUFamily::Mac2)), + && (device.supports_family(MTLGPUFamily::Metal3) + || device.supports_family(MTLGPUFamily::Apple7) + || device.supports_family(MTLGPUFamily::Mac2)) + // Mesh shaders don't work on virtual devices even if they should be supported. + && !is_virtual, shader_barycentrics: device.supports_shader_barycentric_coordinates(), // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf#page=3 supports_memoryless_storage: if family_check { From b4abddd88fe9c663a903a7521971d5d5076415d1 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Wed, 29 Oct 2025 22:57:03 -0500 Subject: [PATCH 74/89] Nicified some stuff --- wgpu-hal/src/metal/device.rs | 141 +++++++++++++++++++---------------- 1 file changed, 78 insertions(+), 63 deletions(-) diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index a4cd1341183..0c5de7c9ec2 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -1125,53 +1125,61 @@ impl crate::Device for super::Device { let vs_info; let ts_info; let ms_info; + + // Create the pipeline descriptor and do vertex/mesh pipeline specific setup let descriptor = match desc.vertex_processor { crate::VertexProcessor::Standard { vertex_buffers, ref vertex_stage, } => { + // Vertex pipeline specific setup + let descriptor = metal::RenderPipelineDescriptor::new(); ts_info = None; ms_info = None; - vs_info = Some({ - let mut vertex_buffer_mappings = - Vec::::new(); - for (i, vbl) in vertex_buffers.iter().enumerate() { - let mut attributes = Vec::::new(); - for attribute in vbl.attributes.iter() { - attributes.push(naga::back::msl::AttributeMapping { - shader_location: attribute.shader_location, - offset: attribute.offset as u32, - format: convert_vertex_format_to_naga(attribute.format), - }); - } - vertex_buffer_mappings.push(naga::back::msl::VertexBufferMapping { - id: self.shared.private_caps.max_vertex_buffers - 1 - i as u32, - stride: if vbl.array_stride > 0 { - vbl.array_stride.try_into().unwrap() - } else { - vbl.attributes - .iter() - .map(|attribute| attribute.offset + attribute.format.size()) - .max() - .unwrap_or(0) - .try_into() - .unwrap() - }, - step_mode: match (vbl.array_stride == 0, vbl.step_mode) { - (true, _) => naga::back::msl::VertexBufferStepMode::Constant, - (false, wgt::VertexStepMode::Vertex) => { - naga::back::msl::VertexBufferStepMode::ByVertex - } - (false, wgt::VertexStepMode::Instance) => { - naga::back::msl::VertexBufferStepMode::ByInstance - } - }, - attributes, + // Collect vertex buffer mappings + let mut vertex_buffer_mappings = + Vec::::new(); + for (i, vbl) in vertex_buffers.iter().enumerate() { + let mut attributes = Vec::::new(); + for attribute in vbl.attributes.iter() { + attributes.push(naga::back::msl::AttributeMapping { + shader_location: attribute.shader_location, + offset: attribute.offset as u32, + format: convert_vertex_format_to_naga(attribute.format), }); } + let mapping = naga::back::msl::VertexBufferMapping { + id: self.shared.private_caps.max_vertex_buffers - 1 - i as u32, + stride: if vbl.array_stride > 0 { + vbl.array_stride.try_into().unwrap() + } else { + vbl.attributes + .iter() + .map(|attribute| attribute.offset + attribute.format.size()) + .max() + .unwrap_or(0) + .try_into() + .unwrap() + }, + step_mode: match (vbl.array_stride == 0, vbl.step_mode) { + (true, _) => naga::back::msl::VertexBufferStepMode::Constant, + (false, wgt::VertexStepMode::Vertex) => { + naga::back::msl::VertexBufferStepMode::ByVertex + } + (false, wgt::VertexStepMode::Instance) => { + naga::back::msl::VertexBufferStepMode::ByInstance + } + }, + attributes, + }; + vertex_buffer_mappings.push(mapping); + } + + // Setup vertex shader + { let vs = self.load_shader( vertex_stage, &vertex_buffer_mappings, @@ -1188,7 +1196,7 @@ impl crate::Device for super::Device { ); } - super::PipelineStageInfo { + vs_info = Some(super::PipelineStageInfo { push_constants: desc.layout.push_constants_infos.vs, sizes_slot: desc.layout.per_stage_map.vs.sizes_buffer, sized_bindings: vs.sized_bindings, @@ -1196,8 +1204,10 @@ impl crate::Device for super::Device { library: Some(vs.library), raw_wg_size: Default::default(), work_group_memory_sizes: vec![], - } - }); + }); + } + + // Validate vertex buffer count if desc.layout.total_counters.vs.buffers + (vertex_buffers.len() as u32) > self.shared.private_caps.max_vertex_buffers { @@ -1212,6 +1222,7 @@ impl crate::Device for super::Device { )); } + // Set the pipeline vertex buffer info if !vertex_buffers.is_empty() { let vertex_descriptor = metal::VertexDescriptor::new(); for (i, vb) in vertex_buffers.iter().enumerate() { @@ -1250,14 +1261,19 @@ impl crate::Device for super::Device { } descriptor.set_vertex_descriptor(Some(vertex_descriptor)); } + MetalGenericRenderPipelineDescriptor::Standard(descriptor) } crate::VertexProcessor::Mesh { ref task_stage, ref mesh_stage, } => { + // Mesh pipeline specific setup + vs_info = None; let descriptor = metal::MeshRenderPipelineDescriptor::new(); + + // Setup task stage if let Some(ref task_stage) = task_stage { let ts = self.load_shader( task_stage, @@ -1285,6 +1301,8 @@ impl crate::Device for super::Device { } else { ts_info = None; } + + // Setup mesh stage { let ms = self.load_shader( mesh_stage, @@ -1310,9 +1328,13 @@ impl crate::Device for super::Device { work_group_memory_sizes: ms.wg_memory_sizes, }); } + MetalGenericRenderPipelineDescriptor::Mesh(descriptor) } }; + + // Standard and mesh render pipeline descriptors don't inherit from the same interface, despite sharing + // many methods. This function lets us call a function by name on whichever descriptor we are using. macro_rules! descriptor_fn { ($method:ident $( ( $($args:expr),* ) )? ) => { match descriptor { @@ -1372,6 +1394,7 @@ impl crate::Device for super::Device { } }; + // Setup pipeline color attachments for (i, ct) in desc.color_targets.iter().enumerate() { let at_descriptor = descriptor_fn!(color_attachments()) .object_at(i as u64) @@ -1402,6 +1425,7 @@ impl crate::Device for super::Device { } } + // Setup depth stencil state let depth_stencil = match desc.depth_stencil { Some(ref ds) => { let raw_format = self.shared.private_caps.map_format(ds.format); @@ -1424,6 +1448,7 @@ impl crate::Device for super::Device { None => None, }; + // Setup multisample state if desc.multisample.count != 1 { //TODO: handle sample mask match descriptor { @@ -1440,36 +1465,26 @@ impl crate::Device for super::Device { //descriptor.set_alpha_to_one_enabled(desc.multisample.alpha_to_one_enabled); } + // Set debug label if let Some(name) = desc.label { descriptor_fn!(set_label(name)); } + // Create the pipeline from descriptor let raw = match descriptor { - MetalGenericRenderPipelineDescriptor::Standard(d) => self - .shared - .device - .lock() - .new_render_pipeline_state(&d) - .map_err(|e| { - crate::PipelineError::Linkage( - wgt::ShaderStages::VERTEX | wgt::ShaderStages::FRAGMENT, - format!("new_render_pipeline_state: {e:?}"), - ) - })?, - MetalGenericRenderPipelineDescriptor::Mesh(d) => self - .shared - .device - .lock() - .new_mesh_render_pipeline_state(&d) - .map_err(|e| { - crate::PipelineError::Linkage( - wgt::ShaderStages::TASK - | wgt::ShaderStages::MESH - | wgt::ShaderStages::FRAGMENT, - format!("new_mesh_render_pipeline_state: {e:?}"), - ) - })?, - }; + MetalGenericRenderPipelineDescriptor::Standard(d) => { + self.shared.device.lock().new_render_pipeline_state(&d) + } + MetalGenericRenderPipelineDescriptor::Mesh(d) => { + self.shared.device.lock().new_mesh_render_pipeline_state(&d) + } + } + .map_err(|e| { + crate::PipelineError::Linkage( + wgt::ShaderStages::VERTEX | wgt::ShaderStages::FRAGMENT, + format!("new_render_pipeline_state: {e:?}"), + ) + })?; self.counters.render_pipelines.add(1); From e100034614c00009f13e51346cbc1ad10b55b551 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Thu, 30 Oct 2025 17:02:07 -0500 Subject: [PATCH 75/89] Fixed bad validation, formatted mesh shader wgsl --- naga/src/valid/interface.rs | 1 + naga/tests/in/wgsl/mesh-shader.wgsl | 72 ++++++++++++++--------------- 2 files changed, 37 insertions(+), 36 deletions(-) diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index f13dba1e584..f3f5a43c060 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -870,6 +870,7 @@ impl super::Validator { (crate::ShaderStage::Mesh, &None) => { return Err(EntryPointError::ExpectedMeshShaderAttributes.with_span()); } + (crate::ShaderStage::Mesh, &Some(..)) => {} (_, &Some(_)) => { return Err(EntryPointError::UnexpectedMeshShaderAttributes.with_span()); } diff --git a/naga/tests/in/wgsl/mesh-shader.wgsl b/naga/tests/in/wgsl/mesh-shader.wgsl index 70fc2aec333..7f094a82f81 100644 --- a/naga/tests/in/wgsl/mesh-shader.wgsl +++ b/naga/tests/in/wgsl/mesh-shader.wgsl @@ -1,71 +1,71 @@ enable mesh_shading; const positions = array( - vec4(0.,1.,0.,1.), - vec4(-1.,-1.,0.,1.), - vec4(1.,-1.,0.,1.) + vec4(0., 1., 0., 1.), + vec4(-1., -1., 0., 1.), + vec4(1., -1., 0., 1.) ); const colors = array( - vec4(0.,1.,0.,1.), - vec4(0.,0.,1.,1.), - vec4(1.,0.,0.,1.) + vec4(0., 1., 0., 1.), + vec4(0., 0., 1., 1.), + vec4(1., 0., 0., 1.) ); struct TaskPayload { - colorMask: vec4, - visible: bool, + colorMask: vec4, + visible: bool, } var taskPayload: TaskPayload; var workgroupData: f32; struct VertexOutput { - @builtin(position) position: vec4, - @location(0) color: vec4, + @builtin(position) position: vec4, + @location(0) color: vec4, } struct PrimitiveOutput { - @builtin(triangle_indices) index: vec3, - @builtin(cull_primitive) cull: bool, - @per_primitive @location(1) colorMask: vec4, + @builtin(triangle_indices) index: vec3, + @builtin(cull_primitive) cull: bool, + @per_primitive @location(1) colorMask: vec4, } struct PrimitiveInput { - @per_primitive @location(1) colorMask: vec4, + @per_primitive @location(1) colorMask: vec4, } @task @payload(taskPayload) @workgroup_size(1) fn ts_main() -> @builtin(mesh_task_size) vec3 { - workgroupData = 1.0; - taskPayload.colorMask = vec4(1.0, 1.0, 0.0, 1.0); - taskPayload.visible = true; - return vec3(3, 1, 1); + workgroupData = 1.0; + taskPayload.colorMask = vec4(1.0, 1.0, 0.0, 1.0); + taskPayload.visible = true; + return vec3(3, 1, 1); } @mesh @payload(taskPayload) @vertex_output(VertexOutput, 3) @primitive_output(PrimitiveOutput, 1) @workgroup_size(1) fn ms_main(@builtin(local_invocation_index) index: u32, @builtin(global_invocation_id) id: vec3) { - setMeshOutputs(3, 1); - workgroupData = 2.0; - var v: VertexOutput; + setMeshOutputs(3, 1); + workgroupData = 2.0; + var v: VertexOutput; - v.position = positions[0]; - v.color = colors[0] * taskPayload.colorMask; - setVertex(0, v); + v.position = positions[0]; + v.color = colors[0] * taskPayload.colorMask; + setVertex(0, v); - v.position = positions[1]; - v.color = colors[1] * taskPayload.colorMask; - setVertex(1, v); + v.position = positions[1]; + v.color = colors[1] * taskPayload.colorMask; + setVertex(1, v); - v.position = positions[2]; - v.color = colors[2] * taskPayload.colorMask; - setVertex(2, v); + v.position = positions[2]; + v.color = colors[2] * taskPayload.colorMask; + setVertex(2, v); - var p: PrimitiveOutput; - p.index = vec3(0, 1, 2); - p.cull = !taskPayload.visible; - p.colorMask = vec4(1.0, 0.0, 1.0, 1.0); - setPrimitive(0, p); + var p: PrimitiveOutput; + p.index = vec3(0, 1, 2); + p.cull = !taskPayload.visible; + p.colorMask = vec4(1.0, 0.0, 1.0, 1.0); + setPrimitive(0, p); } @fragment fn fs_main(vertex: VertexOutput, primitive: PrimitiveInput) -> @location(0) vec4 { - return vertex.color * primitive.colorMask; + return vertex.color * primitive.colorMask; } From edea07e16c6d2979dbfab910bf7ab25ac9e7c2fb Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Thu, 30 Oct 2025 19:31:42 -0500 Subject: [PATCH 76/89] Rewrote the IR and parser significantly --- naga/src/back/dot/mod.rs | 19 - naga/src/back/glsl/mod.rs | 11 +- naga/src/back/hlsl/conv.rs | 6 +- naga/src/back/hlsl/writer.rs | 13 - naga/src/back/msl/mod.rs | 6 +- naga/src/back/msl/writer.rs | 8 - naga/src/back/pipeline_constants.rs | 20 - naga/src/back/spv/block.rs | 1 - naga/src/back/spv/writer.rs | 6 +- naga/src/back/wgsl/writer.rs | 1 - naga/src/common/wgsl/to_wgsl.rs | 6 +- naga/src/compact/statements.rs | 34 -- naga/src/front/spv/mod.rs | 1 - naga/src/front/wgsl/error.rs | 25 - naga/src/front/wgsl/lower/mod.rs | 153 +----- naga/src/front/wgsl/parse/ast.rs | 10 +- naga/src/front/wgsl/parse/conv.rs | 7 +- naga/src/front/wgsl/parse/mod.rs | 47 +- naga/src/ir/mod.rs | 40 +- naga/src/proc/mod.rs | 148 ++++++ naga/src/proc/terminator.rs | 1 - naga/src/valid/analyzer.rs | 30 -- naga/src/valid/function.rs | 35 -- naga/src/valid/handles.rs | 16 - naga/src/valid/interface.rs | 75 ++- naga/tests/in/wgsl/mesh-shader.wgsl | 39 +- .../out/analysis/wgsl-mesh-shader.info.ron | 422 ++++++++++++--- .../tests/out/ir/wgsl-mesh-shader.compact.ron | 480 +++++++++++------- naga/tests/out/ir/wgsl-mesh-shader.ron | 480 +++++++++++------- 29 files changed, 1256 insertions(+), 884 deletions(-) diff --git a/naga/src/back/dot/mod.rs b/naga/src/back/dot/mod.rs index 1f1396eccff..826dad1c219 100644 --- a/naga/src/back/dot/mod.rs +++ b/naga/src/back/dot/mod.rs @@ -307,25 +307,6 @@ impl StatementGraph { crate::RayQueryFunction::Terminate => "RayQueryTerminate", } } - S::MeshFunction(crate::MeshFunction::SetMeshOutputs { - vertex_count, - primitive_count, - }) => { - self.dependencies.push((id, vertex_count, "vertex_count")); - self.dependencies - .push((id, primitive_count, "primitive_count")); - "SetMeshOutputs" - } - S::MeshFunction(crate::MeshFunction::SetVertex { index, value }) => { - self.dependencies.push((id, index, "index")); - self.dependencies.push((id, value, "value")); - "SetVertex" - } - S::MeshFunction(crate::MeshFunction::SetPrimitive { index, value }) => { - self.dependencies.push((id, index, "index")); - self.dependencies.push((id, value, "value")); - "SetPrimitive" - } S::SubgroupBallot { result, predicate } => { if let Some(predicate) = predicate { self.dependencies.push((id, predicate, "predicate")); diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index 716bc8049e0..f29504010d5 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -2675,11 +2675,6 @@ impl<'a, W: Write> Writer<'a, W> { self.write_image_atomic(ctx, image, coordinate, array_index, fun, value)? } Statement::RayQuery { .. } => unreachable!(), - Statement::MeshFunction( - crate::MeshFunction::SetMeshOutputs { .. } - | crate::MeshFunction::SetVertex { .. } - | crate::MeshFunction::SetPrimitive { .. }, - ) => unreachable!(), Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; let res_name = Baked(result).to_string(); @@ -5265,7 +5260,11 @@ const fn glsl_built_in(built_in: crate::BuiltIn, options: VaryingOptions) -> &'s | Bi::PointIndex | Bi::LineIndices | Bi::TriangleIndices - | Bi::MeshTaskSize => { + | Bi::MeshTaskSize + | Bi::VertexCount + | Bi::PrimitiveCount + | Bi::Vertices + | Bi::Primitives => { unimplemented!() } } diff --git a/naga/src/back/hlsl/conv.rs b/naga/src/back/hlsl/conv.rs index 5cd43e14297..ce7f0bc3dc7 100644 --- a/naga/src/back/hlsl/conv.rs +++ b/naga/src/back/hlsl/conv.rs @@ -186,7 +186,11 @@ impl crate::BuiltIn { } Self::CullPrimitive => "SV_CullPrimitive", Self::PointIndex | Self::LineIndices | Self::TriangleIndices => unimplemented!(), - Self::MeshTaskSize => unreachable!(), + Self::MeshTaskSize + | Self::VertexCount + | Self::PrimitiveCount + | Self::Vertices + | Self::Primitives => unreachable!(), }) } } diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 6f0ba814a52..8806137d65a 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -2600,19 +2600,6 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { writeln!(self.out, ".Abort();")?; } }, - Statement::MeshFunction(crate::MeshFunction::SetMeshOutputs { - vertex_count, - primitive_count, - }) => { - write!(self.out, "{level}SetMeshOutputCounts(")?; - self.write_expr(module, vertex_count, func_ctx)?; - write!(self.out, ", ")?; - self.write_expr(module, primitive_count, func_ctx)?; - write!(self.out, ");")?; - } - Statement::MeshFunction( - crate::MeshFunction::SetVertex { .. } | crate::MeshFunction::SetPrimitive { .. }, - ) => unimplemented!(), Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; let name = Baked(result).to_string(); diff --git a/naga/src/back/msl/mod.rs b/naga/src/back/msl/mod.rs index c8e5f68be9a..abb596020f8 100644 --- a/naga/src/back/msl/mod.rs +++ b/naga/src/back/msl/mod.rs @@ -707,7 +707,11 @@ impl ResolvedBinding { Bi::CullPrimitive => "primitive_culled", // TODO: figure out how to make this written as a function call Bi::PointIndex | Bi::LineIndices | Bi::TriangleIndices => unimplemented!(), - Bi::MeshTaskSize => unreachable!(), + Bi::MeshTaskSize + | Bi::VertexCount + | Bi::PrimitiveCount + | Bi::Vertices + | Bi::Primitives => unreachable!(), }; write!(out, "{name}")?; } diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 484142630d2..ca7da02a930 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -4063,14 +4063,6 @@ impl Writer { } } } - // TODO: write emitters for these - crate::Statement::MeshFunction(crate::MeshFunction::SetMeshOutputs { .. }) => { - unimplemented!() - } - crate::Statement::MeshFunction( - crate::MeshFunction::SetVertex { .. } - | crate::MeshFunction::SetPrimitive { .. }, - ) => unimplemented!(), crate::Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; let name = self.namer.call(""); diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs index 109cc591e74..de643b82fab 100644 --- a/naga/src/back/pipeline_constants.rs +++ b/naga/src/back/pipeline_constants.rs @@ -860,26 +860,6 @@ fn adjust_stmt(new_pos: &HandleVec>, stmt: &mut S crate::RayQueryFunction::Terminate => {} } } - Statement::MeshFunction(crate::MeshFunction::SetMeshOutputs { - ref mut vertex_count, - ref mut primitive_count, - }) => { - adjust(vertex_count); - adjust(primitive_count); - } - Statement::MeshFunction( - crate::MeshFunction::SetVertex { - ref mut index, - ref mut value, - } - | crate::MeshFunction::SetPrimitive { - ref mut index, - ref mut value, - }, - ) => { - adjust(index); - adjust(value); - } Statement::Break | Statement::Continue | Statement::Kill diff --git a/naga/src/back/spv/block.rs b/naga/src/back/spv/block.rs index d0556acdc53..dd9a3811687 100644 --- a/naga/src/back/spv/block.rs +++ b/naga/src/back/spv/block.rs @@ -3655,7 +3655,6 @@ impl BlockContext<'_> { } => { self.write_subgroup_gather(mode, argument, result, &mut block)?; } - Statement::MeshFunction(_) => unreachable!(), } } diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index 1beb86577c8..ee1ea847739 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -2156,7 +2156,11 @@ impl Writer { | Bi::CullPrimitive | Bi::PointIndex | Bi::LineIndices - | Bi::TriangleIndices => unreachable!(), + | Bi::TriangleIndices + | Bi::VertexCount + | Bi::PrimitiveCount + | Bi::Vertices + | Bi::Primitives => unreachable!(), }; self.decorate(id, Decoration::BuiltIn, &[built_in as u32]); diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index d1ebf62e6ee..daf32a7116f 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -856,7 +856,6 @@ impl Writer { } } Statement::RayQuery { .. } => unreachable!(), - Statement::MeshFunction(..) => unreachable!(), Statement::SubgroupBallot { result, predicate } => { write!(self.out, "{level}")?; let res_name = Baked(result).to_string(); diff --git a/naga/src/common/wgsl/to_wgsl.rs b/naga/src/common/wgsl/to_wgsl.rs index 25847a5df7b..5e6178c049c 100644 --- a/naga/src/common/wgsl/to_wgsl.rs +++ b/naga/src/common/wgsl/to_wgsl.rs @@ -194,7 +194,11 @@ impl TryToWgsl for crate::BuiltIn { | Bi::TriangleIndices | Bi::LineIndices | Bi::MeshTaskSize - | Bi::PointIndex => return None, + | Bi::PointIndex + | Bi::VertexCount + | Bi::PrimitiveCount + | Bi::Vertices + | Bi::Primitives => return None, }) } } diff --git a/naga/src/compact/statements.rs b/naga/src/compact/statements.rs index b370501baca..39d6065f5f0 100644 --- a/naga/src/compact/statements.rs +++ b/naga/src/compact/statements.rs @@ -117,20 +117,6 @@ impl FunctionTracer<'_> { self.expressions_used.insert(query); self.trace_ray_query_function(fun); } - St::MeshFunction(crate::MeshFunction::SetMeshOutputs { - vertex_count, - primitive_count, - }) => { - self.expressions_used.insert(vertex_count); - self.expressions_used.insert(primitive_count); - } - St::MeshFunction( - crate::MeshFunction::SetPrimitive { index, value } - | crate::MeshFunction::SetVertex { index, value }, - ) => { - self.expressions_used.insert(index); - self.expressions_used.insert(value); - } St::SubgroupBallot { result, predicate } => { if let Some(predicate) = predicate { self.expressions_used.insert(predicate); @@ -349,26 +335,6 @@ impl FunctionMap { adjust(query); self.adjust_ray_query_function(fun); } - St::MeshFunction(crate::MeshFunction::SetMeshOutputs { - ref mut vertex_count, - ref mut primitive_count, - }) => { - adjust(vertex_count); - adjust(primitive_count); - } - St::MeshFunction( - crate::MeshFunction::SetVertex { - ref mut index, - ref mut value, - } - | crate::MeshFunction::SetPrimitive { - ref mut index, - ref mut value, - }, - ) => { - adjust(index); - adjust(value); - } St::SubgroupBallot { ref mut result, ref mut predicate, diff --git a/naga/src/front/spv/mod.rs b/naga/src/front/spv/mod.rs index 2a3a971a8bf..ac9eaf8306f 100644 --- a/naga/src/front/spv/mod.rs +++ b/naga/src/front/spv/mod.rs @@ -4661,7 +4661,6 @@ impl> Frontend { | S::Atomic { .. } | S::ImageAtomic { .. } | S::RayQuery { .. } - | S::MeshFunction(..) | S::SubgroupBallot { .. } | S::SubgroupCollectiveOperation { .. } | S::SubgroupGather { .. } => {} diff --git a/naga/src/front/wgsl/error.rs b/naga/src/front/wgsl/error.rs index 004528dbe91..a8958525ad1 100644 --- a/naga/src/front/wgsl/error.rs +++ b/naga/src/front/wgsl/error.rs @@ -406,19 +406,9 @@ pub(crate) enum Error<'a> { accept_span: Span, accept_type: String, }, - MissingMeshShaderInfo { - mesh_attribute_span: Span, - }, - OneMeshShaderAttribute { - attribute_span: Span, - }, ExpectedGlobalVariable { name_span: Span, }, - MeshPrimitiveNoDefinedTopology { - attribute_span: Span, - struct_span: Span, - }, StructMemberTooLarge { member_name_span: Span, }, @@ -1383,27 +1373,12 @@ impl<'a> Error<'a> { ], notes: vec![], }, - Error::MissingMeshShaderInfo { mesh_attribute_span} => ParseError { - message: "mesh shader entry point is missing both `@vertex_output` and `@primitive_output`".into(), - labels: vec![(mesh_attribute_span, "mesh shader entry declared here".into())], - notes: vec![], - }, - Error::OneMeshShaderAttribute { attribute_span } => ParseError { - message: "only one of `@vertex_output` or `@primitive_output` was given".into(), - labels: vec![(attribute_span, "only one of @vertex_output or @primitive_output is provided".into())], - notes: vec![], - }, Error::ExpectedGlobalVariable { name_span } => ParseError { message: "expected global variable".to_string(), // TODO: I would like to also include the global declaration span labels: vec![(name_span, "variable used here".into())], notes: vec![], }, - Error::MeshPrimitiveNoDefinedTopology { struct_span, attribute_span } => ParseError { - message: "mesh primitive struct must have exactly one of point indices, line indices, or triangle indices".to_string(), - labels: vec![(attribute_span, "primitive type declared here".into()), (struct_span, "primitive struct declared here".into())], - notes: vec![] - }, Error::StructMemberTooLarge { member_name_span } => ParseError { message: "struct member is too large".into(), labels: vec![(member_name_span, "this member exceeds the maximum size".into())], diff --git a/naga/src/front/wgsl/lower/mod.rs b/naga/src/front/wgsl/lower/mod.rs index ef63e6aaea7..33a1de6d579 100644 --- a/naga/src/front/wgsl/lower/mod.rs +++ b/naga/src/front/wgsl/lower/mod.rs @@ -1520,88 +1520,34 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ([0; 3], None) }; - let mesh_info = if let Some(mesh_info) = entry.mesh_shader_info { - let mut const_u32 = |expr| match self.const_u32(expr, &mut ctx.as_const()) { - Ok(value) => Ok((value.0, None)), - Err(err) => { - if let Error::ConstantEvaluatorError(ref ty, _) = *err { - match **ty { - proc::ConstantEvaluatorError::OverrideExpr => Ok(( - 0, - Some( - // This is dubious but it seems the code isn't workgroup size specific - self.workgroup_size_override(expr, &mut ctx.as_override())?, - ), - )), - _ => Err(err), - } - } else { - Err(err) - } - } - }; - let (max_vertices, max_vertices_override) = const_u32(mesh_info.vertex_count)?; - let (max_primitives, max_primitives_override) = - const_u32(mesh_info.primitive_count)?; - let vertex_output_type = - self.resolve_ast_type(mesh_info.vertex_type.0, &mut ctx.as_const())?; - let primitive_output_type = - self.resolve_ast_type(mesh_info.primitive_type.0, &mut ctx.as_const())?; - - let mut topology = None; - let struct_span = ctx.module.types.get_span(primitive_output_type); - match &ctx.module.types[primitive_output_type].inner { - &ir::TypeInner::Struct { - ref members, - span: _, - } => { - for member in members { - let out_topology = match member.binding { - Some(ir::Binding::BuiltIn(ir::BuiltIn::TriangleIndices)) => { - Some(ir::MeshOutputTopology::Triangles) - } - Some(ir::Binding::BuiltIn(ir::BuiltIn::LineIndices)) => { - Some(ir::MeshOutputTopology::Lines) - } - _ => None, - }; - if out_topology.is_some() { - if topology.is_some() { - return Err(Box::new(Error::MeshPrimitiveNoDefinedTopology { - attribute_span: mesh_info.primitive_type.1, - struct_span, - })); - } - topology = out_topology; - } - } - } - _ => { - return Err(Box::new(Error::MeshPrimitiveNoDefinedTopology { - attribute_span: mesh_info.primitive_type.1, - struct_span, + let mesh_info = if let Some((var_name, var_span)) = entry.mesh_output_variable { + let var = match ctx.globals.get(var_name) { + Some(&LoweredGlobalDecl::Var(handle)) => handle, + Some(_) => { + return Err(Box::new(Error::ExpectedGlobalVariable { + name_span: var_span, })) } - } - let topology = if let Some(t) = topology { - t - } else { - return Err(Box::new(Error::MeshPrimitiveNoDefinedTopology { - attribute_span: mesh_info.primitive_type.1, - struct_span, - })); + None => return Err(Box::new(Error::UnknownIdent(var_span, var_name))), }; - Some(ir::MeshStageInfo { - max_vertices, - max_vertices_override, - max_primitives, - max_primitives_override, + let mut info = ctx.module.analyze_mesh_shader_info(var); + if let Some(h) = info.1[0] { + info.0.max_vertices_override = Some( + ctx.module + .global_expressions + .append(crate::Expression::Override(h), Span::UNDEFINED), + ); + } + if let Some(h) = info.1[1] { + info.0.max_primitives_override = Some( + ctx.module + .global_expressions + .append(crate::Expression::Override(h), Span::UNDEFINED), + ); + } - vertex_output_type, - primitive_output_type, - topology, - }) + Some(info.0) } else { None }; @@ -3232,59 +3178,6 @@ impl<'source, 'temp> Lowerer<'source, 'temp> { ); return Ok(Some(result)); } - - "setMeshOutputs" | "setVertex" | "setPrimitive" => { - let mut args = ctx.prepare_args(arguments, 2, span); - let arg1 = args.next()?; - let arg2 = args.next()?; - args.finish()?; - - let mut cast_u32 = |arg| { - // Try to convert abstract values to the known argument types - let expr = self.expression_for_abstract(arg, ctx)?; - let goal_ty = - ctx.ensure_type_exists(ir::TypeInner::Scalar(ir::Scalar::U32)); - ctx.try_automatic_conversions( - expr, - &proc::TypeResolution::Handle(goal_ty), - ctx.ast_expressions.get_span(arg), - ) - }; - - let arg1 = cast_u32(arg1)?; - let arg2 = if function.name == "setMeshOutputs" { - cast_u32(arg2)? - } else { - self.expression(arg2, ctx)? - }; - - let rctx = ctx.runtime_expression_ctx(span)?; - - // Emit all previous expressions, even if not used directly - rctx.block - .extend(rctx.emitter.finish(&rctx.function.expressions)); - rctx.block.push( - crate::Statement::MeshFunction(match function.name { - "setMeshOutputs" => crate::MeshFunction::SetMeshOutputs { - vertex_count: arg1, - primitive_count: arg2, - }, - "setVertex" => crate::MeshFunction::SetVertex { - index: arg1, - value: arg2, - }, - "setPrimitive" => crate::MeshFunction::SetPrimitive { - index: arg1, - value: arg2, - }, - _ => unreachable!(), - }), - span, - ); - rctx.emitter.start(&rctx.function.expressions); - - return Ok(None); - } _ => { return Err(Box::new(Error::UnknownIdent(function.span, function.name))) } diff --git a/naga/src/front/wgsl/parse/ast.rs b/naga/src/front/wgsl/parse/ast.rs index 49ecddfdee5..04964e7ba5f 100644 --- a/naga/src/front/wgsl/parse/ast.rs +++ b/naga/src/front/wgsl/parse/ast.rs @@ -128,18 +128,10 @@ pub struct EntryPoint<'a> { pub stage: crate::ShaderStage, pub early_depth_test: Option, pub workgroup_size: Option<[Option>>; 3]>, - pub mesh_shader_info: Option>, + pub mesh_output_variable: Option<(&'a str, Span)>, pub task_payload: Option<(&'a str, Span)>, } -#[derive(Debug, Clone, Copy)] -pub struct EntryPointMeshShaderInfo<'a> { - pub vertex_count: Handle>, - pub primitive_count: Handle>, - pub vertex_type: (Handle>, Span), - pub primitive_type: (Handle>, Span), -} - #[cfg(doc)] use crate::front::wgsl::lower::{LocalExpressionContext, StatementContext}; diff --git a/naga/src/front/wgsl/parse/conv.rs b/naga/src/front/wgsl/parse/conv.rs index 3b96bde7c9e..16e814f56f5 100644 --- a/naga/src/front/wgsl/parse/conv.rs +++ b/naga/src/front/wgsl/parse/conv.rs @@ -53,10 +53,15 @@ pub fn map_built_in( "subgroup_invocation_id" => crate::BuiltIn::SubgroupInvocationId, // mesh "cull_primitive" => crate::BuiltIn::CullPrimitive, - "point_index" => crate::BuiltIn::PointIndex, + "vertex_indices" => crate::BuiltIn::PointIndex, "line_indices" => crate::BuiltIn::LineIndices, "triangle_indices" => crate::BuiltIn::TriangleIndices, "mesh_task_size" => crate::BuiltIn::MeshTaskSize, + // mesh global variable + "vertex_count" => crate::BuiltIn::VertexCount, + "vertices" => crate::BuiltIn::Vertices, + "primitive_count" => crate::BuiltIn::PrimitiveCount, + "primitives" => crate::BuiltIn::Primitives, _ => return Err(Box::new(Error::UnknownBuiltin(span))), }; match built_in { diff --git a/naga/src/front/wgsl/parse/mod.rs b/naga/src/front/wgsl/parse/mod.rs index 29376614d6e..94df933a6a9 100644 --- a/naga/src/front/wgsl/parse/mod.rs +++ b/naga/src/front/wgsl/parse/mod.rs @@ -2803,8 +2803,7 @@ impl Parser { (ParsedAttribute::default(), ParsedAttribute::default()); let mut id = ParsedAttribute::default(); let mut payload = ParsedAttribute::default(); - let mut vertex_output = ParsedAttribute::default(); - let mut primitive_output = ParsedAttribute::default(); + let mut mesh_output = ParsedAttribute::default(); let mut must_use: ParsedAttribute = ParsedAttribute::default(); @@ -2872,27 +2871,16 @@ impl Parser { "mesh" => { stage.set(ShaderStage::Mesh, name_span)?; compute_like_span = name_span; + + lexer.expect(Token::Paren('('))?; + mesh_output.set(lexer.next_ident_with_span()?, name_span)?; + lexer.expect(Token::Paren(')'))?; } "payload" => { lexer.expect(Token::Paren('('))?; payload.set(lexer.next_ident_with_span()?, name_span)?; lexer.expect(Token::Paren(')'))?; } - "vertex_output" | "primitive_output" => { - lexer.expect(Token::Paren('('))?; - let type_span = lexer.peek().1; - let r#type = self.type_decl(lexer, &mut ctx)?; - let type_span = lexer.span_from(type_span.to_range().unwrap().start); - lexer.expect(Token::Separator(','))?; - let max_output = self.general_expression(lexer, &mut ctx)?; - let end_span = lexer.expect_span(Token::Paren(')'))?; - let total_span = name_span.until(&end_span); - if name == "vertex_output" { - vertex_output.set((r#type, type_span, max_output), total_span)?; - } else if name == "primitive_output" { - primitive_output.set((r#type, type_span, max_output), total_span)?; - } - } "workgroup_size" => { lexer.expect(Token::Paren('('))?; let mut new_workgroup_size = [None; 3]; @@ -3060,35 +3048,12 @@ impl Parser { if stage.compute_like() && workgroup_size.value.is_none() { return Err(Box::new(Error::MissingWorkgroupSize(compute_like_span))); } - if stage == ShaderStage::Mesh - && (vertex_output.value.is_none() || primitive_output.value.is_none()) - { - return Err(Box::new(Error::MissingMeshShaderInfo { - mesh_attribute_span: compute_like_span, - })); - } - let mesh_shader_info = match (vertex_output.value, primitive_output.value) { - (Some(vertex_output), Some(primitive_output)) => { - Some(ast::EntryPointMeshShaderInfo { - vertex_count: vertex_output.2, - primitive_count: primitive_output.2, - vertex_type: (vertex_output.0, vertex_output.1), - primitive_type: (primitive_output.0, primitive_output.1), - }) - } - (None, None) => None, - (Some(v), None) | (None, Some(v)) => { - return Err(Box::new(Error::OneMeshShaderAttribute { - attribute_span: v.1, - })) - } - }; Some(ast::EntryPoint { stage, early_depth_test: early_depth_test.value, workgroup_size: workgroup_size.value, - mesh_shader_info, + mesh_output_variable: mesh_output.value, task_payload: payload.value, }) } else { diff --git a/naga/src/ir/mod.rs b/naga/src/ir/mod.rs index 4093d823b4b..097220a46bb 100644 --- a/naga/src/ir/mod.rs +++ b/naga/src/ir/mod.rs @@ -450,6 +450,15 @@ pub enum BuiltIn { LineIndices, /// Written in mesh shaders TriangleIndices, + + /// Written to a workgroup variable in mesh shaders + VertexCount, + /// Written to a workgroup variable in mesh shaders + Vertices, + /// Written to a workgroup variable in mesh shaders + PrimitiveCount, + /// Written to a workgroup variable in mesh shaders + Primitives, } /// Number of bytes per scalar. @@ -2211,8 +2220,6 @@ pub enum Statement { /// The specific operation we're performing on `query`. fun: RayQueryFunction, }, - /// A mesh shader intrinsic. - MeshFunction(MeshFunction), /// Calculate a bitmask using a boolean from each active thread in the subgroup SubgroupBallot { /// The [`SubgroupBallotResult`] expression representing this load's result. @@ -2569,7 +2576,7 @@ pub struct DocComments { } /// The output topology for a mesh shader. Note that mesh shaders don't allow things like triangle-strips. -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] @@ -2583,7 +2590,7 @@ pub enum MeshOutputTopology { } /// Information specific to mesh shader entry points. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] #[cfg_attr(feature = "serialize", derive(Serialize))] #[cfg_attr(feature = "deserialize", derive(Deserialize))] #[cfg_attr(feature = "arbitrary", derive(Arbitrary))] @@ -2603,29 +2610,8 @@ pub struct MeshStageInfo { pub vertex_output_type: Handle, /// The type used by primitive outputs, i.e. what is passed to `setPrimitive`. pub primitive_output_type: Handle, -} - -/// Mesh shader intrinsics -#[derive(Debug, Clone, Copy)] -#[cfg_attr(feature = "serialize", derive(Serialize))] -#[cfg_attr(feature = "deserialize", derive(Deserialize))] -#[cfg_attr(feature = "arbitrary", derive(Arbitrary))] -pub enum MeshFunction { - /// Sets the number of vertices and primitives that will be outputted. - SetMeshOutputs { - vertex_count: Handle, - primitive_count: Handle, - }, - /// Sets the output vertex at a given index. - SetVertex { - index: Handle, - value: Handle, - }, - /// Sets the output primitive at a given index. - SetPrimitive { - index: Handle, - value: Handle, - }, + /// The global variable holding the outputted vertices, primitives, and counts + pub output_variable: Handle, } /// Shader module. diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index eca63ee4fb5..dd2ae459373 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -27,6 +27,8 @@ use thiserror::Error; pub use type_methods::min_max_float_representable_by; pub use typifier::{compare_types, ResolveContext, ResolveError, TypeResolution}; +use crate::non_max_u32::NonMaxU32; + impl From for super::Scalar { fn from(format: super::StorageFormat) -> Self { use super::{ScalarKind as Sk, StorageFormat as Sf}; @@ -653,3 +655,149 @@ fn test_matrix_size() { 48, ); } + +impl crate::Module { + /// Extracts mesh shader info from a mesh output global variable. Used in frontends + /// and by validators. This only validates the output variable itself, and not the + /// vertex and primitive output types. + #[allow(clippy::type_complexity)] + pub fn analyze_mesh_shader_info( + &self, + gv: crate::Handle, + ) -> ( + crate::MeshStageInfo, + [Option>; 2], + Option>, + ) { + use crate::span::AddSpan; + use crate::valid::EntryPointError; + let null_type = crate::Handle::new(NonMaxU32::new(0).unwrap()); + let mut output = crate::MeshStageInfo { + topology: crate::MeshOutputTopology::Triangles, + max_vertices: 0, + max_vertices_override: None, + max_primitives: 0, + max_primitives_override: None, + vertex_output_type: null_type, + primitive_output_type: null_type, + output_variable: gv, + }; + let mut error = None; + let typ = &self.types[self.global_variables[gv].ty].inner; + + let mut topology = output.topology; + // Max, max override, type + let mut vertex_info = (0, None, null_type); + let mut primitive_info = (0, None, null_type); + + match typ { + &crate::TypeInner::Struct { ref members, .. } => { + let mut builtins = crate::FastHashSet::default(); + for member in members { + match member.binding { + Some(crate::Binding::BuiltIn(crate::BuiltIn::VertexCount)) => { + if self.types[member.ty].inner.scalar() != Some(crate::Scalar::U32) { + error = Some(EntryPointError::BadMeshOutputVariableField); + } + if builtins.contains(&crate::BuiltIn::VertexCount) { + error = Some(EntryPointError::BadMeshOutputVarableType); + } + builtins.insert(crate::BuiltIn::VertexCount); + } + Some(crate::Binding::BuiltIn(crate::BuiltIn::PrimitiveCount)) => { + if self.types[member.ty].inner.scalar() != Some(crate::Scalar::U32) { + error = Some(EntryPointError::BadMeshOutputVariableField); + } + if builtins.contains(&crate::BuiltIn::PrimitiveCount) { + error = Some(EntryPointError::BadMeshOutputVarableType); + } + builtins.insert(crate::BuiltIn::PrimitiveCount); + } + Some(crate::Binding::BuiltIn( + crate::BuiltIn::Vertices | crate::BuiltIn::Primitives, + )) => { + let ty = &self.types[member.ty].inner; + let (a, b, c) = match ty { + &crate::TypeInner::Array { base, size, .. } => { + let ty = base; + let (max, max_override) = match size { + crate::ArraySize::Constant(a) => (a.get(), None), + crate::ArraySize::Pending(o) => (0, Some(o)), + crate::ArraySize::Dynamic => { + error = + Some(EntryPointError::BadMeshOutputVariableField); + (0, None) + } + }; + (max, max_override, ty) + } + _ => { + error = Some(EntryPointError::BadMeshOutputVariableField); + (0, None, null_type) + } + }; + if matches!( + member.binding, + Some(crate::Binding::BuiltIn(crate::BuiltIn::Primitives)) + ) { + primitive_info = (a, b, c); + match self.types[c].inner { + crate::TypeInner::Struct { ref members, .. } => { + for member in members { + match member.binding { + Some(crate::Binding::BuiltIn( + crate::BuiltIn::PointIndex, + )) => { + topology = crate::MeshOutputTopology::Points; + } + Some(crate::Binding::BuiltIn( + crate::BuiltIn::LineIndices, + )) => { + topology = crate::MeshOutputTopology::Lines; + } + Some(crate::Binding::BuiltIn( + crate::BuiltIn::TriangleIndices, + )) => { + topology = crate::MeshOutputTopology::Triangles; + } + _ => (), + } + } + } + _ => (), + } + if builtins.contains(&crate::BuiltIn::Primitives) { + error = Some(EntryPointError::BadMeshOutputVarableType); + } + builtins.insert(crate::BuiltIn::Primitives); + } else { + vertex_info = (a, b, c); + if builtins.contains(&crate::BuiltIn::Vertices) { + error = Some(EntryPointError::BadMeshOutputVarableType); + } + builtins.insert(crate::BuiltIn::Vertices); + } + } + _ => error = Some(EntryPointError::BadMeshOutputVarableType), + } + } + output = crate::MeshStageInfo { + topology, + max_vertices: vertex_info.0, + max_vertices_override: None, + vertex_output_type: vertex_info.2, + max_primitives: primitive_info.0, + max_primitives_override: None, + primitive_output_type: primitive_info.2, + ..output + } + } + _ => error = Some(EntryPointError::BadMeshOutputVarableType), + } + ( + output, + [vertex_info.1, primitive_info.1], + error.map(|a| a.with_span_handle(self.global_variables[gv].ty, &self.types)), + ) + } +} diff --git a/naga/src/proc/terminator.rs b/naga/src/proc/terminator.rs index f76d4c06a3b..b29ccb054a3 100644 --- a/naga/src/proc/terminator.rs +++ b/naga/src/proc/terminator.rs @@ -36,7 +36,6 @@ pub fn ensure_block_returns(block: &mut crate::Block) { | S::ImageStore { .. } | S::Call { .. } | S::RayQuery { .. } - | S::MeshFunction(..) | S::Atomic { .. } | S::ImageAtomic { .. } | S::WorkGroupUniformLoad { .. } diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index 6ef2ca0988d..5befdfe22a6 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -1155,36 +1155,6 @@ impl FunctionInfo { } FunctionUniformity::new() } - S::MeshFunction(func) => { - self.available_stages |= ShaderStages::MESH; - match &func { - // TODO: double check all of this uniformity stuff. I frankly don't fully understand all of it. - &crate::MeshFunction::SetMeshOutputs { - vertex_count, - primitive_count, - } => { - let _ = self.add_ref(vertex_count); - let _ = self.add_ref(primitive_count); - FunctionUniformity::new() - } - &crate::MeshFunction::SetVertex { index, value } - | &crate::MeshFunction::SetPrimitive { index, value } => { - let _ = self.add_ref(index); - let _ = self.add_ref(value); - let ty = self.expressions[value.index()].ty.handle().ok_or( - FunctionError::InvalidMeshShaderOutputType(value).with_span(), - )?; - - if matches!(func, crate::MeshFunction::SetVertex { .. }) { - self.try_update_mesh_vertex_type(ty, value)?; - } else { - self.try_update_mesh_primitive_type(ty, value)?; - }; - - FunctionUniformity::new() - } - } - } S::SubgroupBallot { result: _, predicate, diff --git a/naga/src/valid/function.rs b/naga/src/valid/function.rs index 0216c6ef7f6..abf6bc430a6 100644 --- a/naga/src/valid/function.rs +++ b/naga/src/valid/function.rs @@ -1547,41 +1547,6 @@ impl super::Validator { crate::RayQueryFunction::Terminate => {} } } - S::MeshFunction(func) => { - let ensure_u32 = - |expr: Handle| -> Result<(), WithSpan> { - let u32_ty = TypeResolution::Value(Ti::Scalar(crate::Scalar::U32)); - let ty = context - .resolve_type_impl(expr, &self.valid_expression_set) - .map_err_inner(|source| { - FunctionError::Expression { - source, - handle: expr, - } - .with_span_handle(expr, context.expressions) - })?; - if !context.compare_types(&u32_ty, ty) { - return Err(FunctionError::InvalidMeshFunctionCall(expr) - .with_span_handle(expr, context.expressions)); - } - Ok(()) - }; - match func { - crate::MeshFunction::SetMeshOutputs { - vertex_count, - primitive_count, - } => { - ensure_u32(vertex_count)?; - ensure_u32(primitive_count)?; - } - crate::MeshFunction::SetVertex { index, value: _ } - | crate::MeshFunction::SetPrimitive { index, value: _ } => { - ensure_u32(index)?; - // Value is validated elsewhere (since the value type isn't known ahead of time but must match for all calls - // in a function or the function's called functions) - } - } - } S::SubgroupBallot { result, predicate } => { stages &= self.subgroup_stages; if !self.capabilities.contains(super::Capabilities::SUBGROUP) { diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index adb9f355c11..7fe6fa8803d 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -815,22 +815,6 @@ impl super::Validator { } Ok(()) } - crate::Statement::MeshFunction(func) => match func { - crate::MeshFunction::SetMeshOutputs { - vertex_count, - primitive_count, - } => { - validate_expr(vertex_count)?; - validate_expr(primitive_count)?; - Ok(()) - } - crate::MeshFunction::SetVertex { index, value } - | crate::MeshFunction::SetPrimitive { index, value } => { - validate_expr(index)?; - validate_expr(value)?; - Ok(()) - } - }, crate::Statement::SubgroupBallot { result, predicate } => { validate_expr_opt(predicate)?; validate_expr(result)?; diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index f3f5a43c060..e5e7b6997b1 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -141,24 +141,31 @@ pub enum EntryPointError { TaskPayloadWrongAddressSpace, #[error("For a task payload to be used, it must be declared with @payload")] WrongTaskPayloadUsed, - #[error("A function can only set vertex and primitive types that correspond to the mesh shader attributes")] - WrongMeshOutputType, - #[error("Only mesh shader entry points can write to mesh output vertices and primitives")] - UnexpectedMeshShaderOutput, - #[error("Mesh shader entry point cannot have a return type")] - UnexpectedMeshShaderEntryResult, #[error("Task shader entry point must return @builtin(mesh_task_size) vec3")] WrongTaskShaderEntryResult, - #[error("Mesh output type must be a user-defined struct.")] - InvalidMeshOutputType, - #[error("Mesh primitive outputs must have exactly one of `@builtin(triangle_indices)`, `@builtin(line_indices)`, or `@builtin(point_index)`")] - InvalidMeshPrimitiveOutputType, #[error("Task shaders must declare a task payload output")] ExpectedTaskPayload, #[error( - "The `MESH_SHADER` capability must be enabled to compile mesh shaders and task shaders." + "The `MESH_SHADER` capability must be enabled to compile mesh shaders and task shaders" )] MeshShaderCapabilityDisabled, + + #[error( + "Mesh shader output variable must be a struct with fields that are all allowed builtins" + )] + BadMeshOutputVarableType, + #[error("Mesh shader output variable fields must have types that are in accordance with the mesh shader spec")] + BadMeshOutputVariableField, + #[error("Mesh shader entry point cannot have a return type")] + UnexpectedMeshShaderEntryResult, + #[error( + "Mesh output type must be a user-defined struct with fields in alignment with the mesh shader spec" + )] + InvalidMeshOutputType, + #[error("Mesh primitive outputs must have exactly one of `@builtin(triangle_indices)`, `@builtin(line_indices)`, or `@builtin(point_index)`")] + InvalidMeshPrimitiveOutputType, + #[error("Mesh output global variable must live in the workgroup address space")] + WrongMeshOutputAddressSpace, } fn storage_usage(access: crate::StorageAccess) -> GlobalUse { @@ -390,6 +397,10 @@ impl VaryingContext<'_> { scalar: crate::Scalar::U32, }, ), + // Validated elsewhere + Bi::VertexCount | Bi::PrimitiveCount | Bi::Vertices | Bi::Primitives => { + (true, true) + } }; if !visible { @@ -1074,24 +1085,44 @@ impl super::Validator { } } + // TODO: validate mesh entry point info + // If this is a `Mesh` entry point, check its vertex and primitive output types. // We verified previously that only mesh shaders can have `mesh_info`. if let &Some(ref mesh_info) = &ep.mesh_info { - // Mesh shaders don't return any value. All their results are supplied through - // [`SetVertex`] and [`SetPrimitive`] calls. - if let Some((used_vertex_type, _)) = info.mesh_shader_info.vertex_type { - if used_vertex_type != mesh_info.vertex_output_type { - return Err(EntryPointError::WrongMeshOutputType - .with_span_handle(mesh_info.vertex_output_type, &module.types)); + // TODO: validate global variable + if module.global_variables[mesh_info.output_variable].space + != crate::AddressSpace::WorkGroup + { + return Err(EntryPointError::WrongMeshOutputAddressSpace.with_span()); + } + + let mut implied = module.analyze_mesh_shader_info(mesh_info.output_variable); + if let Some(e) = implied.2 { + return Err(e); + } + + if let Some(e) = mesh_info.max_vertices_override { + if let crate::Expression::Override(o) = module.global_expressions[e] { + if implied.1[0] != Some(o) { + return Err(EntryPointError::BadMeshOutputVarableType.with_span()); + } } } - if let Some((used_primitive_type, _)) = info.mesh_shader_info.primitive_type { - if used_primitive_type != mesh_info.primitive_output_type { - return Err(EntryPointError::WrongMeshOutputType - .with_span_handle(mesh_info.primitive_output_type, &module.types)); + if let Some(e) = mesh_info.max_primitives_override { + if let crate::Expression::Override(o) = module.global_expressions[e] { + if implied.1[1] != Some(o) { + return Err(EntryPointError::BadMeshOutputVarableType.with_span()); + } } } + implied.0.max_vertices_override = mesh_info.max_vertices_override; + implied.0.max_primitives_override = mesh_info.max_primitives_override; + if implied.0 != *mesh_info { + return Err(EntryPointError::BadMeshOutputVarableType.with_span()); + } + self.validate_mesh_output_type( ep, module, @@ -1110,7 +1141,7 @@ impl super::Validator { if info.mesh_shader_info.vertex_type.is_some() || info.mesh_shader_info.primitive_type.is_some() { - return Err(EntryPointError::UnexpectedMeshShaderOutput.with_span()); + return Err(EntryPointError::UnexpectedMeshShaderAttributes.with_span()); } } diff --git a/naga/tests/in/wgsl/mesh-shader.wgsl b/naga/tests/in/wgsl/mesh-shader.wgsl index 7f094a82f81..cdc7366b415 100644 --- a/naga/tests/in/wgsl/mesh-shader.wgsl +++ b/naga/tests/in/wgsl/mesh-shader.wgsl @@ -38,32 +38,35 @@ fn ts_main() -> @builtin(mesh_task_size) vec3 { taskPayload.visible = true; return vec3(3, 1, 1); } -@mesh + +struct MeshOutput { + @builtin(vertices) vertices: array, + @builtin(primitives) primitives: array, + @builtin(vertex_count) vertex_count: u32, + @builtin(primitive_count) primitive_count: u32, +} + +var mesh_output: MeshOutput; +@mesh(mesh_output) @payload(taskPayload) -@vertex_output(VertexOutput, 3) @primitive_output(PrimitiveOutput, 1) @workgroup_size(1) fn ms_main(@builtin(local_invocation_index) index: u32, @builtin(global_invocation_id) id: vec3) { - setMeshOutputs(3, 1); + mesh_output.vertex_count = 3; + mesh_output.primitive_count = 1; workgroupData = 2.0; - var v: VertexOutput; - v.position = positions[0]; - v.color = colors[0] * taskPayload.colorMask; - setVertex(0, v); + mesh_output.vertices[0].position = positions[0]; + mesh_output.vertices[0].color = colors[0] * taskPayload.colorMask; - v.position = positions[1]; - v.color = colors[1] * taskPayload.colorMask; - setVertex(1, v); + mesh_output.vertices[1].position = positions[1]; + mesh_output.vertices[1].color = colors[1] * taskPayload.colorMask; - v.position = positions[2]; - v.color = colors[2] * taskPayload.colorMask; - setVertex(2, v); + mesh_output.vertices[2].position = positions[2]; + mesh_output.vertices[2].color = colors[2] * taskPayload.colorMask; - var p: PrimitiveOutput; - p.index = vec3(0, 1, 2); - p.cull = !taskPayload.visible; - p.colorMask = vec4(1.0, 0.0, 1.0, 1.0); - setPrimitive(0, p); + mesh_output.primitives[0].index = vec3(0, 1, 2); + mesh_output.primitives[0].cull = !taskPayload.visible; + mesh_output.primitives[0].colorMask = vec4(1.0, 0.0, 1.0, 1.0); } @fragment fn fs_main(vertex: VertexOutput, primitive: PrimitiveInput) -> @location(0) vec4 { diff --git a/naga/tests/out/analysis/wgsl-mesh-shader.info.ron b/naga/tests/out/analysis/wgsl-mesh-shader.info.ron index 208e0aac84e..9ba7187ac69 100644 --- a/naga/tests/out/analysis/wgsl-mesh-shader.info.ron +++ b/naga/tests/out/analysis/wgsl-mesh-shader.info.ron @@ -9,6 +9,9 @@ ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), ("DATA | SIZED | COPY | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | HOST_SHAREABLE | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), + ("DATA | SIZED | COPY | CREATION_RESOLVED | ARGUMENT | CONSTRUCTIBLE"), ], functions: [], entry_points: [ @@ -24,6 +27,7 @@ global_uses: [ ("READ | WRITE"), ("WRITE"), + (""), ], expressions: [ ( @@ -233,6 +237,7 @@ global_uses: [ ("READ"), ("WRITE"), + ("WRITE"), ], expressions: [ ( @@ -253,6 +258,30 @@ assignable_global: None, ty: Handle(6), ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 11, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 5, + space: WorkGroup, + )), + ), ( uniformity: ( non_uniform_result: None, @@ -265,6 +294,30 @@ width: 4, ))), ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 11, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 5, + space: WorkGroup, + )), + ), ( uniformity: ( non_uniform_result: None, @@ -303,26 +356,50 @@ ), ( uniformity: ( - non_uniform_result: Some(6), + non_uniform_result: None, requirements: (""), ), - ref_count: 9, - assignable_global: None, + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 11, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 9, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), ty: Value(Pointer( base: 4, - space: Function, + space: WorkGroup, )), ), ( uniformity: ( - non_uniform_result: Some(6), + non_uniform_result: None, requirements: (""), ), ref_count: 1, - assignable_global: None, + assignable_global: Some(2), ty: Value(Pointer( base: 1, - space: Function, + space: WorkGroup, )), ), ( @@ -384,14 +461,50 @@ ), ( uniformity: ( - non_uniform_result: Some(6), + non_uniform_result: None, requirements: (""), ), ref_count: 1, - assignable_global: None, + assignable_global: Some(2), + ty: Value(Pointer( + base: 11, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 9, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 4, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), ty: Value(Pointer( base: 1, - space: Function, + space: WorkGroup, )), ), ( @@ -499,31 +612,46 @@ requirements: (""), ), ref_count: 1, - assignable_global: None, - ty: Value(Scalar(( - kind: Uint, - width: 4, - ))), + assignable_global: Some(2), + ty: Value(Pointer( + base: 11, + space: WorkGroup, + )), ), ( uniformity: ( - non_uniform_result: Some(6), + non_uniform_result: None, requirements: (""), ), ref_count: 1, - assignable_global: None, - ty: Handle(4), + assignable_global: Some(2), + ty: Value(Pointer( + base: 9, + space: WorkGroup, + )), ), ( uniformity: ( - non_uniform_result: Some(6), + non_uniform_result: None, requirements: (""), ), ref_count: 1, - assignable_global: None, + assignable_global: Some(2), + ty: Value(Pointer( + base: 4, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), ty: Value(Pointer( base: 1, - space: Function, + space: WorkGroup, )), ), ( @@ -585,14 +713,50 @@ ), ( uniformity: ( - non_uniform_result: Some(6), + non_uniform_result: None, requirements: (""), ), ref_count: 1, - assignable_global: None, + assignable_global: Some(2), + ty: Value(Pointer( + base: 11, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 9, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 4, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), ty: Value(Pointer( base: 1, - space: Function, + space: WorkGroup, )), ), ( @@ -700,31 +864,46 @@ requirements: (""), ), ref_count: 1, - assignable_global: None, - ty: Value(Scalar(( - kind: Uint, - width: 4, - ))), + assignable_global: Some(2), + ty: Value(Pointer( + base: 11, + space: WorkGroup, + )), ), ( uniformity: ( - non_uniform_result: Some(6), + non_uniform_result: None, requirements: (""), ), ref_count: 1, - assignable_global: None, - ty: Handle(4), + assignable_global: Some(2), + ty: Value(Pointer( + base: 9, + space: WorkGroup, + )), ), ( uniformity: ( - non_uniform_result: Some(6), + non_uniform_result: None, requirements: (""), ), ref_count: 1, - assignable_global: None, + assignable_global: Some(2), + ty: Value(Pointer( + base: 4, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), ty: Value(Pointer( base: 1, - space: Function, + space: WorkGroup, )), ), ( @@ -786,14 +965,50 @@ ), ( uniformity: ( - non_uniform_result: Some(6), + non_uniform_result: None, requirements: (""), ), ref_count: 1, - assignable_global: None, + assignable_global: Some(2), + ty: Value(Pointer( + base: 11, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 9, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 4, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), ty: Value(Pointer( base: 1, - space: Function, + space: WorkGroup, )), ), ( @@ -901,43 +1116,46 @@ requirements: (""), ), ref_count: 1, - assignable_global: None, - ty: Value(Scalar(( - kind: Uint, - width: 4, - ))), + assignable_global: Some(2), + ty: Value(Pointer( + base: 11, + space: WorkGroup, + )), ), ( uniformity: ( - non_uniform_result: Some(6), + non_uniform_result: None, requirements: (""), ), ref_count: 1, - assignable_global: None, - ty: Handle(4), + assignable_global: Some(2), + ty: Value(Pointer( + base: 10, + space: WorkGroup, + )), ), ( uniformity: ( - non_uniform_result: Some(61), + non_uniform_result: None, requirements: (""), ), - ref_count: 4, - assignable_global: None, + ref_count: 1, + assignable_global: Some(2), ty: Value(Pointer( base: 7, - space: Function, + space: WorkGroup, )), ), ( uniformity: ( - non_uniform_result: Some(61), + non_uniform_result: None, requirements: (""), ), ref_count: 1, - assignable_global: None, + assignable_global: Some(2), ty: Value(Pointer( base: 6, - space: Function, + space: WorkGroup, )), ), ( @@ -987,14 +1205,50 @@ ), ( uniformity: ( - non_uniform_result: Some(61), + non_uniform_result: None, requirements: (""), ), ref_count: 1, - assignable_global: None, + assignable_global: Some(2), + ty: Value(Pointer( + base: 11, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 10, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 7, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), ty: Value(Pointer( base: 2, - space: Function, + space: WorkGroup, )), ), ( @@ -1041,14 +1295,14 @@ ), ( uniformity: ( - non_uniform_result: Some(61), + non_uniform_result: None, requirements: (""), ), ref_count: 1, - assignable_global: None, + assignable_global: Some(2), ty: Value(Pointer( - base: 1, - space: Function, + base: 11, + space: WorkGroup, )), ), ( @@ -1057,11 +1311,11 @@ requirements: (""), ), ref_count: 1, - assignable_global: None, - ty: Value(Scalar(( - kind: Float, - width: 4, - ))), + assignable_global: Some(2), + ty: Value(Pointer( + base: 10, + space: WorkGroup, + )), ), ( uniformity: ( @@ -1069,11 +1323,23 @@ requirements: (""), ), ref_count: 1, - assignable_global: None, - ty: Value(Scalar(( - kind: Float, - width: 4, - ))), + assignable_global: Some(2), + ty: Value(Pointer( + base: 7, + space: WorkGroup, + )), + ), + ( + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + ref_count: 1, + assignable_global: Some(2), + ty: Value(Pointer( + base: 1, + space: WorkGroup, + )), ), ( uniformity: ( @@ -1106,7 +1372,10 @@ ), ref_count: 1, assignable_global: None, - ty: Handle(1), + ty: Value(Scalar(( + kind: Float, + width: 4, + ))), ), ( uniformity: ( @@ -1116,26 +1385,26 @@ ref_count: 1, assignable_global: None, ty: Value(Scalar(( - kind: Uint, + kind: Float, width: 4, ))), ), ( uniformity: ( - non_uniform_result: Some(61), + non_uniform_result: None, requirements: (""), ), ref_count: 1, assignable_global: None, - ty: Handle(7), + ty: Handle(1), ), ], sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, mesh_shader_info: ( - vertex_type: Some((4, 24)), - primitive_type: Some((7, 79)), + vertex_type: None, + primitive_type: None, ), ), ( @@ -1150,6 +1419,7 @@ global_uses: [ (""), (""), + (""), ], expressions: [ ( diff --git a/naga/tests/out/ir/wgsl-mesh-shader.compact.ron b/naga/tests/out/ir/wgsl-mesh-shader.compact.ron index 38c79cba451..1147b017f5c 100644 --- a/naga/tests/out/ir/wgsl-mesh-shader.compact.ron +++ b/naga/tests/out/ir/wgsl-mesh-shader.compact.ron @@ -141,6 +141,54 @@ span: 16, ), ), + ( + name: None, + inner: Array( + base: 4, + size: Constant(3), + stride: 32, + ), + ), + ( + name: None, + inner: Array( + base: 7, + size: Constant(1), + stride: 32, + ), + ), + ( + name: Some("MeshOutput"), + inner: Struct( + members: [ + ( + name: Some("vertices"), + ty: 9, + binding: Some(BuiltIn(Vertices)), + offset: 0, + ), + ( + name: Some("primitives"), + ty: 10, + binding: Some(BuiltIn(Primitives)), + offset: 96, + ), + ( + name: Some("vertex_count"), + ty: 5, + binding: Some(BuiltIn(VertexCount)), + offset: 128, + ), + ( + name: Some("primitive_count"), + ty: 5, + binding: Some(BuiltIn(PrimitiveCount)), + offset: 132, + ), + ], + span: 144, + ), + ), ], special_types: ( ray_desc: None, @@ -167,6 +215,13 @@ ty: 0, init: None, ), + ( + name: Some("mesh_output"), + space: WorkGroup, + binding: None, + ty: 11, + init: None, + ), ], global_expressions: [], functions: [], @@ -292,28 +347,35 @@ ), ], result: None, - local_variables: [ - ( - name: Some("v"), - ty: 4, - init: None, - ), - ( - name: Some("p"), - ty: 7, - init: None, - ), - ], + local_variables: [], expressions: [ FunctionArgument(0), FunctionArgument(1), + GlobalVariable(2), + AccessIndex( + base: 2, + index: 2, + ), Literal(U32(3)), + GlobalVariable(2), + AccessIndex( + base: 5, + index: 3, + ), Literal(U32(1)), GlobalVariable(1), Literal(F32(2.0)), - LocalVariable(0), + GlobalVariable(2), + AccessIndex( + base: 10, + index: 0, + ), + AccessIndex( + base: 11, + index: 0, + ), AccessIndex( - base: 6, + base: 12, index: 0, ), Literal(F32(0.0)), @@ -323,23 +385,32 @@ Compose( ty: 1, components: [ - 8, - 9, - 10, - 11, + 14, + 15, + 16, + 17, ], ), + GlobalVariable(2), + AccessIndex( + base: 19, + index: 0, + ), + AccessIndex( + base: 20, + index: 0, + ), AccessIndex( - base: 6, + base: 21, index: 1, ), GlobalVariable(0), AccessIndex( - base: 14, + base: 23, index: 0, ), Load( - pointer: 15, + pointer: 24, ), Literal(F32(0.0)), Literal(F32(1.0)), @@ -348,23 +419,28 @@ Compose( ty: 1, components: [ - 17, - 18, - 19, - 20, + 26, + 27, + 28, + 29, ], ), Binary( op: Multiply, - left: 21, - right: 16, + left: 30, + right: 25, ), - Literal(U32(0)), - Load( - pointer: 6, + GlobalVariable(2), + AccessIndex( + base: 32, + index: 0, + ), + AccessIndex( + base: 33, + index: 1, ), AccessIndex( - base: 6, + base: 34, index: 0, ), Literal(F32(-1.0)), @@ -374,23 +450,32 @@ Compose( ty: 1, components: [ - 26, - 27, - 28, - 29, + 36, + 37, + 38, + 39, ], ), + GlobalVariable(2), AccessIndex( - base: 6, + base: 41, + index: 0, + ), + AccessIndex( + base: 42, + index: 1, + ), + AccessIndex( + base: 43, index: 1, ), GlobalVariable(0), AccessIndex( - base: 32, + base: 45, index: 0, ), Load( - pointer: 33, + pointer: 46, ), Literal(F32(0.0)), Literal(F32(0.0)), @@ -399,23 +484,28 @@ Compose( ty: 1, components: [ - 35, - 36, - 37, - 38, + 48, + 49, + 50, + 51, ], ), Binary( op: Multiply, - left: 39, - right: 34, + left: 52, + right: 47, ), - Literal(U32(1)), - Load( - pointer: 6, + GlobalVariable(2), + AccessIndex( + base: 54, + index: 0, + ), + AccessIndex( + base: 55, + index: 2, ), AccessIndex( - base: 6, + base: 56, index: 0, ), Literal(F32(1.0)), @@ -425,23 +515,32 @@ Compose( ty: 1, components: [ - 44, - 45, - 46, - 47, + 58, + 59, + 60, + 61, ], ), + GlobalVariable(2), + AccessIndex( + base: 63, + index: 0, + ), AccessIndex( - base: 6, + base: 64, + index: 2, + ), + AccessIndex( + base: 65, index: 1, ), GlobalVariable(0), AccessIndex( - base: 50, + base: 67, index: 0, ), Load( - pointer: 51, + pointer: 68, ), Literal(F32(1.0)), Literal(F32(0.0)), @@ -450,24 +549,28 @@ Compose( ty: 1, components: [ - 53, - 54, - 55, - 56, + 70, + 71, + 72, + 73, ], ), Binary( op: Multiply, - left: 57, - right: 52, + left: 74, + right: 69, ), - Literal(U32(2)), - Load( - pointer: 6, + GlobalVariable(2), + AccessIndex( + base: 76, + index: 1, + ), + AccessIndex( + base: 77, + index: 0, ), - LocalVariable(1), AccessIndex( - base: 61, + base: 78, index: 0, ), Literal(U32(0)), @@ -476,29 +579,47 @@ Compose( ty: 6, components: [ - 63, - 64, - 65, + 80, + 81, + 82, ], ), + GlobalVariable(2), AccessIndex( - base: 61, + base: 84, + index: 1, + ), + AccessIndex( + base: 85, + index: 0, + ), + AccessIndex( + base: 86, index: 1, ), GlobalVariable(0), AccessIndex( - base: 68, + base: 88, index: 1, ), Load( - pointer: 69, + pointer: 89, ), Unary( op: LogicalNot, - expr: 70, + expr: 90, + ), + GlobalVariable(2), + AccessIndex( + base: 92, + index: 1, + ), + AccessIndex( + base: 93, + index: 0, ), AccessIndex( - base: 61, + base: 94, index: 2, ), Literal(F32(1.0)), @@ -508,33 +629,45 @@ Compose( ty: 1, components: [ - 73, - 74, - 75, - 76, + 96, + 97, + 98, + 99, ], ), - Literal(U32(0)), - Load( - pointer: 61, - ), ], named_expressions: { 0: "index", 1: "id", }, body: [ - MeshFunction(SetMeshOutputs( - vertex_count: 2, - primitive_count: 3, + Emit(( + start: 3, + end: 4, )), Store( - pointer: 4, - value: 5, + pointer: 3, + value: 4, + ), + Emit(( + start: 6, + end: 7, + )), + Store( + pointer: 6, + value: 7, + ), + Store( + pointer: 8, + value: 9, ), Emit(( - start: 7, - end: 8, + start: 11, + end: 12, + )), + Emit(( + start: 12, + end: 14, )), Emit(( start: 0, @@ -549,16 +682,20 @@ end: 0, )), Emit(( - start: 12, - end: 13, + start: 18, + end: 19, )), Store( - pointer: 7, - value: 12, + pointer: 13, + value: 18, ), Emit(( - start: 13, - end: 14, + start: 20, + end: 21, + )), + Emit(( + start: 21, + end: 23, )), Emit(( start: 0, @@ -573,28 +710,24 @@ end: 0, )), Emit(( - start: 15, - end: 17, + start: 24, + end: 26, )), Emit(( - start: 21, - end: 23, + start: 30, + end: 32, )), Store( - pointer: 13, - value: 22, + pointer: 22, + value: 31, ), Emit(( - start: 24, - end: 25, - )), - MeshFunction(SetVertex( - index: 23, - value: 24, + start: 33, + end: 34, )), Emit(( - start: 25, - end: 26, + start: 34, + end: 36, )), Emit(( start: 0, @@ -609,16 +742,20 @@ end: 0, )), Emit(( - start: 30, - end: 31, + start: 40, + end: 41, )), Store( - pointer: 25, - value: 30, + pointer: 35, + value: 40, ), Emit(( - start: 31, - end: 32, + start: 42, + end: 43, + )), + Emit(( + start: 43, + end: 45, )), Emit(( start: 0, @@ -633,28 +770,24 @@ end: 0, )), Emit(( - start: 33, - end: 35, + start: 46, + end: 48, )), Emit(( - start: 39, - end: 41, + start: 52, + end: 54, )), Store( - pointer: 31, - value: 40, + pointer: 44, + value: 53, ), Emit(( - start: 42, - end: 43, - )), - MeshFunction(SetVertex( - index: 41, - value: 42, + start: 55, + end: 56, )), Emit(( - start: 43, - end: 44, + start: 56, + end: 58, )), Emit(( start: 0, @@ -669,16 +802,20 @@ end: 0, )), Emit(( - start: 48, - end: 49, + start: 62, + end: 63, )), Store( - pointer: 43, - value: 48, + pointer: 57, + value: 62, ), Emit(( - start: 49, - end: 50, + start: 64, + end: 65, + )), + Emit(( + start: 65, + end: 67, )), Emit(( start: 0, @@ -693,69 +830,65 @@ end: 0, )), Emit(( - start: 51, - end: 53, + start: 68, + end: 70, )), Emit(( - start: 57, - end: 59, + start: 74, + end: 76, )), Store( - pointer: 49, - value: 58, + pointer: 66, + value: 75, ), Emit(( - start: 60, - end: 61, - )), - MeshFunction(SetVertex( - index: 59, - value: 60, + start: 77, + end: 78, )), Emit(( - start: 62, - end: 63, + start: 78, + end: 80, )), Emit(( - start: 66, - end: 67, + start: 83, + end: 84, )), Store( - pointer: 62, - value: 66, + pointer: 79, + value: 83, ), Emit(( - start: 67, - end: 68, + start: 85, + end: 86, + )), + Emit(( + start: 86, + end: 88, )), Emit(( - start: 69, - end: 72, + start: 89, + end: 92, )), Store( - pointer: 67, - value: 71, + pointer: 87, + value: 91, ), Emit(( - start: 72, - end: 73, + start: 93, + end: 94, )), Emit(( - start: 77, - end: 78, + start: 94, + end: 96, )), - Store( - pointer: 72, - value: 77, - ), Emit(( - start: 79, - end: 80, - )), - MeshFunction(SetPrimitive( - index: 78, - value: 79, + start: 100, + end: 101, )), + Store( + pointer: 95, + value: 100, + ), Return( value: None, ), @@ -770,6 +903,7 @@ max_primitives_override: None, vertex_output_type: 4, primitive_output_type: 7, + output_variable: 2, )), task_payload: Some(0), ), diff --git a/naga/tests/out/ir/wgsl-mesh-shader.ron b/naga/tests/out/ir/wgsl-mesh-shader.ron index 38c79cba451..1147b017f5c 100644 --- a/naga/tests/out/ir/wgsl-mesh-shader.ron +++ b/naga/tests/out/ir/wgsl-mesh-shader.ron @@ -141,6 +141,54 @@ span: 16, ), ), + ( + name: None, + inner: Array( + base: 4, + size: Constant(3), + stride: 32, + ), + ), + ( + name: None, + inner: Array( + base: 7, + size: Constant(1), + stride: 32, + ), + ), + ( + name: Some("MeshOutput"), + inner: Struct( + members: [ + ( + name: Some("vertices"), + ty: 9, + binding: Some(BuiltIn(Vertices)), + offset: 0, + ), + ( + name: Some("primitives"), + ty: 10, + binding: Some(BuiltIn(Primitives)), + offset: 96, + ), + ( + name: Some("vertex_count"), + ty: 5, + binding: Some(BuiltIn(VertexCount)), + offset: 128, + ), + ( + name: Some("primitive_count"), + ty: 5, + binding: Some(BuiltIn(PrimitiveCount)), + offset: 132, + ), + ], + span: 144, + ), + ), ], special_types: ( ray_desc: None, @@ -167,6 +215,13 @@ ty: 0, init: None, ), + ( + name: Some("mesh_output"), + space: WorkGroup, + binding: None, + ty: 11, + init: None, + ), ], global_expressions: [], functions: [], @@ -292,28 +347,35 @@ ), ], result: None, - local_variables: [ - ( - name: Some("v"), - ty: 4, - init: None, - ), - ( - name: Some("p"), - ty: 7, - init: None, - ), - ], + local_variables: [], expressions: [ FunctionArgument(0), FunctionArgument(1), + GlobalVariable(2), + AccessIndex( + base: 2, + index: 2, + ), Literal(U32(3)), + GlobalVariable(2), + AccessIndex( + base: 5, + index: 3, + ), Literal(U32(1)), GlobalVariable(1), Literal(F32(2.0)), - LocalVariable(0), + GlobalVariable(2), + AccessIndex( + base: 10, + index: 0, + ), + AccessIndex( + base: 11, + index: 0, + ), AccessIndex( - base: 6, + base: 12, index: 0, ), Literal(F32(0.0)), @@ -323,23 +385,32 @@ Compose( ty: 1, components: [ - 8, - 9, - 10, - 11, + 14, + 15, + 16, + 17, ], ), + GlobalVariable(2), + AccessIndex( + base: 19, + index: 0, + ), + AccessIndex( + base: 20, + index: 0, + ), AccessIndex( - base: 6, + base: 21, index: 1, ), GlobalVariable(0), AccessIndex( - base: 14, + base: 23, index: 0, ), Load( - pointer: 15, + pointer: 24, ), Literal(F32(0.0)), Literal(F32(1.0)), @@ -348,23 +419,28 @@ Compose( ty: 1, components: [ - 17, - 18, - 19, - 20, + 26, + 27, + 28, + 29, ], ), Binary( op: Multiply, - left: 21, - right: 16, + left: 30, + right: 25, ), - Literal(U32(0)), - Load( - pointer: 6, + GlobalVariable(2), + AccessIndex( + base: 32, + index: 0, + ), + AccessIndex( + base: 33, + index: 1, ), AccessIndex( - base: 6, + base: 34, index: 0, ), Literal(F32(-1.0)), @@ -374,23 +450,32 @@ Compose( ty: 1, components: [ - 26, - 27, - 28, - 29, + 36, + 37, + 38, + 39, ], ), + GlobalVariable(2), AccessIndex( - base: 6, + base: 41, + index: 0, + ), + AccessIndex( + base: 42, + index: 1, + ), + AccessIndex( + base: 43, index: 1, ), GlobalVariable(0), AccessIndex( - base: 32, + base: 45, index: 0, ), Load( - pointer: 33, + pointer: 46, ), Literal(F32(0.0)), Literal(F32(0.0)), @@ -399,23 +484,28 @@ Compose( ty: 1, components: [ - 35, - 36, - 37, - 38, + 48, + 49, + 50, + 51, ], ), Binary( op: Multiply, - left: 39, - right: 34, + left: 52, + right: 47, ), - Literal(U32(1)), - Load( - pointer: 6, + GlobalVariable(2), + AccessIndex( + base: 54, + index: 0, + ), + AccessIndex( + base: 55, + index: 2, ), AccessIndex( - base: 6, + base: 56, index: 0, ), Literal(F32(1.0)), @@ -425,23 +515,32 @@ Compose( ty: 1, components: [ - 44, - 45, - 46, - 47, + 58, + 59, + 60, + 61, ], ), + GlobalVariable(2), + AccessIndex( + base: 63, + index: 0, + ), AccessIndex( - base: 6, + base: 64, + index: 2, + ), + AccessIndex( + base: 65, index: 1, ), GlobalVariable(0), AccessIndex( - base: 50, + base: 67, index: 0, ), Load( - pointer: 51, + pointer: 68, ), Literal(F32(1.0)), Literal(F32(0.0)), @@ -450,24 +549,28 @@ Compose( ty: 1, components: [ - 53, - 54, - 55, - 56, + 70, + 71, + 72, + 73, ], ), Binary( op: Multiply, - left: 57, - right: 52, + left: 74, + right: 69, ), - Literal(U32(2)), - Load( - pointer: 6, + GlobalVariable(2), + AccessIndex( + base: 76, + index: 1, + ), + AccessIndex( + base: 77, + index: 0, ), - LocalVariable(1), AccessIndex( - base: 61, + base: 78, index: 0, ), Literal(U32(0)), @@ -476,29 +579,47 @@ Compose( ty: 6, components: [ - 63, - 64, - 65, + 80, + 81, + 82, ], ), + GlobalVariable(2), AccessIndex( - base: 61, + base: 84, + index: 1, + ), + AccessIndex( + base: 85, + index: 0, + ), + AccessIndex( + base: 86, index: 1, ), GlobalVariable(0), AccessIndex( - base: 68, + base: 88, index: 1, ), Load( - pointer: 69, + pointer: 89, ), Unary( op: LogicalNot, - expr: 70, + expr: 90, + ), + GlobalVariable(2), + AccessIndex( + base: 92, + index: 1, + ), + AccessIndex( + base: 93, + index: 0, ), AccessIndex( - base: 61, + base: 94, index: 2, ), Literal(F32(1.0)), @@ -508,33 +629,45 @@ Compose( ty: 1, components: [ - 73, - 74, - 75, - 76, + 96, + 97, + 98, + 99, ], ), - Literal(U32(0)), - Load( - pointer: 61, - ), ], named_expressions: { 0: "index", 1: "id", }, body: [ - MeshFunction(SetMeshOutputs( - vertex_count: 2, - primitive_count: 3, + Emit(( + start: 3, + end: 4, )), Store( - pointer: 4, - value: 5, + pointer: 3, + value: 4, + ), + Emit(( + start: 6, + end: 7, + )), + Store( + pointer: 6, + value: 7, + ), + Store( + pointer: 8, + value: 9, ), Emit(( - start: 7, - end: 8, + start: 11, + end: 12, + )), + Emit(( + start: 12, + end: 14, )), Emit(( start: 0, @@ -549,16 +682,20 @@ end: 0, )), Emit(( - start: 12, - end: 13, + start: 18, + end: 19, )), Store( - pointer: 7, - value: 12, + pointer: 13, + value: 18, ), Emit(( - start: 13, - end: 14, + start: 20, + end: 21, + )), + Emit(( + start: 21, + end: 23, )), Emit(( start: 0, @@ -573,28 +710,24 @@ end: 0, )), Emit(( - start: 15, - end: 17, + start: 24, + end: 26, )), Emit(( - start: 21, - end: 23, + start: 30, + end: 32, )), Store( - pointer: 13, - value: 22, + pointer: 22, + value: 31, ), Emit(( - start: 24, - end: 25, - )), - MeshFunction(SetVertex( - index: 23, - value: 24, + start: 33, + end: 34, )), Emit(( - start: 25, - end: 26, + start: 34, + end: 36, )), Emit(( start: 0, @@ -609,16 +742,20 @@ end: 0, )), Emit(( - start: 30, - end: 31, + start: 40, + end: 41, )), Store( - pointer: 25, - value: 30, + pointer: 35, + value: 40, ), Emit(( - start: 31, - end: 32, + start: 42, + end: 43, + )), + Emit(( + start: 43, + end: 45, )), Emit(( start: 0, @@ -633,28 +770,24 @@ end: 0, )), Emit(( - start: 33, - end: 35, + start: 46, + end: 48, )), Emit(( - start: 39, - end: 41, + start: 52, + end: 54, )), Store( - pointer: 31, - value: 40, + pointer: 44, + value: 53, ), Emit(( - start: 42, - end: 43, - )), - MeshFunction(SetVertex( - index: 41, - value: 42, + start: 55, + end: 56, )), Emit(( - start: 43, - end: 44, + start: 56, + end: 58, )), Emit(( start: 0, @@ -669,16 +802,20 @@ end: 0, )), Emit(( - start: 48, - end: 49, + start: 62, + end: 63, )), Store( - pointer: 43, - value: 48, + pointer: 57, + value: 62, ), Emit(( - start: 49, - end: 50, + start: 64, + end: 65, + )), + Emit(( + start: 65, + end: 67, )), Emit(( start: 0, @@ -693,69 +830,65 @@ end: 0, )), Emit(( - start: 51, - end: 53, + start: 68, + end: 70, )), Emit(( - start: 57, - end: 59, + start: 74, + end: 76, )), Store( - pointer: 49, - value: 58, + pointer: 66, + value: 75, ), Emit(( - start: 60, - end: 61, - )), - MeshFunction(SetVertex( - index: 59, - value: 60, + start: 77, + end: 78, )), Emit(( - start: 62, - end: 63, + start: 78, + end: 80, )), Emit(( - start: 66, - end: 67, + start: 83, + end: 84, )), Store( - pointer: 62, - value: 66, + pointer: 79, + value: 83, ), Emit(( - start: 67, - end: 68, + start: 85, + end: 86, + )), + Emit(( + start: 86, + end: 88, )), Emit(( - start: 69, - end: 72, + start: 89, + end: 92, )), Store( - pointer: 67, - value: 71, + pointer: 87, + value: 91, ), Emit(( - start: 72, - end: 73, + start: 93, + end: 94, )), Emit(( - start: 77, - end: 78, + start: 94, + end: 96, )), - Store( - pointer: 72, - value: 77, - ), Emit(( - start: 79, - end: 80, - )), - MeshFunction(SetPrimitive( - index: 78, - value: 79, + start: 100, + end: 101, )), + Store( + pointer: 95, + value: 100, + ), Return( value: None, ), @@ -770,6 +903,7 @@ max_primitives_override: None, vertex_output_type: 4, primitive_output_type: 7, + output_variable: 2, )), task_payload: Some(0), ), From 3905ae8201ce20ba32b8d6b98297acf470aecd6b Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Thu, 30 Oct 2025 19:42:50 -0500 Subject: [PATCH 77/89] Improved validation slightly, remvoed obselete crap, fixed bug in compaction, made clippy happy --- naga/src/compact/mod.rs | 4 + naga/src/proc/mod.rs | 16 ++-- naga/src/valid/analyzer.rs | 93 ------------------- naga/src/valid/handles.rs | 1 + naga/src/valid/interface.rs | 19 +--- naga/tests/out/analysis/spv-shadow.info.ron | 12 --- naga/tests/out/analysis/wgsl-access.info.ron | 76 --------------- naga/tests/out/analysis/wgsl-collatz.info.ron | 8 -- .../out/analysis/wgsl-mesh-shader.info.ron | 12 --- .../out/analysis/wgsl-overrides.info.ron | 4 - .../analysis/wgsl-storage-textures.info.ron | 8 -- 11 files changed, 17 insertions(+), 236 deletions(-) diff --git a/naga/src/compact/mod.rs b/naga/src/compact/mod.rs index a7d3d463f11..2761c7cfaf8 100644 --- a/naga/src/compact/mod.rs +++ b/naga/src/compact/mod.rs @@ -226,6 +226,9 @@ pub fn compact(module: &mut crate::Module, keep_unused: KeepUnused) { module_tracer.global_variables_used.insert(task_payload); } if let Some(ref mesh_info) = entry.mesh_info { + module_tracer + .global_variables_used + .insert(mesh_info.output_variable); module_tracer .types_used .insert(mesh_info.vertex_output_type); @@ -385,6 +388,7 @@ pub fn compact(module: &mut crate::Module, keep_unused: KeepUnused) { module_map.globals.adjust(task_payload); } if let Some(ref mut mesh_info) = entry.mesh_info { + module_map.globals.adjust(&mut mesh_info.output_variable); module_map.types.adjust(&mut mesh_info.vertex_output_type); module_map .types diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index dd2ae459373..4271db391c5 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -683,14 +683,14 @@ impl crate::Module { output_variable: gv, }; let mut error = None; - let typ = &self.types[self.global_variables[gv].ty].inner; + let r#type = &self.types[self.global_variables[gv].ty].inner; let mut topology = output.topology; // Max, max override, type let mut vertex_info = (0, None, null_type); let mut primitive_info = (0, None, null_type); - match typ { + match r#type { &crate::TypeInner::Struct { ref members, .. } => { let mut builtins = crate::FastHashSet::default(); for member in members { @@ -700,7 +700,7 @@ impl crate::Module { error = Some(EntryPointError::BadMeshOutputVariableField); } if builtins.contains(&crate::BuiltIn::VertexCount) { - error = Some(EntryPointError::BadMeshOutputVarableType); + error = Some(EntryPointError::BadMeshOutputVariableType); } builtins.insert(crate::BuiltIn::VertexCount); } @@ -709,7 +709,7 @@ impl crate::Module { error = Some(EntryPointError::BadMeshOutputVariableField); } if builtins.contains(&crate::BuiltIn::PrimitiveCount) { - error = Some(EntryPointError::BadMeshOutputVarableType); + error = Some(EntryPointError::BadMeshOutputVariableType); } builtins.insert(crate::BuiltIn::PrimitiveCount); } @@ -767,18 +767,18 @@ impl crate::Module { _ => (), } if builtins.contains(&crate::BuiltIn::Primitives) { - error = Some(EntryPointError::BadMeshOutputVarableType); + error = Some(EntryPointError::BadMeshOutputVariableType); } builtins.insert(crate::BuiltIn::Primitives); } else { vertex_info = (a, b, c); if builtins.contains(&crate::BuiltIn::Vertices) { - error = Some(EntryPointError::BadMeshOutputVarableType); + error = Some(EntryPointError::BadMeshOutputVariableType); } builtins.insert(crate::BuiltIn::Vertices); } } - _ => error = Some(EntryPointError::BadMeshOutputVarableType), + _ => error = Some(EntryPointError::BadMeshOutputVariableType), } } output = crate::MeshStageInfo { @@ -792,7 +792,7 @@ impl crate::Module { ..output } } - _ => error = Some(EntryPointError::BadMeshOutputVarableType), + _ => error = Some(EntryPointError::BadMeshOutputVariableType), } ( output, diff --git a/naga/src/valid/analyzer.rs b/naga/src/valid/analyzer.rs index 5befdfe22a6..e01a7b0b735 100644 --- a/naga/src/valid/analyzer.rs +++ b/naga/src/valid/analyzer.rs @@ -85,25 +85,6 @@ struct FunctionUniformity { exit: ExitFlags, } -/// Mesh shader related characteristics of a function. -#[derive(Debug, Clone, Default)] -#[cfg_attr(feature = "serialize", derive(serde::Serialize))] -#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] -#[cfg_attr(test, derive(PartialEq))] -pub struct FunctionMeshShaderInfo { - /// The type of value this function passes to [`SetVertex`], and the - /// expression that first established it. - /// - /// [`SetVertex`]: crate::ir::MeshFunction::SetVertex - pub vertex_type: Option<(Handle, Handle)>, - - /// The type of value this function passes to [`SetPrimitive`], and the - /// expression that first established it. - /// - /// [`SetPrimitive`]: crate::ir::MeshFunction::SetPrimitive - pub primitive_type: Option<(Handle, Handle)>, -} - impl ops::BitOr for FunctionUniformity { type Output = Self; fn bitor(self, other: Self) -> Self { @@ -321,9 +302,6 @@ pub struct FunctionInfo { /// See [`DiagnosticFilterNode`] for details on how the tree is represented and used in /// validation. diagnostic_filter_leaf: Option>, - - /// Mesh shader info for this function and its callees. - pub mesh_shader_info: FunctionMeshShaderInfo, } impl FunctionInfo { @@ -520,9 +498,6 @@ impl FunctionInfo { *mine |= *other; } - // Inherit mesh output types from our callees. - self.try_update_mesh_info(&callee.mesh_shader_info)?; - Ok(FunctionUniformity { result: callee.uniformity.clone(), exit: if callee.may_kill { @@ -1200,72 +1175,6 @@ impl FunctionInfo { } Ok(combined_uniformity) } - - /// Note the type of value passed to [`SetVertex`]. - /// - /// Record that this function passed a value of type `ty` as the second - /// argument to the [`SetVertex`] builtin function. All calls to - /// `SetVertex` must pass the same type, and this must match the - /// function's [`vertex_output_type`]. - /// - /// [`SetVertex`]: crate::ir::MeshFunction::SetVertex - /// [`vertex_output_type`]: crate::ir::MeshStageInfo::vertex_output_type - fn try_update_mesh_vertex_type( - &mut self, - ty: Handle, - value: Handle, - ) -> Result<(), WithSpan> { - if let &Some(ref existing) = &self.mesh_shader_info.vertex_type { - if existing.0 != ty { - return Err( - FunctionError::ConflictingMeshOutputTypes(existing.1, value).with_span() - ); - } - } else { - self.mesh_shader_info.vertex_type = Some((ty, value)); - } - Ok(()) - } - - /// Note the type of value passed to [`SetPrimitive`]. - /// - /// Record that this function passed a value of type `ty` as the second - /// argument to the [`SetPrimitive`] builtin function. All calls to - /// `SetPrimitive` must pass the same type, and this must match the - /// function's [`primitive_output_type`]. - /// - /// [`SetPrimitive`]: crate::ir::MeshFunction::SetPrimitive - /// [`primitive_output_type`]: crate::ir::MeshStageInfo::primitive_output_type - fn try_update_mesh_primitive_type( - &mut self, - ty: Handle, - value: Handle, - ) -> Result<(), WithSpan> { - if let &Some(ref existing) = &self.mesh_shader_info.primitive_type { - if existing.0 != ty { - return Err( - FunctionError::ConflictingMeshOutputTypes(existing.1, value).with_span() - ); - } - } else { - self.mesh_shader_info.primitive_type = Some((ty, value)); - } - Ok(()) - } - - /// Update this function's mesh shader info, given that it calls `callee`. - fn try_update_mesh_info( - &mut self, - callee: &FunctionMeshShaderInfo, - ) -> Result<(), WithSpan> { - if let &Some(ref other_vertex) = &callee.vertex_type { - self.try_update_mesh_vertex_type(other_vertex.0, other_vertex.1)?; - } - if let &Some(ref other_primitive) = &callee.primitive_type { - self.try_update_mesh_primitive_type(other_primitive.0, other_primitive.1)?; - } - Ok(()) - } } impl ModuleInfo { @@ -1301,7 +1210,6 @@ impl ModuleInfo { sampling: crate::FastHashSet::default(), dual_source_blending: false, diagnostic_filter_leaf: fun.diagnostic_filter_leaf, - mesh_shader_info: FunctionMeshShaderInfo::default(), }; let resolve_context = ResolveContext::with_locals(module, &fun.local_variables, &fun.arguments); @@ -1435,7 +1343,6 @@ fn uniform_control_flow() { sampling: crate::FastHashSet::default(), dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: FunctionMeshShaderInfo::default(), }; let resolve_context = ResolveContext { constants: &Arena::new(), diff --git a/naga/src/valid/handles.rs b/naga/src/valid/handles.rs index 7fe6fa8803d..5b7fb3fab75 100644 --- a/naga/src/valid/handles.rs +++ b/naga/src/valid/handles.rs @@ -237,6 +237,7 @@ impl super::Validator { Self::validate_global_variable_handle(task_payload, global_variables)?; } if let Some(ref mesh_info) = entry_point.mesh_info { + Self::validate_global_variable_handle(mesh_info.output_variable, global_variables)?; validate_type(mesh_info.vertex_output_type)?; validate_type(mesh_info.primitive_output_type)?; for ov in mesh_info diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index e5e7b6997b1..4d437477ca1 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -153,7 +153,7 @@ pub enum EntryPointError { #[error( "Mesh shader output variable must be a struct with fields that are all allowed builtins" )] - BadMeshOutputVarableType, + BadMeshOutputVariableType, #[error("Mesh shader output variable fields must have types that are in accordance with the mesh shader spec")] BadMeshOutputVariableField, #[error("Mesh shader entry point cannot have a return type")] @@ -1085,12 +1085,9 @@ impl super::Validator { } } - // TODO: validate mesh entry point info - // If this is a `Mesh` entry point, check its vertex and primitive output types. // We verified previously that only mesh shaders can have `mesh_info`. if let &Some(ref mesh_info) = &ep.mesh_info { - // TODO: validate global variable if module.global_variables[mesh_info.output_variable].space != crate::AddressSpace::WorkGroup { @@ -1105,14 +1102,14 @@ impl super::Validator { if let Some(e) = mesh_info.max_vertices_override { if let crate::Expression::Override(o) = module.global_expressions[e] { if implied.1[0] != Some(o) { - return Err(EntryPointError::BadMeshOutputVarableType.with_span()); + return Err(EntryPointError::BadMeshOutputVariableType.with_span()); } } } if let Some(e) = mesh_info.max_primitives_override { if let crate::Expression::Override(o) = module.global_expressions[e] { if implied.1[1] != Some(o) { - return Err(EntryPointError::BadMeshOutputVarableType.with_span()); + return Err(EntryPointError::BadMeshOutputVariableType.with_span()); } } } @@ -1120,7 +1117,7 @@ impl super::Validator { implied.0.max_vertices_override = mesh_info.max_vertices_override; implied.0.max_primitives_override = mesh_info.max_primitives_override; if implied.0 != *mesh_info { - return Err(EntryPointError::BadMeshOutputVarableType.with_span()); + return Err(EntryPointError::BadMeshOutputVariableType.with_span()); } self.validate_mesh_output_type( @@ -1135,14 +1132,6 @@ impl super::Validator { mesh_info.primitive_output_type, MeshOutputType::PrimitiveOutput, )?; - } else { - // This is not a `Mesh` entry point, so ensure that it never tries to produce - // vertices or primitives. - if info.mesh_shader_info.vertex_type.is_some() - || info.mesh_shader_info.primitive_type.is_some() - { - return Err(EntryPointError::UnexpectedMeshShaderAttributes.with_span()); - } } Ok(info) diff --git a/naga/tests/out/analysis/spv-shadow.info.ron b/naga/tests/out/analysis/spv-shadow.info.ron index b08a28438ed..381f841d5d9 100644 --- a/naga/tests/out/analysis/spv-shadow.info.ron +++ b/naga/tests/out/analysis/spv-shadow.info.ron @@ -413,10 +413,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -1595,10 +1591,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ], entry_points: [ @@ -1693,10 +1685,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ], const_expression_types: [ diff --git a/naga/tests/out/analysis/wgsl-access.info.ron b/naga/tests/out/analysis/wgsl-access.info.ron index d297b09a404..c22cd768f2e 100644 --- a/naga/tests/out/analysis/wgsl-access.info.ron +++ b/naga/tests/out/analysis/wgsl-access.info.ron @@ -1197,10 +1197,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -2527,10 +2523,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -2571,10 +2563,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -2624,10 +2612,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -2671,10 +2655,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -2769,10 +2749,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -2894,10 +2870,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -2950,10 +2922,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -3009,10 +2977,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -3065,10 +3029,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -3124,10 +3084,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -3192,10 +3148,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -3269,10 +3221,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -3349,10 +3297,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -3453,10 +3397,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -3653,10 +3593,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ], entry_points: [ @@ -4354,10 +4290,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -4810,10 +4742,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -4884,10 +4812,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ], const_expression_types: [ diff --git a/naga/tests/out/analysis/wgsl-collatz.info.ron b/naga/tests/out/analysis/wgsl-collatz.info.ron index 2796f544510..219e016f8d7 100644 --- a/naga/tests/out/analysis/wgsl-collatz.info.ron +++ b/naga/tests/out/analysis/wgsl-collatz.info.ron @@ -275,10 +275,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ], entry_points: [ @@ -434,10 +430,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ], const_expression_types: [], diff --git a/naga/tests/out/analysis/wgsl-mesh-shader.info.ron b/naga/tests/out/analysis/wgsl-mesh-shader.info.ron index 9ba7187ac69..9422d07107d 100644 --- a/naga/tests/out/analysis/wgsl-mesh-shader.info.ron +++ b/naga/tests/out/analysis/wgsl-mesh-shader.info.ron @@ -220,10 +220,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -1402,10 +1398,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -1471,10 +1463,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ], const_expression_types: [], diff --git a/naga/tests/out/analysis/wgsl-overrides.info.ron b/naga/tests/out/analysis/wgsl-overrides.info.ron index a76c9c89c9b..92e99112e53 100644 --- a/naga/tests/out/analysis/wgsl-overrides.info.ron +++ b/naga/tests/out/analysis/wgsl-overrides.info.ron @@ -201,10 +201,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ], const_expression_types: [ diff --git a/naga/tests/out/analysis/wgsl-storage-textures.info.ron b/naga/tests/out/analysis/wgsl-storage-textures.info.ron index 35b5a7e320c..8bb298a6450 100644 --- a/naga/tests/out/analysis/wgsl-storage-textures.info.ron +++ b/naga/tests/out/analysis/wgsl-storage-textures.info.ron @@ -184,10 +184,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ( flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), @@ -400,10 +396,6 @@ sampling: [], dual_source_blending: false, diagnostic_filter_leaf: None, - mesh_shader_info: ( - vertex_type: None, - primitive_type: None, - ), ), ], const_expression_types: [], From 64798dd1466c16db8a1291d35182fd469b0a3908 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Sat, 1 Nov 2025 01:30:36 -0500 Subject: [PATCH 78/89] Added changelog entry --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 35b9d3c7128..a8d0ce39a89 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -105,6 +105,7 @@ SamplerDescriptor { - Removed three features from `wgpu-hal` which did nothing useful: `"cargo-clippy"`, `"gpu-allocator"`, and `"rustc-hash"`. By @kpreid in [#8357](https://github.com/gfx-rs/wgpu/pull/8357). - `wgpu_types::PollError` now always implements the `Error` trait. By @kpreid in [#8384](https://github.com/gfx-rs/wgpu/pull/8384). - The texture subresources used by the color attachments of a render pass are no longer allowed to overlap when accessed via different texture views. By @andyleiserson in [#8402](https://github.com/gfx-rs/wgpu/pull/8402). +- Add WGSL parsing for mesh shaders. By @inner-daemons in [#8370](https://github.com/gfx-rs/wgpu/pull/8370). #### DX12 From bd923cdc271aa862f4899d4199e8a407c2295c78 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Mon, 3 Nov 2025 12:50:38 -0600 Subject: [PATCH 79/89] Made parser respect enable extension --- naga/src/front/wgsl/error.rs | 1 - naga/src/front/wgsl/parse/conv.rs | 34 +++++++++++++++++++++++--- naga/src/front/wgsl/parse/mod.rs | 40 +++++++++++++++++++++++++++++-- 3 files changed, 69 insertions(+), 6 deletions(-) diff --git a/naga/src/front/wgsl/error.rs b/naga/src/front/wgsl/error.rs index a8958525ad1..0cd7e11c737 100644 --- a/naga/src/front/wgsl/error.rs +++ b/naga/src/front/wgsl/error.rs @@ -1375,7 +1375,6 @@ impl<'a> Error<'a> { }, Error::ExpectedGlobalVariable { name_span } => ParseError { message: "expected global variable".to_string(), - // TODO: I would like to also include the global declaration span labels: vec![(name_span, "variable used here".into())], notes: vec![], }, diff --git a/naga/src/front/wgsl/parse/conv.rs b/naga/src/front/wgsl/parse/conv.rs index 16e814f56f5..0303b7ed6bb 100644 --- a/naga/src/front/wgsl/parse/conv.rs +++ b/naga/src/front/wgsl/parse/conv.rs @@ -6,7 +6,11 @@ use crate::Span; use alloc::boxed::Box; -pub fn map_address_space(word: &str, span: Span) -> Result<'_, crate::AddressSpace> { +pub fn map_address_space<'a>( + word: &str, + span: Span, + enable_extensions: &EnableExtensions, +) -> Result<'a, crate::AddressSpace> { match word { "private" => Ok(crate::AddressSpace::Private), "workgroup" => Ok(crate::AddressSpace::WorkGroup), @@ -16,7 +20,16 @@ pub fn map_address_space(word: &str, span: Span) -> Result<'_, crate::AddressSpa }), "push_constant" => Ok(crate::AddressSpace::PushConstant), "function" => Ok(crate::AddressSpace::Function), - "task_payload" => Ok(crate::AddressSpace::TaskPayload), + "task_payload" => { + if enable_extensions.contains(ImplementedEnableExtension::MeshShader) { + Ok(crate::AddressSpace::TaskPayload) + } else { + Err(Box::new(Error::EnableExtensionNotEnabled { + span, + kind: ImplementedEnableExtension::MeshShader.into(), + })) + } + } _ => Err(Box::new(Error::UnknownAddressSpace(span))), } } @@ -53,7 +66,7 @@ pub fn map_built_in( "subgroup_invocation_id" => crate::BuiltIn::SubgroupInvocationId, // mesh "cull_primitive" => crate::BuiltIn::CullPrimitive, - "vertex_indices" => crate::BuiltIn::PointIndex, + "point_index" => crate::BuiltIn::PointIndex, "line_indices" => crate::BuiltIn::LineIndices, "triangle_indices" => crate::BuiltIn::TriangleIndices, "mesh_task_size" => crate::BuiltIn::MeshTaskSize, @@ -73,6 +86,21 @@ pub fn map_built_in( })); } } + crate::BuiltIn::CullPrimitive + | crate::BuiltIn::PointIndex + | crate::BuiltIn::LineIndices + | crate::BuiltIn::TriangleIndices + | crate::BuiltIn::VertexCount + | crate::BuiltIn::Vertices + | crate::BuiltIn::PrimitiveCount + | crate::BuiltIn::Primitives => { + if !enable_extensions.contains(ImplementedEnableExtension::MeshShader) { + return Err(Box::new(Error::EnableExtensionNotEnabled { + span, + kind: ImplementedEnableExtension::MeshShader.into(), + })); + } + } _ => {} } Ok(built_in) diff --git a/naga/src/front/wgsl/parse/mod.rs b/naga/src/front/wgsl/parse/mod.rs index 94df933a6a9..e4c04644347 100644 --- a/naga/src/front/wgsl/parse/mod.rs +++ b/naga/src/front/wgsl/parse/mod.rs @@ -240,6 +240,15 @@ impl<'a> BindingParser<'a> { lexer.expect(Token::Paren(')'))?; } "per_primitive" => { + if !lexer + .enable_extensions + .contains(ImplementedEnableExtension::MeshShader) + { + return Err(Box::new(Error::EnableExtensionNotEnabled { + span: name_span, + kind: ImplementedEnableExtension::MeshShader.into(), + })); + } self.per_primitive.set((), name_span)?; } _ => return Err(Box::new(Error::UnknownAttribute(name_span))), @@ -1324,7 +1333,7 @@ impl Parser { }; crate::AddressSpace::Storage { access } } - _ => conv::map_address_space(class_str, span)?, + _ => conv::map_address_space(class_str, span, &lexer.enable_extensions)?, }; lexer.expect(Token::Paren('>'))?; } @@ -1697,7 +1706,7 @@ impl Parser { "ptr" => { lexer.expect_generic_paren('<')?; let (ident, span) = lexer.next_ident_with_span()?; - let mut space = conv::map_address_space(ident, span)?; + let mut space = conv::map_address_space(ident, span, &lexer.enable_extensions)?; lexer.expect(Token::Separator(','))?; let base = self.type_decl(lexer, ctx)?; if let crate::AddressSpace::Storage { ref mut access } = space { @@ -2865,10 +2874,28 @@ impl Parser { compute_like_span = name_span; } "task" => { + if !lexer + .enable_extensions + .contains(ImplementedEnableExtension::MeshShader) + { + return Err(Box::new(Error::EnableExtensionNotEnabled { + span: name_span, + kind: ImplementedEnableExtension::MeshShader.into(), + })); + } stage.set(ShaderStage::Task, name_span)?; compute_like_span = name_span; } "mesh" => { + if !lexer + .enable_extensions + .contains(ImplementedEnableExtension::MeshShader) + { + return Err(Box::new(Error::EnableExtensionNotEnabled { + span: name_span, + kind: ImplementedEnableExtension::MeshShader.into(), + })); + } stage.set(ShaderStage::Mesh, name_span)?; compute_like_span = name_span; @@ -2877,6 +2904,15 @@ impl Parser { lexer.expect(Token::Paren(')'))?; } "payload" => { + if !lexer + .enable_extensions + .contains(ImplementedEnableExtension::MeshShader) + { + return Err(Box::new(Error::EnableExtensionNotEnabled { + span: name_span, + kind: ImplementedEnableExtension::MeshShader.into(), + })); + } lexer.expect(Token::Paren('('))?; payload.set(lexer.next_ident_with_span()?, name_span)?; lexer.expect(Token::Paren(')'))?; From d95070aeb7473cd90ba5c26f0ea7f3482dc87fd7 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Mon, 3 Nov 2025 13:41:59 -0600 Subject: [PATCH 80/89] Updated mesh shader spec --- docs/api-specs/mesh_shading.md | 117 +++++++++++++++++---------------- 1 file changed, 60 insertions(+), 57 deletions(-) diff --git a/docs/api-specs/mesh_shading.md b/docs/api-specs/mesh_shading.md index 4b28ec635e7..41720765a55 100644 --- a/docs/api-specs/mesh_shading.md +++ b/docs/api-specs/mesh_shading.md @@ -103,10 +103,8 @@ An example of using mesh shaders to render a single triangle can be seen [here]( * DirectX 12 support is planned. * Metal support is desired but not currently planned. - ## Naga implementation - ### Supported frontends * 🛠️ WGSL * ❌ SPIR-V @@ -114,7 +112,7 @@ An example of using mesh shaders to render a single triangle can be seen [here]( ### Supported backends * 🛠️ SPIR-V -* ❌ HLSL +* 🛠️ HLSL * ❌ MSL * 🚫 GLSL * 🚫 WGSL @@ -130,7 +128,7 @@ The majority of changes relating to mesh shaders will be in WGSL and `naga`. Using any of these features in a `wgsl` program will require adding the `enable mesh_shading` directive to the top of a program. -Two new shader stages will be added to `WGSL`. Fragment shaders are also modified slightly. Both task shaders and mesh shaders are allowed to use any compute-specific functionality, such as subgroup operations. +Two new shader stages will be added to `WGSL`. Fragment shaders are also modified slightly. Both task shaders and mesh shaders are allowed to use any compute-available functionality, including subgroup operations. ### Task shader @@ -145,6 +143,8 @@ A task shader entry point must return a `vec3` value. The return value of e Each task shader workgroup dispatches an independent mesh shader grid: in mesh shader invocations, `@builtin` values like `workgroup_id` and `global_invocation_id` describe the position of the workgroup and invocation within that grid; and `@builtin(num_workgroups)` matches the task shader workgroup's return value. Mesh shaders dispatched for other task shader workgroups are not included in the count. If it is necessary for a mesh shader to know which task shader workgroup dispatched it, the task shader can include its own workgroup id in the task payload. +Task shaders can use compute and subgroup builtin inputs, in addition to `view_index` and `draw_id`. + ### Mesh shader A function with the `@mesh` attribute is a **mesh shader entry point**. Mesh shaders must not return anything. @@ -159,17 +159,19 @@ A mesh shader entry point must have the following attributes: - `@workgroup_size`: this has the same meaning as when it appears on a compute shader entry point. -- `@vertex_output(V, NV)`: This indicates that the mesh shader workgroup will generate at most `NV` vertex values, each of type `V`. +- `@mesh(VAR)`: Here, `VAR` represents a workgroup variable storing the output information. -- `@primitive_output(P, NP)`: This indicates that the mesh shader workgroup will generate at most `NP` primitives, each of type `P`. +All mesh shader outputs are per-workgroup, and taken from the workgroup variable specified above. The type must have exactly 4 fields: +- A field decorated with `@builtin(vertex_count)`, with type `u32`: this field represents the number of vertices that will be drawn +- A field decorated with `@builtin(primitive_count)`, with type `u32`: this field represents the number of primitives that will be drawn +- A field decorated with `@builtin(vertices)`, typed as an array of `V`, where `V` is the vertex output type as specified below +- A field decorated with `@builtin(primitives)`, typed as an array of `P`, where `P` is the primitive output type as specified below -Each mesh shader entry point invocation must call the `setMeshOutputs(numVertices: u32, numPrimitives: u32)` builtin function at least once. The values passed by each workgroup's first invocation (that is, the one whose `local_invocation_index` is `0`) determine how many vertices (values of type `V`) and primitives (values of type `P`) the workgroup must produce. The user can still write past these indices, but they won't be used in the output. +For a vertex count `NV`, the first `NV` elements of the vertex array above are outputted. Therefore, `NV` must be less than or equal to the size of the vertex array. The same is true for primitives with `NP`. -The `numVertices` and `numPrimitives` arguments must be no greater than `NV` and `NP` from the `@vertex_output` and `@primitive_output` attributes. +The vertex output type `V` must meet the same requirements as a struct type returned by a `@vertex` entry point: all members must have either `@builtin` or `@location` attributes, there must be a `@builtin(position)`, and so on. -To produce vertex data, the workgroup as a whole must make `numVertices` calls to the `setVertex(i: u32, vertex: V)` builtin function. This establishes `vertex` as the value of the `i`'th vertex, where `i` is less than the maximum number of output vertices in the `@vertex_output` attribute. `V` is the type given in the `@vertex_output` attribute. `V` must meet the same requirements as a struct type returned by a `@vertex` entry point: all members must have either `@builtin` or `@location` attributes, there must be a `@builtin(position)`, and so on. - -To produce primitives, the workgroup as a whole must make `numPrimitives` calls to the `setPrimitive(i: u32, primitive: P)` builtin function. This establishes `primitive` as the value of the `i`'th primitive, where `i` is less than the maximum number of output primitives in the `@primitive_output` attribute. `P` is the type given in the `@primitive_output` attribute. `P` must be a struct type, every member of which either has a `@location` or `@builtin` attribute. The following `@builtin` attributes are allowed: +The primitive output type `P` must be a struct type, every member of which either has a `@location` or `@builtin` attribute. All members decorated with `@location` must also be decorated with `@per_primitive`, as must the corresponding fragment input. The `@per_primitive` decoration may only be applied to members decorated with `@location`. The following `@builtin` attributes are allowed: - `triangle_indices`, `line_indices`, or `point_index`: The annotated member must be of type `vec3`, `vec2`, or `u32`. @@ -179,15 +181,13 @@ To produce primitives, the workgroup as a whole must make `numPrimitives` calls - `cull_primitive`: The annotated member must be of type `bool`. If it is true, then the primitive is skipped during rendering. -Every member of `P` with a `@location` attribute must either have a `@per_primitive` attribute, or be part of a struct type that appears in the primitive data as a struct member with the `@per_primitive` attribute. - The `@location` attributes of `P` and `V` must not overlap, since they are merged to produce the user-defined inputs to the fragment shader. -It is possible to write to the same vertex or primitive index repeatedly. Since the implicit arrays written by `setVertex` and `setPrimitive` are shared by the workgroup, data races on writes to the same index for a given type are undefined behavior. +Mesh shaders can use compute and mesh shader builtin inputs, in addition to `view_index`, and if no task shader is present, `draw_id`. ### Fragment shader -Fragment shaders can access vertex output data as if it is from a vertex shader. They can also access primitive output data, provided the input is decorated with `@per_primitive`. The `@per_primitive` attribute can be applied to a value directly, such as `@per_primitive @location(1) value: vec4`, to a struct such as `@per_primitive primitive_input: PrimitiveInput` where `PrimitiveInput` is a struct containing fields decorated with `@location` and `@builtin`, or to members of a struct that are themselves decorated with `@location` or `@builtin`. +Fragment shaders can access vertex output data as if it is from a vertex shader. They can also access primitive output data, provided the input is decorated with `@per_primitive`. The `@per_primitive` decoration may only be applied to inputs or struct members decorated with `@location`. The primitive state is part of the fragment input and must match the output of the mesh shader in the pipeline. Using `@per_primitive` also requires enabling the mesh shader extension. Additionally, the locations of vertex and primitive input cannot overlap. @@ -199,72 +199,75 @@ The following is a full example of WGSL shaders that could be used to create a m enable mesh_shading; const positions = array( - vec4(0.,1.,0.,1.), - vec4(-1.,-1.,0.,1.), - vec4(1.,-1.,0.,1.) + vec4(0., 1., 0., 1.), + vec4(-1., -1., 0., 1.), + vec4(1., -1., 0., 1.) ); const colors = array( - vec4(0.,1.,0.,1.), - vec4(0.,0.,1.,1.), - vec4(1.,0.,0.,1.) + vec4(0., 1., 0., 1.), + vec4(0., 0., 1., 1.), + vec4(1., 0., 0., 1.) ); struct TaskPayload { - colorMask: vec4, - visible: bool, + colorMask: vec4, + visible: bool, } var taskPayload: TaskPayload; var workgroupData: f32; struct VertexOutput { - @builtin(position) position: vec4, - @location(0) color: vec4, + @builtin(position) position: vec4, + @location(0) color: vec4, } struct PrimitiveOutput { - @builtin(triangle_indices) index: vec3, - @builtin(cull_primitive) cull: bool, - @per_primitive @location(1) colorMask: vec4, + @builtin(triangle_indices) index: vec3, + @builtin(cull_primitive) cull: bool, + @per_primitive @location(1) colorMask: vec4, } struct PrimitiveInput { - @per_primitive @location(1) colorMask: vec4, + @per_primitive @location(1) colorMask: vec4, } @task @payload(taskPayload) @workgroup_size(1) fn ts_main() -> @builtin(mesh_task_size) vec3 { - workgroupData = 1.0; - taskPayload.colorMask = vec4(1.0, 1.0, 0.0, 1.0); - taskPayload.visible = true; - return vec3(3, 1, 1); + workgroupData = 1.0; + taskPayload.colorMask = vec4(1.0, 1.0, 0.0, 1.0); + taskPayload.visible = true; + return vec3(3, 1, 1); +} + +struct MeshOutput { + @builtin(vertices) vertices: array, + @builtin(primitives) primitives: array, + @builtin(vertex_count) vertex_count: u32, + @builtin(primitive_count) primitive_count: u32, } -@mesh + +var mesh_output: MeshOutput; +@mesh(mesh_output) @payload(taskPayload) -@vertex_output(VertexOutput, 3) @primitive_output(PrimitiveOutput, 1) @workgroup_size(1) fn ms_main(@builtin(local_invocation_index) index: u32, @builtin(global_invocation_id) id: vec3) { - setMeshOutputs(3, 1); - workgroupData = 2.0; - var v: VertexOutput; - - v.position = positions[0]; - v.color = colors[0] * taskPayload.colorMask; - setVertex(0, v); - - v.position = positions[1]; - v.color = colors[1] * taskPayload.colorMask; - setVertex(1, v); - - v.position = positions[2]; - v.color = colors[2] * taskPayload.colorMask; - setVertex(2, v); - - var p: PrimitiveOutput; - p.index = vec3(0, 1, 2); - p.cull = !taskPayload.visible; - p.colorMask = vec4(1.0, 0.0, 1.0, 1.0); - setPrimitive(0, p); + mesh_output.vertex_count = 3; + mesh_output.primitive_count = 1; + workgroupData = 2.0; + + mesh_output.vertices[0].position = positions[0]; + mesh_output.vertices[0].color = colors[0] * taskPayload.colorMask; + + mesh_output.vertices[1].position = positions[1]; + mesh_output.vertices[1].color = colors[1] * taskPayload.colorMask; + + mesh_output.vertices[2].position = positions[2]; + mesh_output.vertices[2].color = colors[2] * taskPayload.colorMask; + + mesh_output.primitives[0].index = vec3(0, 1, 2); + mesh_output.primitives[0].cull = !taskPayload.visible; + mesh_output.primitives[0].colorMask = vec4(1.0, 0.0, 1.0, 1.0); } @fragment fn fs_main(vertex: VertexOutput, primitive: PrimitiveInput) -> @location(0) vec4 { - return vertex.color * primitive.colorMask; + return vertex.color * primitive.colorMask; } ``` From ace7e17f7f8b83d96e7b6f0cfc9012f1aa514a42 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Mon, 3 Nov 2025 14:50:00 -0600 Subject: [PATCH 81/89] Cleaned up the mesh shader analyzer function --- naga/src/proc/mod.rs | 55 ++++++++++++++++++++++++++++--------- naga/src/valid/interface.rs | 4 +-- 2 files changed, 44 insertions(+), 15 deletions(-) diff --git a/naga/src/proc/mod.rs b/naga/src/proc/mod.rs index 4271db391c5..64da0a9661e 100644 --- a/naga/src/proc/mod.rs +++ b/naga/src/proc/mod.rs @@ -660,6 +660,12 @@ impl crate::Module { /// Extracts mesh shader info from a mesh output global variable. Used in frontends /// and by validators. This only validates the output variable itself, and not the /// vertex and primitive output types. + /// + /// The output contains the extracted mesh stage info, with overrides unset, + /// and then the overrides separately. This is because the overrides should be + /// treated as expressions elsewhere, but that requires mutably modifying the + /// module and the expressions should only be created at parse time, not validation + /// time. #[allow(clippy::type_complexity)] pub fn analyze_mesh_shader_info( &self, @@ -671,6 +677,19 @@ impl crate::Module { ) { use crate::span::AddSpan; use crate::valid::EntryPointError; + #[derive(Default)] + struct OutError { + pub inner: Option, + } + impl OutError { + pub fn set(&mut self, err: EntryPointError) { + if self.inner.is_none() { + self.inner = Some(err); + } + } + } + + // Used to temporarily initialize stuff let null_type = crate::Handle::new(NonMaxU32::new(0).unwrap()); let mut output = crate::MeshStageInfo { topology: crate::MeshOutputTopology::Triangles, @@ -682,7 +701,8 @@ impl crate::Module { primitive_output_type: null_type, output_variable: gv, }; - let mut error = None; + // Stores the error to output, if any. + let mut error = OutError::default(); let r#type = &self.types[self.global_variables[gv].ty].inner; let mut topology = output.topology; @@ -696,20 +716,24 @@ impl crate::Module { for member in members { match member.binding { Some(crate::Binding::BuiltIn(crate::BuiltIn::VertexCount)) => { + // Must have type u32 if self.types[member.ty].inner.scalar() != Some(crate::Scalar::U32) { - error = Some(EntryPointError::BadMeshOutputVariableField); + error.set(EntryPointError::BadMeshOutputVariableField); } + // Each builtin should only occur once if builtins.contains(&crate::BuiltIn::VertexCount) { - error = Some(EntryPointError::BadMeshOutputVariableType); + error.set(EntryPointError::BadMeshOutputVariableType); } builtins.insert(crate::BuiltIn::VertexCount); } Some(crate::Binding::BuiltIn(crate::BuiltIn::PrimitiveCount)) => { + // Must have type u32 if self.types[member.ty].inner.scalar() != Some(crate::Scalar::U32) { - error = Some(EntryPointError::BadMeshOutputVariableField); + error.set(EntryPointError::BadMeshOutputVariableField); } + // Each builtin should only occur once if builtins.contains(&crate::BuiltIn::PrimitiveCount) { - error = Some(EntryPointError::BadMeshOutputVariableType); + error.set(EntryPointError::BadMeshOutputVariableType); } builtins.insert(crate::BuiltIn::PrimitiveCount); } @@ -717,6 +741,7 @@ impl crate::Module { crate::BuiltIn::Vertices | crate::BuiltIn::Primitives, )) => { let ty = &self.types[member.ty].inner; + // Analyze the array type to determine size and vertex/primitive type let (a, b, c) = match ty { &crate::TypeInner::Array { base, size, .. } => { let ty = base; @@ -724,15 +749,14 @@ impl crate::Module { crate::ArraySize::Constant(a) => (a.get(), None), crate::ArraySize::Pending(o) => (0, Some(o)), crate::ArraySize::Dynamic => { - error = - Some(EntryPointError::BadMeshOutputVariableField); + error.set(EntryPointError::BadMeshOutputVariableField); (0, None) } }; (max, max_override, ty) } _ => { - error = Some(EntryPointError::BadMeshOutputVariableField); + error.set(EntryPointError::BadMeshOutputVariableField); (0, None, null_type) } }; @@ -740,6 +764,7 @@ impl crate::Module { member.binding, Some(crate::Binding::BuiltIn(crate::BuiltIn::Primitives)) ) { + // Primitives require special analysis to determine topology primitive_info = (a, b, c); match self.types[c].inner { crate::TypeInner::Struct { ref members, .. } => { @@ -766,19 +791,21 @@ impl crate::Module { } _ => (), } + // Each builtin should only occur once if builtins.contains(&crate::BuiltIn::Primitives) { - error = Some(EntryPointError::BadMeshOutputVariableType); + error.set(EntryPointError::BadMeshOutputVariableType); } builtins.insert(crate::BuiltIn::Primitives); } else { vertex_info = (a, b, c); + // Each builtin should only occur once if builtins.contains(&crate::BuiltIn::Vertices) { - error = Some(EntryPointError::BadMeshOutputVariableType); + error.set(EntryPointError::BadMeshOutputVariableType); } builtins.insert(crate::BuiltIn::Vertices); } } - _ => error = Some(EntryPointError::BadMeshOutputVariableType), + _ => error.set(EntryPointError::BadMeshOutputVariableType), } } output = crate::MeshStageInfo { @@ -792,12 +819,14 @@ impl crate::Module { ..output } } - _ => error = Some(EntryPointError::BadMeshOutputVariableType), + _ => error.set(EntryPointError::BadMeshOutputVariableType), } ( output, [vertex_info.1, primitive_info.1], - error.map(|a| a.with_span_handle(self.global_variables[gv].ty, &self.types)), + error + .inner + .map(|a| a.with_span_handle(self.global_variables[gv].ty, &self.types)), ) } } diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index 6c297112fc5..449ae5b163a 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -397,9 +397,9 @@ impl VaryingContext<'_> { scalar: crate::Scalar::U32, }, ), - // Validated elsewhere + // Validated elsewhere, shouldn't be here Bi::VertexCount | Bi::PrimitiveCount | Bi::Vertices | Bi::Primitives => { - (true, true) + (false, true) } }; From c0278f34df186a080e7a92f7b48ecff5cc10542a Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Thu, 6 Nov 2025 20:02:59 -0600 Subject: [PATCH 82/89] Updated changelog --- CHANGELOG.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 528f30df279..a868658302b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -139,9 +139,8 @@ By @SupaMaggie70Incorporated in [#8206](https://github.com/gfx-rs/wgpu/pull/8206 - Removed three features from `wgpu-hal` which did nothing useful: `"cargo-clippy"`, `"gpu-allocator"`, and `"rustc-hash"`. By @kpreid in [#8357](https://github.com/gfx-rs/wgpu/pull/8357). - `wgpu_types::PollError` now always implements the `Error` trait. By @kpreid in [#8384](https://github.com/gfx-rs/wgpu/pull/8384). - The texture subresources used by the color attachments of a render pass are no longer allowed to overlap when accessed via different texture views. By @andyleiserson in [#8402](https://github.com/gfx-rs/wgpu/pull/8402). -- Add WGSL parsing for mesh shaders. By @inner-daemons in [#8370](https://github.com/gfx-rs/wgpu/pull/8370). -- Fixed a bug where the texture aspect was not passed through when calling `copy_texture_to_buffer` in WebGPU, causing the copy to fail for depth/stencil textures. By @Tim-Evans-Seequent in [#8445](https://github.com/gfx-rs/wgpu/pull/8445). - Validate that buffers are unmapped in `write_buffer` calls. By @ErichDonGubler in [#8454](https://github.com/gfx-rs/wgpu/pull/8454). +- Add WGSL parsing for mesh shaders. By @inner-daemons in [#8370](https://github.com/gfx-rs/wgpu/pull/8370). #### DX12 From e7dc9e54f3b94058bb0be607f8ea1734869e2a24 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Thu, 6 Nov 2025 20:09:40 -0600 Subject: [PATCH 83/89] Updated validation from spv-write --- naga/src/valid/interface.rs | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/naga/src/valid/interface.rs b/naga/src/valid/interface.rs index 6c30c554420..a040fd1604d 100644 --- a/naga/src/valid/interface.rs +++ b/naga/src/valid/interface.rs @@ -98,8 +98,6 @@ pub enum VaryingError { InvalidPerPrimitive, #[error("Non-builtin members of a mesh primitive output struct must be decorated with `@per_primitive`")] MissingPerPrimitive, - #[error("The `MESH_SHADER` capability must be enabled to use per-primitive fragment inputs.")] - PerPrimitiveNotAllowed, } #[derive(Clone, Debug, thiserror::Error)] @@ -402,6 +400,24 @@ impl VaryingContext<'_> { (false, true) } }; + match built_in { + Bi::CullPrimitive + | Bi::PointIndex + | Bi::LineIndices + | Bi::TriangleIndices + | Bi::MeshTaskSize + | Bi::VertexCount + | Bi::PrimitiveCount + | Bi::Vertices + | Bi::Primitives => { + if !self.capabilities.contains(Capabilities::MESH_SHADER) { + return Err(VaryingError::UnsupportedCapability( + Capabilities::MESH_SHADER, + )); + } + } + _ => (), + } if !visible { return Err(VaryingError::InvalidBuiltInStage(built_in)); @@ -419,7 +435,9 @@ impl VaryingContext<'_> { per_primitive, } => { if per_primitive && !self.capabilities.contains(Capabilities::MESH_SHADER) { - return Err(VaryingError::PerPrimitiveNotAllowed); + return Err(VaryingError::UnsupportedCapability( + Capabilities::MESH_SHADER, + )); } // Only IO-shareable types may be stored in locations. if !self.type_info[ty.index()] @@ -1130,6 +1148,7 @@ impl super::Validator { mesh_info.primitive_output_type, MeshOutputType::PrimitiveOutput, )?; + info.insert_global_use(GlobalUse::READ, mesh_info.output_variable); } Ok(info) From 592ac165a2a1d7763c0c59e48b444a6643a8e1a4 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Thu, 6 Nov 2025 20:24:25 -0600 Subject: [PATCH 84/89] Slight tweaks --- naga/src/ir/mod.rs | 2 +- naga/tests/out/analysis/wgsl-mesh-shader.info.ron | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/naga/src/ir/mod.rs b/naga/src/ir/mod.rs index 097220a46bb..c3deabe706d 100644 --- a/naga/src/ir/mod.rs +++ b/naga/src/ir/mod.rs @@ -2583,7 +2583,7 @@ pub struct DocComments { pub enum MeshOutputTopology { /// Outputs individual vertices to be rendered as points. Points, - /// Outputs groups of 2 vertices to be renderedas lines . + /// Outputs groups of 2 vertices to be rendered as lines. Lines, /// Outputs groups of 3 vertices to be rendered as triangles. Triangles, diff --git a/naga/tests/out/analysis/wgsl-mesh-shader.info.ron b/naga/tests/out/analysis/wgsl-mesh-shader.info.ron index 9422d07107d..eacd33ad0f1 100644 --- a/naga/tests/out/analysis/wgsl-mesh-shader.info.ron +++ b/naga/tests/out/analysis/wgsl-mesh-shader.info.ron @@ -233,7 +233,7 @@ global_uses: [ ("READ"), ("WRITE"), - ("WRITE"), + ("READ | WRITE"), ], expressions: [ ( From f0e48afd38d861aff41b8b39dc4eba55868725de Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Fri, 7 Nov 2025 14:35:24 -0600 Subject: [PATCH 85/89] Initial changes to setup, tests, examples (no changes to logic) --- examples/features/src/mesh_shader/mod.rs | 99 ++++++------ .../features/src/mesh_shader/shader.metal | 77 --------- examples/features/src/mesh_shader/shader.wgsl | 74 +++++++++ naga/tests/in/wgsl/mesh-shader.toml | 2 +- tests/tests/wgpu-gpu/mesh_shader/mod.rs | 149 +++++++----------- tests/tests/wgpu-gpu/mesh_shader/shader.metal | 77 --------- tests/tests/wgpu-gpu/mesh_shader/shader.wgsl | 74 +++++++++ 7 files changed, 257 insertions(+), 295 deletions(-) delete mode 100644 examples/features/src/mesh_shader/shader.metal create mode 100644 examples/features/src/mesh_shader/shader.wgsl delete mode 100644 tests/tests/wgpu-gpu/mesh_shader/shader.metal create mode 100644 tests/tests/wgpu-gpu/mesh_shader/shader.wgsl diff --git a/examples/features/src/mesh_shader/mod.rs b/examples/features/src/mesh_shader/mod.rs index a0d2272363d..9a202d19272 100644 --- a/examples/features/src/mesh_shader/mod.rs +++ b/examples/features/src/mesh_shader/mod.rs @@ -1,6 +1,12 @@ use std::process::Stdio; // Same as in mesh shader tests +fn compile_wgsl(device: &wgpu::Device) -> wgpu::ShaderModule { + device.create_shader_module(wgpu::ShaderModuleDescriptor { + label: None, + source: wgpu::ShaderSource::Wgsl(include_str!("shader.wgsl").into()), + }) +} fn compile_glsl(device: &wgpu::Device, shader_stage: &'static str) -> wgpu::ShaderModule { let cmd = std::process::Command::new("glslc") .args([ @@ -61,18 +67,6 @@ fn compile_hlsl(device: &wgpu::Device, entry: &str, stage_str: &str) -> wgpu::Sh } } -fn compile_msl(device: &wgpu::Device, entry: &str) -> wgpu::ShaderModule { - unsafe { - device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough { - entry_point: entry.to_owned(), - label: None, - msl: Some(std::borrow::Cow::Borrowed(include_str!("shader.metal"))), - num_workgroups: (1, 1, 1), - ..Default::default() - }) - } -} - pub struct Example { pipeline: wgpu::RenderPipeline, } @@ -83,24 +77,31 @@ impl crate::framework::Example for Example { device: &wgpu::Device, _queue: &wgpu::Queue, ) -> Self { - let (ts, ms, fs) = match adapter.get_info().backend { - wgpu::Backend::Vulkan => ( - compile_glsl(device, "task"), - compile_glsl(device, "mesh"), - compile_glsl(device, "frag"), - ), - wgpu::Backend::Dx12 => ( - compile_hlsl(device, "Task", "as"), - compile_hlsl(device, "Mesh", "ms"), - compile_hlsl(device, "Frag", "ps"), - ), - wgpu::Backend::Metal => ( - compile_msl(device, "taskShader"), - compile_msl(device, "meshShader"), - compile_msl(device, "fragShader"), - ), - _ => panic!("Example can currently only run on vulkan, dx12 or metal"), - }; + let (ts, ms, fs, ts_name, ms_name, fs_name) = + if adapter.get_info().backend == wgpu::Backend::Metal { + let s = compile_wgsl(device); + (s.clone(), s.clone(), s, "ts_main", "ms_main", "fs_main") + } else if adapter.get_info().backend == wgpu::Backend::Vulkan { + ( + compile_glsl(device, "task"), + compile_glsl(device, "mesh"), + compile_glsl(device, "frag"), + "main", + "main", + "main", + ) + } else if adapter.get_info().backend == wgpu::Backend::Dx12 { + ( + compile_hlsl(device, "Task", "as"), + compile_hlsl(device, "Mesh", "ms"), + compile_hlsl(device, "Frag", "ps"), + "main", + "main", + "main", + ) + } else { + panic!("Example can only run on vulkan or dx12"); + }; let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { label: None, bind_group_layouts: &[], @@ -111,17 +112,17 @@ impl crate::framework::Example for Example { layout: Some(&pipeline_layout), task: Some(wgpu::TaskState { module: &ts, - entry_point: Some("main"), + entry_point: Some(ts_name), compilation_options: Default::default(), }), mesh: wgpu::MeshState { module: &ms, - entry_point: Some("main"), + entry_point: Some(ms_name), compilation_options: Default::default(), }, fragment: Some(wgpu::FragmentState { module: &fs, - entry_point: Some("main"), + entry_point: Some(fs_name), compilation_options: Default::default(), targets: &[Some(config.view_formats[0].into())], }), @@ -197,18 +198,20 @@ pub fn main() { #[cfg(test)] #[wgpu_test::gpu_test] -pub static TEST: crate::framework::ExampleTestParams = crate::framework::ExampleTestParams { - name: "mesh_shader", - image_path: "/examples/features/src/mesh_shader/screenshot.png", - width: 1024, - height: 768, - optional_features: wgpu::Features::default(), - base_test_parameters: wgpu_test::TestParameters::default() - .features( - wgpu::Features::EXPERIMENTAL_MESH_SHADER - | wgpu::Features::EXPERIMENTAL_PASSTHROUGH_SHADERS, - ) - .limits(wgpu::Limits::defaults().using_recommended_minimum_mesh_shader_values()), - comparisons: &[wgpu_test::ComparisonType::Mean(0.01)], - _phantom: std::marker::PhantomData::, -}; +pub static TEST: crate::framework::ExampleTestParams = + crate::framework::ExampleTestParams { + name: "mesh_shader", + // Generated on 1080ti on Vk/Windows + image_path: "/examples/features/src/mesh_shader/screenshot.png", + width: 1024, + height: 768, + optional_features: wgpu::Features::default(), + base_test_parameters: wgpu_test::TestParameters::default() + .features( + wgpu::Features::EXPERIMENTAL_MESH_SHADER + | wgpu::Features::EXPERIMENTAL_PASSTHROUGH_SHADERS, + ) + .limits(wgpu::Limits::defaults().using_recommended_minimum_mesh_shader_values()), + comparisons: &[wgpu_test::ComparisonType::Mean(0.005)], + _phantom: std::marker::PhantomData::, + }; diff --git a/examples/features/src/mesh_shader/shader.metal b/examples/features/src/mesh_shader/shader.metal deleted file mode 100644 index 4c7da503832..00000000000 --- a/examples/features/src/mesh_shader/shader.metal +++ /dev/null @@ -1,77 +0,0 @@ -using namespace metal; - -struct OutVertex { - float4 Position [[position]]; - float4 Color [[user(locn0)]]; -}; - -struct OutPrimitive { - float4 ColorMask [[flat]] [[user(locn1)]]; - bool CullPrimitive [[primitive_culled]]; -}; - -struct InVertex { -}; - -struct InPrimitive { - float4 ColorMask [[flat]] [[user(locn1)]]; -}; - -struct FragmentIn { - float4 Color [[user(locn0)]]; - float4 ColorMask [[flat]] [[user(locn1)]]; -}; - -struct PayloadData { - float4 ColorMask; - bool Visible; -}; - -using Meshlet = metal::mesh; - - -constant float4 positions[3] = { - float4(0.0, 1.0, 0.0, 1.0), - float4(-1.0, -1.0, 0.0, 1.0), - float4(1.0, -1.0, 0.0, 1.0) -}; - -constant float4 colors[3] = { - float4(0.0, 1.0, 0.0, 1.0), - float4(0.0, 0.0, 1.0, 1.0), - float4(1.0, 0.0, 0.0, 1.0) -}; - - -[[object]] -void taskShader(uint3 tid [[thread_position_in_grid]], object_data PayloadData &outPayload [[payload]], mesh_grid_properties grid) { - outPayload.ColorMask = float4(1.0, 1.0, 0.0, 1.0); - outPayload.Visible = true; - grid.set_threadgroups_per_grid(uint3(3, 1, 1)); -} - -[[mesh]] -void meshShader( - object_data PayloadData const& payload [[payload]], - Meshlet out -) -{ - out.set_primitive_count(1); - - for(int i = 0;i < 3;i++) { - OutVertex vert; - vert.Position = positions[i]; - vert.Color = colors[i] * payload.ColorMask; - out.set_vertex(i, vert); - out.set_index(i, i); - } - - OutPrimitive prim; - prim.ColorMask = float4(1.0, 0.0, 0.0, 1.0); - prim.CullPrimitive = !payload.Visible; - out.set_primitive(0, prim); -} - -fragment float4 fragShader(FragmentIn data [[stage_in]]) { - return data.Color * data.ColorMask; -} diff --git a/examples/features/src/mesh_shader/shader.wgsl b/examples/features/src/mesh_shader/shader.wgsl new file mode 100644 index 00000000000..cdc7366b415 --- /dev/null +++ b/examples/features/src/mesh_shader/shader.wgsl @@ -0,0 +1,74 @@ +enable mesh_shading; + +const positions = array( + vec4(0., 1., 0., 1.), + vec4(-1., -1., 0., 1.), + vec4(1., -1., 0., 1.) +); +const colors = array( + vec4(0., 1., 0., 1.), + vec4(0., 0., 1., 1.), + vec4(1., 0., 0., 1.) +); +struct TaskPayload { + colorMask: vec4, + visible: bool, +} +var taskPayload: TaskPayload; +var workgroupData: f32; +struct VertexOutput { + @builtin(position) position: vec4, + @location(0) color: vec4, +} +struct PrimitiveOutput { + @builtin(triangle_indices) index: vec3, + @builtin(cull_primitive) cull: bool, + @per_primitive @location(1) colorMask: vec4, +} +struct PrimitiveInput { + @per_primitive @location(1) colorMask: vec4, +} + +@task +@payload(taskPayload) +@workgroup_size(1) +fn ts_main() -> @builtin(mesh_task_size) vec3 { + workgroupData = 1.0; + taskPayload.colorMask = vec4(1.0, 1.0, 0.0, 1.0); + taskPayload.visible = true; + return vec3(3, 1, 1); +} + +struct MeshOutput { + @builtin(vertices) vertices: array, + @builtin(primitives) primitives: array, + @builtin(vertex_count) vertex_count: u32, + @builtin(primitive_count) primitive_count: u32, +} + +var mesh_output: MeshOutput; +@mesh(mesh_output) +@payload(taskPayload) +@workgroup_size(1) +fn ms_main(@builtin(local_invocation_index) index: u32, @builtin(global_invocation_id) id: vec3) { + mesh_output.vertex_count = 3; + mesh_output.primitive_count = 1; + workgroupData = 2.0; + + mesh_output.vertices[0].position = positions[0]; + mesh_output.vertices[0].color = colors[0] * taskPayload.colorMask; + + mesh_output.vertices[1].position = positions[1]; + mesh_output.vertices[1].color = colors[1] * taskPayload.colorMask; + + mesh_output.vertices[2].position = positions[2]; + mesh_output.vertices[2].color = colors[2] * taskPayload.colorMask; + + mesh_output.primitives[0].index = vec3(0, 1, 2); + mesh_output.primitives[0].cull = !taskPayload.visible; + mesh_output.primitives[0].colorMask = vec4(1.0, 0.0, 1.0, 1.0); +} +@fragment +fn fs_main(vertex: VertexOutput, primitive: PrimitiveInput) -> @location(0) vec4 { + return vertex.color * primitive.colorMask; +} diff --git a/naga/tests/in/wgsl/mesh-shader.toml b/naga/tests/in/wgsl/mesh-shader.toml index 1f8b4e23baa..accbae9f2de 100644 --- a/naga/tests/in/wgsl/mesh-shader.toml +++ b/naga/tests/in/wgsl/mesh-shader.toml @@ -1,7 +1,7 @@ # Stolen from ray-query.toml god_mode = true -targets = "IR | ANALYSIS" +targets = "IR | ANALYSIS | METAL" [msl] fake_missing_bindings = true diff --git a/tests/tests/wgpu-gpu/mesh_shader/mod.rs b/tests/tests/wgpu-gpu/mesh_shader/mod.rs index 161c49de569..675436a1d7e 100644 --- a/tests/tests/wgpu-gpu/mesh_shader/mod.rs +++ b/tests/tests/wgpu-gpu/mesh_shader/mod.rs @@ -3,11 +3,16 @@ use std::{ process::Stdio, }; -use wgpu::util::DeviceExt; +use wgpu::{util::DeviceExt, Backends}; use wgpu_test::{ - fail, gpu_test, GpuTestConfiguration, GpuTestInitializer, TestParameters, TestingContext, + gpu_test, FailureCase, GpuTestConfiguration, GpuTestInitializer, TestParameters, TestingContext, }; +/// Backends that support mesh shaders +const MESH_SHADER_BACKENDS: Backends = Backends::DX12 + .union(Backends::VULKAN) + .union(Backends::METAL); + pub fn all_tests(tests: &mut Vec) { tests.extend([ MESH_PIPELINE_BASIC_MESH, @@ -19,11 +24,16 @@ pub fn all_tests(tests: &mut Vec) { MESH_MULTI_DRAW_INDIRECT_COUNT, MESH_PIPELINE_BASIC_MESH_NO_DRAW, MESH_PIPELINE_BASIC_TASK_MESH_FRAG_NO_DRAW, - MESH_DISABLED, ]); } // Same as in mesh shader example +fn compile_wgsl(device: &wgpu::Device) -> wgpu::ShaderModule { + device.create_shader_module(wgpu::ShaderModuleDescriptor { + label: None, + source: wgpu::ShaderSource::Wgsl(include_str!("shader.wgsl").into()), + }) +} fn compile_glsl(device: &wgpu::Device, shader_stage: &'static str) -> wgpu::ShaderModule { let cmd = std::process::Command::new("glslc") .args([ @@ -51,7 +61,6 @@ fn compile_glsl(device: &wgpu::Device, shader_stage: &'static str) -> wgpu::Shad }) } } - fn compile_hlsl( device: &wgpu::Device, entry: &str, @@ -94,18 +103,6 @@ fn compile_hlsl( } } -fn compile_msl(device: &wgpu::Device, entry: &str) -> wgpu::ShaderModule { - unsafe { - device.create_shader_module_passthrough(wgpu::ShaderModuleDescriptorPassthrough { - entry_point: entry.to_owned(), - label: None, - msl: Some(std::borrow::Cow::Borrowed(include_str!("shader.metal"))), - num_workgroups: (1, 1, 1), - ..Default::default() - }) - } -} - fn get_shaders( device: &wgpu::Device, backend: wgpu::Backend, @@ -115,47 +112,47 @@ fn get_shaders( Option, wgpu::ShaderModule, Option, + &'static str, + &'static str, + &'static str, ) { - // On backends that don't support mesh shaders, or for the MESH_DISABLED - // test, compile a dummy shader so we can construct a structurally valid - // pipeline description and test that `create_mesh_pipeline` fails. - // (In the case that the platform does support mesh shaders, the dummy - // shader is used to avoid requiring EXPERIMENTAL_PASSTHROUGH_SHADERS.) + // In the case that the platform does support mesh shaders, the dummy + // shader is used to avoid requiring EXPERIMENTAL_PASSTHROUGH_SHADERS. let dummy_shader = device.create_shader_module(wgpu::include_wgsl!("non_mesh.wgsl")); - match backend { - wgpu::Backend::Vulkan => ( + if backend == wgpu::Backend::Metal { + let s = compile_wgsl(device); + ( + info.use_task.then(|| s.clone()), + s.clone(), + info.use_frag.then_some(s), + "ts_main", + if info.use_task { "ms_main" } else { "ms_no_ts" }, + "fs_main", + ) + } else if backend == wgpu::Backend::Vulkan { + ( info.use_task.then(|| compile_glsl(device, "task")), - if info.use_mesh { - compile_glsl(device, "mesh") - } else { - dummy_shader - }, + compile_glsl(device, "mesh"), info.use_frag.then(|| compile_glsl(device, "frag")), - ), - wgpu::Backend::Dx12 => ( + "main", + "main", + "main", + ) + } else if backend == wgpu::Backend::Dx12 { + ( info.use_task .then(|| compile_hlsl(device, "Task", "as", test_name)), - if info.use_mesh { - compile_hlsl(device, "Mesh", "ms", test_name) - } else { - dummy_shader - }, + compile_hlsl(device, "Mesh", "ms", test_name), info.use_frag .then(|| compile_hlsl(device, "Frag", "ps", test_name)), - ), - wgpu::Backend::Metal => ( - info.use_task.then(|| compile_msl(device, "taskShader")), - if info.use_mesh { - compile_msl(device, "meshShader") - } else { - dummy_shader - }, - info.use_frag.then(|| compile_msl(device, "fragShader")), - ), - _ => { - assert!(!info.use_task && !info.use_mesh && !info.use_frag); - (None, dummy_shader, None) - } + "main", + "main", + "main", + ) + } else { + assert!(!MESH_SHADER_BACKENDS.contains(Backends::from(backend))); + assert!(!info.use_task && !info.use_frag); + (None, dummy_shader, None, "main", "main", "main") } } @@ -190,7 +187,6 @@ fn create_depth( struct MeshPipelineTestInfo { use_task: bool, - use_mesh: bool, use_frag: bool, draw: bool, } @@ -207,7 +203,8 @@ fn mesh_pipeline_build(ctx: &TestingContext, info: MeshPipelineTestInfo) { let (_depth_image, depth_view, depth_state) = create_depth(device); let test_hash = hash_testing_context(ctx).to_string(); - let (task, mesh, frag) = get_shaders(device, backend, &test_hash, &info); + let (task, mesh, frag, ts_name, ms_name, fs_name) = + get_shaders(device, backend, &test_hash, &info); let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { label: None, bind_group_layouts: &[], @@ -218,17 +215,17 @@ fn mesh_pipeline_build(ctx: &TestingContext, info: MeshPipelineTestInfo) { layout: Some(&layout), task: task.as_ref().map(|task| wgpu::TaskState { module: task, - entry_point: Some("main"), + entry_point: Some(ts_name), compilation_options: Default::default(), }), mesh: wgpu::MeshState { module: &mesh, - entry_point: Some("main"), + entry_point: Some(ms_name), compilation_options: Default::default(), }, fragment: frag.as_ref().map(|frag| wgpu::FragmentState { module: frag, - entry_point: Some("main"), + entry_point: Some(fs_name), targets: &[], compilation_options: Default::default(), }), @@ -289,11 +286,11 @@ fn mesh_draw(ctx: &TestingContext, draw_type: DrawType) { let test_hash = hash_testing_context(ctx).to_string(); let info = MeshPipelineTestInfo { use_task: true, - use_mesh: true, use_frag: true, draw: true, }; - let (task, mesh, frag) = get_shaders(device, backend, &test_hash, &info); + let (task, mesh, frag, ts_name, ms_name, fs_name) = + get_shaders(device, backend, &test_hash, &info); let task = task.unwrap(); let frag = frag.unwrap(); let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { @@ -306,17 +303,17 @@ fn mesh_draw(ctx: &TestingContext, draw_type: DrawType) { layout: Some(&layout), task: Some(wgpu::TaskState { module: &task, - entry_point: Some("main"), + entry_point: Some(ts_name), compilation_options: Default::default(), }), mesh: wgpu::MeshState { module: &mesh, - entry_point: Some("main"), + entry_point: Some(ms_name), compilation_options: Default::default(), }, fragment: Some(wgpu::FragmentState { module: &frag, - entry_point: Some("main"), + entry_point: Some(fs_name), targets: &[], compilation_options: Default::default(), }), @@ -393,6 +390,7 @@ fn mesh_draw(ctx: &TestingContext, draw_type: DrawType) { fn default_gpu_test_config(draw_type: DrawType) -> GpuTestConfiguration { GpuTestConfiguration::new().parameters( TestParameters::default() + .skip(FailureCase::backend(!MESH_SHADER_BACKENDS)) .test_features_limits() .features( wgpu::Features::EXPERIMENTAL_MESH_SHADER @@ -415,7 +413,6 @@ pub static MESH_PIPELINE_BASIC_MESH: GpuTestConfiguration = &ctx, MeshPipelineTestInfo { use_task: false, - use_mesh: true, use_frag: false, draw: true, }, @@ -428,7 +425,6 @@ pub static MESH_PIPELINE_BASIC_TASK_MESH: GpuTestConfiguration = &ctx, MeshPipelineTestInfo { use_task: true, - use_mesh: true, use_frag: false, draw: true, }, @@ -441,7 +437,6 @@ pub static MESH_PIPELINE_BASIC_MESH_FRAG: GpuTestConfiguration = &ctx, MeshPipelineTestInfo { use_task: false, - use_mesh: true, use_frag: true, draw: true, }, @@ -454,7 +449,6 @@ pub static MESH_PIPELINE_BASIC_TASK_MESH_FRAG: GpuTestConfiguration = &ctx, MeshPipelineTestInfo { use_task: true, - use_mesh: true, use_frag: true, draw: true, }, @@ -467,7 +461,6 @@ pub static MESH_PIPELINE_BASIC_MESH_NO_DRAW: GpuTestConfiguration = &ctx, MeshPipelineTestInfo { use_task: false, - use_mesh: true, use_frag: false, draw: false, }, @@ -480,7 +473,6 @@ pub static MESH_PIPELINE_BASIC_TASK_MESH_FRAG_NO_DRAW: GpuTestConfiguration = &ctx, MeshPipelineTestInfo { use_task: true, - use_mesh: true, use_frag: true, draw: false, }, @@ -503,30 +495,3 @@ pub static MESH_MULTI_DRAW_INDIRECT_COUNT: GpuTestConfiguration = default_gpu_test_config(DrawType::MultiIndirectCount).run_sync(|ctx| { mesh_draw(&ctx, DrawType::MultiIndirectCount); }); - -/// When the mesh shading feature is disabled, calls to `create_mesh_pipeline` -/// should be rejected. This should be the case on all backends, not just the -/// ones where the feature could be turned on. -#[gpu_test] -pub static MESH_DISABLED: GpuTestConfiguration = GpuTestConfiguration::new().run_sync(|ctx| { - fail( - &ctx.device, - || { - mesh_pipeline_build( - &ctx, - MeshPipelineTestInfo { - use_task: false, - use_mesh: false, - use_frag: false, - draw: true, - }, - ); - }, - Some(concat![ - "Features Features { ", - "features_wgpu: FeaturesWGPU(EXPERIMENTAL_MESH_SHADER), ", - "features_webgpu: FeaturesWebGPU(0x0) ", - "} are required but not enabled on the device", - ]), - ) -}); diff --git a/tests/tests/wgpu-gpu/mesh_shader/shader.metal b/tests/tests/wgpu-gpu/mesh_shader/shader.metal deleted file mode 100644 index 4c7da503832..00000000000 --- a/tests/tests/wgpu-gpu/mesh_shader/shader.metal +++ /dev/null @@ -1,77 +0,0 @@ -using namespace metal; - -struct OutVertex { - float4 Position [[position]]; - float4 Color [[user(locn0)]]; -}; - -struct OutPrimitive { - float4 ColorMask [[flat]] [[user(locn1)]]; - bool CullPrimitive [[primitive_culled]]; -}; - -struct InVertex { -}; - -struct InPrimitive { - float4 ColorMask [[flat]] [[user(locn1)]]; -}; - -struct FragmentIn { - float4 Color [[user(locn0)]]; - float4 ColorMask [[flat]] [[user(locn1)]]; -}; - -struct PayloadData { - float4 ColorMask; - bool Visible; -}; - -using Meshlet = metal::mesh; - - -constant float4 positions[3] = { - float4(0.0, 1.0, 0.0, 1.0), - float4(-1.0, -1.0, 0.0, 1.0), - float4(1.0, -1.0, 0.0, 1.0) -}; - -constant float4 colors[3] = { - float4(0.0, 1.0, 0.0, 1.0), - float4(0.0, 0.0, 1.0, 1.0), - float4(1.0, 0.0, 0.0, 1.0) -}; - - -[[object]] -void taskShader(uint3 tid [[thread_position_in_grid]], object_data PayloadData &outPayload [[payload]], mesh_grid_properties grid) { - outPayload.ColorMask = float4(1.0, 1.0, 0.0, 1.0); - outPayload.Visible = true; - grid.set_threadgroups_per_grid(uint3(3, 1, 1)); -} - -[[mesh]] -void meshShader( - object_data PayloadData const& payload [[payload]], - Meshlet out -) -{ - out.set_primitive_count(1); - - for(int i = 0;i < 3;i++) { - OutVertex vert; - vert.Position = positions[i]; - vert.Color = colors[i] * payload.ColorMask; - out.set_vertex(i, vert); - out.set_index(i, i); - } - - OutPrimitive prim; - prim.ColorMask = float4(1.0, 0.0, 0.0, 1.0); - prim.CullPrimitive = !payload.Visible; - out.set_primitive(0, prim); -} - -fragment float4 fragShader(FragmentIn data [[stage_in]]) { - return data.Color * data.ColorMask; -} diff --git a/tests/tests/wgpu-gpu/mesh_shader/shader.wgsl b/tests/tests/wgpu-gpu/mesh_shader/shader.wgsl new file mode 100644 index 00000000000..cdc7366b415 --- /dev/null +++ b/tests/tests/wgpu-gpu/mesh_shader/shader.wgsl @@ -0,0 +1,74 @@ +enable mesh_shading; + +const positions = array( + vec4(0., 1., 0., 1.), + vec4(-1., -1., 0., 1.), + vec4(1., -1., 0., 1.) +); +const colors = array( + vec4(0., 1., 0., 1.), + vec4(0., 0., 1., 1.), + vec4(1., 0., 0., 1.) +); +struct TaskPayload { + colorMask: vec4, + visible: bool, +} +var taskPayload: TaskPayload; +var workgroupData: f32; +struct VertexOutput { + @builtin(position) position: vec4, + @location(0) color: vec4, +} +struct PrimitiveOutput { + @builtin(triangle_indices) index: vec3, + @builtin(cull_primitive) cull: bool, + @per_primitive @location(1) colorMask: vec4, +} +struct PrimitiveInput { + @per_primitive @location(1) colorMask: vec4, +} + +@task +@payload(taskPayload) +@workgroup_size(1) +fn ts_main() -> @builtin(mesh_task_size) vec3 { + workgroupData = 1.0; + taskPayload.colorMask = vec4(1.0, 1.0, 0.0, 1.0); + taskPayload.visible = true; + return vec3(3, 1, 1); +} + +struct MeshOutput { + @builtin(vertices) vertices: array, + @builtin(primitives) primitives: array, + @builtin(vertex_count) vertex_count: u32, + @builtin(primitive_count) primitive_count: u32, +} + +var mesh_output: MeshOutput; +@mesh(mesh_output) +@payload(taskPayload) +@workgroup_size(1) +fn ms_main(@builtin(local_invocation_index) index: u32, @builtin(global_invocation_id) id: vec3) { + mesh_output.vertex_count = 3; + mesh_output.primitive_count = 1; + workgroupData = 2.0; + + mesh_output.vertices[0].position = positions[0]; + mesh_output.vertices[0].color = colors[0] * taskPayload.colorMask; + + mesh_output.vertices[1].position = positions[1]; + mesh_output.vertices[1].color = colors[1] * taskPayload.colorMask; + + mesh_output.vertices[2].position = positions[2]; + mesh_output.vertices[2].color = colors[2] * taskPayload.colorMask; + + mesh_output.primitives[0].index = vec3(0, 1, 2); + mesh_output.primitives[0].cull = !taskPayload.visible; + mesh_output.primitives[0].colorMask = vec4(1.0, 0.0, 1.0, 1.0); +} +@fragment +fn fs_main(vertex: VertexOutput, primitive: PrimitiveInput) -> @location(0) vec4 { + return vertex.color * primitive.colorMask; +} From 439817127e5e3af94011a228bfc6e28a8bd7babe Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Fri, 7 Nov 2025 14:40:16 -0600 Subject: [PATCH 86/89] Added capabilitiy thing blah blah blah --- wgpu-core/src/device/mod.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/wgpu-core/src/device/mod.rs b/wgpu-core/src/device/mod.rs index ec5203f291b..c9e77e8ba45 100644 --- a/wgpu-core/src/device/mod.rs +++ b/wgpu-core/src/device/mod.rs @@ -514,6 +514,10 @@ pub fn create_validator( Caps::SHADER_BARYCENTRICS, features.intersects(wgt::Features::SHADER_BARYCENTRICS), ); + caps.set( + Caps::MESH_SHADER, + features.intersects(wgt::Features::EXPERIMENTAL_MESH_SHADER), + ); naga::valid::Validator::new(flags, caps) } From 106ae83172a6c2a463684639b566c94ed3c7bed7 Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Fri, 7 Nov 2025 15:09:08 -0600 Subject: [PATCH 87/89] Got snapshots to write (not correctly) --- naga/src/back/msl/mod.rs | 11 +-- naga/src/back/msl/writer.rs | 48 +++++++++---- naga/tests/out/msl/wgsl-mesh-shader.msl | 95 +++++++++++++++++++++++++ 3 files changed, 136 insertions(+), 18 deletions(-) create mode 100644 naga/tests/out/msl/wgsl-mesh-shader.msl diff --git a/naga/src/back/msl/mod.rs b/naga/src/back/msl/mod.rs index 2456cdbae8b..984084b0700 100644 --- a/naga/src/back/msl/mod.rs +++ b/naga/src/back/msl/mod.rs @@ -711,14 +711,15 @@ impl ResolvedBinding { Bi::CullDistance | Bi::DrawID => { return Err(Error::UnsupportedBuiltIn(built_in)) } - Bi::CullPrimitive => "primitive_culled", - // TODO: figure out how to make this written as a function call - Bi::PointIndex | Bi::LineIndices | Bi::TriangleIndices => unimplemented!(), - Bi::MeshTaskSize + Bi::CullPrimitive + | Bi::PointIndex + | Bi::LineIndices + | Bi::TriangleIndices + | Bi::MeshTaskSize | Bi::VertexCount | Bi::PrimitiveCount | Bi::Vertices - | Bi::Primitives => unreachable!(), + | Bi::Primitives => "TODO_MESH_BUILTIN", }; write!(out, "{name}")?; } diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index ca7da02a930..3e282577147 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -608,7 +608,7 @@ impl crate::AddressSpace { // may end up with "const" even if the binding is read-write, // and that should be OK. Self::Storage { .. } => true, - Self::TaskPayload => unimplemented!(), + Self::TaskPayload => true, // These should always be read-write. Self::Private | Self::WorkGroup => false, // These translate to `constant` address space, no need for qualifiers. @@ -6603,26 +6603,42 @@ template self.write_wrapped_functions(module, &ctx)?; - let (em_str, in_mode, out_mode, can_vertex_pull) = match ep.stage { + let (em_str, in_mode, out_mode, can_vertex_pull, extra_attribute) = match ep.stage { crate::ShaderStage::Vertex => ( - "vertex", + Some("vertex"), LocationMode::VertexInput, LocationMode::VertexOutput, true, + None, ), crate::ShaderStage::Fragment => ( - "fragment", + Some("fragment"), LocationMode::FragmentInput, LocationMode::FragmentOutput, false, + None, ), crate::ShaderStage::Compute => ( - "kernel", + Some("kernel"), LocationMode::Uniform, LocationMode::Uniform, false, + None, + ), + crate::ShaderStage::Task => ( + None, + LocationMode::Uniform, + LocationMode::Uniform, + false, + Some("task"), + ), + crate::ShaderStage::Mesh => ( + None, + LocationMode::Uniform, + LocationMode::Uniform, + false, + Some("mesh"), ), - crate::ShaderStage::Task | crate::ShaderStage::Mesh => unimplemented!(), }; // Should this entry point be modified to do vertex pulling? @@ -6689,9 +6705,7 @@ template break; } } - crate::AddressSpace::TaskPayload => { - unimplemented!() - } + crate::AddressSpace::TaskPayload => {} crate::AddressSpace::Function | crate::AddressSpace::Private | crate::AddressSpace::WorkGroup => {} @@ -6817,7 +6831,7 @@ template let stage_out_name = self.namer.call(&format!("{fun_name}Output")); let result_member_name = self.namer.call("member"); let result_type_name = match fun.result { - Some(ref result) => { + Some(ref result) if ep.stage != crate::ShaderStage::Task => { let mut result_members = Vec::new(); if let crate::TypeInner::Struct { ref members, .. } = module.types[result.ty].inner @@ -6888,7 +6902,7 @@ template writeln!(self.out, "}};")?; &stage_out_name } - None => "void", + _ => "void", }; // If we're doing a vertex pulling transform, define the buffer @@ -6908,8 +6922,16 @@ template } } + // Mesh/task (object) shaders use `[[mesh]] void ...` syntax instead of `kernel void ...`. + if let Some(extra_attribute) = extra_attribute { + writeln!(self.out, "[[{extra_attribute}]]")?; + } + // Write the entry point function's name, and begin its argument list. - writeln!(self.out, "{em_str} {result_type_name} {fun_name}(")?; + if let Some(em_str) = em_str { + write!(self.out, "{em_str} ")?; + } + writeln!(self.out, "{result_type_name} {fun_name}(")?; let mut is_first_argument = true; let mut separator = || { @@ -7114,7 +7136,7 @@ template // the resolves have already been checked for `!fake_missing_bindings` case let resolved = match var.space { crate::AddressSpace::PushConstant => options.resolve_push_constants(ep).ok(), - crate::AddressSpace::WorkGroup => None, + crate::AddressSpace::WorkGroup | crate::AddressSpace::TaskPayload => None, _ => options .resolve_resource_binding(ep, var.binding.as_ref().unwrap()) .ok(), diff --git a/naga/tests/out/msl/wgsl-mesh-shader.msl b/naga/tests/out/msl/wgsl-mesh-shader.msl new file mode 100644 index 00000000000..7b46dc63947 --- /dev/null +++ b/naga/tests/out/msl/wgsl-mesh-shader.msl @@ -0,0 +1,95 @@ +// language: metal2.4 +#include +#include + +using metal::uint; + +struct TaskPayload { + metal::float4 colorMask; + bool visible; + char _pad2[15]; +}; +struct VertexOutput { + metal::float4 position; + metal::float4 color; +}; +struct PrimitiveOutput { + metal::packed_uint3 index; + bool cull; + char _pad2[3]; + metal::float4 colorMask; +}; +struct PrimitiveInput { + metal::float4 colorMask; +}; +struct type_5 { + VertexOutput inner[3]; +}; +struct type_6 { + PrimitiveOutput inner[1]; +}; +struct MeshOutput { + type_5 vertices; + type_6 primitives; + uint vertex_count; + uint primitive_count; + char _pad4[8]; +}; + +[[task]] +void ts_main( + object_data TaskPayload& taskPayload +, threadgroup float& workgroupData +) { + workgroupData = 1.0; + taskPayload.colorMask = metal::float4(1.0, 1.0, 0.0, 1.0); + taskPayload.visible = true; + return ts_mainOutput { metal::uint3(3u, 1u, 1u) }; +} + + +struct ms_mainInput { +}; +[[mesh]] +void ms_main( + uint index [[thread_index_in_threadgroup]] +, metal::uint3 id [[thread_position_in_grid]] +, object_data TaskPayload const& taskPayload +, threadgroup float& workgroupData +, threadgroup MeshOutput& mesh_output +) { + mesh_output.vertex_count = 3u; + mesh_output.primitive_count = 1u; + workgroupData = 2.0; + mesh_output.vertices.inner[0].position = metal::float4(0.0, 1.0, 0.0, 1.0); + metal::float4 _e25 = taskPayload.colorMask; + mesh_output.vertices.inner[0].color = metal::float4(0.0, 1.0, 0.0, 1.0) * _e25; + mesh_output.vertices.inner[1].position = metal::float4(-1.0, -1.0, 0.0, 1.0); + metal::float4 _e47 = taskPayload.colorMask; + mesh_output.vertices.inner[1].color = metal::float4(0.0, 0.0, 1.0, 1.0) * _e47; + mesh_output.vertices.inner[2].position = metal::float4(1.0, -1.0, 0.0, 1.0); + metal::float4 _e69 = taskPayload.colorMask; + mesh_output.vertices.inner[2].color = metal::float4(1.0, 0.0, 0.0, 1.0) * _e69; + mesh_output.primitives.inner[0].index = metal::uint3(0u, 1u, 2u); + bool _e90 = taskPayload.visible; + mesh_output.primitives.inner[0].cull = !(_e90); + mesh_output.primitives.inner[0].colorMask = metal::float4(1.0, 0.0, 1.0, 1.0); + return; +} + + +struct fs_mainInput { + metal::float4 color [[user(loc0), center_perspective]]; + metal::float4 colorMask [[user(loc1), center_perspective]]; +}; +struct fs_mainOutput { + metal::float4 member_2 [[color(0)]]; +}; +fragment fs_mainOutput fs_main( + fs_mainInput varyings_2 [[stage_in]] +, metal::float4 position [[position]] +) { + const VertexOutput vertex_ = { position, varyings_2.color }; + const PrimitiveInput primitive = { varyings_2.colorMask }; + return fs_mainOutput { vertex_.color * primitive.colorMask }; +} From b699a30ce1dcf9015d9d3deafd71c2bafb1379cf Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Fri, 7 Nov 2025 15:17:43 -0600 Subject: [PATCH 88/89] Added changelog entry --- CHANGELOG.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ef2c7869485..ee4a4b5c4b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -158,6 +158,10 @@ By @SupaMaggie70Incorporated in [#8206](https://github.com/gfx-rs/wgpu/pull/8206 - Fixed a bug where the texture aspect was not passed through when calling `copy_texture_to_buffer` in WebGPU, causing the copy to fail for depth/stencil textures. By @Tim-Evans-Seequent in [#8445](https://github.com/gfx-rs/wgpu/pull/8445). +### Metal + +- Complete support for mesh shaders without passthrough shaders. By @inner-daemons in [#8493](https://github.com/gfx-rs/wgpu/pull/8493). + #### hal - `DropCallback`s are now called after dropping all other fields of their parent structs. By @jerzywilczek in [#8353](https://github.com/gfx-rs/wgpu/pull/8353) From 58edde28bc0da91dddc5d45526d171800d74cf2b Mon Sep 17 00:00:00 2001 From: SupaMaggie70 Date: Fri, 7 Nov 2025 15:28:57 -0600 Subject: [PATCH 89/89] Slight improvement to task payload --- naga/src/back/msl/writer.rs | 18 ++++++++++++------ naga/tests/out/msl/wgsl-mesh-shader.msl | 9 ++++++--- 2 files changed, 18 insertions(+), 9 deletions(-) diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 3e282577147..adb94a55d33 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -410,7 +410,7 @@ impl TypedGlobalVariable<'_> { first_time: false, }; - let (space, access, reference) = match var.space.to_msl_name() { + let (space, access, reference, trailing_attribute) = match var.space.to_msl_name() { Some(space) if self.reference => { let access = if var.space.needs_access_qualifier() && !self.usage.intersects(valid::GlobalUse::WRITE) @@ -419,14 +419,19 @@ impl TypedGlobalVariable<'_> { } else { "" }; - (space, access, "&") + let trailing_attribute = if var.space == crate::AddressSpace::TaskPayload { + " [[payload]]" + } else { + "" + }; + (space, access, "&", trailing_attribute) } - _ => ("", "", ""), + _ => ("", "", "", ""), }; Ok(write!( out, - "{}{}{}{}{}{} {}", + "{}{}{}{}{}{} {}{}", space, if space.is_empty() { "" } else { " " }, ty_name, @@ -434,6 +439,7 @@ impl TypedGlobalVariable<'_> { access, reference, name, + trailing_attribute )?) } } @@ -6831,7 +6837,7 @@ template let stage_out_name = self.namer.call(&format!("{fun_name}Output")); let result_member_name = self.namer.call("member"); let result_type_name = match fun.result { - Some(ref result) if ep.stage != crate::ShaderStage::Task => { + Some(ref result) => { let mut result_members = Vec::new(); if let crate::TypeInner::Struct { ref members, .. } = module.types[result.ty].inner @@ -6902,7 +6908,7 @@ template writeln!(self.out, "}};")?; &stage_out_name } - _ => "void", + None => "void", }; // If we're doing a vertex pulling transform, define the buffer diff --git a/naga/tests/out/msl/wgsl-mesh-shader.msl b/naga/tests/out/msl/wgsl-mesh-shader.msl index 7b46dc63947..5280464ea63 100644 --- a/naga/tests/out/msl/wgsl-mesh-shader.msl +++ b/naga/tests/out/msl/wgsl-mesh-shader.msl @@ -36,9 +36,12 @@ struct MeshOutput { char _pad4[8]; }; +struct ts_mainOutput { + metal::uint3 member [[TODO_MESH_BUILTIN]]; +}; [[task]] -void ts_main( - object_data TaskPayload& taskPayload +ts_mainOutput ts_main( + object_data TaskPayload& taskPayload [[payload]] , threadgroup float& workgroupData ) { workgroupData = 1.0; @@ -54,7 +57,7 @@ struct ms_mainInput { void ms_main( uint index [[thread_index_in_threadgroup]] , metal::uint3 id [[thread_position_in_grid]] -, object_data TaskPayload const& taskPayload +, object_data TaskPayload const& taskPayload [[payload]] , threadgroup float& workgroupData , threadgroup MeshOutput& mesh_output ) {