From 08279574c358eaf8f1665cb2fe0467e9446ec1ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Sun, 19 Oct 2025 12:36:23 +0200 Subject: [PATCH 1/7] first definition of `offload` intrinsic (dirty code) --- .../src/builder/gpu_offload.rs | 138 ++++++++++++------ compiler/rustc_codegen_llvm/src/intrinsic.rs | 71 +++++++++ compiler/rustc_codegen_llvm/src/lib.rs | 2 + .../rustc_hir_analysis/src/check/intrinsic.rs | 2 + compiler/rustc_span/src/symbol.rs | 1 + library/core/src/intrinsics/mod.rs | 4 + .../gpu_offload/offload_intrinsic.rs | 37 +++++ 7 files changed, 210 insertions(+), 45 deletions(-) create mode 100644 tests/codegen-llvm/gpu_offload/offload_intrinsic.rs diff --git a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs index 5c2f8f700627e..c2df6489a726e 100644 --- a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs +++ b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs @@ -4,17 +4,18 @@ use llvm::Linkage::*; use rustc_abi::Align; use rustc_codegen_ssa::back::write::CodegenContext; use rustc_codegen_ssa::traits::BaseTypeCodegenMethods; +use rustc_middle::ty::{self, PseudoCanonicalInput, Ty, TyCtxt, TypingEnv}; use crate::builder::SBuilder; -use crate::common::AsCCharPtr; use crate::llvm::AttributePlace::Function; -use crate::llvm::{self, Linkage, Type, Value}; +use crate::llvm::{self, BasicBlock, Linkage, Type, Value}; use crate::{LlvmCodegenBackend, SimpleCx, attributes}; pub(crate) fn handle_gpu_code<'ll>( _cgcx: &CodegenContext, - cx: &'ll SimpleCx<'_>, + _cx: &'ll SimpleCx<'_>, ) { + /* // The offload memory transfer type for each kernel let mut memtransfer_types = vec![]; let mut region_ids = vec![]; @@ -32,6 +33,7 @@ pub(crate) fn handle_gpu_code<'ll>( } gen_call_handling(&cx, &memtransfer_types, ®ion_ids); + */ } // ; Function Attrs: nounwind @@ -79,7 +81,7 @@ fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value { at_one } -struct TgtOffloadEntry { +pub(crate) struct TgtOffloadEntry { // uint64_t Reserved; // uint16_t Version; // uint16_t Kind; @@ -256,11 +258,14 @@ pub(crate) fn add_global<'ll>( // This function returns a memtransfer value which encodes how arguments to this kernel shall be // mapped to/from the gpu. It also returns a region_id with the name of this kernel, to be // concatenated into the list of region_ids. -fn gen_define_handling<'ll>( - cx: &'ll SimpleCx<'_>, +pub(crate) fn gen_define_handling<'ll, 'tcx>( + cx: &SimpleCx<'ll>, + tcx: TyCtxt<'tcx>, kernel: &'ll llvm::Value, offload_entry_ty: &'ll llvm::Type, - num: i64, + // TODO(Sa4dUs): Define a typetree once i have a better idea of what do we exactly need + tt: Vec>, + symbol: &str, ) -> (&'ll llvm::Value, &'ll llvm::Value) { let types = cx.func_params_types(cx.get_type_of_global(kernel)); // It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or @@ -270,11 +275,21 @@ fn gen_define_handling<'ll>( .filter(|&x| matches!(cx.type_kind(x), rustc_codegen_ssa::common::TypeKind::Pointer)) .count(); + // TODO(Sa4dUs): Add typetrees here + let ptr_sizes = types + .iter() + .zip(tt) + .filter_map(|(&x, ty)| match cx.type_kind(x) { + rustc_codegen_ssa::common::TypeKind::Pointer => Some(get_payload_size(tcx, ty)), + _ => None, + }) + .collect::>(); + // We do not know their size anymore at this level, so hardcode a placeholder. // A follow-up pr will track these from the frontend, where we still have Rust types. // Then, we will be able to figure out that e.g. `&[f32;256]` will result in 4*256 bytes. // I decided that 1024 bytes is a great placeholder value for now. - add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{num}"), &vec![1024; num_ptr_types]); + add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{symbol}"), &ptr_sizes); // Here we figure out whether something needs to be copied to the gpu (=1), from the gpu (=2), // or both to and from the gpu (=3). Other values shouldn't affect us for now. // A non-mutable reference or pointer will be 1, an array that's not read, but fully overwritten @@ -282,25 +297,28 @@ fn gen_define_handling<'ll>( // 1+2+32: 1 (MapTo), 2 (MapFrom), 32 (Add one extra input ptr per function, to be used later). let memtransfer_types = add_priv_unnamed_arr( &cx, - &format!(".offload_maptypes.{num}"), + &format!(".offload_maptypes.{symbol}"), &vec![1 + 2 + 32; num_ptr_types], ); + // Next: For each function, generate these three entries. A weak constant, // the llvm.rodata entry name, and the llvm_offload_entries value - let name = format!(".kernel_{num}.region_id"); + let name = format!(".{symbol}.region_id"); let initializer = cx.get_const_i8(0); let region_id = add_unnamed_global(&cx, &name, initializer, WeakAnyLinkage); - let c_entry_name = CString::new(format!("kernel_{num}")).unwrap(); + let c_entry_name = CString::new(symbol).unwrap(); let c_val = c_entry_name.as_bytes_with_nul(); - let offload_entry_name = format!(".offloading.entry_name.{num}"); + let offload_entry_name = format!(".offloading.entry_name.{symbol}"); let initializer = crate::common::bytes_in_context(cx.llcx, c_val); let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage); llvm::set_alignment(llglobal, Align::ONE); llvm::set_section(llglobal, c".llvm.rodata.offloading"); - let name = format!(".offloading.entry.kernel_{num}"); + + // Not actively used yet, for calling real kernels + let name = format!(".offloading.entry.{symbol}"); // See the __tgt_offload_entry documentation above. let elems = TgtOffloadEntry::new(&cx, region_id, llglobal); @@ -317,7 +335,57 @@ fn gen_define_handling<'ll>( (memtransfer_types, region_id) } -pub(crate) fn declare_offload_fn<'ll>( +// TODO(Sa4dUs): move this to a proper place +fn get_payload_size<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> u64 { + match ty.kind() { + /* + rustc_middle::infer::canonical::ir::TyKind::Bool => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Char => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Int(int_ty) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Uint(uint_ty) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Float(float_ty) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Adt(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Foreign(_) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Str => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Array(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Pat(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Slice(_) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::RawPtr(_, mutability) => todo!(), + */ + ty::Ref(_, inner, _) => get_payload_size(tcx, *inner), + /* + rustc_middle::infer::canonical::ir::TyKind::FnDef(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::FnPtr(binder, fn_header) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::UnsafeBinder(unsafe_binder_inner) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Dynamic(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Closure(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::CoroutineClosure(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Coroutine(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::CoroutineWitness(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Never => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Tuple(_) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Alias(alias_ty_kind, alias_ty) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Param(_) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Bound(bound_var_index_kind, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Placeholder(_) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Infer(infer_ty) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Error(_) => todo!(), + */ + _ => { + tcx + // TODO(Sa4dUs): Maybe `.as_query_input()`? + .layout_of(PseudoCanonicalInput { + typing_env: TypingEnv::fully_monomorphized(), + value: ty, + }) + .unwrap() + .size + .bytes() + } + } +} + +fn declare_offload_fn<'ll>( cx: &'ll SimpleCx<'_>, name: &str, ty: &'ll llvm::Type, @@ -352,10 +420,13 @@ pub(crate) fn declare_offload_fn<'ll>( // 4. set insert point after kernel call. // 5. generate all the GEPS and stores, to be used in 6) // 6. generate __tgt_target_data_end calls to move data from the GPU -fn gen_call_handling<'ll>( - cx: &'ll SimpleCx<'_>, +pub(crate) fn gen_call_handling<'ll>( + cx: &SimpleCx<'ll>, + bb: &BasicBlock, + kernels: &[&'ll llvm::Value], memtransfer_types: &[&'ll llvm::Value], region_ids: &[&'ll llvm::Value], + llfn: &'ll Value, ) { let (tgt_decl, tgt_target_kernel_ty) = generate_launcher(&cx); // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr } @@ -368,27 +439,14 @@ fn gen_call_handling<'ll>( let tgt_kernel_decl = KernelArgsTy::new_decl(&cx); let (begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers(&cx); - let main_fn = cx.get_function("main"); - let Some(main_fn) = main_fn else { return }; - let kernel_name = "kernel_1"; - let call = unsafe { - llvm::LLVMRustGetFunctionCall(main_fn, kernel_name.as_c_char_ptr(), kernel_name.len()) - }; - let Some(kernel_call) = call else { - return; - }; - let kernel_call_bb = unsafe { llvm::LLVMGetInstructionParent(kernel_call) }; - let called = unsafe { llvm::LLVMGetCalledValue(kernel_call).unwrap() }; - let mut builder = SBuilder::build(cx, kernel_call_bb); - - let types = cx.func_params_types(cx.get_type_of_global(called)); + let mut builder = SBuilder::build(cx, bb); + + let types = cx.func_params_types(cx.get_type_of_global(kernels[0])); let num_args = types.len() as u64; // Step 0) // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr } // %6 = alloca %struct.__tgt_bin_desc, align 8 - unsafe { llvm::LLVMRustPositionBuilderPastAllocas(builder.llbuilder, main_fn) }; - let tgt_bin_desc_alloca = builder.direct_alloca(tgt_bin_desc, Align::EIGHT, "EmptyDesc"); let ty = cx.type_array(cx.type_ptr(), num_args); @@ -404,15 +462,14 @@ fn gen_call_handling<'ll>( let a5 = builder.direct_alloca(tgt_kernel_decl, Align::EIGHT, "kernel_args"); // Step 1) - unsafe { llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call) }; builder.memset(tgt_bin_desc_alloca, cx.get_const_i8(0), cx.get_const_i64(32), Align::EIGHT); // Now we allocate once per function param, a copy to be passed to one of our maps. let mut vals = vec![]; let mut geps = vec![]; let i32_0 = cx.get_const_i32(0); - for index in 0..types.len() { - let v = unsafe { llvm::LLVMGetOperand(kernel_call, index as u32).unwrap() }; + for index in 0..num_args { + let v = unsafe { llvm::LLVMGetParam(llfn, index as u32) }; let gep = builder.inbounds_gep(cx.type_f32(), v, &[i32_0]); vals.push(v); geps.push(gep); @@ -504,13 +561,8 @@ fn gen_call_handling<'ll>( region_ids[0], a5, ]; - let offload_success = builder.call(tgt_target_kernel_ty, tgt_decl, &args, None); + builder.call(tgt_target_kernel_ty, tgt_decl, &args, None); // %41 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr @.kernel_1.region_id, ptr %kernel_args) - unsafe { - let next = llvm::LLVMGetNextInstruction(offload_success).unwrap(); - llvm::LLVMRustPositionAfter(builder.llbuilder, next); - llvm::LLVMInstructionEraseFromParent(next); - } // Step 4) let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4); @@ -519,8 +571,4 @@ fn gen_call_handling<'ll>( builder.call(mapper_fn_ty, unregister_lib_decl, &[tgt_bin_desc_alloca], None); drop(builder); - // FIXME(offload) The issue is that we right now add a call to the gpu version of the function, - // and then delete the call to the CPU version. In the future, we should use an intrinsic which - // directly resolves to a call to the GPU version. - unsafe { llvm::LLVMDeleteFunction(called) }; } diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 0626cb3f2f16b..24b521dab7b03 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -23,6 +23,7 @@ use tracing::debug; use crate::abi::FnAbiLlvmExt; use crate::builder::Builder; use crate::builder::autodiff::{adjust_activity_to_abi, generate_enzyme_call}; +use crate::builder::gpu_offload::TgtOffloadEntry; use crate::context::CodegenCx; use crate::errors::AutoDiffWithoutEnable; use crate::llvm::{self, Metadata, Type, Value}; @@ -195,6 +196,10 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { codegen_autodiff(self, tcx, instance, args, result); return Ok(()); } + sym::offload => { + codegen_offload(self, tcx, instance, args, result); + return Ok(()); + } sym::is_val_statically_known => { if let OperandValue::Immediate(imm) = args[0].val { self.call_intrinsic( @@ -1227,6 +1232,72 @@ fn codegen_autodiff<'ll, 'tcx>( ); } +fn codegen_offload<'ll, 'tcx>( + bx: &mut Builder<'_, 'll, 'tcx>, + tcx: TyCtxt<'tcx>, + instance: ty::Instance<'tcx>, + _args: &[OperandRef<'tcx, &'ll Value>], + _result: PlaceRef<'tcx, &'ll Value>, +) { + let cx = bx.cx; + let fn_args = instance.args; + + let (target_id, target_args) = match fn_args.into_type_list(tcx)[0].kind() { + ty::FnDef(def_id, params) => (def_id, params), + _ => bug!("invalid offload intrinsic arg"), + }; + + let fn_target = match Instance::try_resolve(tcx, cx.typing_env(), *target_id, target_args) { + Ok(Some(instance)) => instance, + Ok(None) => bug!( + "could not resolve ({:?}, {:?}) to a specific offload instance", + target_id, + target_args + ), + Err(_) => { + // An error has already been emitted + return; + } + }; + + // TODO(Sa4dUs): Will need typetrees + let target_symbol = symbol_name_for_instance_in_crate(tcx, fn_target.clone(), LOCAL_CRATE); + let Some(kernel) = cx.get_function(&target_symbol) else { + bug!("could not find target function") + }; + + let offload_entry_ty = TgtOffloadEntry::new_decl(&cx); + + // Build TypeTree (or something similar) + let sig = tcx.fn_sig(fn_target.def_id()).skip_binder().skip_binder(); + let inputs = sig.inputs(); + + // TODO(Sa4dUs): separate globals from call-independent headers and use typetrees to reserve the correct amount of memory + let (memtransfer_type, region_id) = crate::builder::gpu_offload::gen_define_handling( + cx, + tcx, + kernel, + offload_entry_ty, + inputs.to_vec(), + &target_symbol, + ); + + let kernels = &[kernel]; + + let llfn = bx.llfn(); + + // TODO(Sa4dUs): this is a patch for delaying lifetime's issue fix + let bb = unsafe { llvm::LLVMGetInsertBlock(bx.llbuilder) }; + crate::builder::gpu_offload::gen_call_handling( + cx, + bb, + kernels, + &[memtransfer_type], + &[region_id], + llfn, + ); +} + fn get_args_from_tuple<'ll, 'tcx>( bx: &mut Builder<'_, 'll, 'tcx>, tuple_op: OperandRef<'tcx, &'ll Value>, diff --git a/compiler/rustc_codegen_llvm/src/lib.rs b/compiler/rustc_codegen_llvm/src/lib.rs index aaf1f6fbc804a..71afe9d9b363d 100644 --- a/compiler/rustc_codegen_llvm/src/lib.rs +++ b/compiler/rustc_codegen_llvm/src/lib.rs @@ -4,6 +4,8 @@ //! //! This API is completely unstable and subject to change. +// TODO(Sa4dUs): remove this once we have a great version, just to ignore unused LLVM wrappers +#![allow(unused)] // tidy-alphabetical-start #![feature(assert_matches)] #![feature(extern_types)] diff --git a/compiler/rustc_hir_analysis/src/check/intrinsic.rs b/compiler/rustc_hir_analysis/src/check/intrinsic.rs index a8e8830db99d7..737337b901f4f 100644 --- a/compiler/rustc_hir_analysis/src/check/intrinsic.rs +++ b/compiler/rustc_hir_analysis/src/check/intrinsic.rs @@ -163,6 +163,7 @@ fn intrinsic_operation_unsafety(tcx: TyCtxt<'_>, intrinsic_id: LocalDefId) -> hi | sym::minnumf128 | sym::mul_with_overflow | sym::needs_drop + | sym::offload | sym::powf16 | sym::powf32 | sym::powf64 @@ -310,6 +311,7 @@ pub(crate) fn check_intrinsic_type( let type_id = tcx.type_of(tcx.lang_items().type_id().unwrap()).instantiate_identity(); (0, 0, vec![type_id, type_id], tcx.types.bool) } + sym::offload => (2, 0, vec![param(0)], param(1)), sym::offset => (2, 0, vec![param(0), param(1)], param(0)), sym::arith_offset => ( 1, diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index 38718bad9e57e..7241aac6eb22d 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -1578,6 +1578,7 @@ symbols! { object_safe_for_dispatch, of, off, + offload, offset, offset_of, offset_of_enum, diff --git a/library/core/src/intrinsics/mod.rs b/library/core/src/intrinsics/mod.rs index 5ba2d92a4596f..dafc88e66ed2c 100644 --- a/library/core/src/intrinsics/mod.rs +++ b/library/core/src/intrinsics/mod.rs @@ -3276,6 +3276,10 @@ pub const fn copysignf128(x: f128, y: f128) -> f128; #[rustc_intrinsic] pub const fn autodiff(f: F, df: G, args: T) -> R; +#[rustc_nounwind] +#[rustc_intrinsic] +pub const fn offload(f: F) -> R; + /// Inform Miri that a given pointer definitely has a certain alignment. #[cfg(miri)] #[rustc_allow_const_fn_unstable(const_eval_select)] diff --git a/tests/codegen-llvm/gpu_offload/offload_intrinsic.rs b/tests/codegen-llvm/gpu_offload/offload_intrinsic.rs new file mode 100644 index 0000000000000..739186abc4f45 --- /dev/null +++ b/tests/codegen-llvm/gpu_offload/offload_intrinsic.rs @@ -0,0 +1,37 @@ +//@ compile-flags: -Zoffload=Enable -Zunstable-options -C opt-level=0 -Clto=fat +//@ no-prefer-dynamic +//@ needs-enzyme + +// This test is verifying that we generate __tgt_target_data_*_mapper before and after a call to the +// kernel_1. Better documentation to what each global or variable means is available in the gpu +// offlaod code, or the LLVM offload documentation. This code does not launch any GPU kernels yet, +// and will be rewritten once a proper offload frontend has landed. +// +// We currently only handle memory transfer for specific calls to functions named `kernel_{num}`, +// when inside of a function called main. This, too, is a temporary workaround for not having a +// frontend. + +// CHECK: ; +#![feature(core_intrinsics)] +#![no_main] + +#[unsafe(no_mangle)] +fn main() { + let mut x = [3.0; 256]; + kernel(&mut x); + core::hint::black_box(&x); +} + +#[unsafe(no_mangle)] +#[inline(never)] +pub fn kernel(x: &mut [f32; 256]) { + core::intrinsics::offload(_kernel) +} + +#[unsafe(no_mangle)] +#[inline(never)] +pub fn _kernel(x: &mut [f32; 256]) { + for i in 0..256 { + x[i] = 21.0; + } +} From 9f3ccf3940811f76c85c3e4c5026db9c219eee19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Tue, 21 Oct 2025 09:49:12 +0200 Subject: [PATCH 2/7] Add basic offload metadata --- .../src/builder/gpu_offload.rs | 65 ++--------------- compiler/rustc_codegen_llvm/src/intrinsic.rs | 12 ++-- compiler/rustc_middle/src/ty/mod.rs | 1 + compiler/rustc_middle/src/ty/offload_meta.rs | 70 +++++++++++++++++++ 4 files changed, 84 insertions(+), 64 deletions(-) create mode 100644 compiler/rustc_middle/src/ty/offload_meta.rs diff --git a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs index c2df6489a726e..b5a15673d1833 100644 --- a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs +++ b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs @@ -4,6 +4,7 @@ use llvm::Linkage::*; use rustc_abi::Align; use rustc_codegen_ssa::back::write::CodegenContext; use rustc_codegen_ssa::traits::BaseTypeCodegenMethods; +use rustc_middle::ty::offload_meta::OffloadMetadata; use rustc_middle::ty::{self, PseudoCanonicalInput, Ty, TyCtxt, TypingEnv}; use crate::builder::SBuilder; @@ -263,8 +264,7 @@ pub(crate) fn gen_define_handling<'ll, 'tcx>( tcx: TyCtxt<'tcx>, kernel: &'ll llvm::Value, offload_entry_ty: &'ll llvm::Type, - // TODO(Sa4dUs): Define a typetree once i have a better idea of what do we exactly need - tt: Vec>, + metadata: Vec, symbol: &str, ) -> (&'ll llvm::Value, &'ll llvm::Value) { let types = cx.func_params_types(cx.get_type_of_global(kernel)); @@ -275,12 +275,11 @@ pub(crate) fn gen_define_handling<'ll, 'tcx>( .filter(|&x| matches!(cx.type_kind(x), rustc_codegen_ssa::common::TypeKind::Pointer)) .count(); - // TODO(Sa4dUs): Add typetrees here let ptr_sizes = types .iter() - .zip(tt) - .filter_map(|(&x, ty)| match cx.type_kind(x) { - rustc_codegen_ssa::common::TypeKind::Pointer => Some(get_payload_size(tcx, ty)), + .zip(metadata) + .filter_map(|(&x, meta)| match cx.type_kind(x) { + rustc_codegen_ssa::common::TypeKind::Pointer => Some(meta.payload_size), _ => None, }) .collect::>(); @@ -335,56 +334,6 @@ pub(crate) fn gen_define_handling<'ll, 'tcx>( (memtransfer_types, region_id) } -// TODO(Sa4dUs): move this to a proper place -fn get_payload_size<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> u64 { - match ty.kind() { - /* - rustc_middle::infer::canonical::ir::TyKind::Bool => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Char => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Int(int_ty) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Uint(uint_ty) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Float(float_ty) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Adt(_, _) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Foreign(_) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Str => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Array(_, _) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Pat(_, _) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Slice(_) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::RawPtr(_, mutability) => todo!(), - */ - ty::Ref(_, inner, _) => get_payload_size(tcx, *inner), - /* - rustc_middle::infer::canonical::ir::TyKind::FnDef(_, _) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::FnPtr(binder, fn_header) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::UnsafeBinder(unsafe_binder_inner) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Dynamic(_, _) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Closure(_, _) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::CoroutineClosure(_, _) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Coroutine(_, _) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::CoroutineWitness(_, _) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Never => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Tuple(_) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Alias(alias_ty_kind, alias_ty) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Param(_) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Bound(bound_var_index_kind, _) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Placeholder(_) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Infer(infer_ty) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Error(_) => todo!(), - */ - _ => { - tcx - // TODO(Sa4dUs): Maybe `.as_query_input()`? - .layout_of(PseudoCanonicalInput { - typing_env: TypingEnv::fully_monomorphized(), - value: ty, - }) - .unwrap() - .size - .bytes() - } - } -} - fn declare_offload_fn<'ll>( cx: &'ll SimpleCx<'_>, name: &str, @@ -423,7 +372,7 @@ fn declare_offload_fn<'ll>( pub(crate) fn gen_call_handling<'ll>( cx: &SimpleCx<'ll>, bb: &BasicBlock, - kernels: &[&'ll llvm::Value], + kernel: &'ll llvm::Value, memtransfer_types: &[&'ll llvm::Value], region_ids: &[&'ll llvm::Value], llfn: &'ll Value, @@ -441,7 +390,7 @@ pub(crate) fn gen_call_handling<'ll>( let mut builder = SBuilder::build(cx, bb); - let types = cx.func_params_types(cx.get_type_of_global(kernels[0])); + let types = cx.func_params_types(cx.get_type_of_global(kernel)); let num_args = types.len() as u64; // Step 0) diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 24b521dab7b03..e492e91682754 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -13,6 +13,7 @@ use rustc_hir::def_id::LOCAL_CRATE; use rustc_hir::{self as hir}; use rustc_middle::mir::BinOp; use rustc_middle::ty::layout::{FnAbiOf, HasTyCtxt, HasTypingEnv, LayoutOf}; +use rustc_middle::ty::offload_meta::OffloadMetadata; use rustc_middle::ty::{self, GenericArgsRef, Instance, SimdAlign, Ty, TyCtxt, TypingEnv}; use rustc_middle::{bug, span_bug}; use rustc_span::{Span, Symbol, sym}; @@ -1260,7 +1261,6 @@ fn codegen_offload<'ll, 'tcx>( } }; - // TODO(Sa4dUs): Will need typetrees let target_symbol = symbol_name_for_instance_in_crate(tcx, fn_target.clone(), LOCAL_CRATE); let Some(kernel) = cx.get_function(&target_symbol) else { bug!("could not find target function") @@ -1272,26 +1272,26 @@ fn codegen_offload<'ll, 'tcx>( let sig = tcx.fn_sig(fn_target.def_id()).skip_binder().skip_binder(); let inputs = sig.inputs(); + let metadata = inputs.iter().map(|ty| OffloadMetadata::from_ty(tcx, *ty)).collect::>(); + // TODO(Sa4dUs): separate globals from call-independent headers and use typetrees to reserve the correct amount of memory let (memtransfer_type, region_id) = crate::builder::gpu_offload::gen_define_handling( cx, tcx, kernel, offload_entry_ty, - inputs.to_vec(), + metadata, &target_symbol, ); - let kernels = &[kernel]; - let llfn = bx.llfn(); - // TODO(Sa4dUs): this is a patch for delaying lifetime's issue fix + // TODO(Sa4dUs): this is just to a void lifetime's issues let bb = unsafe { llvm::LLVMGetInsertBlock(bx.llbuilder) }; crate::builder::gpu_offload::gen_call_handling( cx, bb, - kernels, + kernel, &[memtransfer_type], &[region_id], llfn, diff --git a/compiler/rustc_middle/src/ty/mod.rs b/compiler/rustc_middle/src/ty/mod.rs index d253deb2fe8fd..8ce6a9a6de3d5 100644 --- a/compiler/rustc_middle/src/ty/mod.rs +++ b/compiler/rustc_middle/src/ty/mod.rs @@ -130,6 +130,7 @@ pub mod fast_reject; pub mod inhabitedness; pub mod layout; pub mod normalize_erasing_regions; +pub mod offload_meta; pub mod pattern; pub mod print; pub mod relate; diff --git a/compiler/rustc_middle/src/ty/offload_meta.rs b/compiler/rustc_middle/src/ty/offload_meta.rs new file mode 100644 index 0000000000000..e7159888a643d --- /dev/null +++ b/compiler/rustc_middle/src/ty/offload_meta.rs @@ -0,0 +1,70 @@ +use crate::ty::{self, PseudoCanonicalInput, Ty, TyCtxt, TypingEnv}; + +// TODO(Sa4dUs): it doesn't feel correct for me to place this on `rustc_ast::expand`, will look for a proper location +pub struct OffloadMetadata { + pub payload_size: u64, + pub mode: TransferKind, +} + +pub enum TransferKind { + FromGpu = 1, + ToGpu = 2, + Both = 3, +} + +impl OffloadMetadata { + pub fn new(payload_size: u64, mode: TransferKind) -> Self { + OffloadMetadata { payload_size, mode } + } + + pub fn from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> Self { + OffloadMetadata { payload_size: get_payload_size(tcx, ty), mode: TransferKind::Both } + } +} + +// TODO(Sa4dUs): WIP, rn we just have a naive logic for references +fn get_payload_size<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> u64 { + match ty.kind() { + /* + rustc_middle::infer::canonical::ir::TyKind::Bool => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Char => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Int(int_ty) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Uint(uint_ty) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Float(float_ty) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Adt(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Foreign(_) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Str => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Array(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Pat(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Slice(_) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::RawPtr(_, mutability) => todo!(), + */ + ty::Ref(_, inner, _) => get_payload_size(tcx, *inner), + /* + rustc_middle::infer::canonical::ir::TyKind::FnDef(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::FnPtr(binder, fn_header) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::UnsafeBinder(unsafe_binder_inner) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Dynamic(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Closure(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::CoroutineClosure(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Coroutine(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::CoroutineWitness(_, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Never => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Tuple(_) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Alias(alias_ty_kind, alias_ty) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Param(_) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Bound(bound_var_index_kind, _) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Placeholder(_) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Infer(infer_ty) => todo!(), + rustc_middle::infer::canonical::ir::TyKind::Error(_) => todo!(), + */ + _ => tcx + .layout_of(PseudoCanonicalInput { + typing_env: TypingEnv::fully_monomorphized(), + value: ty, + }) + .unwrap() + .size + .bytes(), + } +} From 0915824adeeffaa806f05fdeb74cf8c578a974ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Mon, 27 Oct 2025 12:30:19 +0100 Subject: [PATCH 3/7] Set maptypes using offload metadata --- .../src/builder/gpu_offload.rs | 21 +++---- compiler/rustc_middle/src/ty/offload_meta.rs | 56 ++++++++++++++++++- .../gpu_offload/offload_intrinsic.rs | 2 +- 3 files changed, 63 insertions(+), 16 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs index b5a15673d1833..69518358b5b63 100644 --- a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs +++ b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs @@ -270,19 +270,17 @@ pub(crate) fn gen_define_handling<'ll, 'tcx>( let types = cx.func_params_types(cx.get_type_of_global(kernel)); // It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or // reference) types. - let num_ptr_types = types - .iter() - .filter(|&x| matches!(cx.type_kind(x), rustc_codegen_ssa::common::TypeKind::Pointer)) - .count(); - - let ptr_sizes = types + let ptr_meta = types .iter() .zip(metadata) .filter_map(|(&x, meta)| match cx.type_kind(x) { - rustc_codegen_ssa::common::TypeKind::Pointer => Some(meta.payload_size), + rustc_codegen_ssa::common::TypeKind::Pointer => Some(meta), _ => None, }) - .collect::>(); + .collect::>(); + + let ptr_sizes = ptr_meta.iter().map(|m| m.payload_size).collect::>(); + let ptr_transfer = ptr_meta.iter().map(|m| m.mode as u64 | 0x20).collect::>(); // We do not know their size anymore at this level, so hardcode a placeholder. // A follow-up pr will track these from the frontend, where we still have Rust types. @@ -294,11 +292,8 @@ pub(crate) fn gen_define_handling<'ll, 'tcx>( // A non-mutable reference or pointer will be 1, an array that's not read, but fully overwritten // will be 2. For now, everything is 3, until we have our frontend set up. // 1+2+32: 1 (MapTo), 2 (MapFrom), 32 (Add one extra input ptr per function, to be used later). - let memtransfer_types = add_priv_unnamed_arr( - &cx, - &format!(".offload_maptypes.{symbol}"), - &vec![1 + 2 + 32; num_ptr_types], - ); + let memtransfer_types = + add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{symbol}"), &ptr_transfer); // Next: For each function, generate these three entries. A weak constant, // the llvm.rodata entry name, and the llvm_offload_entries value diff --git a/compiler/rustc_middle/src/ty/offload_meta.rs b/compiler/rustc_middle/src/ty/offload_meta.rs index e7159888a643d..7c1b42b8cc08b 100644 --- a/compiler/rustc_middle/src/ty/offload_meta.rs +++ b/compiler/rustc_middle/src/ty/offload_meta.rs @@ -6,10 +6,13 @@ pub struct OffloadMetadata { pub mode: TransferKind, } +// TODO(Sa4dUs): add `OMP_MAP_TARGET_PARAM = 0x20` flag only when needed +#[repr(u64)] +#[derive(Debug, Copy, Clone)] pub enum TransferKind { FromGpu = 1, ToGpu = 2, - Both = 3, + Both = 1 + 2, } impl OffloadMetadata { @@ -18,7 +21,10 @@ impl OffloadMetadata { } pub fn from_ty<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> Self { - OffloadMetadata { payload_size: get_payload_size(tcx, ty), mode: TransferKind::Both } + OffloadMetadata { + payload_size: get_payload_size(tcx, ty), + mode: TransferKind::from_ty(tcx, ty), + } } } @@ -68,3 +74,49 @@ fn get_payload_size<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> u64 { .bytes(), } } + +impl TransferKind { + pub fn from_ty<'tcx>(_tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> Self { + // TODO(Sa4dUs): this logic is probs not fully correct, but it works for now + match ty.kind() { + rustc_type_ir::TyKind::Bool + | rustc_type_ir::TyKind::Char + | rustc_type_ir::TyKind::Int(_) + | rustc_type_ir::TyKind::Uint(_) + | rustc_type_ir::TyKind::Float(_) => TransferKind::ToGpu, + + rustc_type_ir::TyKind::Adt(_, _) + | rustc_type_ir::TyKind::Tuple(_) + | rustc_type_ir::TyKind::Array(_, _) => TransferKind::ToGpu, + + rustc_type_ir::TyKind::RawPtr(_, rustc_ast::Mutability::Not) + | rustc_type_ir::TyKind::Ref(_, _, rustc_ast::Mutability::Not) => TransferKind::ToGpu, + + rustc_type_ir::TyKind::RawPtr(_, rustc_ast::Mutability::Mut) + | rustc_type_ir::TyKind::Ref(_, _, rustc_ast::Mutability::Mut) => TransferKind::Both, + + rustc_type_ir::TyKind::Slice(_) + | rustc_type_ir::TyKind::Str + | rustc_type_ir::TyKind::Dynamic(_, _) => TransferKind::Both, + + rustc_type_ir::TyKind::FnDef(_, _) + | rustc_type_ir::TyKind::FnPtr(_, _) + | rustc_type_ir::TyKind::Closure(_, _) + | rustc_type_ir::TyKind::CoroutineClosure(_, _) + | rustc_type_ir::TyKind::Coroutine(_, _) + | rustc_type_ir::TyKind::CoroutineWitness(_, _) => TransferKind::ToGpu, + + rustc_type_ir::TyKind::Alias(_, _) + | rustc_type_ir::TyKind::Param(_) + | rustc_type_ir::TyKind::Bound(_, _) + | rustc_type_ir::TyKind::Placeholder(_) + | rustc_type_ir::TyKind::Infer(_) + | rustc_type_ir::TyKind::Error(_) => TransferKind::ToGpu, + + rustc_type_ir::TyKind::Never => TransferKind::ToGpu, + rustc_type_ir::TyKind::Foreign(_) => TransferKind::Both, + rustc_type_ir::TyKind::Pat(_, _) => TransferKind::Both, + rustc_type_ir::TyKind::UnsafeBinder(_) => TransferKind::Both, + } + } +} diff --git a/tests/codegen-llvm/gpu_offload/offload_intrinsic.rs b/tests/codegen-llvm/gpu_offload/offload_intrinsic.rs index 739186abc4f45..c3df15e3be6bd 100644 --- a/tests/codegen-llvm/gpu_offload/offload_intrinsic.rs +++ b/tests/codegen-llvm/gpu_offload/offload_intrinsic.rs @@ -1,4 +1,4 @@ -//@ compile-flags: -Zoffload=Enable -Zunstable-options -C opt-level=0 -Clto=fat +//@ compile-flags: -Zoffload=Enable -Zunstable-options -C opt-level=3 -Clto=fat //@ no-prefer-dynamic //@ needs-enzyme From 5e02ffe93c6e457030d6d65794e9927d1cf828ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Mon, 27 Oct 2025 20:32:46 +0100 Subject: [PATCH 4/7] Pass frontend info to `gen_call_handling` --- compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs | 3 ++- compiler/rustc_codegen_llvm/src/intrinsic.rs | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs index 69518358b5b63..cbabc2c27106c 100644 --- a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs +++ b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs @@ -371,6 +371,7 @@ pub(crate) fn gen_call_handling<'ll>( memtransfer_types: &[&'ll llvm::Value], region_ids: &[&'ll llvm::Value], llfn: &'ll Value, + metadata: Vec, ) { let (tgt_decl, tgt_target_kernel_ty) = generate_launcher(&cx); // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr } @@ -441,7 +442,7 @@ pub(crate) fn gen_call_handling<'ll>( // As mentioned above, we don't use Rust type information yet. So for now we will just // assume that we have 1024 bytes, 256 f32 values. // FIXME(offload): write an offload frontend and handle arbitrary types. - builder.store(cx.get_const_i64(1024), gep3, Align::EIGHT); + builder.store(cx.get_const_i64(metadata[i].payload_size), gep3, Align::EIGHT); } // For now we have a very simplistic indexing scheme into our diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index e492e91682754..f913b6a697d36 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -1295,6 +1295,7 @@ fn codegen_offload<'ll, 'tcx>( &[memtransfer_type], &[region_id], llfn, + metadata, ); } From c534f99251a25da50e6de35e487985c2e6aaf44f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Tue, 4 Nov 2025 18:58:24 +0100 Subject: [PATCH 5/7] Mark globals as used + some minor fixes --- .../src/builder/gpu_offload.rs | 62 +++++++++++++++---- compiler/rustc_codegen_llvm/src/intrinsic.rs | 14 ++--- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 1 + tests/codegen-llvm/gpu_offload/gpu_host.rs | 26 +++++--- .../gpu_offload/offload_intrinsic.rs | 37 ----------- 5 files changed, 72 insertions(+), 68 deletions(-) delete mode 100644 tests/codegen-llvm/gpu_offload/offload_intrinsic.rs diff --git a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs index cbabc2c27106c..de7245bafec83 100644 --- a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs +++ b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs @@ -14,7 +14,7 @@ use crate::{LlvmCodegenBackend, SimpleCx, attributes}; pub(crate) fn handle_gpu_code<'ll>( _cgcx: &CodegenContext, - _cx: &'ll SimpleCx<'_>, + cx: &'ll SimpleCx<'_>, ) { /* // The offload memory transfer type for each kernel @@ -259,15 +259,14 @@ pub(crate) fn add_global<'ll>( // This function returns a memtransfer value which encodes how arguments to this kernel shall be // mapped to/from the gpu. It also returns a region_id with the name of this kernel, to be // concatenated into the list of region_ids. -pub(crate) fn gen_define_handling<'ll, 'tcx>( +pub(crate) fn gen_define_handling<'ll>( cx: &SimpleCx<'ll>, - tcx: TyCtxt<'tcx>, - kernel: &'ll llvm::Value, + llfn: &'ll llvm::Value, offload_entry_ty: &'ll llvm::Type, - metadata: Vec, + metadata: &Vec, symbol: &str, ) -> (&'ll llvm::Value, &'ll llvm::Value) { - let types = cx.func_params_types(cx.get_type_of_global(kernel)); + let types = cx.func_params_types(cx.get_type_of_global(llfn)); // It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or // reference) types. let ptr_meta = types @@ -277,7 +276,7 @@ pub(crate) fn gen_define_handling<'ll, 'tcx>( rustc_codegen_ssa::common::TypeKind::Pointer => Some(meta), _ => None, }) - .collect::>(); + .collect::>(); let ptr_sizes = ptr_meta.iter().map(|m| m.payload_size).collect::>(); let ptr_transfer = ptr_meta.iter().map(|m| m.mode as u64 | 0x20).collect::>(); @@ -286,7 +285,7 @@ pub(crate) fn gen_define_handling<'ll, 'tcx>( // A follow-up pr will track these from the frontend, where we still have Rust types. // Then, we will be able to figure out that e.g. `&[f32;256]` will result in 4*256 bytes. // I decided that 1024 bytes is a great placeholder value for now. - add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{symbol}"), &ptr_sizes); + let offload_sizes = add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{symbol}"), &ptr_sizes); // Here we figure out whether something needs to be copied to the gpu (=1), from the gpu (=2), // or both to and from the gpu (=3). Other values shouldn't affect us for now. // A non-mutable reference or pointer will be 1, an array that's not read, but fully overwritten @@ -326,6 +325,8 @@ pub(crate) fn gen_define_handling<'ll, 'tcx>( llvm::set_alignment(llglobal, Align::EIGHT); let c_section_name = CString::new("llvm_offload_entries").unwrap(); llvm::set_section(llglobal, &c_section_name); + + add_to_llvm_used(cx, &[offload_sizes, memtransfer_types, region_id, llglobal]); (memtransfer_types, region_id) } @@ -367,11 +368,10 @@ fn declare_offload_fn<'ll>( pub(crate) fn gen_call_handling<'ll>( cx: &SimpleCx<'ll>, bb: &BasicBlock, - kernel: &'ll llvm::Value, memtransfer_types: &[&'ll llvm::Value], region_ids: &[&'ll llvm::Value], llfn: &'ll Value, - metadata: Vec, + metadata: &Vec, ) { let (tgt_decl, tgt_target_kernel_ty) = generate_launcher(&cx); // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr } @@ -386,7 +386,7 @@ pub(crate) fn gen_call_handling<'ll>( let mut builder = SBuilder::build(cx, bb); - let types = cx.func_params_types(cx.get_type_of_global(kernel)); + let types = cx.func_params_types(cx.get_type_of_global(llfn)); let num_args = types.len() as u64; // Step 0) @@ -442,7 +442,7 @@ pub(crate) fn gen_call_handling<'ll>( // As mentioned above, we don't use Rust type information yet. So for now we will just // assume that we have 1024 bytes, 256 f32 values. // FIXME(offload): write an offload frontend and handle arbitrary types. - builder.store(cx.get_const_i64(metadata[i].payload_size), gep3, Align::EIGHT); + builder.store(cx.get_const_i64(metadata[i as usize].payload_size), gep3, Align::EIGHT); } // For now we have a very simplistic indexing scheme into our @@ -517,3 +517,41 @@ pub(crate) fn gen_call_handling<'ll>( drop(builder); } + +// TODO(Sa4dUs): check if there's a better way of doing this, also move to a proper location +fn add_to_llvm_used<'ll>(cx: &'ll SimpleCx<'_>, globals: &[&'ll Value]) { + let ptr_ty = cx.type_ptr(); + let arr_ty = cx.type_array(ptr_ty, globals.len() as u64); + let arr_val = cx.const_array(ptr_ty, globals); + + let name = CString::new("llvm.used").unwrap(); + + let used_global_opt = unsafe { llvm::LLVMGetNamedGlobal(cx.llmod, name.as_ptr()) }; + + if used_global_opt.is_none() { + let new_global = unsafe { llvm::LLVMAddGlobal(cx.llmod, arr_ty, name.as_ptr()) }; + unsafe { llvm::LLVMSetLinkage(new_global, llvm::Linkage::AppendingLinkage) }; + unsafe { + llvm::LLVMSetSection(new_global, CString::new("llvm.metadata").unwrap().as_ptr()) + }; + unsafe { llvm::LLVMSetInitializer(new_global, arr_val) }; + llvm::LLVMSetGlobalConstant(new_global, llvm::TRUE); + return; + } + + let used_global = used_global_opt.expect("expected @llvm.used"); + let mut combined: Vec<&'ll Value> = Vec::new(); + + if let Some(existing_init) = llvm::LLVMGetInitializer(used_global) { + let num_elems = unsafe { llvm::LLVMGetNumOperands(existing_init) }; + for i in 0..num_elems { + if let Some(elem) = unsafe { llvm::LLVMGetOperand(existing_init, i) } { + combined.push(elem); + } + } + } + + combined.extend_from_slice(globals); + let new_arr = cx.const_array(ptr_ty, &combined); + unsafe { llvm::LLVMSetInitializer(used_global, new_arr) }; +} diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index f913b6a697d36..687d7e1e473a3 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -1262,9 +1262,6 @@ fn codegen_offload<'ll, 'tcx>( }; let target_symbol = symbol_name_for_instance_in_crate(tcx, fn_target.clone(), LOCAL_CRATE); - let Some(kernel) = cx.get_function(&target_symbol) else { - bug!("could not find target function") - }; let offload_entry_ty = TgtOffloadEntry::new_decl(&cx); @@ -1273,29 +1270,26 @@ fn codegen_offload<'ll, 'tcx>( let inputs = sig.inputs(); let metadata = inputs.iter().map(|ty| OffloadMetadata::from_ty(tcx, *ty)).collect::>(); + let llfn = bx.llfn(); // TODO(Sa4dUs): separate globals from call-independent headers and use typetrees to reserve the correct amount of memory let (memtransfer_type, region_id) = crate::builder::gpu_offload::gen_define_handling( cx, - tcx, - kernel, + llfn, offload_entry_ty, - metadata, + &metadata, &target_symbol, ); - let llfn = bx.llfn(); - // TODO(Sa4dUs): this is just to a void lifetime's issues let bb = unsafe { llvm::LLVMGetInsertBlock(bx.llbuilder) }; crate::builder::gpu_offload::gen_call_handling( cx, bb, - kernel, &[memtransfer_type], &[region_id], llfn, - metadata, + &metadata, ); } diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index 74d268ad5dd2e..33795e0d6674b 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -1164,6 +1164,7 @@ unsafe extern "C" { pub(crate) fn LLVMGetOperand(Val: &Value, Index: c_uint) -> Option<&Value>; pub(crate) fn LLVMGetNextInstruction(Val: &Value) -> Option<&Value>; pub(crate) fn LLVMInstructionEraseFromParent(Val: &Value); + pub(crate) fn LLVMGetNumOperands(Val: &Value) -> c_uint; // Operations on call sites pub(crate) fn LLVMSetInstructionCallConv(Instr: &Value, CC: c_uint); diff --git a/tests/codegen-llvm/gpu_offload/gpu_host.rs b/tests/codegen-llvm/gpu_offload/gpu_host.rs index fac4054d1b7ff..69eea6a6a8cea 100644 --- a/tests/codegen-llvm/gpu_offload/gpu_host.rs +++ b/tests/codegen-llvm/gpu_offload/gpu_host.rs @@ -11,12 +11,13 @@ // when inside of a function called main. This, too, is a temporary workaround for not having a // frontend. +#![feature(core_intrinsics)] #![no_main] #[unsafe(no_mangle)] fn main() { let mut x = [3.0; 256]; - kernel_1(&mut x); + kernel(&mut x); core::hint::black_box(&x); } @@ -25,13 +26,14 @@ fn main() { // CHECK: %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr } // CHECK: %struct.__tgt_kernel_arguments = type { i32, i32, ptr, ptr, ptr, ptr, ptr, ptr, i64, i64, [3 x i32], [3 x i32], i32 } -// CHECK: @.offload_sizes.1 = private unnamed_addr constant [1 x i64] [i64 1024] -// CHECK: @.offload_maptypes.1 = private unnamed_addr constant [1 x i64] [i64 35] -// CHECK: @.kernel_1.region_id = weak unnamed_addr constant i8 0 -// CHECK: @.offloading.entry_name.1 = internal unnamed_addr constant [9 x i8] c"kernel_1\00", section ".llvm.rodata.offloading", align 1 -// CHECK: @.offloading.entry.kernel_1 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @.kernel_1.region_id, ptr @.offloading.entry_name.1, i64 0, i64 0, ptr null }, section "llvm_offload_entries", align 8 -// CHECK: @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1 -// CHECK: @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8 +// CHECK: @.offload_sizes._kernel = private unnamed_addr constant [1 x i64] [i64 1024] +// CHECK: @.offload_maptypes._kernel = private unnamed_addr constant [1 x i64] [i64 35] +// CHECK: @._kernel.region_id = weak unnamed_addr constant i8 0 +// CHECK: @.offloading.entry_name._kernel = internal unnamed_addr constant [8 x i8] c"_kernel\00", section ".llvm.rodata.offloading", align 1 +// CHECK: @.offloading.entry._kernel = weak constant %struct.__tgt_offload_entry { i64 0, i16 1, i16 1, i32 0, ptr @._kernel.region_id, ptr @.offloading.entry_name._kernel, i64 0, i64 0, ptr null }, section "llvm_offload_entries", align 8 + +// CHECK: @anon.{{.*}}.0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1 +// CHECK: @anon.{{.*}}.1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @anon.{{.*}}.0 }, align 8 // CHECK: Function Attrs: // CHECK-NEXT: define{{( dso_local)?}} void @main() @@ -99,7 +101,13 @@ fn main() { #[unsafe(no_mangle)] #[inline(never)] -pub fn kernel_1(x: &mut [f32; 256]) { +pub fn kernel(x: &mut [f32; 256]) { + core::intrinsics::offload(_kernel) +} + +#[unsafe(no_mangle)] +#[inline(never)] +pub fn _kernel(x: &mut [f32; 256]) { for i in 0..256 { x[i] = 21.0; } diff --git a/tests/codegen-llvm/gpu_offload/offload_intrinsic.rs b/tests/codegen-llvm/gpu_offload/offload_intrinsic.rs deleted file mode 100644 index c3df15e3be6bd..0000000000000 --- a/tests/codegen-llvm/gpu_offload/offload_intrinsic.rs +++ /dev/null @@ -1,37 +0,0 @@ -//@ compile-flags: -Zoffload=Enable -Zunstable-options -C opt-level=3 -Clto=fat -//@ no-prefer-dynamic -//@ needs-enzyme - -// This test is verifying that we generate __tgt_target_data_*_mapper before and after a call to the -// kernel_1. Better documentation to what each global or variable means is available in the gpu -// offlaod code, or the LLVM offload documentation. This code does not launch any GPU kernels yet, -// and will be rewritten once a proper offload frontend has landed. -// -// We currently only handle memory transfer for specific calls to functions named `kernel_{num}`, -// when inside of a function called main. This, too, is a temporary workaround for not having a -// frontend. - -// CHECK: ; -#![feature(core_intrinsics)] -#![no_main] - -#[unsafe(no_mangle)] -fn main() { - let mut x = [3.0; 256]; - kernel(&mut x); - core::hint::black_box(&x); -} - -#[unsafe(no_mangle)] -#[inline(never)] -pub fn kernel(x: &mut [f32; 256]) { - core::intrinsics::offload(_kernel) -} - -#[unsafe(no_mangle)] -#[inline(never)] -pub fn _kernel(x: &mut [f32; 256]) { - for i in 0..256 { - x[i] = 21.0; - } -} From 834347a3fd64097c349f5d9a9f00bdf64d1514fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Thu, 6 Nov 2025 12:34:55 +0100 Subject: [PATCH 6/7] Get types from fn_sig --- compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs | 5 ++--- compiler/rustc_codegen_llvm/src/intrinsic.rs | 7 +++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs index de7245bafec83..fab867ba8b53b 100644 --- a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs +++ b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs @@ -261,12 +261,11 @@ pub(crate) fn add_global<'ll>( // concatenated into the list of region_ids. pub(crate) fn gen_define_handling<'ll>( cx: &SimpleCx<'ll>, - llfn: &'ll llvm::Value, offload_entry_ty: &'ll llvm::Type, metadata: &Vec, + types: &Vec<&Type>, symbol: &str, ) -> (&'ll llvm::Value, &'ll llvm::Value) { - let types = cx.func_params_types(cx.get_type_of_global(llfn)); // It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or // reference) types. let ptr_meta = types @@ -371,6 +370,7 @@ pub(crate) fn gen_call_handling<'ll>( memtransfer_types: &[&'ll llvm::Value], region_ids: &[&'ll llvm::Value], llfn: &'ll Value, + types: &Vec<&Type>, metadata: &Vec, ) { let (tgt_decl, tgt_target_kernel_ty) = generate_launcher(&cx); @@ -386,7 +386,6 @@ pub(crate) fn gen_call_handling<'ll>( let mut builder = SBuilder::build(cx, bb); - let types = cx.func_params_types(cx.get_type_of_global(llfn)); let num_args = types.len() as u64; // Step 0) diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index 687d7e1e473a3..ffd30aebf29d1 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -21,7 +21,7 @@ use rustc_symbol_mangling::{mangle_internal_symbol, symbol_name_for_instance_in_ use rustc_target::callconv::PassMode; use tracing::debug; -use crate::abi::FnAbiLlvmExt; +use crate::abi::{FnAbiLlvmExt, LlvmType}; use crate::builder::Builder; use crate::builder::autodiff::{adjust_activity_to_abi, generate_enzyme_call}; use crate::builder::gpu_offload::TgtOffloadEntry; @@ -1272,12 +1272,14 @@ fn codegen_offload<'ll, 'tcx>( let metadata = inputs.iter().map(|ty| OffloadMetadata::from_ty(tcx, *ty)).collect::>(); let llfn = bx.llfn(); + let types = inputs.iter().map(|ty| cx.layout_of(*ty).llvm_type(cx)).collect::>(); + // TODO(Sa4dUs): separate globals from call-independent headers and use typetrees to reserve the correct amount of memory let (memtransfer_type, region_id) = crate::builder::gpu_offload::gen_define_handling( cx, - llfn, offload_entry_ty, &metadata, + &types, &target_symbol, ); @@ -1289,6 +1291,7 @@ fn codegen_offload<'ll, 'tcx>( &[memtransfer_type], &[region_id], llfn, + &types, &metadata, ); } From 97a8e963f8e98a5baf8ced86d55a183122310a2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marcelo=20Dom=C3=ADnguez?= Date: Thu, 6 Nov 2025 20:30:44 +0100 Subject: [PATCH 7/7] Don't depend on outer fn and some cleanup --- .../src/builder/gpu_offload.rs | 30 +++--- compiler/rustc_codegen_llvm/src/intrinsic.rs | 11 +-- .../rustc_hir_analysis/src/check/intrinsic.rs | 2 +- compiler/rustc_middle/src/ty/offload_meta.rs | 96 +++++++------------ library/core/src/intrinsics/mod.rs | 2 +- tests/codegen-llvm/gpu_offload/gpu_host.rs | 2 +- 6 files changed, 53 insertions(+), 90 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs index fab867ba8b53b..151be8fcb4757 100644 --- a/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs +++ b/compiler/rustc_codegen_llvm/src/builder/gpu_offload.rs @@ -262,23 +262,19 @@ pub(crate) fn add_global<'ll>( pub(crate) fn gen_define_handling<'ll>( cx: &SimpleCx<'ll>, offload_entry_ty: &'ll llvm::Type, - metadata: &Vec, - types: &Vec<&Type>, + metadata: &[OffloadMetadata], + types: &[&Type], symbol: &str, ) -> (&'ll llvm::Value, &'ll llvm::Value) { // It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or // reference) types. - let ptr_meta = types - .iter() - .zip(metadata) - .filter_map(|(&x, meta)| match cx.type_kind(x) { - rustc_codegen_ssa::common::TypeKind::Pointer => Some(meta), - _ => None, - }) - .collect::>(); - - let ptr_sizes = ptr_meta.iter().map(|m| m.payload_size).collect::>(); - let ptr_transfer = ptr_meta.iter().map(|m| m.mode as u64 | 0x20).collect::>(); + let ptr_meta = types.iter().zip(metadata).filter_map(|(&x, meta)| match cx.type_kind(x) { + rustc_codegen_ssa::common::TypeKind::Pointer => Some(meta), + _ => None, + }); + + let (ptr_sizes, ptr_transfer): (Vec<_>, Vec<_>) = + ptr_meta.map(|m| (m.payload_size, m.mode as u64 | 0x20)).unzip(); // We do not know their size anymore at this level, so hardcode a placeholder. // A follow-up pr will track these from the frontend, where we still have Rust types. @@ -369,9 +365,9 @@ pub(crate) fn gen_call_handling<'ll>( bb: &BasicBlock, memtransfer_types: &[&'ll llvm::Value], region_ids: &[&'ll llvm::Value], - llfn: &'ll Value, - types: &Vec<&Type>, - metadata: &Vec, + args: &[&'ll Value], + types: &[&Type], + metadata: &[OffloadMetadata], ) { let (tgt_decl, tgt_target_kernel_ty) = generate_launcher(&cx); // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr } @@ -413,7 +409,7 @@ pub(crate) fn gen_call_handling<'ll>( let mut geps = vec![]; let i32_0 = cx.get_const_i32(0); for index in 0..num_args { - let v = unsafe { llvm::LLVMGetParam(llfn, index as u32) }; + let v = args[index as usize]; let gep = builder.inbounds_gep(cx.type_f32(), v, &[i32_0]); vals.push(v); geps.push(gep); diff --git a/compiler/rustc_codegen_llvm/src/intrinsic.rs b/compiler/rustc_codegen_llvm/src/intrinsic.rs index ffd30aebf29d1..12936f2182af4 100644 --- a/compiler/rustc_codegen_llvm/src/intrinsic.rs +++ b/compiler/rustc_codegen_llvm/src/intrinsic.rs @@ -198,7 +198,7 @@ impl<'ll, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'_, 'll, 'tcx> { return Ok(()); } sym::offload => { - codegen_offload(self, tcx, instance, args, result); + codegen_offload(self, tcx, instance, args); return Ok(()); } sym::is_val_statically_known => { @@ -1237,8 +1237,7 @@ fn codegen_offload<'ll, 'tcx>( bx: &mut Builder<'_, 'll, 'tcx>, tcx: TyCtxt<'tcx>, instance: ty::Instance<'tcx>, - _args: &[OperandRef<'tcx, &'ll Value>], - _result: PlaceRef<'tcx, &'ll Value>, + args: &[OperandRef<'tcx, &'ll Value>], ) { let cx = bx.cx; let fn_args = instance.args; @@ -1261,7 +1260,8 @@ fn codegen_offload<'ll, 'tcx>( } }; - let target_symbol = symbol_name_for_instance_in_crate(tcx, fn_target.clone(), LOCAL_CRATE); + let args = get_args_from_tuple(bx, args[1], fn_target); + let target_symbol = symbol_name_for_instance_in_crate(tcx, fn_target, LOCAL_CRATE); let offload_entry_ty = TgtOffloadEntry::new_decl(&cx); @@ -1270,7 +1270,6 @@ fn codegen_offload<'ll, 'tcx>( let inputs = sig.inputs(); let metadata = inputs.iter().map(|ty| OffloadMetadata::from_ty(tcx, *ty)).collect::>(); - let llfn = bx.llfn(); let types = inputs.iter().map(|ty| cx.layout_of(*ty).llvm_type(cx)).collect::>(); @@ -1290,7 +1289,7 @@ fn codegen_offload<'ll, 'tcx>( bb, &[memtransfer_type], &[region_id], - llfn, + &args, &types, &metadata, ); diff --git a/compiler/rustc_hir_analysis/src/check/intrinsic.rs b/compiler/rustc_hir_analysis/src/check/intrinsic.rs index 737337b901f4f..38ab057b75876 100644 --- a/compiler/rustc_hir_analysis/src/check/intrinsic.rs +++ b/compiler/rustc_hir_analysis/src/check/intrinsic.rs @@ -311,7 +311,7 @@ pub(crate) fn check_intrinsic_type( let type_id = tcx.type_of(tcx.lang_items().type_id().unwrap()).instantiate_identity(); (0, 0, vec![type_id, type_id], tcx.types.bool) } - sym::offload => (2, 0, vec![param(0)], param(1)), + sym::offload => (3, 0, vec![param(0), param(1)], param(2)), sym::offset => (2, 0, vec![param(0), param(1)], param(0)), sym::arith_offset => ( 1, diff --git a/compiler/rustc_middle/src/ty/offload_meta.rs b/compiler/rustc_middle/src/ty/offload_meta.rs index 7c1b42b8cc08b..11a0ca2741bb4 100644 --- a/compiler/rustc_middle/src/ty/offload_meta.rs +++ b/compiler/rustc_middle/src/ty/offload_meta.rs @@ -31,39 +31,7 @@ impl OffloadMetadata { // TODO(Sa4dUs): WIP, rn we just have a naive logic for references fn get_payload_size<'tcx>(tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> u64 { match ty.kind() { - /* - rustc_middle::infer::canonical::ir::TyKind::Bool => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Char => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Int(int_ty) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Uint(uint_ty) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Float(float_ty) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Adt(_, _) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Foreign(_) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Str => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Array(_, _) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Pat(_, _) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Slice(_) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::RawPtr(_, mutability) => todo!(), - */ - ty::Ref(_, inner, _) => get_payload_size(tcx, *inner), - /* - rustc_middle::infer::canonical::ir::TyKind::FnDef(_, _) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::FnPtr(binder, fn_header) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::UnsafeBinder(unsafe_binder_inner) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Dynamic(_, _) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Closure(_, _) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::CoroutineClosure(_, _) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Coroutine(_, _) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::CoroutineWitness(_, _) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Never => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Tuple(_) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Alias(alias_ty_kind, alias_ty) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Param(_) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Bound(bound_var_index_kind, _) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Placeholder(_) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Infer(infer_ty) => todo!(), - rustc_middle::infer::canonical::ir::TyKind::Error(_) => todo!(), - */ + ty::RawPtr(inner, _) | ty::Ref(_, inner, _) => get_payload_size(tcx, *inner), _ => tcx .layout_of(PseudoCanonicalInput { typing_env: TypingEnv::fully_monomorphized(), @@ -79,44 +47,44 @@ impl TransferKind { pub fn from_ty<'tcx>(_tcx: TyCtxt<'tcx>, ty: Ty<'tcx>) -> Self { // TODO(Sa4dUs): this logic is probs not fully correct, but it works for now match ty.kind() { - rustc_type_ir::TyKind::Bool - | rustc_type_ir::TyKind::Char - | rustc_type_ir::TyKind::Int(_) - | rustc_type_ir::TyKind::Uint(_) - | rustc_type_ir::TyKind::Float(_) => TransferKind::ToGpu, + ty::Bool + | ty::Char + | ty::Int(_) + | ty::Uint(_) + | ty::Float(_) => TransferKind::ToGpu, - rustc_type_ir::TyKind::Adt(_, _) - | rustc_type_ir::TyKind::Tuple(_) - | rustc_type_ir::TyKind::Array(_, _) => TransferKind::ToGpu, + ty::Adt(_, _) + | ty::Tuple(_) + | ty::Array(_, _) => TransferKind::ToGpu, - rustc_type_ir::TyKind::RawPtr(_, rustc_ast::Mutability::Not) - | rustc_type_ir::TyKind::Ref(_, _, rustc_ast::Mutability::Not) => TransferKind::ToGpu, + ty::RawPtr(_, rustc_ast::Mutability::Not) + | ty::Ref(_, _, rustc_ast::Mutability::Not) => TransferKind::ToGpu, - rustc_type_ir::TyKind::RawPtr(_, rustc_ast::Mutability::Mut) - | rustc_type_ir::TyKind::Ref(_, _, rustc_ast::Mutability::Mut) => TransferKind::Both, + ty::RawPtr(_, rustc_ast::Mutability::Mut) + | ty::Ref(_, _, rustc_ast::Mutability::Mut) => TransferKind::Both, - rustc_type_ir::TyKind::Slice(_) - | rustc_type_ir::TyKind::Str - | rustc_type_ir::TyKind::Dynamic(_, _) => TransferKind::Both, + ty::Slice(_) + | ty::Str + | ty::Dynamic(_, _) => TransferKind::Both, - rustc_type_ir::TyKind::FnDef(_, _) - | rustc_type_ir::TyKind::FnPtr(_, _) - | rustc_type_ir::TyKind::Closure(_, _) - | rustc_type_ir::TyKind::CoroutineClosure(_, _) - | rustc_type_ir::TyKind::Coroutine(_, _) - | rustc_type_ir::TyKind::CoroutineWitness(_, _) => TransferKind::ToGpu, + ty::FnDef(_, _) + | ty::FnPtr(_, _) + | ty::Closure(_, _) + | ty::CoroutineClosure(_, _) + | ty::Coroutine(_, _) + | ty::CoroutineWitness(_, _) => TransferKind::ToGpu, - rustc_type_ir::TyKind::Alias(_, _) - | rustc_type_ir::TyKind::Param(_) - | rustc_type_ir::TyKind::Bound(_, _) - | rustc_type_ir::TyKind::Placeholder(_) - | rustc_type_ir::TyKind::Infer(_) - | rustc_type_ir::TyKind::Error(_) => TransferKind::ToGpu, + ty::Alias(_, _) + | ty::Param(_) + | ty::Bound(_, _) + | ty::Placeholder(_) + | ty::Infer(_) + | ty::Error(_) => TransferKind::ToGpu, - rustc_type_ir::TyKind::Never => TransferKind::ToGpu, - rustc_type_ir::TyKind::Foreign(_) => TransferKind::Both, - rustc_type_ir::TyKind::Pat(_, _) => TransferKind::Both, - rustc_type_ir::TyKind::UnsafeBinder(_) => TransferKind::Both, + ty::Never => TransferKind::ToGpu, + ty::Foreign(_) => TransferKind::Both, + ty::Pat(_, _) => TransferKind::Both, + ty::UnsafeBinder(_) => TransferKind::Both, } } } diff --git a/library/core/src/intrinsics/mod.rs b/library/core/src/intrinsics/mod.rs index dafc88e66ed2c..8c8a9f51e0208 100644 --- a/library/core/src/intrinsics/mod.rs +++ b/library/core/src/intrinsics/mod.rs @@ -3278,7 +3278,7 @@ pub const fn autodiff(f: F, df: G, args: T) -> #[rustc_nounwind] #[rustc_intrinsic] -pub const fn offload(f: F) -> R; +pub const fn offload(f: F, args: T) -> R; /// Inform Miri that a given pointer definitely has a certain alignment. #[cfg(miri)] diff --git a/tests/codegen-llvm/gpu_offload/gpu_host.rs b/tests/codegen-llvm/gpu_offload/gpu_host.rs index 69eea6a6a8cea..8a469f42906bf 100644 --- a/tests/codegen-llvm/gpu_offload/gpu_host.rs +++ b/tests/codegen-llvm/gpu_offload/gpu_host.rs @@ -102,7 +102,7 @@ fn main() { #[unsafe(no_mangle)] #[inline(never)] pub fn kernel(x: &mut [f32; 256]) { - core::intrinsics::offload(_kernel) + core::intrinsics::offload(_kernel, (x,)) } #[unsafe(no_mangle)]