diff --git a/crates/spirv-std/src/builtin.rs b/crates/spirv-std/src/builtin.rs new file mode 100644 index 0000000000..42f17af8b5 --- /dev/null +++ b/crates/spirv-std/src/builtin.rs @@ -0,0 +1,27 @@ +//! Symbols to query SPIR-V read-only global built-ins + +/// compute shader built-ins +pub mod compute { + #[cfg(target_arch = "spirv")] + use core::arch::asm; + use glam::UVec3; + + /// GLSL: `gl_LocalInvocationID()` + /// WGSL: `local_invocation_id` + #[doc(alias = "gl_LocalInvocationID")] + #[inline] + #[gpu_only] + pub fn local_invocation_id() -> UVec3 { + unsafe { + let result = UVec3::default(); + asm! { + "%builtin = OpVariable typeof{result} Input", + "OpDecorate %builtin BuiltIn LocalInvocationId", + "%result = OpLoad typeof*{result} %builtin", + "OpStore {result} %result", + result = in(reg) &result, + } + result + } + } +} diff --git a/crates/spirv-std/src/lib.rs b/crates/spirv-std/src/lib.rs index 2c85dc9af0..cf26855488 100644 --- a/crates/spirv-std/src/lib.rs +++ b/crates/spirv-std/src/lib.rs @@ -91,6 +91,7 @@ pub use macros::spirv; pub use macros::{debug_printf, debug_printfln}; pub mod arch; +pub mod builtin; pub mod byte_addressable_buffer; pub mod debug_printf; pub mod float; diff --git a/tests/compiletests/ui/builtin/compute.rs b/tests/compiletests/ui/builtin/compute.rs new file mode 100644 index 0000000000..55089199d4 --- /dev/null +++ b/tests/compiletests/ui/builtin/compute.rs @@ -0,0 +1,32 @@ +// build-pass +// compile-flags: -C llvm-args=--disassemble +// normalize-stderr-test "OpLine .*\n" -> "" +// normalize-stderr-test "OpSource .*\n" -> "" +// normalize-stderr-test "%\d+ = OpString .*\n" -> "" +// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> "" +// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple" +// normalize-stderr-test "; .*\n" -> "" +// ignore-spv1.0 +// ignore-spv1.1 +// ignore-spv1.2 +// ignore-spv1.3 +// ignore-vulkan1.0 +// ignore-vulkan1.1 + +use spirv_std::glam::*; +use spirv_std::spirv; + +#[spirv(compute(threads(1)))] +pub fn compute( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut u32, + // #[spirv(global_invocation_id)] global_invocation_id: UVec3, + // #[spirv(local_invocation_id)] local_invocation_id: UVec3, + // #[spirv(subgroup_local_invocation_id)] subgroup_local_invocation_id: u32, + // #[spirv(num_subgroups)] num_subgroups: u32, + // #[spirv(num_workgroups)] num_workgroups: UVec3, + // #[spirv(subgroup_id)] subgroup_id: u32, + // #[spirv(workgroup_id)] workgroup_id: UVec3, +) { + let local_invocation_id = spirv_std::builtin::compute::local_invocation_id(); + *out = local_invocation_id.x; +} diff --git a/tests/compiletests/ui/builtin/compute.stderr b/tests/compiletests/ui/builtin/compute.stderr new file mode 100644 index 0000000000..d1afba02ad --- /dev/null +++ b/tests/compiletests/ui/builtin/compute.stderr @@ -0,0 +1,29 @@ +OpCapability Shader +OpMemoryModel Logical Simple +OpEntryPoint GLCompute %1 "compute" %2 %3 +OpExecutionMode %1 LocalSize 1 1 1 +OpDecorate %6 Block +OpMemberDecorate %6 0 Offset 0 +OpDecorate %2 Binding 0 +OpDecorate %2 DescriptorSet 0 +OpDecorate %3 BuiltIn LocalInvocationId +%7 = OpTypeInt 32 0 +%6 = OpTypeStruct %7 +%8 = OpTypePointer StorageBuffer %6 +%9 = OpTypeVoid +%10 = OpTypeFunction %9 +%11 = OpTypePointer StorageBuffer %7 +%2 = OpVariable %8 StorageBuffer +%12 = OpConstant %7 0 +%13 = OpTypeVector %7 3 +%14 = OpTypePointer Input %13 +%3 = OpVariable %14 Input +%1 = OpFunction %9 None %10 +%15 = OpLabel +%16 = OpInBoundsAccessChain %11 %2 %12 +%17 = OpLoad %13 %3 +%18 = OpCompositeExtract %7 %17 0 +OpStore %16 %18 +OpNoLine +OpReturn +OpFunctionEnd diff --git a/tests/compiletests/ui/builtin/compute_attr.rs b/tests/compiletests/ui/builtin/compute_attr.rs new file mode 100644 index 0000000000..e7bcd6bd6d --- /dev/null +++ b/tests/compiletests/ui/builtin/compute_attr.rs @@ -0,0 +1,31 @@ +// build-pass +// compile-flags: -C llvm-args=--disassemble +// normalize-stderr-test "OpLine .*\n" -> "" +// normalize-stderr-test "OpSource .*\n" -> "" +// normalize-stderr-test "%\d+ = OpString .*\n" -> "" +// normalize-stderr-test "OpCapability VulkanMemoryModel\n" -> "" +// normalize-stderr-test "OpMemoryModel Logical Vulkan" -> "OpMemoryModel Logical Simple" +// normalize-stderr-test "; .*\n" -> "" +// ignore-spv1.0 +// ignore-spv1.1 +// ignore-spv1.2 +// ignore-spv1.3 +// ignore-vulkan1.0 +// ignore-vulkan1.1 + +use spirv_std::glam::*; +use spirv_std::spirv; + +#[spirv(compute(threads(1)))] +pub fn compute( + #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] out: &mut u32, + // #[spirv(global_invocation_id)] global_invocation_id: UVec3, + #[spirv(local_invocation_id)] local_invocation_id: UVec3, + // #[spirv(subgroup_local_invocation_id)] subgroup_local_invocation_id: u32, + // #[spirv(num_subgroups)] num_subgroups: u32, + // #[spirv(num_workgroups)] num_workgroups: UVec3, + // #[spirv(subgroup_id)] subgroup_id: u32, + // #[spirv(workgroup_id)] workgroup_id: UVec3, +) { + *out = local_invocation_id.x; +} diff --git a/tests/compiletests/ui/builtin/compute_attr.stderr b/tests/compiletests/ui/builtin/compute_attr.stderr new file mode 100644 index 0000000000..6fa60f919a --- /dev/null +++ b/tests/compiletests/ui/builtin/compute_attr.stderr @@ -0,0 +1,30 @@ +OpCapability Shader +OpMemoryModel Logical Simple +OpEntryPoint GLCompute %1 "compute" %2 %3 +OpExecutionMode %1 LocalSize 1 1 1 +OpName %3 "local_invocation_id" +OpDecorate %5 Block +OpMemberDecorate %5 0 Offset 0 +OpDecorate %2 Binding 0 +OpDecorate %2 DescriptorSet 0 +OpDecorate %3 BuiltIn LocalInvocationId +%6 = OpTypeInt 32 0 +%5 = OpTypeStruct %6 +%7 = OpTypePointer StorageBuffer %5 +%8 = OpTypeVector %6 3 +%9 = OpTypePointer Input %8 +%10 = OpTypeVoid +%11 = OpTypeFunction %10 +%12 = OpTypePointer StorageBuffer %6 +%2 = OpVariable %7 StorageBuffer +%13 = OpConstant %6 0 +%3 = OpVariable %9 Input +%1 = OpFunction %10 None %11 +%14 = OpLabel +%15 = OpInBoundsAccessChain %12 %2 %13 +%16 = OpLoad %8 %3 +%17 = OpCompositeExtract %6 %16 0 +OpStore %15 %17 +OpNoLine +OpReturn +OpFunctionEnd