diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index e66f0c5e30..d12d7fd930 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -27,7 +27,7 @@ jobs: - name: Install Vulkan SDK uses: jakoch/install-vulkan-sdk-action@v1 with: - vulkan_version: 1.4.309.0 + vulkan_version: 1.4.321.0 install_runtime: true cache: true stripdown: true @@ -92,7 +92,7 @@ jobs: - name: Install Vulkan SDK uses: jakoch/install-vulkan-sdk-action@v1 with: - vulkan_version: 1.4.309.0 + vulkan_version: 1.4.321.0 install_runtime: true cache: true stripdown: true @@ -139,7 +139,7 @@ jobs: - name: Install Vulkan SDK uses: jakoch/install-vulkan-sdk-action@v1 with: - vulkan_version: 1.4.309.0 + vulkan_version: 1.4.321.0 install_runtime: true cache: true stripdown: true @@ -165,7 +165,7 @@ jobs: - name: Install Vulkan SDK uses: jakoch/install-vulkan-sdk-action@v1 with: - vulkan_version: 1.4.309.0 + vulkan_version: 1.4.321.0 install_runtime: true cache: true stripdown: true @@ -231,7 +231,7 @@ jobs: - name: Install Vulkan SDK uses: jakoch/install-vulkan-sdk-action@v1 with: - vulkan_version: 1.4.309.0 + vulkan_version: 1.4.321.0 install_runtime: true cache: true stripdown: true diff --git a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs index 30fc1ae6c9..389e9a763d 100644 --- a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs +++ b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs @@ -2381,13 +2381,6 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { #[instrument(level = "trace", skip(self), fields(ptr, ptr_ty = ?self.debug_type(ptr.ty), dest_ty = ?self.debug_type(dest_ty)))] fn pointercast(&mut self, ptr: Self::Value, dest_ty: Self::Type) -> Self::Value { - // HACK(eddyb) reuse the special-casing in `const_bitcast`, which relies - // on adding a pointer type to an untyped pointer (to some const data). 
- if let SpirvValueKind::IllegalConst(_) = ptr.kind { - trace!("illegal const"); - return self.const_bitcast(ptr, dest_ty); - } - if ptr.ty == dest_ty { trace!("ptr.ty == dest_ty"); return ptr; @@ -2446,13 +2439,43 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { self.debug_type(ptr_pointee), self.debug_type(dest_pointee), ); + + // HACK(eddyb) reuse the special-casing in `const_bitcast`, which relies + // on adding a pointer type to an untyped pointer (to some const data). + if self.builder.lookup_const(ptr).is_some() { + // FIXME(eddyb) remove the condition on `zombie_waiting_for_span`, + // and constant-fold all pointer bitcasts, regardless of "legality", + // once `strip_ptrcasts` can undo `const_bitcast`, as well. + if ptr.zombie_waiting_for_span { + trace!("illegal const"); + return self.const_bitcast(ptr, dest_ty); + } + } + // Defer the cast so that it has a chance to be avoided. - let original_ptr = ptr.def(self); + let ptr_id = ptr.def(self); + let bitcast_result_id = self.emit().bitcast(dest_ty, None, ptr_id).unwrap(); + + self.zombie( + bitcast_result_id, + &format!( + "cannot cast between pointer types\ + \nfrom `{}`\ + \n to `{}`", + self.debug_type(ptr.ty), + self.debug_type(dest_ty) + ), + ); + SpirvValue { - kind: SpirvValueKind::LogicalPtrCast { - original_ptr, - original_ptr_ty: ptr.ty, - bitcast_result_id: self.emit().bitcast(dest_ty, None, original_ptr).unwrap(), + zombie_waiting_for_span: false, + kind: SpirvValueKind::Def { + id: bitcast_result_id, + original_ptr_before_casts: Some(SpirvValue { + zombie_waiting_for_span: ptr.zombie_waiting_for_span, + kind: ptr_id, + ty: ptr.ty, + }), }, ty: dest_ty, } @@ -3269,7 +3292,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { return_type, arguments, } => ( - if let SpirvValueKind::FnAddr { function } = callee.kind { + if let SpirvValueKind::FnAddr { function, .. 
} = callee.kind { assert_ty_eq!(self, callee_ty, pointee); function } @@ -3406,11 +3429,11 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { // HACK(eddyb) some entry-points only take a `&str`, not `fmt::Arguments`. if let [ SpirvValue { - kind: SpirvValueKind::Def(a_id), + kind: SpirvValueKind::Def { id: a_id, .. }, .. }, SpirvValue { - kind: SpirvValueKind::Def(b_id), + kind: SpirvValueKind::Def { id: b_id, .. }, .. }, ref other_args @ .., @@ -3429,14 +3452,20 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { // HACK(eddyb) `panic_nounwind_fmt` takes an extra argument. [ SpirvValue { - kind: SpirvValueKind::Def(format_args_id), + kind: + SpirvValueKind::Def { + id: format_args_id, .. + }, .. }, _, // `&'static panic::Location<'static>` ] | [ SpirvValue { - kind: SpirvValueKind::Def(format_args_id), + kind: + SpirvValueKind::Def { + id: format_args_id, .. + }, .. }, _, // `force_no_backtrace: bool` @@ -4110,10 +4139,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { self.codegen_buffer_store_intrinsic(args, mode); let void_ty = SpirvType::Void.def(rustc_span::DUMMY_SP, self); - return SpirvValue { - kind: SpirvValueKind::IllegalTypeUsed(void_ty), - ty: void_ty, - }; + return self.undef(void_ty); } if let Some((source_ty, target_ty)) = from_trait_impl { diff --git a/crates/rustc_codegen_spirv/src/builder/byte_addressable_buffer.rs b/crates/rustc_codegen_spirv/src/builder/byte_addressable_buffer.rs index ab2a78cf65..60f0109573 100644 --- a/crates/rustc_codegen_spirv/src/builder/byte_addressable_buffer.rs +++ b/crates/rustc_codegen_spirv/src/builder/byte_addressable_buffer.rs @@ -2,7 +2,7 @@ use crate::maybe_pqp_cg_ssa as rustc_codegen_ssa; use super::Builder; -use crate::builder_spirv::{SpirvValue, SpirvValueExt, SpirvValueKind}; +use crate::builder_spirv::{SpirvValue, SpirvValueExt}; use crate::spirv_type::SpirvType; use rspirv::spirv::{Decoration, Word}; use rustc_abi::{Align, Size}; @@ -186,12 +186,8 @@ impl<'a, 
'tcx> Builder<'a, 'tcx> { pass_mode: &PassMode, ) -> SpirvValue { match pass_mode { - PassMode::Ignore => { - return SpirvValue { - kind: SpirvValueKind::IllegalTypeUsed(result_type), - ty: result_type, - }; - } + PassMode::Ignore => return self.undef(result_type), + // PassMode::Pair is identical to PassMode::Direct - it's returned as a struct PassMode::Direct(_) | PassMode::Pair(_, _) => (), PassMode::Cast { .. } => { diff --git a/crates/rustc_codegen_spirv/src/builder_spirv.rs b/crates/rustc_codegen_spirv/src/builder_spirv.rs index e05b433057..f430dfdf7f 100644 --- a/crates/rustc_codegen_spirv/src/builder_spirv.rs +++ b/crates/rustc_codegen_spirv/src/builder_spirv.rs @@ -16,7 +16,6 @@ use rustc_abi::Size; use rustc_arena::DroplessArena; use rustc_codegen_ssa::traits::ConstCodegenMethods as _; use rustc_data_structures::fx::{FxHashMap, FxHashSet}; -use rustc_middle::bug; use rustc_middle::mir::interpret::ConstAllocation; use rustc_middle::ty::TyCtxt; use rustc_span::source_map::SourceMap; @@ -31,91 +30,86 @@ use std::str; use std::sync::Arc; use std::{fs::File, io::Write, path::Path}; +// HACK(eddyb) silence warnings that are inaccurate wrt future changes. +#[non_exhaustive] #[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq, Hash)] pub enum SpirvValueKind { - Def(Word), - - /// The ID of a global instruction matching a `SpirvConst`, but which cannot - /// pass validation. Used to error (or attach zombie spans), at the usesites - /// of such constants, instead of where they're generated (and cached). - IllegalConst(Word), - - /// This can only happen in one specific case - which is as a result of - /// `codegen_buffer_store_intrinsic`, that function is supposed to return - /// `OpTypeVoid`, however because it gets inline by the compiler it can't. - /// Instead we return this, and trigger an error if we ever end up using the - /// result of this function call (which we can't). 
- IllegalTypeUsed(Word), + Def { + id: Word, + + /// If `id` is a pointer cast, this will be `Some`, and contain all the + /// information necessary to regenerate the original `SpirvValue` before + /// *any* pointer casts were applied, effectively deferring the casts + /// (as long as all downstream uses apply `.strip_ptrcasts()` first), + /// and bypassing errors they might cause (due to SPIR-V limitations). + // + // FIXME(eddyb) wouldn't it be easier to use this for *any* bitcasts? + // (with some caveats around dedicated int<->ptr casts vs bitcasts) + original_ptr_before_casts: Option>, + }, // FIXME(eddyb) this shouldn't be needed, but `rustc_codegen_ssa` still relies // on converting `Function`s to `Value`s even for direct calls, the `Builder` // should just have direct and indirect `call` variants (or a `Callee` enum). FnAddr { function: Word, - }, - - /// Deferred pointer cast, for the `Logical` addressing model (which doesn't - /// really support raw pointers in the way Rust expects to be able to use). - /// - /// The cast's target pointer type is the `ty` of the `SpirvValue` that has - /// `LogicalPtrCast` as its `kind`, as it would be redundant to have it here. - LogicalPtrCast { - /// Pointer value being cast. - original_ptr: Word, - - /// Pointer type of `original_ptr`. - original_ptr_ty: Word, - /// Result ID for the `OpBitcast` instruction representing the cast, - /// to attach zombies to. - // - // HACK(eddyb) having an `OpBitcast` only works by being DCE'd away, - // or by being replaced with a noop in `qptr::lower`. - bitcast_result_id: Word, + // FIXME(eddyb) replace this ad-hoc zombie with a proper `SpirvConst`. + zombie_id: Word, }, } #[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq, Hash)] -pub struct SpirvValue { - pub kind: SpirvValueKind, +pub struct SpirvValue { + // HACK(eddyb) used to cheaply check whether this is a SPIR-V value ID + // with a "zombie" (deferred error) attached to it, that may need a `Span` + // still (e.g. 
such as constants, which can't easily take a `Span`). + // FIXME(eddyb) a whole `bool` field is sadly inefficient, but anything + // which may make `SpirvValue` smaller requires far too much impl effort. + pub zombie_waiting_for_span: bool, + + pub kind: K, pub ty: Word, } +impl SpirvValue { + fn map_kind(self, f: impl FnOnce(K) -> K2) -> SpirvValue { + let SpirvValue { + zombie_waiting_for_span, + kind, + ty, + } = self; + SpirvValue { + zombie_waiting_for_span, + kind: f(kind), + ty, + } + } +} + impl SpirvValue { pub fn strip_ptrcasts(self) -> Self { match self.kind { - SpirvValueKind::LogicalPtrCast { - original_ptr, - original_ptr_ty, - bitcast_result_id: _, - } => original_ptr.with_type(original_ptr_ty), + SpirvValueKind::Def { + id: _, + original_ptr_before_casts: Some(original_ptr), + } => original_ptr.map_kind(|id| SpirvValueKind::Def { + id, + original_ptr_before_casts: None, + }), _ => self, } } pub fn const_fold_load(self, cx: &CodegenCx<'_>) -> Option { - match self.kind { - SpirvValueKind::Def(id) | SpirvValueKind::IllegalConst(id) => { - let &entry = cx.builder.id_to_const.borrow().get(&id)?; - match entry.val { - SpirvConst::PtrTo { pointee } => { - let ty = match cx.lookup_type(self.ty) { - SpirvType::Pointer { pointee } => pointee, - ty => bug!("load called on value that wasn't a pointer: {:?}", ty), - }; - // FIXME(eddyb) deduplicate this `if`-`else` and its other copies. - let kind = if entry.legal.is_ok() { - SpirvValueKind::Def(pointee) - } else { - SpirvValueKind::IllegalConst(pointee) - }; - Some(SpirvValue { kind, ty }) - } - _ => None, - } + match cx.builder.lookup_const(self)? { + SpirvConst::PtrTo { pointee } => { + // HACK(eddyb) this obtains a `SpirvValue` from the ID it contains, + // so there's some conceptual inefficiency there, but it does + // prevent any of the other details from being lost accidentally. 
+ Some(cx.builder.id_to_const_and_val.borrow().get(&pointee)?.val.1) } - _ => None, } } @@ -134,80 +128,13 @@ impl SpirvValue { } pub fn def_with_span(self, cx: &CodegenCx<'_>, span: Span) -> Word { - match self.kind { - SpirvValueKind::Def(id) => id, - - SpirvValueKind::IllegalConst(id) => { - let entry = &cx.builder.id_to_const.borrow()[&id]; - let msg = match entry.legal.unwrap_err() { - IllegalConst::Shallow(cause) => { - if let ( - LeafIllegalConst::CompositeContainsPtrTo, - SpirvConst::Composite(_fields), - ) = (cause, &entry.val) - { - // FIXME(eddyb) materialize this at runtime, using - // `OpCompositeConstruct` (transitively, i.e. after - // putting every field through `SpirvValue::def`), - // if we have a `Builder` to do that in. - // FIXME(eddyb) this isn't possible right now, as - // the builder would be dynamically "locked" anyway - // (i.e. attempting to do `bx.emit()` would panic). - } - - cause.message() - } - - IllegalConst::Indirect(cause) => cause.message(), - }; - - cx.zombie_with_span(id, span, msg); - - id - } - - SpirvValueKind::IllegalTypeUsed(id) => { - cx.tcx - .dcx() - .struct_span_err(span, "Can't use type as a value") - .with_note(format!("Type: *{}", cx.debug_type(id))) - .emit(); - - id - } - - SpirvValueKind::FnAddr { .. } => { - cx.builder - .const_to_id - .borrow() - .get(&WithType { - ty: self.ty, - val: SpirvConst::ZombieUndefForFnAddr, - }) - .expect("FnAddr didn't go through proper undef registration") - .val - } - - SpirvValueKind::LogicalPtrCast { - original_ptr: _, - original_ptr_ty, - bitcast_result_id, - } => { - cx.zombie_with_span( - bitcast_result_id, - span, - &format!( - "cannot cast between pointer types\ - \nfrom `{}`\ - \n to `{}`", - cx.debug_type(original_ptr_ty), - cx.debug_type(self.ty) - ), - ); - - bitcast_result_id - } + let id = match self.kind { + SpirvValueKind::Def { id, .. } | SpirvValueKind::FnAddr { zombie_id: id, .. 
} => id, + }; + if self.zombie_waiting_for_span { + cx.add_span_to_zombie_if_missing(id, span); } + id } } @@ -218,7 +145,11 @@ pub trait SpirvValueExt { impl SpirvValueExt for Word { fn with_type(self, ty: Word) -> SpirvValue { SpirvValue { - kind: SpirvValueKind::Def(self), + zombie_waiting_for_span: false, + kind: SpirvValueKind::Def { + id: self, + original_ptr_before_casts: None, + }, ty, } } @@ -436,11 +367,12 @@ pub struct BuilderSpirv<'tcx> { builder: RefCell, // Bidirectional maps between `SpirvConst` and the ID of the defined global - // (e.g. `OpConstant...`) instruction. - // NOTE(eddyb) both maps have `WithConstLegality` around their keys, which - // allows getting that legality information without additional lookups. - const_to_id: RefCell>, WithConstLegality>>, - id_to_const: RefCell>>>, + // (e.g. `OpConstant...`) instruction, with additional information in values + // (i.e. each map is keyed by only some part of the other map's value type), + // as needed to streamline operations (e.g. avoiding rederiving `SpirvValue`). + const_to_val: RefCell>, SpirvValue>>, + id_to_const_and_val: + RefCell, SpirvValue)>>>, debug_file_cache: RefCell>>, @@ -511,8 +443,8 @@ impl<'tcx> BuilderSpirv<'tcx> { source_map: tcx.sess.source_map(), dropless_arena: &tcx.arena.dropless, builder: RefCell::new(builder), - const_to_id: Default::default(), - id_to_const: Default::default(), + const_to_val: Default::default(), + id_to_const_and_val: Default::default(), debug_file_cache: Default::default(), enabled_capabilities, } @@ -616,14 +548,8 @@ impl<'tcx> BuilderSpirv<'tcx> { }; let val_with_type = WithType { ty, val }; - if let Some(entry) = self.const_to_id.borrow().get(&val_with_type) { - // FIXME(eddyb) deduplicate this `if`-`else` and its other copies. 
- let kind = if entry.legal.is_ok() { - SpirvValueKind::Def(entry.val) - } else { - SpirvValueKind::IllegalConst(entry.val) - }; - return SpirvValue { kind, ty }; + if let Some(&v) = self.const_to_val.borrow().get(&val_with_type) { + return v; } let val = val_with_type.val; @@ -755,11 +681,11 @@ impl<'tcx> BuilderSpirv<'tcx> { SpirvConst::Composite(v) => v .iter() .map(|field| { - let field_entry = &self.id_to_const.borrow()[field]; + let field_entry = &self.id_to_const_and_val.borrow()[field]; field_entry.legal.and( // `field` is itself some legal `SpirvConst`, but can we have // it as part of an `OpConstantComposite`? - match field_entry.val { + match field_entry.val.0 { SpirvConst::PtrTo { .. } => Err(IllegalConst::Shallow( LeafIllegalConst::CompositeContainsPtrTo, )), @@ -787,50 +713,71 @@ impl<'tcx> BuilderSpirv<'tcx> { }) .unwrap_or(Ok(())), - SpirvConst::PtrTo { pointee } => match self.id_to_const.borrow()[&pointee].legal { - Ok(()) => Ok(()), + SpirvConst::PtrTo { pointee } => { + match self.id_to_const_and_val.borrow()[&pointee].legal { + Ok(()) => Ok(()), - // `Shallow` becomes `Indirect` when placed behind a pointer. - Err(IllegalConst::Shallow(cause) | IllegalConst::Indirect(cause)) => { - Err(IllegalConst::Indirect(cause)) + // `Shallow` becomes `Indirect` when placed behind a pointer. + Err(IllegalConst::Shallow(cause) | IllegalConst::Indirect(cause)) => { + Err(IllegalConst::Indirect(cause)) + } } - }, + } SpirvConst::ConstDataFromAlloc(_) => Err(IllegalConst::Shallow( LeafIllegalConst::UntypedConstDataFromAlloc, )), }; + + // FIXME(eddyb) avoid dragging "const (il)legality" around, as well + // (sadly that does require that `SpirvConst` -> SPIR-V be injective, + // e.g. `OpUndef` can never be used for unrepresentable constants). 
+ if let Err(illegal) = legal { + let msg = match illegal { + IllegalConst::Shallow(cause) | IllegalConst::Indirect(cause) => cause.message(), + }; + cx.zombie_no_span(id, msg); + } + let val = val.tcx_arena_alloc_slices(cx); + + // FIXME(eddyb) the `val`/`v` name clash is a bit unfortunate. + let v = SpirvValue { + zombie_waiting_for_span: legal.is_err(), + kind: SpirvValueKind::Def { + id, + original_ptr_before_casts: None, + }, + ty, + }; + assert_matches!( - self.const_to_id + self.const_to_val .borrow_mut() - .insert(WithType { ty, val }, WithConstLegality { val: id, legal }), + .insert(WithType { ty, val }, v), None ); assert_matches!( - self.id_to_const - .borrow_mut() - .insert(id, WithConstLegality { val, legal }), + self.id_to_const_and_val.borrow_mut().insert( + id, + WithConstLegality { + val: (val, v), + legal + } + ), None ); - // FIXME(eddyb) deduplicate this `if`-`else` and its other copies. - let kind = if legal.is_ok() { - SpirvValueKind::Def(id) - } else { - SpirvValueKind::IllegalConst(id) - }; - SpirvValue { kind, ty } + + v } pub fn lookup_const_by_id(&self, id: Word) -> Option> { - Some(self.id_to_const.borrow().get(&id)?.val) + Some(self.id_to_const_and_val.borrow().get(&id)?.val.0) } pub fn lookup_const(&self, def: SpirvValue) -> Option> { match def.kind { - SpirvValueKind::Def(id) | SpirvValueKind::IllegalConst(id) => { - self.lookup_const_by_id(id) - } + SpirvValueKind::Def { id, .. 
} => self.lookup_const_by_id(id), _ => None, } } diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs b/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs index eb8783049c..1eef73b3f1 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs @@ -3,7 +3,7 @@ use crate::maybe_pqp_cg_ssa as rustc_codegen_ssa; use super::CodegenCx; use crate::abi::ConvSpirvType; -use crate::builder_spirv::{SpirvConst, SpirvValue, SpirvValueExt, SpirvValueKind}; +use crate::builder_spirv::{SpirvConst, SpirvValue, SpirvValueExt}; use crate::spirv_type::SpirvType; use itertools::Itertools as _; use rspirv::spirv::Word; @@ -334,8 +334,7 @@ impl<'tcx> CodegenCx<'tcx> { pub fn const_bitcast(&self, val: SpirvValue, ty: Word) -> SpirvValue { // HACK(eddyb) special-case `const_data_from_alloc` + `static_addr_of` // as the old `from_const_alloc` (now `OperandRef::from_const_alloc`). - if let SpirvValueKind::IllegalConst(_) = val.kind - && let Some(SpirvConst::PtrTo { pointee }) = self.builder.lookup_const(val) + if let Some(SpirvConst::PtrTo { pointee }) = self.builder.lookup_const(val) && let Some(SpirvConst::ConstDataFromAlloc(alloc)) = self.builder.lookup_const_by_id(pointee) && let SpirvType::Pointer { pointee } = self.lookup_type(ty) diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs b/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs index f3811f8a19..fcebbde3f3 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs @@ -245,9 +245,9 @@ impl<'tcx> CodegenCx<'tcx> { /// is stripped from the binary. /// /// Errors will only be emitted (by `linker::zombies`) for reachable zombies. 
- pub fn zombie_with_span(&self, word: Word, span: Span, reason: &str) { + pub fn zombie_with_span(&self, id: Word, span: Span, reason: &str) { self.zombie_decorations.borrow_mut().insert( - word, + id, ( ZombieDecoration { // FIXME(eddyb) this could take advantage of `Cow` and use @@ -258,8 +258,16 @@ impl<'tcx> CodegenCx<'tcx> { ), ); } - pub fn zombie_no_span(&self, word: Word, reason: &str) { - self.zombie_with_span(word, DUMMY_SP, reason); + pub fn zombie_no_span(&self, id: Word, reason: &str) { + self.zombie_with_span(id, DUMMY_SP, reason); + } + + pub fn add_span_to_zombie_if_missing(&self, id: Word, span: Span) { + if span != DUMMY_SP + && let Some((_, src_loc @ None)) = self.zombie_decorations.borrow_mut().get_mut(&id) + { + *src_loc = SrcLocDecoration::from_rustc_span(span, &self.builder); + } } pub fn finalize_module(self) -> Module { @@ -846,11 +854,15 @@ impl<'tcx> MiscCodegenMethods<'tcx> for CodegenCx<'tcx> { // Create these `OpUndef`s up front, instead of on-demand in `SpirvValue::def`, // because `SpirvValue::def` can't use `cx.emit()`. 
- self.def_constant(ty, SpirvConst::ZombieUndefForFnAddr); + let zombie_id = self + .def_constant(ty, SpirvConst::ZombieUndefForFnAddr) + .def_with_span(self, span); SpirvValue { + zombie_waiting_for_span: false, kind: SpirvValueKind::FnAddr { function: function.id, + zombie_id, }, ty, } diff --git a/crates/rustc_codegen_spirv/src/linker/duplicates.rs b/crates/rustc_codegen_spirv/src/linker/duplicates.rs index 76f85a7713..70554caddd 100644 --- a/crates/rustc_codegen_spirv/src/linker/duplicates.rs +++ b/crates/rustc_codegen_spirv/src/linker/duplicates.rs @@ -283,7 +283,20 @@ pub fn remove_duplicate_debuginfo(module: &mut Module) { }) .map(|inst| inst.result_id.unwrap()); + let deduper = DebuginfoDeduplicator { + custom_ext_inst_set_import, + }; for func in &mut module.functions { + deduper.remove_duplicate_debuginfo_in_function(func); + } +} + +pub struct DebuginfoDeduplicator { + pub custom_ext_inst_set_import: Option, +} + +impl DebuginfoDeduplicator { + pub fn remove_duplicate_debuginfo_in_function(&self, func: &mut rspirv::dr::Function) { for block in &mut func.blocks { // Ignore the terminator, it's effectively "outside" debuginfo. 
let (_, insts) = block.instructions.split_last_mut().unwrap(); @@ -339,7 +352,8 @@ pub fn remove_duplicate_debuginfo(module: &mut Module) { let inst = &insts[inst_idx]; let custom_op = match inst.class.opcode { Op::ExtInst - if Some(inst.operands[0].unwrap_id_ref()) == custom_ext_inst_set_import => + if Some(inst.operands[0].unwrap_id_ref()) + == self.custom_ext_inst_set_import => { Some(CustomOp::decode_from_ext_inst(inst)) } diff --git a/crates/rustc_codegen_spirv/src/linker/inline.rs b/crates/rustc_codegen_spirv/src/linker/inline.rs index 0ab19bd40f..10883f5f87 100644 --- a/crates/rustc_codegen_spirv/src/linker/inline.rs +++ b/crates/rustc_codegen_spirv/src/linker/inline.rs @@ -133,12 +133,58 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> { .map(Ok) .collect(); - // Inline functions in post-order (aka inside-out aka bottom-out) - that is, + + let mut mem2reg_pointer_to_pointee = FxHashMap::default(); + let mut mem2reg_constants = FxHashMap::default(); + { + let mut u32 = None; + for inst in &module.types_global_values { + match inst.class.opcode { + Op::TypePointer => { + mem2reg_pointer_to_pointee + .insert(inst.result_id.unwrap(), inst.operands[1].unwrap_id_ref()); + } + Op::TypeInt + if inst.operands[0].unwrap_literal_bit32() == 32 + && inst.operands[1].unwrap_literal_bit32() == 0 => + { + assert!(u32.is_none()); + u32 = Some(inst.result_id.unwrap()); + } + Op::Constant if u32.is_some() && inst.result_type == u32 => { + let value = inst.operands[0].unwrap_literal_bit32(); + mem2reg_constants.insert(inst.result_id.unwrap(), value); + } + _ => {} + } + } + } + + // Inline functions in post-order (aka inside-out aka bottom-up) - that is, // callees are processed before their callers, to avoid duplicating work. 
for func_idx in call_graph.post_order() { let mut function = mem::replace(&mut functions[func_idx], Err(FuncIsBeingInlined)).unwrap(); inliner.inline_fn(&mut function, &functions); fuse_trivial_branches(&mut function); + + super::duplicates::DebuginfoDeduplicator { + custom_ext_inst_set_import, + } + .remove_duplicate_debuginfo_in_function(&mut function); + + { + super::simple_passes::block_ordering_pass(&mut function); + // Note: mem2reg requires functions to be in RPO order (i.e. block_ordering_pass) + super::mem2reg::mem2reg( + inliner.header, + &mut module.types_global_values, + &mem2reg_pointer_to_pointee, + &mem2reg_constants, + &mut function, + ); + super::destructure_composites::destructure_composites(&mut function); + } + functions[func_idx] = Ok(function); } @@ -411,7 +457,7 @@ fn should_inline( } // If the call isn't passing a legal pointer argument (a "memory object", - // i.e. an `OpVariable` or one of the caller's `OpFunctionParameters), + // i.e. an `OpVariable` or one of the caller's `OpFunctionParameter`s), // then inlining is required to have a chance at producing legal SPIR-V. // // FIXME(eddyb) rewriting away the pointer could be another alternative. 
@@ -826,7 +872,7 @@ impl Inliner<'_, '_> { } // `vars_and_debuginfo_range.end` indicates where `OpVariable`s - // end and other instructions start (module debuginfo), but to + // end and other instructions start (modulo debuginfo), but to // split the block in two, both sides of the "cut" need "repair": // - the variables are missing "inlined call frames" pops, that // may happen later in the block, and have to be synthesized diff --git a/crates/rustc_codegen_spirv/src/linker/mem2reg.rs b/crates/rustc_codegen_spirv/src/linker/mem2reg.rs index df2434d63d..c6fd8c084e 100644 --- a/crates/rustc_codegen_spirv/src/linker/mem2reg.rs +++ b/crates/rustc_codegen_spirv/src/linker/mem2reg.rs @@ -193,6 +193,8 @@ fn insert_phis_all( for (var_map, _) in &var_maps_and_types { split_copy_memory(header, blocks, var_map); } + + let mut rewrite_rules = FxHashMap::default(); for &(ref var_map, base_var_type) in &var_maps_and_types { let blocks_with_phi = insert_phis(blocks, dominance_frontier, var_map); let mut renamer = Renamer { @@ -205,16 +207,15 @@ fn insert_phis_all( phi_defs: FxHashSet::default(), visited: FxHashSet::default(), stack: Vec::new(), - rewrite_rules: FxHashMap::default(), + rewrite_rules: &mut rewrite_rules, }; renamer.rename(0, None); - // FIXME(eddyb) shouldn't this full rescan of the function be done once? 
- apply_rewrite_rules( - &renamer.rewrite_rules, - blocks.values_mut().map(|block| &mut **block), - ); - remove_nops(blocks); } + apply_rewrite_rules( + &rewrite_rules, + blocks.values_mut().map(|block| &mut **block), + ); + remove_nops(blocks); remove_old_variables(blocks, &var_maps_and_types); true } @@ -443,7 +444,7 @@ struct Renamer<'a, 'b> { phi_defs: FxHashSet, visited: FxHashSet, stack: Vec, - rewrite_rules: FxHashMap, + rewrite_rules: &'a mut FxHashMap, } impl Renamer<'_, '_> { diff --git a/crates/rustc_codegen_spirv/src/linker/mod.rs b/crates/rustc_codegen_spirv/src/linker/mod.rs index de40a19131..916d35b337 100644 --- a/crates/rustc_codegen_spirv/src/linker/mod.rs +++ b/crates/rustc_codegen_spirv/src/linker/mod.rs @@ -103,7 +103,7 @@ fn apply_rewrite_rules<'a>( ) }); for id in all_ids_mut { - if let Some(&rewrite) = rewrite_rules.get(id) { + while let Some(&rewrite) = rewrite_rules.get(id) { *id = rewrite; } } @@ -562,6 +562,12 @@ pub fn link( ); } + { + let timer = before_pass("spirt_passes::explicit_layout::erase_when_invalid"); + spirt_passes::explicit_layout::erase_when_invalid(module); + after_pass(Some(module), timer); + } + { let timer = before_pass("spirt_passes::validate"); spirt_passes::validate::validate(module); diff --git a/crates/rustc_codegen_spirv/src/linker/spirt_passes/explicit_layout.rs b/crates/rustc_codegen_spirv/src/linker/spirt_passes/explicit_layout.rs new file mode 100644 index 0000000000..2812e452be --- /dev/null +++ b/crates/rustc_codegen_spirv/src/linker/spirt_passes/explicit_layout.rs @@ -0,0 +1,860 @@ +//! SPIR-T passes related to "explicit layout decorations" (`Offset`/`ArrayStride`). 
+ +use either::Either; +use itertools::Itertools; +use rustc_data_structures::fx::{FxHashMap, FxHashSet}; +use smallvec::SmallVec; +use spirt::func_at::{FuncAt, FuncAtMut}; +use spirt::transform::{InnerInPlaceTransform, InnerTransform, Transformed, Transformer}; +use spirt::visit::InnerVisit as _; +use spirt::{ + AddrSpace, Attr, AttrSetDef, Const, ConstKind, Context, ControlNode, ControlNodeKind, DataInst, + DataInstDef, DataInstForm, DataInstFormDef, DataInstKind, DeclDef, Diag, Func, FuncDecl, + GlobalVar, GlobalVarDecl, Module, Type, TypeDef, TypeKind, TypeOrConst, Value, spv, +}; +use std::cmp::Ordering; +use std::collections::VecDeque; + +/// Erase explicit layout decorations from struct/array types, when used with +/// storage classes that do not support them (as per the Vulkan spec). +// +// NOTE(eddyb) this is a stop-gap until `spirt::{mem,qptr}` can replace it. +pub fn erase_when_invalid(module: &mut Module) { + let spv_spec = super::SpvSpecWithExtras::get(); + let wk = &spv_spec.well_known; + + let mut eraser = SelectiveEraser { + cx: &module.cx(), + wk, + + transformed_types: FxHashMap::default(), + transformed_consts: FxHashMap::default(), + transformed_data_inst_forms: FxHashMap::default(), + seen_global_vars: FxHashSet::default(), + global_var_queue: VecDeque::new(), + seen_funcs: FxHashSet::default(), + func_queue: VecDeque::new(), + + cached_erased_explicit_layout_types: FxHashMap::default(), + cached_erased_explicit_layout_consts: FxHashMap::default(), + + parent_block: None, + }; + + // Seed the queues starting from the module exports. + for exportee in module.exports.values_mut() { + exportee + .inner_transform_with(&mut eraser) + .apply_to(exportee); + } + + // Process the queues until they're all empty. 
+ while !eraser.global_var_queue.is_empty() || !eraser.func_queue.is_empty() { + while let Some(gv) = eraser.global_var_queue.pop_front() { + eraser.in_place_transform_global_var_decl(&mut module.global_vars[gv]); + } + while let Some(func) = eraser.func_queue.pop_front() { + eraser.in_place_transform_func_decl(&mut module.funcs[func]); + } + } +} + +// FIXME(eddyb) name could be better (avoiding overly verbose is a bit tricky). +struct SelectiveEraser<'a> { + cx: &'a Context, + wk: &'static super::SpvWellKnownWithExtras, + + // FIXME(eddyb) build some automation to avoid ever repeating these. + transformed_types: FxHashMap>, + transformed_consts: FxHashMap>, + transformed_data_inst_forms: FxHashMap>, + seen_global_vars: FxHashSet, + global_var_queue: VecDeque, + seen_funcs: FxHashSet, + func_queue: VecDeque, + + // FIXME(eddyb) these overlap with some `transformed_*` fields above, + // but they're contextually transformed additionally. + // HACK(eddyb) these are now used via the `EraseExplicitLayout` newtype. + cached_erased_explicit_layout_types: FxHashMap>, + cached_erased_explicit_layout_consts: FxHashMap>, + + // HACK(eddyb) this is to allow `in_place_transform_data_inst_def` inject + // new instructions into its parent block. + parent_block: Option, +} + +impl Transformer for SelectiveEraser<'_> { + // FIXME(eddyb) build some automation to avoid ever repeating these. 
+ fn transform_type_use(&mut self, ty: Type) -> Transformed { + if let Some(&cached) = self.transformed_types.get(&ty) { + return cached; + } + let transformed = self + .transform_type_def(&self.cx[ty]) + .map(|ty_def| self.cx.intern(ty_def)); + self.transformed_types.insert(ty, transformed); + transformed + } + fn transform_const_use(&mut self, ct: Const) -> Transformed { + if let Some(&cached) = self.transformed_consts.get(&ct) { + return cached; + } + let transformed = self + .transform_const_def(&self.cx[ct]) + .map(|ct_def| self.cx.intern(ct_def)); + self.transformed_consts.insert(ct, transformed); + transformed + } + fn transform_data_inst_form_use( + &mut self, + data_inst_form: DataInstForm, + ) -> Transformed { + if let Some(&cached) = self.transformed_data_inst_forms.get(&data_inst_form) { + return cached; + } + let transformed = self + .transform_data_inst_form_def(&self.cx[data_inst_form]) + .map(|data_inst_form_def| self.cx.intern(data_inst_form_def)); + self.transformed_data_inst_forms + .insert(data_inst_form, transformed); + transformed + } + + fn transform_global_var_use(&mut self, gv: GlobalVar) -> Transformed { + if self.seen_global_vars.insert(gv) { + self.global_var_queue.push_back(gv); + } + Transformed::Unchanged + } + fn transform_func_use(&mut self, func: Func) -> Transformed { + if self.seen_funcs.insert(func) { + self.func_queue.push_back(func); + } + Transformed::Unchanged + } + + // NOTE(eddyb) above methods are plumbing, erasure methods are below. + + fn transform_type_def(&mut self, ty_def: &TypeDef) -> Transformed { + let wk = self.wk; + + let needs_erasure_of_explicit_layout = match &ty_def.kind { + TypeKind::SpvInst { + spv_inst, + type_and_const_inputs: _, + } if spv_inst.opcode == wk.OpTypePointer => match spv_inst.imms[..] 
{ + [spv::Imm::Short(sc_kind, sc)] => { + assert_eq!(sc_kind, wk.StorageClass); + !self.addr_space_allows_explicit_layout(AddrSpace::SpvStorageClass(sc)) + } + _ => unreachable!(), + }, + + _ => false, + }; + if needs_erasure_of_explicit_layout { + ty_def.inner_transform_with(&mut EraseExplicitLayout(self)) + } else { + ty_def.inner_transform_with(self) + } + } + + fn in_place_transform_global_var_decl(&mut self, gv_decl: &mut GlobalVarDecl) { + let needs_erasure_of_explicit_layout = + !self.addr_space_allows_explicit_layout(gv_decl.addr_space); + if needs_erasure_of_explicit_layout { + gv_decl.inner_in_place_transform_with(&mut EraseExplicitLayout(self)); + } else { + gv_decl.inner_in_place_transform_with(self); + } + } + + fn in_place_transform_func_decl(&mut self, func_decl: &mut FuncDecl) { + // HACK(eddyb) to catch any instructions having their input/output types + // changed from under them, a separate visit has to be used before *any* + // region input/node output declarations in the function body may change. + if let DeclDef::Present(func_def_body) = &mut func_decl.def { + let mut errors_to_attach = vec![]; + func_def_body.inner_visit_with(&mut super::VisitAllControlRegionsAndNodes { + state: (), + visit_control_region: |_: &mut _, _| {}, + visit_control_node: |_: &mut _, func_at_node: FuncAt<'_, ControlNode>| { + if let ControlNodeKind::Block { insts } = func_at_node.def().kind { + for func_at_inst in func_at_node.at(insts) { + if let Err(e) = self.pre_check_data_inst(func_at_inst) { + errors_to_attach.push((func_at_inst.position, e)); + } + } + } + }, + }); + for (inst, err) in errors_to_attach { + func_def_body + .at_mut(inst) + .def() + .attrs + .push_diag(self.cx, err); + } + } + + func_decl.inner_in_place_transform_with(self); + } + + fn in_place_transform_control_node_def( + &mut self, + mut func_at_control_node: FuncAtMut<'_, ControlNode>, + ) { + let old_parent_block = self.parent_block.take(); + if let ControlNodeKind::Block { .. 
} = func_at_control_node.reborrow().def().kind { + self.parent_block = Some(func_at_control_node.position); + } + func_at_control_node.inner_in_place_transform_with(self); + self.parent_block = old_parent_block; + } + + fn in_place_transform_data_inst_def(&mut self, mut func_at_data_inst: FuncAtMut<'_, DataInst>) { + let cx = self.cx; + let wk = self.wk; + + func_at_data_inst + .reborrow() + .inner_in_place_transform_with(self); + + let func_at_data_inst_frozen = func_at_data_inst.reborrow().freeze(); + let data_inst = func_at_data_inst_frozen.position; + let data_inst_def = func_at_data_inst_frozen.def(); + let data_inst_form_def = &cx[data_inst_def.form]; + let func = func_at_data_inst_frozen.at(()); + let type_of_val = |v: Value| func.at(v).type_of(cx); + let pointee_type_of_ptr_val = |p: Value| match &cx[type_of_val(p)].kind { + TypeKind::SpvInst { + spv_inst, + type_and_const_inputs, + } if spv_inst.opcode == wk.OpTypePointer => match type_and_const_inputs[..] { + [TypeOrConst::Type(elem_type)] => Some(elem_type), + _ => unreachable!(), + }, + _ => None, + }; + + let DataInstKind::SpvInst(spv_inst) = &data_inst_form_def.kind else { + return; + }; + + // FIXME(eddyb) filter attributes into debuginfo and + // semantic, and understand the semantic ones. + let attrs = data_inst_def.attrs; + + let mk_bitcast_def = |in_value, out_type| DataInstDef { + attrs, + form: cx.intern(DataInstFormDef { + kind: DataInstKind::SpvInst(wk.OpBitcast.into()), + output_type: Some(out_type), + }), + inputs: [in_value].into_iter().collect(), + }; + + if spv_inst.opcode == wk.OpLoad { + let pointee_type = pointee_type_of_ptr_val(data_inst_def.inputs[0]); + let value_type = data_inst_form_def.output_type.unwrap(); + // FIXME(eddyb) leave a BUG diagnostic in the `None` case? 
+ if pointee_type.is_some_and(|ty| { + ty != value_type && ty == self.erase_explicit_layout_in_type(value_type) + }) { + let func = func_at_data_inst.at(()); + let ControlNodeKind::Block { insts } = + &mut func.control_nodes[self.parent_block.unwrap()].kind + else { + unreachable!() + }; + + let fixed_load_inst = func.data_insts.define( + cx, + DataInstDef { + attrs, + form: cx.intern(DataInstFormDef { + kind: data_inst_form_def.kind.clone(), + output_type: Some(pointee_type.unwrap()), + }), + inputs: func.data_insts[data_inst].inputs.clone(), + } + .into(), + ); + insts.insert_before(fixed_load_inst, data_inst, func.data_insts); + *func.data_insts[data_inst] = + mk_bitcast_def(Value::DataInstOutput(fixed_load_inst), value_type); + + self.disaggregate_bitcast(func.at(data_inst)); + } + } else if spv_inst.opcode == wk.OpStore { + let pointee_type = pointee_type_of_ptr_val(data_inst_def.inputs[0]); + let value_type = type_of_val(data_inst_def.inputs[1]); + // FIXME(eddyb) leave a BUG diagnostic in the `None` case? 
+ if pointee_type.is_some_and(|ty| { + ty != value_type && ty == self.erase_explicit_layout_in_type(value_type) + }) { + let func = func_at_data_inst.at(()); + let stored_value = &mut func.data_insts[data_inst].inputs[1]; + + if let Value::Const(ct) = stored_value { + EraseExplicitLayout(self) + .transform_const_use(*ct) + .apply_to(ct); + } else { + let original_stored_value = *stored_value; + + let ControlNodeKind::Block { insts } = + &mut func.control_nodes[self.parent_block.unwrap()].kind + else { + unreachable!() + }; + let stored_value_cast_inst = func.data_insts.define( + cx, + mk_bitcast_def(original_stored_value, pointee_type.unwrap()).into(), + ); + insts.insert_before(stored_value_cast_inst, data_inst, func.data_insts); + func.data_insts[data_inst].inputs[1] = + Value::DataInstOutput(stored_value_cast_inst); + + self.disaggregate_bitcast(func.at(stored_value_cast_inst)); + } + } + } else if spv_inst.opcode == wk.OpCopyMemory { + let dst_ptr = data_inst_def.inputs[0]; + let src_ptr = data_inst_def.inputs[1]; + let [dst_pointee_type, src_pointee_type] = + [dst_ptr, src_ptr].map(pointee_type_of_ptr_val); + // FIXME(eddyb) leave a BUG diagnostic in the `None` case? + let mismatched_dst_src_types = match [dst_pointee_type, src_pointee_type] { + [Some(a), Some(b)] => { + // FIXME(eddyb) there has to be a nicer way to write this?? + fn equal([a, b]: [T; 2]) -> bool { + a == b + } + !equal([a, b]) && equal([a, b].map(|ty| self.erase_explicit_layout_in_type(ty))) + } + _ => false, + }; + if mismatched_dst_src_types { + let is_memory_access_imm = + |imm| matches!(imm, &spv::Imm::Short(k, _) if k == wk.MemoryAccess); + + // HACK(eddyb) this relies on `MemoryAccess` being non-recursive + // (in fact, its parameters seem to only be simple literals/IDs). + let (dst_imms, src_imms) = + match (spv_inst.imms.iter().positions(is_memory_access_imm)) + .collect::>()[..] 
+ { + [] | [0] => (&spv_inst.imms[..], &spv_inst.imms[..]), + [0, i] => spv_inst.imms.split_at(i), + _ => unreachable!(), + }; + + let func = func_at_data_inst.at(()); + let ControlNodeKind::Block { insts } = + &mut func.control_nodes[self.parent_block.unwrap()].kind + else { + unreachable!() + }; + + let load_inst = func.data_insts.define( + cx, + DataInstDef { + attrs, + form: cx.intern(DataInstFormDef { + kind: DataInstKind::SpvInst(spv::Inst { + opcode: wk.OpLoad, + imms: src_imms.iter().copied().collect(), + }), + output_type: Some(src_pointee_type.unwrap()), + }), + inputs: [src_ptr].into_iter().collect(), + } + .into(), + ); + insts.insert_before(load_inst, data_inst, func.data_insts); + let cast_inst = func.data_insts.define( + cx, + mk_bitcast_def(Value::DataInstOutput(load_inst), dst_pointee_type.unwrap()) + .into(), + ); + insts.insert_before(cast_inst, data_inst, func.data_insts); + + *func.data_insts[data_inst] = DataInstDef { + attrs, + form: cx.intern(DataInstFormDef { + kind: DataInstKind::SpvInst(spv::Inst { + opcode: wk.OpStore, + imms: dst_imms.iter().copied().collect(), + }), + output_type: None, + }), + inputs: [dst_ptr, Value::DataInstOutput(cast_inst)] + .into_iter() + .collect(), + }; + + self.disaggregate_bitcast(func.at(cast_inst)); + } + } + } +} + +impl<'a> SelectiveEraser<'a> { + fn addr_space_allows_explicit_layout(&self, addr_space: AddrSpace) -> bool { + let wk = self.wk; + + // FIXME(eddyb) this might need to include `Workgroup`, specifically when + // `WorkgroupMemoryExplicitLayoutKHR` is being relied upon. + [ + wk.PushConstant, + wk.Uniform, + wk.StorageBuffer, + wk.PhysicalStorageBuffer, + ] + .map(AddrSpace::SpvStorageClass) + .contains(&addr_space) + } + + // FIXME(eddyb) properly distinguish between zero-extension and sign-extension. 
+    fn const_as_u32(&self, ct: Const) -> Option<u32> {
+        if let ConstKind::SpvInst {
+            spv_inst_and_const_inputs,
+        } = &self.cx[ct].kind
+        {
+            let (spv_inst, _const_inputs) = &**spv_inst_and_const_inputs;
+            if spv_inst.opcode == self.wk.OpConstant && spv_inst.imms.len() == 1 {
+                match spv_inst.imms[..] {
+                    [spv::Imm::Short(_, x)] => return Some(x),
+                    _ => unreachable!(),
+                }
+            }
+        }
+        None
+    }
+
+    fn aggregate_component_types(
+        &self,
+        ty: Type,
+    ) -> Option<impl ExactSizeIterator<Item = Type> + Clone + 'a> {
+        let cx = self.cx;
+        let wk = self.wk;
+
+        match &cx[ty].kind {
+            TypeKind::SpvInst {
+                spv_inst,
+                type_and_const_inputs,
+            } if spv_inst.opcode == wk.OpTypeStruct => {
+                Some(Either::Left(type_and_const_inputs.iter().map(
+                    |&ty_or_ct| match ty_or_ct {
+                        TypeOrConst::Type(ty) => ty,
+                        TypeOrConst::Const(_) => unreachable!(),
+                    },
+                )))
+            }
+            TypeKind::SpvInst {
+                spv_inst,
+                type_and_const_inputs,
+            } if spv_inst.opcode == wk.OpTypeArray => {
+                let [TypeOrConst::Type(elem_type), TypeOrConst::Const(count)] =
+                    type_and_const_inputs[..]
+                else {
+                    unreachable!()
+                };
+                let count = self.const_as_u32(count)?;
+                Some(Either::Right((0..count).map(move |_| elem_type)))
+            }
+            _ => None,
+        }
+    }
+
+    fn erase_explicit_layout_in_type(&mut self, mut ty: Type) -> Type {
+        EraseExplicitLayout(self)
+            .transform_type_use(ty)
+            .apply_to(&mut ty);
+        ty
+    }
+
+    // HACK(eddyb) this expands an illegal `OpBitcast` of a struct/array, into
+    // leaf values from the source aggregate that are then recomposed into the
+    // target aggregate - this should go away when SPIR-T `disaggregate` lands.
+    fn disaggregate_bitcast(&mut self, mut func_at_cast_inst: FuncAtMut<'_, DataInst>) {
+        let cx = self.cx;
+        let wk = self.wk;
+
+        let cast_inst = func_at_cast_inst.position;
+        let cast_def = func_at_cast_inst.reborrow().freeze().def().clone();
+        let cast_form_def = &cx[cast_def.form];
+
+        // FIXME(eddyb) filter attributes into debuginfo and
+        // semantic, and understand the semantic ones.
+ let attrs = cast_def.attrs; + + assert!(cast_form_def.kind == DataInstKind::SpvInst(wk.OpBitcast.into())); + let in_value = cast_def.inputs[0]; + let out_type = cast_form_def.output_type.unwrap(); + + let mut func = func_at_cast_inst.reborrow(); + let in_type = func.reborrow().freeze().at(in_value).type_of(cx); + + // FIXME(eddyb) there has to be a nicer way to write this?? + fn equal([a, b]: [T; 2]) -> bool { + a == b + } + + let [in_component_types, out_component_types] = Some([in_type, out_type]) + .filter(|&types| { + !equal(types) && equal(types.map(|ty| self.erase_explicit_layout_in_type(ty))) + }) + .map(|types| types.map(|ty| self.aggregate_component_types(ty))) + .unwrap_or_default(); + + // NOTE(eddyb) such sanity checks should always succeed, because of the + // "in/out types are equal after erasure" check, earlier above. + assert_eq!( + in_component_types.as_ref().map(|iter| iter.len()), + out_component_types.as_ref().map(|iter| iter.len()), + ); + + let [Some(in_component_types), Some(out_component_types)] = + [in_component_types, out_component_types] + else { + return; + }; + + let components = (in_component_types.zip_eq(out_component_types).enumerate()) + .map(|(component_idx, (component_in_type, component_out_type))| { + let component_idx = u32::try_from(component_idx).unwrap(); + + let component_cast_types = + Some([component_in_type, component_out_type]).filter(|&types| !equal(types)); + if let Some(component_cast_types) = component_cast_types { + assert!(equal( + component_cast_types.map(|ty| self.erase_explicit_layout_in_type(ty)) + )); + } + + let component_extract_inst = func.data_insts.define( + cx, + DataInstDef { + attrs, + form: cx.intern(DataInstFormDef { + kind: DataInstKind::SpvInst(spv::Inst { + opcode: wk.OpCompositeExtract, + imms: [spv::Imm::Short(wk.LiteralInteger, component_idx)] + .into_iter() + .collect(), + }), + output_type: Some(component_in_type), + }), + inputs: [in_value].into_iter().collect(), + } + .into(), + ); + + 
let ControlNodeKind::Block { insts } = + &mut func.control_nodes[self.parent_block.unwrap()].kind + else { + unreachable!() + }; + insts.insert_before(component_extract_inst, cast_inst, func.data_insts); + + let component_cast_inst = component_cast_types.map(|[_, component_out_type]| { + let inst = func.data_insts.define( + cx, + DataInstDef { + attrs, + form: cx.intern(DataInstFormDef { + kind: DataInstKind::SpvInst(wk.OpBitcast.into()), + output_type: Some(component_out_type), + }), + inputs: [Value::DataInstOutput(component_extract_inst)] + .into_iter() + .collect(), + } + .into(), + ); + insts.insert_before(inst, cast_inst, func.data_insts); + + inst + }); + + if let Some(component_cast_inst) = component_cast_inst { + self.disaggregate_bitcast(func.reborrow().at(component_cast_inst)); + } + + Value::DataInstOutput(component_cast_inst.unwrap_or(component_extract_inst)) + }) + .collect(); + + *func.at(cast_inst).def() = DataInstDef { + attrs, + form: cx.intern(DataInstFormDef { + kind: DataInstKind::SpvInst(wk.OpCompositeConstruct.into()), + output_type: Some(out_type), + }), + inputs: components, + }; + } + + // HACK(eddyb) this runs on every `DataInst` in a function body, before the + // declarations of any region input/node output, are ever changed, to catch + // the cases that would need special handling, but lack it. + fn pre_check_data_inst(&mut self, func_at_inst: FuncAt<'_, DataInst>) -> Result<(), Diag> { + let cx = self.cx; + let wk = self.wk; + + let data_inst_def = func_at_inst.def(); + let data_inst_form_def = &cx[data_inst_def.form]; + + // FIXME(eddyb) consider preserving the actual type change in the error. 
+ let any_types_will_change = (data_inst_form_def.output_type.into_iter()) + .chain( + data_inst_def + .inputs + .iter() + .map(|&v| func_at_inst.at(v).type_of(cx)), + ) + .any(|ty| { + let mut new_ty = ty; + self.transform_type_use(ty).apply_to(&mut new_ty); + new_ty != ty + }); + if !any_types_will_change { + return Ok(()); + } + + let spv_inst = match &data_inst_form_def.kind { + DataInstKind::FuncCall(_) => return Ok(()), + + DataInstKind::SpvInst(spv_inst) + if [wk.OpLoad, wk.OpStore, wk.OpCopyMemory].contains(&spv_inst.opcode) => + { + return Ok(()); + } + + DataInstKind::QPtr(_) => { + return Err(Diag::bug([ + "unhandled pointer type change in unexpected `qptr` instruction".into(), + ])); + } + &DataInstKind::SpvExtInst { ext_set, inst } => { + let ext_set = &cx[ext_set]; + return Err(Diag::bug([format!( + "unhandled pointer type change in extended SPIR-V \ + (`{ext_set}` / #{inst}) instruction" + ) + .into()])); + } + + DataInstKind::SpvInst(spv_inst) => spv_inst, + }; + + let sigs = crate::spirv_type_constraints::instruction_signatures( + rspirv::spirv::Op::from_u32(spv_inst.opcode.as_u16().into()).unwrap(), + ); + let pointer_pointee_correlated_sigs: SmallVec<[_; 1]> = sigs + .unwrap_or(&[]) + .iter() + .filter(|sig| { + use crate::spirv_type_constraints::{TyListPat, TyPat}; + + #[derive(Copy, Clone, Default)] + struct ConstrainedVars { + direct: bool, + in_pointee: bool, + } + impl std::ops::BitOr for ConstrainedVars { + type Output = Self; + fn bitor(self, rhs: Self) -> Self { + Self { + direct: self.direct | rhs.direct, + in_pointee: self.in_pointee | rhs.in_pointee, + } + } + } + impl ConstrainedVars { + fn collect_from(pat: &TyPat<'_>) -> Self { + match pat { + TyPat::Pointer(_, inner) => { + let Self { direct, in_pointee } = Self::collect_from(inner); + Self { + direct: false, + in_pointee: direct | in_pointee, + } + } + + TyPat::Any | TyPat::Void => Self::default(), + TyPat::Var(_) => Self { + direct: true, + in_pointee: false, + }, + 
TyPat::Either(a, b) => Self::collect_from(a) | Self::collect_from(b), + TyPat::Array(inner) + | TyPat::Vector(inner) + | TyPat::Vector4(inner) + | TyPat::Matrix(inner) + | TyPat::Image(inner) + | TyPat::Pipe(inner) + | TyPat::SampledImage(inner) + | TyPat::IndexComposite(inner) => Self::collect_from(inner), + TyPat::Struct(fields) => Self::collect_from_list_leaves(fields), + TyPat::Function(output, inputs) => { + Self::collect_from(output) | Self::collect_from_list_leaves(inputs) + } + } + } + fn collect_from_list_leaves(pat: &TyListPat<'_>) -> Self { + match pat { + TyListPat::Any | TyListPat::Nil | TyListPat::Var(_) => Self::default(), + TyListPat::Repeat(inner) => Self::collect_from_list_leaves(inner), + TyListPat::Cons { first, suffix } => { + Self::collect_from(first) | Self::collect_from_list_leaves(suffix) + } + } + } + } + + let mut min_expected_inputs = 0; + let mut constrained_vars = sig + .output_type + .map(ConstrainedVars::collect_from) + .unwrap_or_default(); + + let mut inputs = sig.input_types; + while let TyListPat::Cons { first, suffix } = inputs { + min_expected_inputs += 1; + constrained_vars = constrained_vars | ConstrainedVars::collect_from(first); + + inputs = suffix; + } + + if let (Ordering::Less, _) | (Ordering::Greater, TyListPat::Nil) = + (data_inst_def.inputs.len().cmp(&min_expected_inputs), inputs) + { + return false; + } + + constrained_vars.direct && constrained_vars.in_pointee + }) + .collect(); + if !pointer_pointee_correlated_sigs.is_empty() { + return Err(Diag::bug([format!( + "unhandled pointer type change in `{}` SPIR-V instruction: \ + {pointer_pointee_correlated_sigs:#?}", + spv_inst.opcode.name() + ) + .into()])); + } + Ok(()) + } +} + +// HACK(eddyb) wrapper modifying `Transformer` behavior of `SelectiveEraser`. +struct EraseExplicitLayout<'a, 'b>(&'a mut SelectiveEraser<'b>); + +impl Transformer for EraseExplicitLayout<'_, '_> { + // FIXME(eddyb) build some automation to avoid ever repeating these. 
+    fn transform_type_use(&mut self, ty: Type) -> Transformed<Type> {
+        if let Some(&cached) = self.0.cached_erased_explicit_layout_types.get(&ty) {
+            return cached;
+        }
+        let transformed = self
+            .transform_type_def(&self.0.cx[ty])
+            .map(|ty_def| self.0.cx.intern(ty_def));
+        self.0
+            .cached_erased_explicit_layout_types
+            .insert(ty, transformed);
+        transformed
+    }
+    fn transform_const_use(&mut self, ct: Const) -> Transformed<Const> {
+        if let Some(&cached) = self.0.cached_erased_explicit_layout_consts.get(&ct) {
+            return cached;
+        }
+        let transformed = self
+            .transform_const_def(&self.0.cx[ct])
+            .map(|ct_def| self.0.cx.intern(ct_def));
+        self.0
+            .cached_erased_explicit_layout_consts
+            .insert(ct, transformed);
+        transformed
+    }
+    fn transform_data_inst_form_use(&mut self, _: DataInstForm) -> Transformed<DataInstForm> {
+        unreachable!()
+    }
+
+    fn transform_global_var_use(&mut self, gv: GlobalVar) -> Transformed<GlobalVar> {
+        self.0.transform_global_var_use(gv)
+    }
+    fn transform_func_use(&mut self, func: Func) -> Transformed<Func> {
+        self.0.transform_func_use(func)
+    }
+
+    // NOTE(eddyb) above methods are plumbing, erasure methods are below.
+
+    fn transform_type_def(&mut self, ty_def: &TypeDef) -> Transformed<TypeDef> {
+        let wk = self.0.wk;
+
+        // HACK(eddyb) reconsider pointer types, based on *their* storage class
+        // (e.g. implicit-layout pointers to explicit-layout pointers, even if
+        // for Vulkan that's only possible by involving `PhysicalStorageBuffer`).
+ match &ty_def.kind { + TypeKind::SpvInst { + spv_inst, + type_and_const_inputs: _, + } if spv_inst.opcode == wk.OpTypePointer => { + return self.0.transform_type_def(ty_def); + } + _ => {} + } + + let transformed = ty_def.inner_transform_with(self); + + let old_attrs = match &transformed { + Transformed::Unchanged => ty_def.attrs, + Transformed::Changed(new_ty_def) => new_ty_def.attrs, + }; + + let new_attrs = self.0.cx.intern(AttrSetDef { + attrs: self.0.cx[old_attrs] + .attrs + .iter() + .filter(|attr| { + // FIXME(eddyb) `rustfmt` breaks down for `matches!`. + #[allow(clippy::match_like_matches_macro)] + let is_explicit_layout_decoration = match attr { + Attr::SpvAnnotation(attr_spv_inst) + if (attr_spv_inst.opcode == wk.OpDecorate + && [wk.ArrayStride, wk.MatrixStride] + .map(|d| spv::Imm::Short(wk.Decoration, d)) + .contains(&attr_spv_inst.imms[0])) + || (attr_spv_inst.opcode == wk.OpMemberDecorate + && attr_spv_inst.imms[1] + == spv::Imm::Short(wk.Decoration, wk.Offset)) => + { + true + } + + _ => false, + }; + !is_explicit_layout_decoration + }) + .cloned() + .collect(), + }); + + if old_attrs == new_attrs { + return transformed; + } + + let mut ty_def = TypeDef { + attrs: ty_def.attrs, + kind: ty_def.kind.clone(), + }; + transformed.apply_to(&mut ty_def); + + ty_def.attrs = new_attrs; + Transformed::Changed(ty_def) + } +} diff --git a/crates/rustc_codegen_spirv/src/linker/spirt_passes/mod.rs b/crates/rustc_codegen_spirv/src/linker/spirt_passes/mod.rs index f8b8c8518a..f41dc5aa4d 100644 --- a/crates/rustc_codegen_spirv/src/linker/spirt_passes/mod.rs +++ b/crates/rustc_codegen_spirv/src/linker/spirt_passes/mod.rs @@ -3,6 +3,7 @@ pub(crate) mod controlflow; pub(crate) mod debuginfo; pub(crate) mod diagnostics; +pub(crate) mod explicit_layout; mod fuse_selects; mod reduce; pub(crate) mod validate; @@ -64,15 +65,16 @@ macro_rules! 
def_spv_spec_with_extra_well_known { let spv_spec = spv::spec::Spec::get(); let wk = &spv_spec.well_known; - let decorations = match wk.Decoration.def() { + let [decorations, storage_classes] = [wk.Decoration, wk.StorageClass].map(|kind| match kind.def() { spv::spec::OperandKindDef::ValueEnum { variants } => variants, _ => unreachable!(), - }; + }); let lookup_fns = PerWellKnownGroup { opcode: |name| spv_spec.instructions.lookup(name).unwrap(), operand_kind: |name| spv_spec.operand_kinds.lookup(name).unwrap(), decoration: |name| decorations.lookup(name).unwrap().into(), + storage_class: |name| storage_classes.lookup(name).unwrap().into(), }; SpvSpecWithExtras { @@ -100,14 +102,25 @@ def_spv_spec_with_extra_well_known! { OpBitcast, OpCompositeInsert, OpCompositeExtract, + OpCompositeConstruct, + + OpCopyMemory, ], operand_kind: spv::spec::OperandKind = [ Capability, ExecutionModel, ImageFormat, + MemoryAccess, ], decoration: u32 = [ UserTypeGOOGLE, + MatrixStride, + ], + storage_class: u32 = [ + PushConstant, + Uniform, + StorageBuffer, + PhysicalStorageBuffer, ], } diff --git a/tests/compiletests/ui/dis/entry-pass-mode-cast-array.stderr b/tests/compiletests/ui/dis/entry-pass-mode-cast-array.stderr index 5fd5cdd459..92f15bfa31 100644 --- a/tests/compiletests/ui/dis/entry-pass-mode-cast-array.stderr +++ b/tests/compiletests/ui/dis/entry-pass-mode-cast-array.stderr @@ -2,12 +2,18 @@ %4 = OpLabel OpLine %5 13 12 %6 = OpLoad %7 %8 -OpLine %5 14 4 %9 = OpCompositeExtract %10 %6 0 -%11 = OpFAdd %10 %9 %12 -%13 = OpCompositeInsert %7 %11 %6 0 +%11 = OpCompositeExtract %10 %6 1 +%12 = OpCompositeConstruct %13 %9 %11 +OpLine %5 14 4 +%14 = OpCompositeExtract %10 %12 0 +%15 = OpFAdd %10 %14 %16 +%17 = OpCompositeInsert %13 %15 %12 0 OpLine %5 15 4 -OpStore %14 %13 +%18 = OpCompositeExtract %10 %17 0 +%19 = OpCompositeExtract %10 %17 1 +%20 = OpCompositeConstruct %7 %18 %19 +OpStore %21 %20 OpNoLine OpReturn OpFunctionEnd diff --git a/tests/compiletests/ui/dis/issue-731.stderr 
b/tests/compiletests/ui/dis/issue-731.stderr index 78fdc54539..a891aaf5b0 100644 --- a/tests/compiletests/ui/dis/issue-731.stderr +++ b/tests/compiletests/ui/dis/issue-731.stderr @@ -2,12 +2,20 @@ %4 = OpLabel OpLine %5 11 12 %6 = OpLoad %7 %8 -OpLine %5 12 4 %9 = OpCompositeExtract %10 %6 0 -%11 = OpFAdd %10 %9 %12 -%13 = OpCompositeInsert %7 %11 %6 0 +%11 = OpCompositeExtract %10 %6 1 +%12 = OpCompositeExtract %10 %6 2 +%13 = OpCompositeConstruct %14 %9 %11 %12 +OpLine %5 12 4 +%15 = OpCompositeExtract %10 %13 0 +%16 = OpFAdd %10 %15 %17 +%18 = OpCompositeInsert %14 %16 %13 0 OpLine %5 13 4 -OpStore %14 %13 +%19 = OpCompositeExtract %10 %18 0 +%20 = OpCompositeExtract %10 %18 1 +%21 = OpCompositeExtract %10 %18 2 +%22 = OpCompositeConstruct %7 %19 %20 %21 +OpStore %23 %22 OpNoLine OpReturn OpFunctionEnd diff --git a/tests/compiletests/ui/dis/panic_builtin_bounds_check.stderr b/tests/compiletests/ui/dis/panic_builtin_bounds_check.stderr index edef031324..a3bee642b1 100644 --- a/tests/compiletests/ui/dis/panic_builtin_bounds_check.stderr +++ b/tests/compiletests/ui/dis/panic_builtin_bounds_check.stderr @@ -12,39 +12,45 @@ OpDecorate %6 ArrayStride 4 %8 = OpTypeFunction %7 %9 = OpTypeInt 32 0 %10 = OpConstant %9 4 +%11 = OpTypeArray %9 %10 +%12 = OpTypePointer Function %11 %6 = OpTypeArray %9 %10 -%11 = OpTypePointer Function %6 -%12 = OpConstant %9 0 -%13 = OpConstant %9 1 -%14 = OpConstant %9 2 -%15 = OpConstant %9 3 -%16 = OpTypeBool -%17 = OpConstant %9 5 -%18 = OpTypePointer Function %9 +%13 = OpConstant %9 0 +%14 = OpConstant %9 1 +%15 = OpConstant %9 2 +%16 = OpConstant %9 3 +%17 = OpTypeBool +%18 = OpConstant %9 5 +%19 = OpTypePointer Function %9 %2 = OpFunction %7 None %8 -%19 = OpLabel +%20 = OpLabel OpLine %5 32 4 -%20 = OpVariable %11 Function +%21 = OpVariable %12 Function OpLine %5 32 23 -%21 = OpCompositeConstruct %6 %12 %13 %14 %15 +%22 = OpCompositeConstruct %6 %13 %14 %15 %16 OpLine %5 27 4 -OpStore %20 %21 -%22 = OpULessThan %16 %17 %10 +%23 = 
OpCompositeExtract %9 %22 0 +%24 = OpCompositeExtract %9 %22 1 +%25 = OpCompositeExtract %9 %22 2 +%26 = OpCompositeExtract %9 %22 3 +%27 = OpCompositeConstruct %11 %23 %24 %25 %26 +OpStore %21 %27 +%28 = OpULessThan %17 %18 %10 OpNoLine -OpSelectionMerge %23 None -OpBranchConditional %22 %24 %25 -%24 = OpLabel -OpBranch %23 -%25 = OpLabel +OpSelectionMerge %29 None +OpBranchConditional %28 %30 %31 +%30 = OpLabel +OpBranch %29 +%31 = OpLabel OpLine %4 280 4 -%26 = OpExtInst %7 %1 1 %3 %10 %17 +%32 = OpExtInst %7 %1 1 %3 %10 %18 OpNoLine OpReturn -%23 = OpLabel +%29 = OpLabel OpLine %5 27 4 -%27 = OpIAdd %9 %12 %17 -%28 = OpInBoundsAccessChain %18 %20 %27 -%29 = OpLoad %9 %28 +%33 = OpIAdd %9 %13 %18 +%34 = OpInBoundsAccessChain %19 %21 %33 +%35 = OpLoad %9 %34 OpNoLine OpReturn OpFunctionEnd diff --git a/tests/compiletests/ui/dis/ptr_copy.normal.stderr b/tests/compiletests/ui/dis/ptr_copy.normal.stderr index c7db2ddf11..b993618ede 100644 --- a/tests/compiletests/ui/dis/ptr_copy.normal.stderr +++ b/tests/compiletests/ui/dis/ptr_copy.normal.stderr @@ -28,6 +28,12 @@ note: called by `main` error: cannot cast between pointer types from `*f32` to `*struct () { }` + --> $CORE_SRC/ptr/mod.rs:625:34 + | +625 | src: *const () = src as *const (), + | ^^^^^^^^^^^^^^^^ + | +note: used from within `core::ptr::copy::` --> $CORE_SRC/ptr/mod.rs:621:9 | 621 | / ub_checks::assert_unsafe_precondition!( @@ -37,6 +43,29 @@ error: cannot cast between pointer types 631 | | && ub_checks::maybe_is_aligned_and_not_null(dst, align, zero_size) 632 | | ); | |_________^ +note: called by `ptr_copy::copy_via_raw_ptr` + --> $DIR/ptr_copy.rs:28:18 + | +28 | unsafe { core::ptr::copy(src, dst, 1) } + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +note: called by `ptr_copy::main` + --> $DIR/ptr_copy.rs:33:5 + | +33 | copy_via_raw_ptr(&i, o); + | ^^^^^^^^^^^^^^^^^^^^^^^ +note: called by `main` + --> $DIR/ptr_copy.rs:32:8 + | +32 | pub fn main(i: f32, o: &mut f32) { + | ^^^^ + +error: cannot cast between pointer types 
+ from `*f32` + to `*struct () { }` + --> $CORE_SRC/ptr/mod.rs:626:32 + | +626 | dst: *mut () = dst as *mut (), + | ^^^^^^^^^^^^^^ | note: used from within `core::ptr::copy::` --> $CORE_SRC/ptr/mod.rs:621:9 @@ -64,5 +93,5 @@ note: called by `main` 32 | pub fn main(i: f32, o: &mut f32) { | ^^^^ -error: aborting due to 2 previous errors +error: aborting due to 3 previous errors