From 0ab78aa27073df4e54e749697fb60dd023bc3fcf Mon Sep 17 00:00:00 2001 From: Sonny Scroggin Date: Tue, 11 Nov 2025 23:10:04 -0600 Subject: [PATCH 1/4] Async NIFs --- rustler/Cargo.toml | 3 + rustler/src/lib.rs | 3 + rustler/src/tokio/mod.rs | 3 + rustler/src/tokio/runtime.rs | 218 ++++++++++++++++++ rustler/src/types/local_pid.rs | 52 +++++ rustler/src/types/mod.rs | 2 +- rustler_codegen/src/nif.rs | 168 +++++++++++++- rustler_tests/config/config.exs | 7 + rustler_tests/lib/rustler_test.ex | 5 + rustler_tests/native/rustler_test/Cargo.toml | 5 +- rustler_tests/native/rustler_test/src/lib.rs | 15 +- .../native/rustler_test/src/test_async.rs | 40 ++++ rustler_tests/test/async_test.exs | 70 ++++++ 13 files changed, 583 insertions(+), 8 deletions(-) create mode 100644 rustler/src/tokio/mod.rs create mode 100644 rustler/src/tokio/runtime.rs create mode 100644 rustler_tests/config/config.exs create mode 100644 rustler_tests/native/rustler_test/src/test_async.rs create mode 100644 rustler_tests/test/async_test.exs diff --git a/rustler/Cargo.toml b/rustler/Cargo.toml index 99d6eddb..0380c78e 100644 --- a/rustler/Cargo.toml +++ b/rustler/Cargo.toml @@ -19,12 +19,15 @@ nif_version_2_15 = ["nif_version_2_14"] nif_version_2_16 = ["nif_version_2_15"] nif_version_2_17 = ["nif_version_2_16"] serde = ["dep:serde"] +tokio_rt = ["dep:tokio"] [dependencies] inventory = "0.3" rustler_codegen = { path = "../rustler_codegen", version = "0.37.1"} num-bigint = { version = "0.4", optional = true } serde = { version = "1", optional = true } +tokio = { version = "1", optional = true, features = ["rt", "rt-multi-thread", "sync"] } +once_cell = "1" [target.'cfg(not(windows))'.dependencies] libloading = "0.9" diff --git a/rustler/src/lib.rs b/rustler/src/lib.rs index b046f151..4d0fdf2c 100644 --- a/rustler/src/lib.rs +++ b/rustler/src/lib.rs @@ -83,4 +83,7 @@ pub mod serde; #[cfg(feature = "serde")] pub use crate::serde::SerdeTerm; +#[cfg(feature = "tokio_rt")] +pub mod tokio; + pub mod sys; 
diff --git a/rustler/src/tokio/mod.rs b/rustler/src/tokio/mod.rs new file mode 100644 index 00000000..3e28c799 --- /dev/null +++ b/rustler/src/tokio/mod.rs @@ -0,0 +1,3 @@ +mod runtime; + +pub use runtime::{configure, configure_runtime, runtime_handle, ConfigError, RuntimeConfig}; diff --git a/rustler/src/tokio/runtime.rs b/rustler/src/tokio/runtime.rs new file mode 100644 index 00000000..d9857006 --- /dev/null +++ b/rustler/src/tokio/runtime.rs @@ -0,0 +1,218 @@ +use crate::{Decoder, NifResult, Term}; +use once_cell::sync::OnceCell; +use std::sync::Arc; +use tokio::runtime::Runtime; + +/// Global tokio runtime for async NIFs. +/// +/// This runtime can be configured via `configure_runtime()` in your NIF's `load` callback, +/// or will be lazily initialized with default settings on first use. +static TOKIO_RUNTIME: OnceCell> = OnceCell::new(); + +/// Error type for runtime configuration failures. +#[derive(Debug)] +pub enum ConfigError { + /// The runtime has already been initialized (either by configuration or first use). + AlreadyInitialized, + /// Failed to build the Tokio runtime. + BuildFailed(std::io::Error), +} + +impl std::fmt::Display for ConfigError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + ConfigError::AlreadyInitialized => { + write!(f, "Tokio runtime already initialized") + } + ConfigError::BuildFailed(e) => { + write!(f, "Failed to build Tokio runtime: {}", e) + } + } + } +} + +impl std::error::Error for ConfigError {} + +/// Configuration options for the Tokio runtime. +/// +/// These can be passed from Elixir via the `load_data` option: +/// +/// ```elixir +/// use Rustler, +/// otp_app: :my_app, +/// crate: :my_nif, +/// load_data: [ +/// worker_threads: 4, +/// thread_name: "my-runtime" +/// ] +/// ``` +#[derive(Debug, Clone)] +pub struct RuntimeConfig { + /// Number of worker threads for the runtime. + /// If not specified, uses Tokio's default (number of CPU cores). 
+ pub worker_threads: Option, + + /// Thread name prefix for worker threads. + /// If not specified, uses "rustler-tokio". + pub thread_name: Option, + + /// Stack size for worker threads in bytes. + /// If not specified, uses Tokio's default. + pub thread_stack_size: Option, +} + +impl Default for RuntimeConfig { + fn default() -> Self { + RuntimeConfig { + worker_threads: None, + thread_name: Some("rustler-tokio".to_string()), + thread_stack_size: None, + } + } +} + +impl<'a> Decoder<'a> for RuntimeConfig { + fn decode(term: Term<'a>) -> NifResult { + use crate::types::map::MapIterator; + use crate::Error; + + let mut config = RuntimeConfig::default(); + + // Try to decode as a map/keyword list + let map_iter = MapIterator::new(term).ok_or(Error::BadArg)?; + + for (key, value) in map_iter { + let key_str: String = key.decode()?; + + match key_str.as_str() { + "worker_threads" => { + config.worker_threads = Some(value.decode()?); + } + "thread_name" => { + config.thread_name = Some(value.decode()?); + } + "thread_stack_size" => { + config.thread_stack_size = Some(value.decode()?); + } + _ => { + // Ignore unknown options for forward compatibility + } + } + } + + Ok(config) + } +} + +/// Configure the global Tokio runtime from Elixir load_data. +/// +/// This is the recommended way to configure the runtime, allowing Elixir application +/// developers to tune the runtime without recompiling the NIF. 
+/// +/// # Example +/// +/// ```ignore +/// use rustler::{Env, Term}; +/// +/// fn load(_env: Env, load_info: Term) -> bool { +/// // Try to decode runtime config from load_info +/// if let Ok(config) = load_info.decode::<rustler::tokio::RuntimeConfig>() { +/// rustler::tokio::configure(config) +/// .expect("Failed to configure Tokio runtime"); +/// } +/// true +/// } +/// ``` +/// +/// In your Elixir config: +/// +/// ```elixir +/// # config/config.exs +/// config :my_app, MyNif, +/// load_data: [ +/// worker_threads: 4, +/// thread_name: "my-runtime" +/// ] +/// ``` +pub fn configure(config: RuntimeConfig) -> Result<(), ConfigError> { + let mut builder = tokio::runtime::Builder::new_multi_thread(); + builder.enable_all(); + + // Apply configuration + if let Some(threads) = config.worker_threads { + builder.worker_threads(threads); + } + + if let Some(name) = config.thread_name { + builder.thread_name(name); + } + + if let Some(stack_size) = config.thread_stack_size { + builder.thread_stack_size(stack_size); + } + + let runtime = builder.build().map_err(ConfigError::BuildFailed)?; + + TOKIO_RUNTIME + .set(Arc::new(runtime)) + .map_err(|_| ConfigError::AlreadyInitialized) +} + +/// Configure the global Tokio runtime programmatically. +/// +/// This provides direct access to the Tokio Builder API for advanced use cases. +/// For most applications, prefer `configure`, which takes a `RuntimeConfig` and allows +/// configuration from Elixir. 
+/// +/// # Example +/// +/// ```ignore +/// use rustler::{Env, Term}; +/// +/// fn load(_env: Env, _: Term) -> bool { +/// rustler::tokio::configure_runtime(|builder| { +/// builder +/// .worker_threads(4) +/// .thread_name("myapp-tokio") +/// .thread_stack_size(3 * 1024 * 1024); +/// }).expect("Failed to configure Tokio runtime"); +/// +/// true +/// } +/// ``` +pub fn configure_runtime(config_fn: F) -> Result<(), ConfigError> +where + F: FnOnce(&mut tokio::runtime::Builder), +{ + let mut builder = tokio::runtime::Builder::new_multi_thread(); + builder.enable_all(); + + // Allow user to customize + config_fn(&mut builder); + + let runtime = builder.build().map_err(ConfigError::BuildFailed)?; + + TOKIO_RUNTIME + .set(Arc::new(runtime)) + .map_err(|_| ConfigError::AlreadyInitialized) +} + +/// Get a handle to the global tokio runtime, or the current runtime if already inside one. +pub fn runtime_handle() -> tokio::runtime::Handle { + // Try to get the current runtime handle first (if already in a tokio context) + tokio::runtime::Handle::try_current().unwrap_or_else(|_| { + // Get or initialize with default configuration + TOKIO_RUNTIME + .get_or_init(|| { + Arc::new( + tokio::runtime::Builder::new_multi_thread() + .enable_all() + .thread_name("rustler-tokio") + .build() + .expect("Failed to create default tokio runtime for async NIFs"), + ) + }) + .handle() + .clone() + }) +} diff --git a/rustler/src/types/local_pid.rs b/rustler/src/types/local_pid.rs index 7336dedd..166dc3b1 100644 --- a/rustler/src/types/local_pid.rs +++ b/rustler/src/types/local_pid.rs @@ -63,6 +63,58 @@ impl Ord for LocalPid { } } +/// A wrapper for `LocalPid` that represents the calling process in async NIFs. +/// +/// When used as the first parameter of an async NIF, `CallerPid` is automatically +/// populated with the calling process's PID, and is not decoded from the arguments. +/// This allows async NIFs to send intermediate messages back to the caller. 
+/// +/// # Example +/// +/// ```ignore +/// #[rustler::nif] +/// async fn with_progress(caller: CallerPid, work: Vec) -> i64 { +/// // Send progress updates +/// let mut env = OwnedEnv::new(); +/// env.send(caller.as_pid(), |e| "started".encode(e)); +/// +/// let result = do_work(work).await; +/// +/// // Final result sent automatically +/// result +/// } +/// ``` +#[derive(Copy, Clone)] +pub struct CallerPid(LocalPid); + +impl CallerPid { + /// Create a new CallerPid from a LocalPid. + /// + /// This is only used internally by the NIF macro. + #[doc(hidden)] + pub fn new(pid: LocalPid) -> Self { + CallerPid(pid) + } + + /// Get the underlying LocalPid. + pub fn as_pid(&self) -> &LocalPid { + &self.0 + } + + /// Check whether the calling process is alive. + pub fn is_alive(self, env: Env) -> bool { + self.0.is_alive(env) + } +} + +impl std::ops::Deref for CallerPid { + type Target = LocalPid; + + fn deref(&self) -> &Self::Target { + &self.0 + } +} + impl Env<'_> { /// Return the calling process's pid. 
/// diff --git a/rustler/src/types/mod.rs b/rustler/src/types/mod.rs index c7b72005..c1b7d2f3 100644 --- a/rustler/src/types/mod.rs +++ b/rustler/src/types/mod.rs @@ -28,7 +28,7 @@ pub mod tuple; #[doc(hidden)] pub mod local_pid; -pub use self::local_pid::LocalPid; +pub use self::local_pid::{CallerPid, LocalPid}; #[doc(hidden)] pub mod reference; diff --git a/rustler_codegen/src/nif.rs b/rustler_codegen/src/nif.rs index 46674423..c8909034 100644 --- a/rustler_codegen/src/nif.rs +++ b/rustler_codegen/src/nif.rs @@ -39,12 +39,11 @@ pub fn transcoder_decorator(nif_attributes: NifAttributes, fun: syn::ItemFn) -> let sig = &fun.sig; let name = &sig.ident; let inputs = &sig.inputs; + let is_async = sig.asyncness.is_some(); let flags = schedule_flag(nif_attributes.schedule); let function = fun.to_owned().into_token_stream(); let arity = arity(inputs.clone()); - let decoded_terms = extract_inputs(inputs.clone()); - let argument_names = create_function_params(inputs.clone()); let erl_func_name = nif_attributes .custom_name .map_or_else(|| name.to_string(), |n| n.value().to_string()); @@ -53,6 +52,24 @@ pub fn transcoder_decorator(nif_attributes: NifAttributes, fun: syn::ItemFn) -> panic!("Only non-Control ASCII strings are supported as function names"); } + if is_async { + generate_async_nif(erl_func_name, name, flags, arity, function, inputs.clone()) + } else { + generate_sync_nif(erl_func_name, name, flags, arity, function, inputs.clone()) + } +} + +fn generate_sync_nif( + erl_func_name: String, + name: &syn::Ident, + flags: TokenStream, + arity: u32, + function: TokenStream, + inputs: Punctuated, +) -> TokenStream { + let decoded_terms = extract_inputs(inputs.clone()); + let argument_names = create_function_params(inputs); + quote! 
{ rustler::codegen_runtime::inventory::submit!( rustler::Nif { @@ -98,6 +115,79 @@ pub fn transcoder_decorator(nif_attributes: NifAttributes, fun: syn::ItemFn) -> } } +fn generate_async_nif( + erl_func_name: String, + name: &syn::Ident, + flags: TokenStream, + arity: u32, + function: TokenStream, + inputs: Punctuated, +) -> TokenStream { + let decoded_terms_async = extract_inputs_for_async(inputs.clone()); + let argument_names = create_function_params(inputs); + + quote! { + // Define the original async function at module level + #function + + // Submit the NIF wrapper to inventory + rustler::codegen_runtime::inventory::submit!( + rustler::Nif { + name: concat!(#erl_func_name, "\0").as_ptr() + as *const rustler::codegen_runtime::c_char, + arity: #arity, + flags: #flags as rustler::codegen_runtime::c_uint, + raw_func: { + unsafe extern "C" fn nif_func( + nif_env: rustler::codegen_runtime::NIF_ENV, + argc: rustler::codegen_runtime::c_int, + argv: *const rustler::codegen_runtime::NIF_TERM + ) -> rustler::codegen_runtime::NIF_TERM { + let lifetime = (); + let env = rustler::Env::new(&lifetime, nif_env); + + let terms = std::slice::from_raw_parts(argv, argc as usize) + .iter() + .map(|term| rustler::Term::new(env, *term)) + .collect::>(); + + fn wrapper<'a>( + env: rustler::Env<'a>, + args: &[rustler::Term<'a>] + ) -> rustler::codegen_runtime::NifReturned { + // Get the calling process PID + let pid = env.pid(); + + // Decode all arguments before spawning async task + #decoded_terms_async + + // Spawn async task on tokio runtime + let handle = rustler::tokio::runtime_handle(); + handle.spawn(async move { + // Execute the async function and get the result + let value = #name(#argument_names).await; + + // Send result back to calling process + let mut msg_env = rustler::OwnedEnv::new(); + let _ = msg_env.send_and_clear(&pid, |env| { + rustler::Encoder::encode(&value, env) + }); + }); + + // Return :ok immediately + rustler::codegen_runtime::NifReturned::Term( + 
rustler::types::atom::ok().to_term(env).as_c_arg() + ) + } + wrapper(env, &terms).apply(env) + } + nif_func + } + } + ); + } +} + fn schedule_flag(schedule: Option) -> TokenStream { let mut tokens = TokenStream::new(); @@ -201,7 +291,8 @@ fn arity(inputs: Punctuated) -> u32 { if let syn::Type::Path(syn::TypePath { path, .. }) = &*typed.ty { let ident = path.segments.last().unwrap().ident.to_string(); - if i == 0 && ident == "Env" { + // Skip Env and CallerPid when they're the first parameter + if i == 0 && (ident == "Env" || ident == "CallerPid") { continue; } @@ -217,3 +308,74 @@ fn arity(inputs: Punctuated) -> u32 { arity } + +fn extract_inputs_for_async(inputs: Punctuated) -> TokenStream { + let mut tokens = TokenStream::new(); + let mut args_offset = 0; + + for (param_idx, item) in inputs.iter().enumerate() { + if let syn::FnArg::Typed(ref typed) = item { + let name = &typed.pat; + let typ = &typed.ty; + + match &**typ { + syn::Type::Path(syn::TypePath { path, .. }) => { + let ident = path.segments.last().unwrap().ident.to_string(); + + // Special case: CallerPid as first parameter + if param_idx == 0 && ident == "CallerPid" { + let caller_setup = quote! { + let #name: #typ = rustler::types::CallerPid::new(pid); + }; + tokens.extend(caller_setup); + args_offset = 1; // Don't consume an arg slot + continue; + } + + // Async functions cannot take Env or Term parameters + if ident == "Env" || ident == "Term" { + panic!( + "Async NIFs cannot accept '{}' parameters. \ + All arguments must be decodable types that can be moved into the async task.", + ident + ); + } + + let args_idx = param_idx - args_offset; + let decoder = quote! { + let #name: #typ = match args[#args_idx].decode() { + Ok(value) => value, + Err(_) => return rustler::codegen_runtime::NifReturned::BadArg + }; + }; + + tokens.extend(decoder); + } + syn::Type::Reference(_) => { + panic!( + "Async NIFs cannot accept reference parameters. 
\ + All arguments must be owned types that can be moved into the async task." + ); + } + syn::Type::Tuple(typ) => { + let args_idx = param_idx - args_offset; + let decoder = quote! { + let #name: #typ = match args[#args_idx].decode() { + Ok(value) => value, + Err(_) => return rustler::codegen_runtime::NifReturned::BadArg + }; + }; + + tokens.extend(decoder); + } + other => { + panic!("unsupported async input type: {other:?}"); + } + } + } else { + panic!("unsupported input given: {:?}", stringify!(&item)); + }; + } + + tokens +} diff --git a/rustler_tests/config/config.exs b/rustler_tests/config/config.exs new file mode 100644 index 00000000..5e63e4d3 --- /dev/null +++ b/rustler_tests/config/config.exs @@ -0,0 +1,7 @@ +import Config + +config :rustler_test, RustlerTest, + load_data: [ + worker_threads: 4, + thread_name: "rustler-test" + ] diff --git a/rustler_tests/lib/rustler_test.ex b/rustler_tests/lib/rustler_test.ex index 161536a5..fef4eb94 100644 --- a/rustler_tests/lib/rustler_test.ex +++ b/rustler_tests/lib/rustler_test.ex @@ -40,6 +40,11 @@ defmodule RustlerTest do def compare_local_pids(_, _), do: err() def are_equal_local_pids(_, _), do: err() + def async_add(_, _), do: err() + def async_sleep_and_return(_, _), do: err() + def async_tuple_multiply(_), do: err() + def async_with_progress(_), do: err() + def term_debug(_), do: err() def term_debug_and_reparse(term) do diff --git a/rustler_tests/native/rustler_test/Cargo.toml b/rustler_tests/native/rustler_test/Cargo.toml index acc27453..ea3962bf 100644 --- a/rustler_tests/native/rustler_test/Cargo.toml +++ b/rustler_tests/native/rustler_test/Cargo.toml @@ -14,10 +14,13 @@ name = "hello_rust" path = "src/main.rs" [features] +default = ["rustler/tokio_rt"] +tokio_rt = ["rustler/tokio_rt"] nif_version_2_14 = ["rustler/nif_version_2_14"] nif_version_2_15 = ["nif_version_2_14", "rustler/nif_version_2_15"] nif_version_2_16 = ["nif_version_2_15", "rustler/nif_version_2_16"] nif_version_2_17 = ["nif_version_2_16", 
"rustler/nif_version_2_17"] [dependencies] -rustler = { path = "../../../rustler", features = ["allocator"] } +rustler = { path = "../../../rustler", features = ["allocator", "tokio_rt"] } +tokio = { version = "1", features = ["time"] } diff --git a/rustler_tests/native/rustler_test/src/lib.rs b/rustler_tests/native/rustler_test/src/lib.rs index e37e35be..82202a50 100644 --- a/rustler_tests/native/rustler_test/src/lib.rs +++ b/rustler_tests/native/rustler_test/src/lib.rs @@ -1,3 +1,4 @@ +mod test_async; mod test_atom; mod test_binary; mod test_codegen; @@ -19,9 +20,17 @@ mod test_term; mod test_thread; mod test_tuple; -// Intentional usage of the explicit form (in an "invalid" way, listing a wrong set of functions) to ensure that the warning stays alive -rustler::init!("Elixir.RustlerTest", [deprecated, usage], load = load); +// Temporarily add async_add explicitly to debug +rustler::init!("Elixir.RustlerTest", load = load); + +fn load(env: rustler::Env, load_info: rustler::Term) -> bool { + // Configure Tokio runtime from Elixir load_data + #[cfg(feature = "tokio_rt")] + { + if let Ok(config) = load_info.decode::() { + rustler::tokio::configure(config).ok(); + } + } -fn load(env: rustler::Env, _: rustler::Term) -> bool { test_resource::on_load(env) } diff --git a/rustler_tests/native/rustler_test/src/test_async.rs b/rustler_tests/native/rustler_test/src/test_async.rs new file mode 100644 index 00000000..16197ffd --- /dev/null +++ b/rustler_tests/native/rustler_test/src/test_async.rs @@ -0,0 +1,40 @@ +use rustler::types::CallerPid; +use rustler::OwnedEnv; +use std::time::Duration; + +#[rustler::nif] +async fn async_add(a: i64, b: i64) -> i64 { + tokio::time::sleep(Duration::from_millis(10)).await; + a + b +} + +#[rustler::nif] +async fn async_sleep_and_return(ms: u64, value: String) -> String { + tokio::time::sleep(Duration::from_millis(ms)).await; + value +} + +#[rustler::nif] +async fn async_tuple_multiply(input: (i64, i64)) -> i64 { + 
tokio::time::sleep(Duration::from_millis(5)).await; + input.0 * input.1 +} + +#[rustler::nif] +async fn async_with_progress(caller: CallerPid, work_items: i64) -> i64 { + let mut total = 0; + + for i in 0..work_items { + tokio::time::sleep(Duration::from_millis(10)).await; + total += i; + + // Send progress update + let mut env = OwnedEnv::new(); + let _ = env.send_and_clear(caller.as_pid(), |e| { + use rustler::Encoder; + ("progress", i).encode(e) + }); + } + + total +} diff --git a/rustler_tests/test/async_test.exs b/rustler_tests/test/async_test.exs new file mode 100644 index 00000000..e5e499ac --- /dev/null +++ b/rustler_tests/test/async_test.exs @@ -0,0 +1,70 @@ +defmodule RustlerTest.AsyncTest do + use ExUnit.Case, async: false + + test "async_add returns :ok and result comes via message" do + assert :ok == RustlerTest.async_add(10, 20) + + assert_receive result, 1000 + assert result == 30 + end + + test "async_sleep_and_return" do + assert :ok == RustlerTest.async_sleep_and_return(50, "hello world") + + assert_receive result, 1000 + assert result == "hello world" + end + + test "async_tuple_multiply" do + assert :ok == RustlerTest.async_tuple_multiply({6, 7}) + + assert_receive result, 1000 + assert result == 42 + end + + test "multiple async calls can run concurrently" do + # Start 3 async operations + assert :ok == RustlerTest.async_sleep_and_return(100, "first") + assert :ok == RustlerTest.async_sleep_and_return(100, "second") + assert :ok == RustlerTest.async_sleep_and_return(100, "third") + + # Collect all results + results = + for _ <- 1..3 do + receive do + msg -> msg + after + 1000 -> :timeout + end + end + + # All should have completed + assert "first" in results + assert "second" in results + assert "third" in results + end + + test "async_with_progress sends intermediate messages using CallerPid" do + assert :ok == RustlerTest.async_with_progress(3) + + # Should receive progress messages: {:progress, 0}, {:progress, 1}, {:progress, 2} + # Then 
final result: 3 (which is 0 + 1 + 2) + + messages = + for _ <- 1..4 do + receive do + msg -> msg + after + 500 -> :timeout + end + end + + # Check we got progress updates + assert {"progress", 0} in messages + assert {"progress", 1} in messages + assert {"progress", 2} in messages + + # Final result should be sum: 0 + 1 + 2 = 3 + assert 3 in messages + end +end From 89952e50299f61fa3726a4825486672da623ac48 Mon Sep 17 00:00:00 2001 From: Sonny Scroggin Date: Thu, 13 Nov 2025 08:12:47 -0600 Subject: [PATCH 2/4] rustler::task --- rustler/src/lib.rs | 11 +- rustler/src/task_ref.rs | 37 ++++++ rustler/src/types/local_pid.rs | 107 ++++++++++++------ rustler/src/types/mod.rs | 5 +- rustler_codegen/src/lib.rs | 60 ++++++++++ rustler_codegen/src/nif.rs | 98 +++++++++++++--- .../native/rustler_test/src/test_async.rs | 23 ++-- rustler_tests/test/async_test.exs | 51 +++++---- 8 files changed, 305 insertions(+), 87 deletions(-) create mode 100644 rustler/src/task_ref.rs diff --git a/rustler/src/lib.rs b/rustler/src/lib.rs index 4d0fdf2c..83d391d1 100644 --- a/rustler/src/lib.rs +++ b/rustler/src/lib.rs @@ -1,3 +1,4 @@ +#![deny(warnings)] #![allow(non_camel_case_types)] #![allow(clippy::missing_safety_doc)] @@ -73,8 +74,8 @@ pub use nif::Nif; pub type NifResult = Result; pub use rustler_codegen::{ - init, nif, resource_impl, NifException, NifMap, NifRecord, NifStruct, NifTaggedEnum, NifTuple, - NifUnitEnum, NifUntaggedEnum, + init, nif, resource_impl, task, NifException, NifMap, NifRecord, NifStruct, NifTaggedEnum, + NifTuple, NifUnitEnum, NifUntaggedEnum, }; #[cfg(feature = "serde")] @@ -86,4 +87,10 @@ pub use crate::serde::SerdeTerm; #[cfg(feature = "tokio_rt")] pub mod tokio; +#[cfg(feature = "tokio_rt")] +mod task_ref; + +#[cfg(feature = "tokio_rt")] +pub use task_ref::TaskRef; + pub mod sys; diff --git a/rustler/src/task_ref.rs b/rustler/src/task_ref.rs new file mode 100644 index 00000000..1b456ac0 --- /dev/null +++ b/rustler/src/task_ref.rs @@ -0,0 +1,37 @@ +use 
std::sync::atomic::{AtomicU64, Ordering}; + +static TASK_COUNTER: AtomicU64 = AtomicU64::new(0); + +/// Task reference resource for async tasks. +/// +/// This is automatically created by `#[rustler::task]` and returned to the caller. +/// All messages sent by the task (both intermediate and final) are tagged with this reference. +#[cfg(feature = "tokio_rt")] +#[derive(Debug, Clone)] +pub struct TaskRef { + #[allow(dead_code)] + id: u64, +} + +#[cfg(feature = "tokio_rt")] +impl TaskRef { + /// Create a new TaskRef with a unique ID. + /// + /// This is used internally by the `#[rustler::task]` macro. + #[doc(hidden)] + pub fn new() -> Self { + Self { + id: TASK_COUNTER.fetch_add(1, Ordering::Relaxed), + } + } +} + +// Implement Resource trait +#[cfg(feature = "tokio_rt")] +impl crate::Resource for TaskRef {} + +// Auto-register TaskRef resource via inventory +#[cfg(feature = "tokio_rt")] +crate::codegen_runtime::inventory::submit! { + crate::resource::Registration::new::() +} diff --git a/rustler/src/types/local_pid.rs b/rustler/src/types/local_pid.rs index 166dc3b1..b2c57e4f 100644 --- a/rustler/src/types/local_pid.rs +++ b/rustler/src/types/local_pid.rs @@ -63,55 +63,98 @@ impl Ord for LocalPid { } } -/// A wrapper for `LocalPid` that represents the calling process in async NIFs. +/// Caller information for async tasks with type-safe message sending. /// -/// When used as the first parameter of an async NIF, `CallerPid` is automatically -/// populated with the calling process's PID, and is not decoded from the arguments. -/// This allows async NIFs to send intermediate messages back to the caller. +/// Contains the calling process's PID and the task reference. When used as the first +/// parameter of a `#[rustler::task]`, it is automatically populated and provides +/// convenient methods for sending messages tagged with the task reference. 
+/// +/// The generic type `T` is automatically inferred from the task's return type, +/// ensuring that intermediate messages sent via `send()` are the same type as +/// the final result. /// /// # Example /// /// ```ignore -/// #[rustler::nif] -/// async fn with_progress(caller: CallerPid, work: Vec) -> i64 { -/// // Send progress updates -/// let mut env = OwnedEnv::new(); -/// env.send(caller.as_pid(), |e| "started".encode(e)); -/// -/// let result = do_work(work).await; -/// -/// // Final result sent automatically -/// result +/// #[rustler::task] +/// async fn with_progress(caller: Caller, work: Vec) -> Result { +/// for (i, item) in work.iter().enumerate() { +/// // Type-checked: must send Result +/// caller.send(Ok(i as i64)); +/// process(item).await?; +/// } +/// Ok(work.len() as i64) /// } /// ``` -#[derive(Copy, Clone)] -pub struct CallerPid(LocalPid); +#[cfg(feature = "tokio_rt")] +#[derive(Clone)] +pub struct Caller { + pid: LocalPid, + task_ref: crate::ResourceArc, + _phantom: std::marker::PhantomData, +} -impl CallerPid { - /// Create a new CallerPid from a LocalPid. +#[cfg(feature = "tokio_rt")] +impl Caller { + /// Create a new Caller. /// - /// This is only used internally by the NIF macro. + /// This is only used internally by the task macro. #[doc(hidden)] - pub fn new(pid: LocalPid) -> Self { - CallerPid(pid) + pub fn new(pid: LocalPid, task_ref: crate::ResourceArc) -> Self { + Self { + pid, + task_ref, + _phantom: std::marker::PhantomData, + } } - /// Get the underlying LocalPid. - pub fn as_pid(&self) -> &LocalPid { - &self.0 + /// Get the calling process's PID. + pub fn pid(&self) -> &LocalPid { + &self.pid } - /// Check whether the calling process is alive. - pub fn is_alive(self, env: Env) -> bool { - self.0.is_alive(env) + /// Get the task reference. 
+ pub fn task_ref(&self) -> &crate::ResourceArc { + &self.task_ref } -} -impl std::ops::Deref for CallerPid { - type Target = LocalPid; + /// Send an intermediate message to the caller, automatically tagged with the task reference. + /// + /// The message will be sent as `{task_ref, message}`. + /// + /// The message type `T` must match the task's return type, ensuring type safety + /// for all messages sent during task execution. + /// + /// # Example + /// + /// ```ignore + /// #[rustler::task] + /// async fn process(caller: Caller, count: i64) -> String { + /// for i in 0..count { + /// caller.send(format!("Progress: {}", i)); // ✅ Type-safe + /// // caller.send(i); // ❌ Compile error: expected String, got i64 + /// } + /// "Done".to_string() + /// } + /// ``` + pub fn send(&self, message: T) { + let mut env = crate::OwnedEnv::new(); + let task_ref = self.task_ref.clone(); + let _ = env.send_and_clear(&self.pid, move |env| (task_ref, message).encode(env)); + } - fn deref(&self) -> &Self::Target { - &self.0 + /// Send the final message and complete the task. + /// + /// This is used internally by the `#[rustler::task]` macro to send the + /// task's return value. User code should just return the value normally. + #[doc(hidden)] + pub fn finish(self, message: T) { + self.send(message); + } + + /// Check whether the calling process is alive. 
+ pub fn is_alive(&self, env: Env) -> bool { + self.pid.is_alive(env) } } diff --git a/rustler/src/types/mod.rs b/rustler/src/types/mod.rs index c1b7d2f3..27512d38 100644 --- a/rustler/src/types/mod.rs +++ b/rustler/src/types/mod.rs @@ -28,7 +28,10 @@ pub mod tuple; #[doc(hidden)] pub mod local_pid; -pub use self::local_pid::{CallerPid, LocalPid}; +pub use self::local_pid::LocalPid; + +#[cfg(feature = "tokio_rt")] +pub use self::local_pid::Caller; #[doc(hidden)] pub mod reference; diff --git a/rustler_codegen/src/lib.rs b/rustler_codegen/src/lib.rs index 9792a3f6..7f6deb13 100644 --- a/rustler_codegen/src/lib.rs +++ b/rustler_codegen/src/lib.rs @@ -1,3 +1,4 @@ +#![deny(warnings)] #![recursion_limit = "128"] use proc_macro::TokenStream; @@ -102,6 +103,65 @@ pub fn nif(args: TokenStream, input: TokenStream) -> TokenStream { let input = syn::parse_macro_input!(input as syn::ItemFn); + // Reject async functions in #[rustler::nif] + if input.sig.asyncness.is_some() { + return syn::Error::new_spanned( + input.sig.asyncness, + "async functions are not supported with #[rustler::nif]. Use #[rustler::task] instead.", + ) + .to_compile_error() + .into(); + } + + nif::transcoder_decorator(nif_attributes, input).into() +} + +/// Wrap an async function as a spawned task that returns a reference. +/// +/// The task is spawned onto the configured async runtime and returns a unique +/// reference immediately. When the task completes, a message `{ref, result}` is +/// sent to the calling process. 
+/// +/// ```ignore +/// #[rustler::task] +/// async fn fetch_data(url: String) -> Result { +/// // Long-running async operation +/// tokio::time::sleep(Duration::from_secs(1)).await; +/// Ok("data".to_string()) +/// } +/// ``` +/// +/// From Elixir: +/// ```elixir +/// ref = MyNIF.fetch_data("https://example.com") +/// receive do +/// {^ref, result} -> IO.puts("Got result: #{inspect(result)}") +/// after +/// 5000 -> :timeout +/// end +/// ``` +#[proc_macro_attribute] +pub fn task(args: TokenStream, input: TokenStream) -> TokenStream { + let mut nif_attributes = nif::NifAttributes::default(); + + if !args.is_empty() { + let nif_macro_parser = syn::meta::parser(|meta| nif_attributes.parse(meta)); + + syn::parse_macro_input!(args with nif_macro_parser); + } + + let input = syn::parse_macro_input!(input as syn::ItemFn); + + // Require async functions for #[rustler::task] + if input.sig.asyncness.is_none() { + return syn::Error::new_spanned( + input.sig.fn_token, + "#[rustler::task] requires an async function", + ) + .to_compile_error() + .into(); + } + nif::transcoder_decorator(nif_attributes, input).into() } diff --git a/rustler_codegen/src/nif.rs b/rustler_codegen/src/nif.rs index c8909034..4e308961 100644 --- a/rustler_codegen/src/nif.rs +++ b/rustler_codegen/src/nif.rs @@ -53,13 +53,21 @@ pub fn transcoder_decorator(nif_attributes: NifAttributes, fun: syn::ItemFn) -> } if is_async { - generate_async_nif(erl_func_name, name, flags, arity, function, inputs.clone()) + generate_task( + erl_func_name, + name, + flags, + arity, + function, + inputs.clone(), + &sig.output, + ) } else { - generate_sync_nif(erl_func_name, name, flags, arity, function, inputs.clone()) + generate_nif(erl_func_name, name, flags, arity, function, inputs.clone()) } } -fn generate_sync_nif( +fn generate_nif( erl_func_name: String, name: &syn::Ident, flags: TokenStream, @@ -115,17 +123,54 @@ fn generate_sync_nif( } } -fn generate_async_nif( +fn generate_task( erl_func_name: String, name: 
&syn::Ident, flags: TokenStream, arity: u32, function: TokenStream, inputs: Punctuated, + return_type: &syn::ReturnType, ) -> TokenStream { - let decoded_terms_async = extract_inputs_for_async(inputs.clone()); + // Check if first parameter is Caller + let uses_caller = inputs + .first() + .and_then(|arg| { + if let syn::FnArg::Typed(typed) = arg { + if let syn::Type::Path(syn::TypePath { path, .. }) = &*typed.ty { + let segment = path.segments.last()?; + return Some(segment.ident == "Caller"); + } + } + None + }) + .unwrap_or(false); + + let decoded_terms_async = extract_inputs_for_async(inputs.clone(), return_type); let argument_names = create_function_params(inputs); + // Generate code for sending the final result + let (send_result, caller_for_finish) = if uses_caller { + // When using Caller, clone it for the finish call + let caller_clone = quote! { + let caller_for_finish = caller.clone(); + }; + let finish = quote! { + caller_for_finish.finish(value); + }; + (finish, Some(caller_clone)) + } else { + // When not using Caller, send directly + let direct_send = quote! { + let mut msg_env = rustler::OwnedEnv::new(); + let _ = msg_env.send_and_clear(&pid, |env| { + use rustler::Encoder; + (task_ref_for_spawn, value).encode(env) + }); + }; + (direct_send, None) + }; + quote! 
{ // Define the original async function at module level #function @@ -158,25 +203,30 @@ fn generate_async_nif( // Get the calling process PID let pid = env.pid(); + // Create a unique task reference resource + let task_ref = rustler::ResourceArc::new(rustler::TaskRef::new()); + let task_ref_for_spawn = task_ref.clone(); + // Decode all arguments before spawning async task #decoded_terms_async + // Clone caller if needed for finish() call + #caller_for_finish + // Spawn async task on tokio runtime let handle = rustler::tokio::runtime_handle(); handle.spawn(async move { // Execute the async function and get the result let value = #name(#argument_names).await; - // Send result back to calling process - let mut msg_env = rustler::OwnedEnv::new(); - let _ = msg_env.send_and_clear(&pid, |env| { - rustler::Encoder::encode(&value, env) - }); + // Send {ref, result} back to calling process + #send_result }); - // Return :ok immediately + // Return the task reference immediately + use rustler::Encoder; rustler::codegen_runtime::NifReturned::Term( - rustler::types::atom::ok().to_term(env).as_c_arg() + task_ref.encode(env).as_c_arg() ) } wrapper(env, &terms).apply(env) @@ -291,8 +341,8 @@ fn arity(inputs: Punctuated) -> u32 { if let syn::Type::Path(syn::TypePath { path, .. 
}) = &*typed.ty { let ident = path.segments.last().unwrap().ident.to_string(); - // Skip Env and CallerPid when they're the first parameter - if i == 0 && (ident == "Env" || ident == "CallerPid") { + // Skip Env and Caller when they're the first parameter + if i == 0 && (ident == "Env" || ident == "Caller") { continue; } @@ -309,10 +359,18 @@ fn arity(inputs: Punctuated) -> u32 { arity } -fn extract_inputs_for_async(inputs: Punctuated) -> TokenStream { +fn extract_inputs_for_async( + inputs: Punctuated, + return_type: &syn::ReturnType, +) -> TokenStream { let mut tokens = TokenStream::new(); let mut args_offset = 0; + // Validate that async tasks have an explicit return type + if matches!(return_type, syn::ReturnType::Default) { + panic!("Async tasks must have an explicit return type"); + } + for (param_idx, item) in inputs.iter().enumerate() { if let syn::FnArg::Typed(ref typed) = item { let name = &typed.pat; @@ -322,10 +380,14 @@ fn extract_inputs_for_async(inputs: Punctuated) -> TokenStrea syn::Type::Path(syn::TypePath { path, .. }) => { let ident = path.segments.last().unwrap().ident.to_string(); - // Special case: CallerPid as first parameter - if param_idx == 0 && ident == "CallerPid" { + // Special case: Caller as first parameter + if param_idx == 0 && ident == "Caller" { + // Validate that generic argument matches return type (optional check) + // The Rust compiler will catch mismatches anyway, but we could add + // a better error message here if needed + let caller_setup = quote! 
{ - let #name: #typ = rustler::types::CallerPid::new(pid); + let #name: #typ = rustler::types::Caller::new(pid, task_ref.clone()); }; tokens.extend(caller_setup); args_offset = 1; // Don't consume an arg slot diff --git a/rustler_tests/native/rustler_test/src/test_async.rs b/rustler_tests/native/rustler_test/src/test_async.rs index 16197ffd..82d84310 100644 --- a/rustler_tests/native/rustler_test/src/test_async.rs +++ b/rustler_tests/native/rustler_test/src/test_async.rs @@ -1,39 +1,36 @@ -use rustler::types::CallerPid; -use rustler::OwnedEnv; +use rustler::types::Caller; use std::time::Duration; -#[rustler::nif] +#[rustler::task] async fn async_add(a: i64, b: i64) -> i64 { tokio::time::sleep(Duration::from_millis(10)).await; a + b } -#[rustler::nif] +#[rustler::task] async fn async_sleep_and_return(ms: u64, value: String) -> String { tokio::time::sleep(Duration::from_millis(ms)).await; value } -#[rustler::nif] +#[rustler::task] async fn async_tuple_multiply(input: (i64, i64)) -> i64 { tokio::time::sleep(Duration::from_millis(5)).await; input.0 * input.1 } -#[rustler::nif] -async fn async_with_progress(caller: CallerPid, work_items: i64) -> i64 { +#[rustler::task] +async fn async_with_progress(caller: Caller, work_items: i64) -> i64 { let mut total = 0; for i in 0..work_items { tokio::time::sleep(Duration::from_millis(10)).await; total += i; - // Send progress update - let mut env = OwnedEnv::new(); - let _ = env.send_and_clear(caller.as_pid(), |e| { - use rustler::Encoder; - ("progress", i).encode(e) - }); + // Send progress update - automatically tagged with task ref + // Note: This would be a compile error if we tried to send the tuple: + // caller.send(("progress", i)); // ❌ Type error: expected i64, got tuple + caller.send(i); // ✅ Type-safe: i64 matches return type } total diff --git a/rustler_tests/test/async_test.exs b/rustler_tests/test/async_test.exs index e5e499ac..f01e7448 100644 --- a/rustler_tests/test/async_test.exs +++ 
b/rustler_tests/test/async_test.exs @@ -1,38 +1,45 @@ defmodule RustlerTest.AsyncTest do use ExUnit.Case, async: false - test "async_add returns :ok and result comes via message" do - assert :ok == RustlerTest.async_add(10, 20) + test "async_add returns ref and result comes via message" do + ref = RustlerTest.async_add(10, 20) + assert is_reference(ref) - assert_receive result, 1000 + assert_receive {^ref, result}, 1000 assert result == 30 end test "async_sleep_and_return" do - assert :ok == RustlerTest.async_sleep_and_return(50, "hello world") + ref = RustlerTest.async_sleep_and_return(50, "hello world") + assert is_reference(ref) - assert_receive result, 1000 + assert_receive {^ref, result}, 1000 assert result == "hello world" end test "async_tuple_multiply" do - assert :ok == RustlerTest.async_tuple_multiply({6, 7}) + ref = RustlerTest.async_tuple_multiply({6, 7}) + assert is_reference(ref) - assert_receive result, 1000 + assert_receive {^ref, result}, 1000 assert result == 42 end test "multiple async calls can run concurrently" do # Start 3 async operations - assert :ok == RustlerTest.async_sleep_and_return(100, "first") - assert :ok == RustlerTest.async_sleep_and_return(100, "second") - assert :ok == RustlerTest.async_sleep_and_return(100, "third") + ref1 = RustlerTest.async_sleep_and_return(100, "first") + ref2 = RustlerTest.async_sleep_and_return(100, "second") + ref3 = RustlerTest.async_sleep_and_return(100, "third") + + assert is_reference(ref1) + assert is_reference(ref2) + assert is_reference(ref3) # Collect all results results = for _ <- 1..3 do receive do - msg -> msg + {_ref, msg} -> msg after 1000 -> :timeout end @@ -44,11 +51,13 @@ defmodule RustlerTest.AsyncTest do assert "third" in results end - test "async_with_progress sends intermediate messages using CallerPid" do - assert :ok == RustlerTest.async_with_progress(3) + test "async_with_progress sends intermediate messages using Caller" do + ref = RustlerTest.async_with_progress(3) + assert 
is_reference(ref) - # Should receive progress messages: {:progress, 0}, {:progress, 1}, {:progress, 2} - # Then final result: 3 (which is 0 + 1 + 2) + # All messages (intermediate and final) are tagged with the ref and have same type (i64) + # Should receive: {ref, 0}, {ref, 1}, {ref, 2}, {ref, 3} + # Final result: {ref, 3} (which is 0 + 1 + 2) messages = for _ <- 1..4 do @@ -59,12 +68,12 @@ defmodule RustlerTest.AsyncTest do end end - # Check we got progress updates - assert {"progress", 0} in messages - assert {"progress", 1} in messages - assert {"progress", 2} in messages + # Check we got progress updates (intermediate i64 values) tagged with ref + assert {ref, 0} in messages + assert {ref, 1} in messages + assert {ref, 2} in messages - # Final result should be sum: 0 + 1 + 2 = 3 - assert 3 in messages + # Final result should also be tagged with ref: {ref, 3} + assert {ref, 3} in messages end end From 5ab4ac1760c36ac7ffcdf535aa7cf2755493324f Mon Sep 17 00:00:00 2001 From: Sonny Scroggin Date: Thu, 13 Nov 2025 12:13:52 -0600 Subject: [PATCH 3/4] Channel and AsyncRuntime --- rustler/Cargo.toml | 6 +- rustler/build.rs | 3 + rustler/src/lib.rs | 40 +- rustler/src/runtime/async_runtime.rs | 30 ++ rustler/src/runtime/channel.rs | 360 ++++++++++++++++++ rustler/src/runtime/mod.rs | 91 +++++ .../{tokio/runtime.rs => runtime/tokio.rs} | 53 ++- rustler/src/task_ref.rs | 37 -- rustler/src/tokio/mod.rs | 3 - rustler/src/types/local_pid.rs | 100 +---- rustler/src/types/mod.rs | 3 - rustler_codegen/src/nif.rs | 126 +++--- rustler_tests/lib/rustler_test.ex | 5 + .../native/rustler_test/.cargo/config.toml | 2 + rustler_tests/native/rustler_test/Cargo.toml | 8 +- rustler_tests/native/rustler_test/src/lib.rs | 9 +- .../native/rustler_test/src/test_async.rs | 160 +++++++- rustler_tests/test/async_test.exs | 100 +++++ 18 files changed, 910 insertions(+), 226 deletions(-) create mode 100644 rustler/src/runtime/async_runtime.rs create mode 100644 rustler/src/runtime/channel.rs 
create mode 100644 rustler/src/runtime/mod.rs rename rustler/src/{tokio/runtime.rs => runtime/tokio.rs} (80%) delete mode 100644 rustler/src/task_ref.rs delete mode 100644 rustler/src/tokio/mod.rs create mode 100644 rustler_tests/native/rustler_test/.cargo/config.toml diff --git a/rustler/Cargo.toml b/rustler/Cargo.toml index 0380c78e..0e0015a2 100644 --- a/rustler/Cargo.toml +++ b/rustler/Cargo.toml @@ -11,7 +11,7 @@ rust-version = "1.91" [features] big_integer = ["dep:num-bigint"] -default = ["nif_version_2_15"] +default = ["nif_version_2_15", "async-rt", "tokio-rt"] derive = [] allocator = [] nif_version_2_14 = [] @@ -19,7 +19,8 @@ nif_version_2_15 = ["nif_version_2_14"] nif_version_2_16 = ["nif_version_2_15"] nif_version_2_17 = ["nif_version_2_16"] serde = ["dep:serde"] -tokio_rt = ["dep:tokio"] +async-rt = [] +tokio-rt = ["async-rt", "dep:tokio", "dep:futures-core"] [dependencies] inventory = "0.3" @@ -27,6 +28,7 @@ rustler_codegen = { path = "../rustler_codegen", version = "0.37.1"} num-bigint = { version = "0.4", optional = true } serde = { version = "1", optional = true } tokio = { version = "1", optional = true, features = ["rt", "rt-multi-thread", "sync"] } +futures-core = { version = "0.3", optional = true } once_cell = "1" [target.'cfg(not(windows))'.dependencies] diff --git a/rustler/build.rs b/rustler/build.rs index 96508f18..9d94593f 100644 --- a/rustler/build.rs +++ b/rustler/build.rs @@ -901,6 +901,9 @@ fn main() { let dest_path = Path::new(&out_dir).join(SNIPPET_NAME); fs::write(dest_path, api).unwrap(); + // Tell Cargo that rustler_unstable is a valid cfg + println!("cargo::rustc-check-cfg=cfg(rustler_unstable)"); + // The following lines are important to tell Cargo to recompile if something changes. 
println!("cargo:rerun-if-changed=build.rs"); } diff --git a/rustler/src/lib.rs b/rustler/src/lib.rs index 83d391d1..eae72912 100644 --- a/rustler/src/lib.rs +++ b/rustler/src/lib.rs @@ -59,7 +59,7 @@ pub use crate::schedule::SchedulerFlags; pub mod env; pub use crate::env::{Env, OwnedEnv}; pub mod thread; -pub use crate::thread::{spawn, JobSpawner, ThreadSpawner}; +pub use crate::thread::{JobSpawner, ThreadSpawner}; pub mod error; pub use crate::error::Error; @@ -84,13 +84,35 @@ pub mod serde; #[cfg(feature = "serde")] pub use crate::serde::SerdeTerm; -#[cfg(feature = "tokio_rt")] -pub mod tokio; - -#[cfg(feature = "tokio_rt")] -mod task_ref; - -#[cfg(feature = "tokio_rt")] -pub use task_ref::TaskRef; +#[cfg(feature = "async-rt")] +pub mod runtime; + +/// Spawn an async task on the global runtime. +/// +/// This provides a runtime-agnostic API similar to `tokio::spawn()`. +/// The future is spawned on the global runtime and executed to completion. +/// +/// Returns a join handle that can be used to await the result or cancel the task. +/// +/// # Example +/// +/// ```ignore +/// let handle = rustler::spawn(async { +/// // Your async code +/// process_data().await +/// }); +/// ``` +/// +/// # Panics +/// +/// Panics if the runtime fails to spawn the task. +#[cfg(feature = "tokio-rt")] +pub fn spawn(future: F) -> tokio::task::JoinHandle +where + F: std::future::Future + Send + 'static, + F::Output: Send + 'static, +{ + runtime::handle().spawn(future) +} pub mod sys; diff --git a/rustler/src/runtime/async_runtime.rs b/rustler/src/runtime/async_runtime.rs new file mode 100644 index 00000000..cad72d10 --- /dev/null +++ b/rustler/src/runtime/async_runtime.rs @@ -0,0 +1,30 @@ +use std::future::Future; +use std::pin::Pin; + +/// Trait for pluggable async runtimes. +/// +/// This allows users to provide their own async runtime implementation +/// instead of being locked into Tokio. 
+/// +/// # Example +/// +/// ```ignore +/// use rustler::runtime::AsyncRuntime; +/// use std::future::Future; +/// use std::pin::Pin; +/// +/// struct MyRuntime; +/// +/// impl AsyncRuntime for MyRuntime { +/// fn spawn(&self, future: Pin + Send + 'static>>) { +/// // Spawn on your custom runtime +/// } +/// } +/// ``` +pub trait AsyncRuntime: Send + Sync + 'static { + /// Spawn a future onto the runtime. + /// + /// The future should be executed to completion, and the runtime + /// is responsible for driving it. + fn spawn(&self, future: Pin + Send + 'static>>); +} diff --git a/rustler/src/runtime/channel.rs b/rustler/src/runtime/channel.rs new file mode 100644 index 00000000..c7d28ad3 --- /dev/null +++ b/rustler/src/runtime/channel.rs @@ -0,0 +1,360 @@ +use crate::types::LocalPid; +use crate::{Decoder, Encoder, Env, Error, NifResult, OwnedEnv, ResourceArc, Term}; +use futures_core::Stream; +use std::marker::PhantomData; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use tokio::sync::mpsc; + +// Type-erased sender function for channel messages. +type SendFn = Arc NifResult<()> + Send + Sync + std::panic::RefUnwindSafe>; + +/// Internal sender resource (type-erased for resource registration). +/// +/// This is the actual resource registered with BEAM. It holds a type-erased +/// sender function that decodes Terms and sends them to the typed channel. +pub struct ChannelSenderInner { + send_fn: SendFn, +} + +impl crate::Resource for ChannelSenderInner {} + +// Auto-register ChannelSenderInner resource +crate::codegen_runtime::inventory::submit! { + crate::resource::Registration::new::() +} + +/// Type-safe wrapper around the channel sender resource. +/// +/// This is returned to Elixir and can be used to send messages of type `Request` +/// to the running task. It also serves as the task reference for pattern matching +/// on response messages. 
+pub struct ChannelSender { + inner: ResourceArc, + _phantom: PhantomData, +} + +impl Clone for ChannelSender { + fn clone(&self) -> Self { + ChannelSender { + inner: self.inner.clone(), + _phantom: PhantomData, + } + } +} + +unsafe impl Send for ChannelSender {} +unsafe impl Sync for ChannelSender {} + +/// Cloneable sender for responses that can be passed to spawned tasks. +/// +/// This allows spawned subtasks to send their own responses back to Elixir, +/// all tagged with the same channel sender reference. +pub struct ResponseSender { + sender: ChannelSender, + pid: LocalPid, + _phantom: PhantomData, +} + +impl Clone for ResponseSender { + fn clone(&self) -> Self { + ResponseSender { + sender: self.sender.clone(), + pid: self.pid, + _phantom: PhantomData, + } + } +} + +unsafe impl Send for ResponseSender {} +unsafe impl Sync for ResponseSender {} + +impl ResponseSender +where + Response: Encoder + Send + 'static, +{ + /// Send a response message to the calling process. + /// + /// The message will be sent as `{channel_sender, response}`. + pub fn send(&self, response: Response) { + let mut env = OwnedEnv::new(); + let sender = self.sender.clone(); + let _ = env.send_and_clear(&self.pid, move |env| (sender, response).encode(env)); + } +} + +impl Encoder for ChannelSender { + fn encode<'a>(&self, env: Env<'a>) -> Term<'a> { + self.inner.encode(env) + } +} + +impl<'a, Request: 'a> Decoder<'a> for ChannelSender { + fn decode(term: Term<'a>) -> NifResult { + let inner: ResourceArc = term.decode()?; + Ok(ChannelSender { + inner, + _phantom: PhantomData, + }) + } +} + +/// Bidirectional channel for typed communication with async tasks. +/// +/// `Channel` provides both: +/// - **Receiving requests** from the calling Elixir process via `Stream` trait or `recv()` +/// - **Sending responses** back to the caller via `send()` +/// +/// The channel implements `Stream`, allowing idiomatic async iteration over incoming requests. 
+/// All response messages are automatically tagged with the channel sender reference. +/// +/// # Type Parameters +/// +/// - `Request`: Type of messages the task receives from Elixir (default: `()` for one-way tasks) +/// - `Response`: Type of messages the task sends back to Elixir +/// +/// # Examples +/// +/// ## One-way task (no incoming messages) +/// +/// ```ignore +/// #[rustler::task] +/// async fn compute(channel: Channel<(), i64>, n: i64) -> i64 { +/// channel.send(n / 2); // Send progress +/// tokio::time::sleep(Duration::from_millis(100)).await; +/// n * 2 // Final result +/// } +/// ``` +/// +/// ## Interactive task with Stream trait +/// +/// ```ignore +/// use futures::StreamExt; // for next() +/// +/// #[rustler::task] +/// async fn interactive(channel: Channel) -> String { +/// let mut count = 0; +/// +/// // Stream trait - idiomatic async iteration +/// while let Some(cmd) = channel.next().await { +/// match cmd { +/// Command::Stop => break, +/// Command::Process(x) => { +/// count += 1; +/// channel.send(format!("Processed: {}", x)); +/// } +/// } +/// } +/// +/// format!("Processed {} commands", count) +/// } +/// ``` +/// +/// ## Using recv() directly +/// +/// ```ignore +/// #[rustler::task] +/// async fn simple(channel: Channel) -> String { +/// if let Some(should_proceed) = channel.recv().await { +/// if should_proceed { +/// channel.send("Processing...".to_string()); +/// // do work +/// return "Done".to_string(); +/// } +/// } +/// "Cancelled".to_string() +/// } +/// ``` +pub struct Channel { + receiver: mpsc::UnboundedReceiver, + sender: ChannelSender, + pid: LocalPid, + _phantom_response: PhantomData, +} + +unsafe impl Send for Channel {} +unsafe impl Sync for Channel {} + +impl Channel +where + Request: for<'a> Decoder<'a> + Send + 'static, + Response: Encoder + Send + 'static, +{ + /// Create a new channel with a paired sender. + /// + /// Returns a tuple of (ChannelSender, Channel). 
The sender should be + /// returned to Elixir, and the channel is used by the async task. + /// + /// This is typically called automatically by the `#[rustler::task]` macro. + #[doc(hidden)] + pub fn new(pid: LocalPid) -> (ChannelSender, Self) { + let (tx, rx) = mpsc::unbounded_channel(); + + // Create type-erased sender function that decodes Terms to Request + let send_fn = Arc::new(move |_env: Env, term: Term| -> NifResult<()> { + let value: Request = term.decode()?; + tx.send(value) + .map_err(|_| Error::RaiseTerm(Box::new("channel_closed")))?; + Ok(()) + }); + + let inner = ChannelSenderInner { send_fn }; + let resource_arc = ResourceArc::new(inner); + + let sender: ChannelSender = ChannelSender { + inner: resource_arc, + _phantom: PhantomData, + }; + + let channel = Channel { + receiver: rx, + sender: sender.clone(), + pid, + _phantom_response: PhantomData, + }; + + (sender, channel) + } + + /// Receive the next request from the channel. + /// + /// Returns `Some(Request)` if a message was received, or `None` if the channel + /// has been closed (all senders dropped). + /// + /// This is an async function that will wait for a message to arrive. + /// + /// # Example + /// + /// ```ignore + /// if let Some(request) = channel.recv().await { + /// // Handle request + /// channel.send(process(request)); + /// } + /// ``` + pub async fn recv(&mut self) -> Option { + self.receiver.recv().await + } + + /// Try to receive a request without blocking. + /// + /// Returns `Some(Request)` if a message is available, or `None` if the channel + /// is empty or closed. + pub fn try_recv(&mut self) -> Option { + self.receiver.try_recv().ok() + } + + /// Send a response message to the calling process. + /// + /// The message will be sent as `{channel_sender, response}` where `channel_sender` + /// is the resource reference returned to Elixir. 
+ /// + /// # Example + /// + /// ```ignore + /// channel.send("Progress: 50%".to_string()); + /// ``` + pub fn send(&self, response: Response) { + let mut env = OwnedEnv::new(); + let sender = self.sender.clone(); + let _ = env.send_and_clear(&self.pid, move |env| (sender, response).encode(env)); + } + + /// Send the final response and complete the task. + /// + /// This is used internally by the `#[rustler::task]` macro to send the + /// task's return value. User code should just return the value normally. + #[doc(hidden)] + pub fn finish(self, response: Response) { + self.send(response); + } + + /// Get a reference to the channel sender for this channel. + /// + /// This can be cloned and passed to other tasks or threads for sending + /// requests TO the channel. + pub fn sender(&self) -> &ChannelSender { + &self.sender + } + + /// Get a cloneable response sender that can send responses from spawned tasks. + /// + /// This is useful when you need to spawn subtasks that send their own + /// responses back to Elixir. + /// + /// # Example + /// + /// ```ignore + /// let responder = channel.responder(); + /// rustler::spawn(async move { + /// responder.send(42); + /// }); + /// ``` + pub fn responder(&self) -> ResponseSender { + ResponseSender { + sender: self.sender.clone(), + pid: self.pid, + _phantom: PhantomData, + } + } + + /// Check if the channel is closed (all senders dropped). + pub fn is_closed(&self) -> bool { + self.receiver.is_closed() + } + + /// Get the next request from the channel using the Stream trait. + /// + /// This is a convenience method that's equivalent to using the Stream trait + /// directly. Returns `Some(Request)` if a message was received, or `None` if + /// the channel has been closed. 
+ /// + /// # Example + /// + /// ```ignore + /// while let Some(request) = channel.next().await { + /// // Handle request + /// channel.send(process(request)); + /// } + /// ``` + pub async fn next(&mut self) -> Option { + self.recv().await + } +} + +// Implement Stream trait for idiomatic async iteration +impl Stream for Channel +where + Request: for<'a> Decoder<'a> + Send + 'static, + Response: Encoder + Send + 'static, +{ + type Item = Request; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + // SAFETY: We never move the receiver + let this = unsafe { self.get_unchecked_mut() }; + this.receiver.poll_recv(cx) + } +} + +/// NIF function to send a message to a channel. +/// +/// This should be exported as a NIF in your module: +/// +/// ```ignore +/// #[rustler::nif] +/// fn channel_send_string( +/// env: Env, +/// sender: rustler::runtime::ChannelSender, +/// message: Term +/// ) -> NifResult { +/// rustler::runtime::channel::send(env, sender, message) +/// } +/// ``` +pub fn send(env: Env, sender: ChannelSender, message: Term) -> NifResult +where + T: for<'a> Decoder<'a> + Send + 'static, +{ + (sender.inner.send_fn)(env, message)?; + Ok(crate::types::atom::ok()) +} diff --git a/rustler/src/runtime/mod.rs b/rustler/src/runtime/mod.rs new file mode 100644 index 00000000..569c633a --- /dev/null +++ b/rustler/src/runtime/mod.rs @@ -0,0 +1,91 @@ +mod async_runtime; + +pub use async_runtime::AsyncRuntime; + +#[cfg(feature = "tokio-rt")] +pub mod tokio; + +#[cfg(feature = "tokio-rt")] +pub use tokio::TokioRuntime; + +#[cfg(feature = "tokio-rt")] +pub use tokio::{ConfigError, RuntimeConfig}; + +#[cfg(rustler_unstable)] +pub mod channel; + +#[cfg(rustler_unstable)] +pub use channel::{Channel, ChannelSender, ResponseSender}; + +/// Configure the global async runtime from Elixir configuration. +/// +/// This is the recommended way to configure the runtime, allowing Elixir application +/// developers to tune the runtime without recompiling the NIF. 
+/// +/// # Example +/// +/// ```ignore +/// use rustler::{Env, Term}; +/// +/// fn load(_env: Env, load_info: Term) -> bool { +/// if let Ok(config) = load_info.decode::() { +/// rustler::runtime::configure(config) +/// .expect("Failed to configure runtime"); +/// } +/// true +/// } +/// ``` +#[cfg(feature = "tokio-rt")] +pub fn configure(config: RuntimeConfig) -> Result<(), ConfigError> { + tokio::configure(config) +} + +/// Configure the global async runtime with a builder function. +/// +/// This provides a runtime-agnostic API. The builder type is determined +/// by the enabled runtime feature. +/// +/// # Example +/// +/// ```ignore +/// use rustler::{Env, Term}; +/// +/// fn load(_env: Env, _: Term) -> bool { +/// rustler::runtime::builder(|builder| { +/// builder +/// .worker_threads(4) +/// .thread_name("myapp-runtime") +/// .thread_stack_size(3 * 1024 * 1024); +/// }).expect("Failed to configure runtime"); +/// +/// true +/// } +/// ``` +#[cfg(feature = "tokio-rt")] +pub fn builder(config_fn: F) -> Result<(), ConfigError> +where + F: FnOnce(&mut ::tokio::runtime::Builder), +{ + self::tokio::configure_runtime(config_fn) +} + +/// Get a handle to the global async runtime. +/// +/// This provides a runtime-agnostic API. The handle type is determined +/// by the enabled runtime feature. +/// +/// Returns a handle to the current runtime if already inside one, otherwise +/// returns a handle to the global runtime (initializing it with defaults if needed). 
+/// +/// # Example +/// +/// ```ignore +/// let handle = rustler::runtime::handle(); +/// handle.spawn(async { +/// // Your async code +/// }); +/// ``` +#[cfg(feature = "tokio-rt")] +pub fn handle() -> ::tokio::runtime::Handle { + self::tokio::runtime_handle() +} diff --git a/rustler/src/tokio/runtime.rs b/rustler/src/runtime/tokio.rs similarity index 80% rename from rustler/src/tokio/runtime.rs rename to rustler/src/runtime/tokio.rs index d9857006..19953892 100644 --- a/rustler/src/tokio/runtime.rs +++ b/rustler/src/runtime/tokio.rs @@ -1,8 +1,34 @@ +use crate::runtime::AsyncRuntime; use crate::{Decoder, NifResult, Term}; use once_cell::sync::OnceCell; +use std::future::Future; +use std::pin::Pin; use std::sync::Arc; use tokio::runtime::Runtime; +/// Tokio runtime implementation of AsyncRuntime. +pub struct TokioRuntime { + handle: tokio::runtime::Handle, +} + +impl TokioRuntime { + pub fn new(runtime: Arc) -> Self { + Self { + handle: runtime.handle().clone(), + } + } + + pub fn from_handle(handle: tokio::runtime::Handle) -> Self { + Self { handle } + } +} + +impl AsyncRuntime for TokioRuntime { + fn spawn(&self, future: Pin + Send + 'static>>) { + self.handle.spawn(future); + } +} + /// Global tokio runtime for async NIFs. /// /// This runtime can be configured via `configure_runtime()` in your NIF's `load` callback, @@ -106,8 +132,8 @@ impl<'a> Decoder<'a> for RuntimeConfig { /// Configure the global Tokio runtime from Elixir load_data. /// -/// This is the recommended way to configure the runtime, allowing Elixir application -/// developers to tune the runtime without recompiling the NIF. +/// **Note:** Most users should use `rustler::runtime::configure()` instead for +/// a more runtime-agnostic API. 
/// /// # Example /// @@ -115,10 +141,10 @@ impl<'a> Decoder<'a> for RuntimeConfig { /// use rustler::{Env, Term}; /// /// fn load(_env: Env, load_info: Term) -> bool { -/// // Try to decode runtime config from load_info -/// if let Ok(config) = load_info.decode::() { -/// rustler::tokio::configure(config) -/// .expect("Failed to configure Tokio runtime"); +/// // Prefer: rustler::runtime::configure() +/// if let Ok(config) = load_info.decode::() { +/// rustler::runtime::configure(config) +/// .expect("Failed to configure runtime"); /// } /// true /// } @@ -160,9 +186,8 @@ pub fn configure(config: RuntimeConfig) -> Result<(), ConfigError> { /// Configure the global Tokio runtime programmatically. /// -/// This provides direct access to the Tokio Builder API for advanced use cases. -/// For most applications, prefer `configure_runtime_from_term` which allows -/// configuration from Elixir. +/// **Note:** Most users should use `rustler::runtime::builder()` instead for +/// a more runtime-agnostic API. /// /// # Example /// @@ -170,12 +195,13 @@ pub fn configure(config: RuntimeConfig) -> Result<(), ConfigError> { /// use rustler::{Env, Term}; /// /// fn load(_env: Env, _: Term) -> bool { -/// rustler::tokio::configure_runtime(|builder| { +/// // Prefer: rustler::runtime::builder() +/// rustler::runtime::builder(|builder| { /// builder /// .worker_threads(4) -/// .thread_name("myapp-tokio") +/// .thread_name("myapp-runtime") /// .thread_stack_size(3 * 1024 * 1024); -/// }).expect("Failed to configure Tokio runtime"); +/// }).expect("Failed to configure runtime"); /// /// true /// } @@ -198,6 +224,9 @@ where } /// Get a handle to the global tokio runtime, or the current runtime if already inside one. +/// +/// **Note:** Most users should use `rustler::runtime::handle()` instead for +/// a more runtime-agnostic API. 
pub fn runtime_handle() -> tokio::runtime::Handle { // Try to get the current runtime handle first (if already in a tokio context) tokio::runtime::Handle::try_current().unwrap_or_else(|_| { diff --git a/rustler/src/task_ref.rs b/rustler/src/task_ref.rs deleted file mode 100644 index 1b456ac0..00000000 --- a/rustler/src/task_ref.rs +++ /dev/null @@ -1,37 +0,0 @@ -use std::sync::atomic::{AtomicU64, Ordering}; - -static TASK_COUNTER: AtomicU64 = AtomicU64::new(0); - -/// Task reference resource for async tasks. -/// -/// This is automatically created by `#[rustler::task]` and returned to the caller. -/// All messages sent by the task (both intermediate and final) are tagged with this reference. -#[cfg(feature = "tokio_rt")] -#[derive(Debug, Clone)] -pub struct TaskRef { - #[allow(dead_code)] - id: u64, -} - -#[cfg(feature = "tokio_rt")] -impl TaskRef { - /// Create a new TaskRef with a unique ID. - /// - /// This is used internally by the `#[rustler::task]` macro. - #[doc(hidden)] - pub fn new() -> Self { - Self { - id: TASK_COUNTER.fetch_add(1, Ordering::Relaxed), - } - } -} - -// Implement Resource trait -#[cfg(feature = "tokio_rt")] -impl crate::Resource for TaskRef {} - -// Auto-register TaskRef resource via inventory -#[cfg(feature = "tokio_rt")] -crate::codegen_runtime::inventory::submit! { - crate::resource::Registration::new::() -} diff --git a/rustler/src/tokio/mod.rs b/rustler/src/tokio/mod.rs deleted file mode 100644 index 3e28c799..00000000 --- a/rustler/src/tokio/mod.rs +++ /dev/null @@ -1,3 +0,0 @@ -mod runtime; - -pub use runtime::{configure, configure_runtime, runtime_handle, ConfigError, RuntimeConfig}; diff --git a/rustler/src/types/local_pid.rs b/rustler/src/types/local_pid.rs index b2c57e4f..c0b9bd71 100644 --- a/rustler/src/types/local_pid.rs +++ b/rustler/src/types/local_pid.rs @@ -9,6 +9,11 @@ pub struct LocalPid { c: ErlNifPid, } +// Safe: LocalPid is just a process identifier that can be safely sent across threads. 
+// PIDs are used for message passing in BEAM, which is inherently thread-safe. +unsafe impl Send for LocalPid {} +unsafe impl Sync for LocalPid {} + impl LocalPid { #[inline] pub fn as_c_arg(&self) -> &ErlNifPid { @@ -63,101 +68,6 @@ impl Ord for LocalPid { } } -/// Caller information for async tasks with type-safe message sending. -/// -/// Contains the calling process's PID and the task reference. When used as the first -/// parameter of a `#[rustler::task]`, it is automatically populated and provides -/// convenient methods for sending messages tagged with the task reference. -/// -/// The generic type `T` is automatically inferred from the task's return type, -/// ensuring that intermediate messages sent via `send()` are the same type as -/// the final result. -/// -/// # Example -/// -/// ```ignore -/// #[rustler::task] -/// async fn with_progress(caller: Caller, work: Vec) -> Result { -/// for (i, item) in work.iter().enumerate() { -/// // Type-checked: must send Result -/// caller.send(Ok(i as i64)); -/// process(item).await?; -/// } -/// Ok(work.len() as i64) -/// } -/// ``` -#[cfg(feature = "tokio_rt")] -#[derive(Clone)] -pub struct Caller { - pid: LocalPid, - task_ref: crate::ResourceArc, - _phantom: std::marker::PhantomData, -} - -#[cfg(feature = "tokio_rt")] -impl Caller { - /// Create a new Caller. - /// - /// This is only used internally by the task macro. - #[doc(hidden)] - pub fn new(pid: LocalPid, task_ref: crate::ResourceArc) -> Self { - Self { - pid, - task_ref, - _phantom: std::marker::PhantomData, - } - } - - /// Get the calling process's PID. - pub fn pid(&self) -> &LocalPid { - &self.pid - } - - /// Get the task reference. - pub fn task_ref(&self) -> &crate::ResourceArc { - &self.task_ref - } - - /// Send an intermediate message to the caller, automatically tagged with the task reference. - /// - /// The message will be sent as `{task_ref, message}`. 
- /// - /// The message type `T` must match the task's return type, ensuring type safety - /// for all messages sent during task execution. - /// - /// # Example - /// - /// ```ignore - /// #[rustler::task] - /// async fn process(caller: Caller, count: i64) -> String { - /// for i in 0..count { - /// caller.send(format!("Progress: {}", i)); // ✅ Type-safe - /// // caller.send(i); // ❌ Compile error: expected String, got i64 - /// } - /// "Done".to_string() - /// } - /// ``` - pub fn send(&self, message: T) { - let mut env = crate::OwnedEnv::new(); - let task_ref = self.task_ref.clone(); - let _ = env.send_and_clear(&self.pid, move |env| (task_ref, message).encode(env)); - } - - /// Send the final message and complete the task. - /// - /// This is used internally by the `#[rustler::task]` macro to send the - /// task's return value. User code should just return the value normally. - #[doc(hidden)] - pub fn finish(self, message: T) { - self.send(message); - } - - /// Check whether the calling process is alive. - pub fn is_alive(&self, env: Env) -> bool { - self.pid.is_alive(env) - } -} - impl Env<'_> { /// Return the calling process's pid. 
/// diff --git a/rustler/src/types/mod.rs b/rustler/src/types/mod.rs index 27512d38..c7b72005 100644 --- a/rustler/src/types/mod.rs +++ b/rustler/src/types/mod.rs @@ -30,9 +30,6 @@ pub mod tuple; pub mod local_pid; pub use self::local_pid::LocalPid; -#[cfg(feature = "tokio_rt")] -pub use self::local_pid::Caller; - #[doc(hidden)] pub mod reference; pub use self::reference::Reference; diff --git a/rustler_codegen/src/nif.rs b/rustler_codegen/src/nif.rs index 4e308961..976816ef 100644 --- a/rustler_codegen/src/nif.rs +++ b/rustler_codegen/src/nif.rs @@ -132,43 +132,62 @@ fn generate_task( inputs: Punctuated, return_type: &syn::ReturnType, ) -> TokenStream { - // Check if first parameter is Caller - let uses_caller = inputs - .first() - .and_then(|arg| { - if let syn::FnArg::Typed(typed) = arg { - if let syn::Type::Path(syn::TypePath { path, .. }) = &*typed.ty { - let segment = path.segments.last()?; - return Some(segment.ident == "Caller"); + // Check if first parameter is Channel + // and extract the types if present + let channel_info = inputs.first().and_then(|arg| { + if let syn::FnArg::Typed(typed) = arg { + if let syn::Type::Path(syn::TypePath { path, .. }) = &*typed.ty { + let segment = path.segments.last()?; + if segment.ident == "Channel" { + // Return the full type for generating Channel::new + return Some(typed.ty.clone()); } } - None - }) - .unwrap_or(false); + } + None + }); + + let uses_channel = channel_info.is_some(); let decoded_terms_async = extract_inputs_for_async(inputs.clone(), return_type); let argument_names = create_function_params(inputs); - // Generate code for sending the final result - let (send_result, caller_for_finish) = if uses_caller { - // When using Caller, clone it for the finish call - let caller_clone = quote! 
{ - let caller_for_finish = caller.clone(); + // Determine the Channel type to use + let channel_type = if let Some(ty) = channel_info { + // Use the type from the function signature + ty + } else { + // Default to Channel<(), Response> where Response is the return type + let response_type = match return_type { + syn::ReturnType::Type(_, ty) => ty.clone(), + syn::ReturnType::Default => { + panic!("Async tasks must have an explicit return type"); + } }; - let finish = quote! { - caller_for_finish.finish(value); + syn::parse_quote! { rustler::runtime::Channel<(), #response_type> } + }; + + // Generate code for sending the final result + let (clone_setup, send_result) = if uses_channel { + // When using Channel, the function is responsible for calling finish() + // The macro just executes the function and does nothing with the result + let send = quote! { + // Function is responsible for calling channel.finish() }; - (finish, Some(caller_clone)) + (quote! {}, send) } else { - // When not using Caller, send directly - let direct_send = quote! { + // When not using Channel, clone channel_sender before async block + let clone = quote! { + let channel_sender_for_send = channel_sender.clone(); + }; + let send = quote! { let mut msg_env = rustler::OwnedEnv::new(); let _ = msg_env.send_and_clear(&pid, |env| { use rustler::Encoder; - (task_ref_for_spawn, value).encode(env) + (channel_sender_for_send, value).encode(env) }); }; - (direct_send, None) + (clone, send) }; quote! 
{ @@ -203,30 +222,30 @@ fn generate_task( // Get the calling process PID let pid = env.pid(); - // Create a unique task reference resource - let task_ref = rustler::ResourceArc::new(rustler::TaskRef::new()); - let task_ref_for_spawn = task_ref.clone(); + // Create channel - if task doesn't use Channel param, + // still create Channel<(), Response> for message tagging + let (channel_sender, channel): (_, #channel_type) = rustler::runtime::Channel::new(pid); + + // Clone channel_sender if needed (for tasks without Channel param) + #clone_setup // Decode all arguments before spawning async task #decoded_terms_async - // Clone caller if needed for finish() call - #caller_for_finish - - // Spawn async task on tokio runtime - let handle = rustler::tokio::runtime_handle(); - handle.spawn(async move { - // Execute the async function and get the result + // Spawn async task + rustler::spawn(async move { + // Execute the async function + #[allow(unused_variables)] let value = #name(#argument_names).await; - // Send {ref, result} back to calling process + // Send {channel_sender, result} back to calling process #send_result }); - // Return the task reference immediately + // Return the channel sender as task reference use rustler::Encoder; rustler::codegen_runtime::NifReturned::Term( - task_ref.encode(env).as_c_arg() + channel_sender.encode(env).as_c_arg() ) } wrapper(env, &terms).apply(env) @@ -341,8 +360,8 @@ fn arity(inputs: Punctuated) -> u32 { if let syn::Type::Path(syn::TypePath { path, .. 
}) = &*typed.ty { let ident = path.segments.last().unwrap().ident.to_string(); - // Skip Env and Caller when they're the first parameter - if i == 0 && (ident == "Env" || ident == "Caller") { + // Skip Env, Caller, and Channel when they're the first parameter + if i == 0 && (ident == "Env" || ident == "Caller" || ident == "Channel") { continue; } @@ -366,8 +385,21 @@ fn extract_inputs_for_async( let mut tokens = TokenStream::new(); let mut args_offset = 0; - // Validate that async tasks have an explicit return type - if matches!(return_type, syn::ReturnType::Default) { + // Check if first parameter is Channel (determines if explicit return type is required) + let has_channel = inputs + .first() + .and_then(|arg| { + if let syn::FnArg::Typed(typed) = arg { + if let syn::Type::Path(syn::TypePath { path, .. }) = &*typed.ty { + return path.segments.last().map(|s| s.ident == "Channel"); + } + } + None + }) + .unwrap_or(false); + + // Validate that async tasks have an explicit return type (unless they have a Channel parameter) + if !has_channel && matches!(return_type, syn::ReturnType::Default) { panic!("Async tasks must have an explicit return type"); } @@ -380,16 +412,10 @@ fn extract_inputs_for_async( syn::Type::Path(syn::TypePath { path, .. }) => { let ident = path.segments.last().unwrap().ident.to_string(); - // Special case: Caller as first parameter - if param_idx == 0 && ident == "Caller" { - // Validate that generic argument matches return type (optional check) - // The Rust compiler will catch mismatches anyway, but we could add - // a better error message here if needed - - let caller_setup = quote! 
{ - let #name: #typ = rustler::types::Caller::new(pid, task_ref.clone()); - }; - tokens.extend(caller_setup); + // Special case: Channel as first parameter + if param_idx == 0 && ident == "Channel" { + // Channel is already created by wrapper, just pass it through + // No need to decode from args, and it doesn't consume an arg slot args_offset = 1; // Don't consume an arg slot continue; } diff --git a/rustler_tests/lib/rustler_test.ex b/rustler_tests/lib/rustler_test.ex index fef4eb94..39e6b2c7 100644 --- a/rustler_tests/lib/rustler_test.ex +++ b/rustler_tests/lib/rustler_test.ex @@ -44,6 +44,11 @@ defmodule RustlerTest do def async_sleep_and_return(_, _), do: err() def async_tuple_multiply(_), do: err() def async_with_progress(_), do: err() + def async_spawned_work(_), do: err() + def async_channel_echo(), do: err() + def channel_send_string(_, _), do: err() + def stateful_worker(), do: err() + def worker_send_command(_, _), do: err() def term_debug(_), do: err() diff --git a/rustler_tests/native/rustler_test/.cargo/config.toml b/rustler_tests/native/rustler_test/.cargo/config.toml new file mode 100644 index 00000000..7287c244 --- /dev/null +++ b/rustler_tests/native/rustler_test/.cargo/config.toml @@ -0,0 +1,2 @@ +[build] +rustflags = ["--cfg", "rustler_unstable"] diff --git a/rustler_tests/native/rustler_test/Cargo.toml b/rustler_tests/native/rustler_test/Cargo.toml index ea3962bf..9ebd0d87 100644 --- a/rustler_tests/native/rustler_test/Cargo.toml +++ b/rustler_tests/native/rustler_test/Cargo.toml @@ -14,13 +14,15 @@ name = "hello_rust" path = "src/main.rs" [features] -default = ["rustler/tokio_rt"] -tokio_rt = ["rustler/tokio_rt"] +default = ["rustler/async-rt", "rustler/tokio-rt"] +async-rt = ["rustler/async-rt"] +tokio-rt = ["async-rt", "rustler/tokio-rt"] nif_version_2_14 = ["rustler/nif_version_2_14"] nif_version_2_15 = ["nif_version_2_14", "rustler/nif_version_2_15"] nif_version_2_16 = ["nif_version_2_15", "rustler/nif_version_2_16"] nif_version_2_17 = 
["nif_version_2_16", "rustler/nif_version_2_17"] [dependencies] -rustler = { path = "../../../rustler", features = ["allocator", "tokio_rt"] } +rustler = { path = "../../../rustler", features = ["allocator", "async-rt", "tokio-rt"] } tokio = { version = "1", features = ["time"] } +futures-core = "0.3" diff --git a/rustler_tests/native/rustler_test/src/lib.rs b/rustler_tests/native/rustler_test/src/lib.rs index 82202a50..0a5665ae 100644 --- a/rustler_tests/native/rustler_test/src/lib.rs +++ b/rustler_tests/native/rustler_test/src/lib.rs @@ -20,15 +20,14 @@ mod test_term; mod test_thread; mod test_tuple; -// Temporarily add async_add explicitly to debug rustler::init!("Elixir.RustlerTest", load = load); fn load(env: rustler::Env, load_info: rustler::Term) -> bool { - // Configure Tokio runtime from Elixir load_data - #[cfg(feature = "tokio_rt")] + // Configure runtime from Elixir load_data + #[cfg(feature = "tokio-rt")] { - if let Ok(config) = load_info.decode::() { - rustler::tokio::configure(config).ok(); + if let Ok(config) = load_info.decode::() { + rustler::runtime::configure(config).ok(); } } diff --git a/rustler_tests/native/rustler_test/src/test_async.rs b/rustler_tests/native/rustler_test/src/test_async.rs index 82d84310..e8d5cbb7 100644 --- a/rustler_tests/native/rustler_test/src/test_async.rs +++ b/rustler_tests/native/rustler_test/src/test_async.rs @@ -1,4 +1,4 @@ -use rustler::types::Caller; +use rustler::runtime::Channel; use std::time::Duration; #[rustler::task] @@ -20,18 +20,164 @@ async fn async_tuple_multiply(input: (i64, i64)) -> i64 { } #[rustler::task] -async fn async_with_progress(caller: Caller, work_items: i64) -> i64 { +async fn async_with_progress(channel: Channel<(), i64>, work_items: i64) { let mut total = 0; for i in 0..work_items { tokio::time::sleep(Duration::from_millis(10)).await; total += i; - // Send progress update - automatically tagged with task ref - // Note: This would be a compile error if we tried to send the tuple: - // 
caller.send(("progress", i)); // ❌ Type error: expected i64, got tuple - caller.send(i); // ✅ Type-safe: i64 matches return type + // Send progress update - automatically tagged with channel sender + // Note: This would be a compile error if we tried to send the wrong type: + // channel.send("progress"); // Type error: expected i64, got &str + channel.send(i); // Type-safe: i64 matches return type } - total + // Send final result and consume channel + channel.finish(total); +} + +#[rustler::task] +async fn async_spawned_work(channel: Channel<(), i64>, work_items: i64) { + let mut total = 0; + + // Demonstrate that ResponseSender can be cloned and sent across threads + for i in 0..work_items { + let responder = channel.responder(); // Clone responder for each spawned task + rustler::spawn(async move { + tokio::time::sleep(Duration::from_millis(5)).await; + responder.send(i); // Send from spawned task + }); + total += i; + } + + tokio::time::sleep(Duration::from_millis(50)).await; // Wait for spawned tasks + + // Send final result + channel.finish(total); +} + +// Bidirectional task using Channel with Stream trait +#[rustler::task] +async fn async_channel_echo(channel: Channel) { + let mut channel = channel; // Make it mutable in the function body + let mut count = 0; + + // Use Stream trait for idiomatic async iteration + while let Some(msg) = channel.next().await { + if msg == "stop" { + break; + } + count += 1; + // Echo each message back + channel.send(format!("echo: {}", msg)); + } + + // Send final result and consume channel + channel.finish(format!("Received {} messages", count)); +} + +// NIF to send to channel +#[rustler::nif] +fn channel_send_string( + env: rustler::Env, + sender: rustler::runtime::ChannelSender, + message: rustler::Term, +) -> rustler::NifResult { + rustler::runtime::channel::send(env, sender, message) +} + +// Example using enums for Request and Response types +#[derive(rustler::NifTaggedEnum, Clone, Debug)] +enum WorkerCommand { + Add { 
value: i64 }, + Subtract { value: i64 }, + Multiply { value: i64 }, + GetCurrent, + Reset, + Shutdown, +} + +#[derive(rustler::NifTaggedEnum, Clone, Debug)] +enum WorkerResponse { + Updated { old_value: i64, new_value: i64 }, + Current { value: i64 }, + Reset, + Error { reason: String }, + ShuttingDown { final_value: i64, operations: i64 }, +} + +#[rustler::task] +async fn stateful_worker(channel: Channel) { + let mut channel = channel; + let mut current_value = 0i64; + let mut operation_count = 0i64; + + while let Some(cmd) = channel.next().await { + tokio::time::sleep(Duration::from_millis(5)).await; + + let response = match cmd { + WorkerCommand::Add { value } => { + let old = current_value; + current_value += value; + operation_count += 1; + WorkerResponse::Updated { + old_value: old, + new_value: current_value, + } + } + WorkerCommand::Subtract { value } => { + let old = current_value; + current_value -= value; + operation_count += 1; + WorkerResponse::Updated { + old_value: old, + new_value: current_value, + } + } + WorkerCommand::Multiply { value } => { + let old = current_value; + current_value *= value; + operation_count += 1; + WorkerResponse::Updated { + old_value: old, + new_value: current_value, + } + } + WorkerCommand::GetCurrent => WorkerResponse::Current { + value: current_value, + }, + WorkerCommand::Reset => { + current_value = 0; + operation_count = 0; + WorkerResponse::Reset + } + WorkerCommand::Shutdown => { + // Send shutdown response and break + channel.send(WorkerResponse::ShuttingDown { + final_value: current_value, + operations: operation_count, + }); + break; + } + }; + + channel.send(response); + } + + // Final message when loop exits + channel.finish(WorkerResponse::ShuttingDown { + final_value: current_value, + operations: operation_count, + }); +} + +// NIF to send commands to the stateful worker +#[rustler::nif] +fn worker_send_command( + env: rustler::Env, + sender: rustler::runtime::ChannelSender, + command: rustler::Term, +) -> 
rustler::NifResult { + rustler::runtime::channel::send(env, sender, command) } diff --git a/rustler_tests/test/async_test.exs b/rustler_tests/test/async_test.exs index f01e7448..7fe98dd1 100644 --- a/rustler_tests/test/async_test.exs +++ b/rustler_tests/test/async_test.exs @@ -76,4 +76,104 @@ defmodule RustlerTest.AsyncTest do # Final result should also be tagged with ref: {ref, 3} assert {ref, 3} in messages end + + test "async_spawned_work demonstrates Caller can be cloned and sent across threads" do + ref = RustlerTest.async_spawned_work(3) + assert is_reference(ref) + + # Should receive messages from spawned tasks and final result + # Each spawned task sends its i value, plus final result + messages = + for _ <- 1..4 do + receive do + msg -> msg + after + 500 -> :timeout + end + end + + # Check we got messages from spawned tasks (sent across threads) + assert {ref, 0} in messages + assert {ref, 1} in messages + assert {ref, 2} in messages + + # Final result: 0 + 1 + 2 = 3 + assert {ref, 3} in messages + end + + test "async_channel_echo demonstrates bidirectional communication with Stream trait" do + # Start the task and get channel sender (which is also the task ref) + channel_sender = RustlerTest.async_channel_echo() + assert is_reference(channel_sender) + + # Send messages to the task via the channel + assert :ok == RustlerTest.channel_send_string(channel_sender, "hello") + assert :ok == RustlerTest.channel_send_string(channel_sender, "world") + assert :ok == RustlerTest.channel_send_string(channel_sender, "stop") + + # Collect echo messages + messages = + for _ <- 1..3 do + receive do + msg -> msg + after + 500 -> :timeout + end + end + + # Check we got echoes (tagged with channel_sender) + assert {channel_sender, "echo: hello"} in messages + assert {channel_sender, "echo: world"} in messages + + # Final message with count + assert {channel_sender, "Received 2 messages"} in messages + end + + test "stateful_worker demonstrates enum-based commands and 
responses" do + # Start the stateful worker + worker = RustlerTest.stateful_worker() + assert is_reference(worker) + + # Add 10 + assert :ok == RustlerTest.worker_send_command(worker, {:add, %{value: 10}}) + + assert_receive {^worker, {:updated, %{old_value: 0, new_value: 10}}}, 500 + + # Add 5 + assert :ok == RustlerTest.worker_send_command(worker, {:add, %{value: 5}}) + + assert_receive {^worker, {:updated, %{old_value: 10, new_value: 15}}}, 500 + + # Multiply by 2 + assert :ok == RustlerTest.worker_send_command(worker, {:multiply, %{value: 2}}) + + assert_receive {^worker, {:updated, %{old_value: 15, new_value: 30}}}, 500 + + # Subtract 5 + assert :ok == RustlerTest.worker_send_command(worker, {:subtract, %{value: 5}}) + + assert_receive {^worker, {:updated, %{old_value: 30, new_value: 25}}}, 500 + + # Get current value + assert :ok == RustlerTest.worker_send_command(worker, :get_current) + + assert_receive {^worker, {:current, %{value: 25}}}, 500 + + # Reset + assert :ok == RustlerTest.worker_send_command(worker, :reset) + + assert_receive {^worker, :reset}, 500 + + # Verify reset worked + assert :ok == RustlerTest.worker_send_command(worker, :get_current) + + assert_receive {^worker, {:current, %{value: 0}}}, 500 + + # Shutdown + assert :ok == RustlerTest.worker_send_command(worker, :shutdown) + + # Should receive two shutdown messages: one from the command, one as final + assert_receive {^worker, {:shutting_down, %{final_value: 0, operations: 0}}}, 500 + assert_receive {^worker, {:shutting_down, %{final_value: 0, operations: 0}}}, 500 + end end From bdcb343d958ca0673d9cdcf313b4249efc0ab215 Mon Sep 17 00:00:00 2001 From: Sonny Scroggin Date: Thu, 13 Nov 2025 17:07:44 -0600 Subject: [PATCH 4/4] Blocking async NIFs --- rustler/src/codegen_runtime.rs | 8 + rustler/src/runtime/mod.rs | 6 + rustler/src/runtime/yielding.rs | 311 ++++++++++++++ rustler_codegen/src/lib.rs | 14 +- rustler_codegen/src/nif.rs | 105 ++++- rustler_tests/lib/rustler_test.ex | 6 + 
rustler_tests/native/rustler_test/src/lib.rs | 3 +- .../src/{test_async.rs => test_tasks.rs} | 0 .../native/rustler_test/src/test_yielding.rs | 73 ++++ rustler_tests/test/yielding_test.exs | 379 ++++++++++++++++++ 10 files changed, 882 insertions(+), 23 deletions(-) create mode 100644 rustler/src/runtime/yielding.rs rename rustler_tests/native/rustler_test/src/{test_async.rs => test_tasks.rs} (100%) create mode 100644 rustler_tests/native/rustler_test/src/test_yielding.rs create mode 100644 rustler_tests/test/yielding_test.exs diff --git a/rustler/src/codegen_runtime.rs b/rustler/src/codegen_runtime.rs index 9ab1f898..b3f22ee7 100644 --- a/rustler/src/codegen_runtime.rs +++ b/rustler/src/codegen_runtime.rs @@ -58,6 +58,14 @@ unsafe impl NifReturnable for OwnedBinary { } } +// Allow returning NifReturned directly from NIFs +// This is useful for advanced use cases like yielding NIFs +unsafe impl NifReturnable for NifReturned { + unsafe fn into_returned(self, _env: Env) -> NifReturned { + self + } +} + pub enum NifReturned { Term(NIF_TERM), Raise(NIF_TERM), diff --git a/rustler/src/runtime/mod.rs b/rustler/src/runtime/mod.rs index 569c633a..ed22f361 100644 --- a/rustler/src/runtime/mod.rs +++ b/rustler/src/runtime/mod.rs @@ -11,6 +11,12 @@ pub use tokio::TokioRuntime; #[cfg(feature = "tokio-rt")] pub use tokio::{ConfigError, RuntimeConfig}; +#[cfg(feature = "async-rt")] +pub mod yielding; + +#[cfg(feature = "async-rt")] +pub use yielding::{yield_now, yielding_nif_run, YieldingTaskState}; + #[cfg(rustler_unstable)] pub mod channel; diff --git a/rustler/src/runtime/yielding.rs b/rustler/src/runtime/yielding.rs new file mode 100644 index 00000000..1e06e7df --- /dev/null +++ b/rustler/src/runtime/yielding.rs @@ -0,0 +1,311 @@ +/// True cooperative yielding NIFs using enif_schedule_nif +/// +/// This approach makes NIF calls appear synchronous to Elixir while yielding internally. 
+use crate::codegen_runtime::NifReturned; +use crate::schedule::SchedulerFlags; +use crate::wrapper::NIF_TERM; +use crate::{Encoder, Env, ResourceArc}; +use std::ffi::CString; +use std::future::Future; +use std::pin::Pin; +use std::sync::Mutex; +use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker}; + +// Type-erased poll function that takes Env and returns encoded result +type PollFn = dyn FnMut(&mut Context<'_>, Env) -> Poll + Send; + +/// Saved state for a yielding computation +pub struct YieldingTaskState { + /// Type-erased future polling function + poll_fn: Mutex>>, +} + +impl crate::Resource for YieldingTaskState {} + +// Auto-register the resource +crate::codegen_runtime::inventory::submit! { + crate::resource::Registration::new::() +} + +/// Run a future cooperatively, yielding to the BEAM scheduler as needed. +/// +/// This is the main entry point for yielding NIFs. Call this with your async code +/// and it will handle yielding automatically. +/// +/// # Example +/// +/// ```ignore +/// use rustler::codegen_runtime::NifReturned; +/// +/// #[rustler::nif] +/// fn my_yielding_nif(env: Env) -> NifReturned { +/// yielding_nif_run(env, async { +/// // Your async code here - yields automatically +/// let mut sum = 0; +/// for i in 0..1000 { +/// sum += i; +/// // Yield periodically to avoid blocking +/// yield_now().await; +/// } +/// sum +/// }) +/// } +/// ``` +/// +/// From Elixir, this appears as a normal blocking call: +/// ```elixir +/// result = MyNif.my_yielding_nif() # Blocks cooperatively until done +/// ``` +pub fn yielding_nif_run(env: Env, future: F) -> NifReturned +where + F: Future + Send + 'static, + T: Encoder + Send + 'static, +{ + start_yielding(env, future) +} + +/// Internal function for managing continuation state. +/// +/// This should not be called directly by users. 
+pub fn yielding_nif( + env: Env, + state: Option>, + future: F, +) -> NifReturned +where + F: Future + Send + 'static, + T: Encoder + Send + 'static, +{ + match state { + None => { + // Initial call - create state and start + start_yielding(env, future) + } + Some(state_resource) => { + // Continuation - resume from state + resume_yielding(env, state_resource) + } + } +} + +/// Start a new yielding computation +fn start_yielding(env: Env, future: F) -> NifReturned +where + F: Future + Send + 'static, + T: Encoder + Send + 'static, +{ + // Box and pin the future + let mut future = Box::pin(future); + + // Create type-erased poll function + let poll_fn: Pin> = + Box::pin( + move |ctx: &mut Context<'_>, env: Env| match future.as_mut().poll(ctx) { + Poll::Ready(result) => Poll::Ready(result.encode(env).as_c_arg()), + Poll::Pending => Poll::Pending, + }, + ); + + // Create task state resource + let task_state = YieldingTaskState { + poll_fn: Mutex::new(poll_fn), + }; + let resource = ResourceArc::new(task_state); + + // Poll immediately + poll_and_return(env, resource) +} + +/// Resume a yielding computation from saved state +fn resume_yielding(env: Env, state: ResourceArc) -> NifReturned { + poll_and_return(env, state) +} + +/// Poll the future and return appropriate NifReturned +fn poll_and_return(env: Env, state: ResourceArc) -> NifReturned { + // Create a simple waker that does nothing (we'll poll again on reschedule) + let waker = noop_waker(); + let mut context = Context::from_waker(&waker); + + // Poll the future first - don't check timeslice before giving it a chance to complete + let result = { + let mut poll_fn = state + .poll_fn + .lock() + .expect("YieldingTaskState mutex poisoned"); + + // SAFETY: We're not moving the function, just calling it + let f = unsafe { poll_fn.as_mut().get_unchecked_mut() }; + f(&mut context, env) + }; + + match result { + Poll::Ready(term) => { + // Future completed - return result + NifReturned::Term(term) + } + Poll::Pending 
=> { + // Future still running - check if we should yield + // Consume a small amount of timeslice (10%) and check if we should continue + if crate::schedule::consume_timeslice(env, 10) { + // Still have timeslice - could poll again immediately + // But for now, let's reschedule to give other work a chance + reschedule_continuation(env, state) + } else { + // Timeslice exhausted - definitely reschedule + reschedule_continuation(env, state) + } + } + } +} + +/// Reschedule the continuation to run again +fn reschedule_continuation(env: Env, state: ResourceArc) -> NifReturned { + // Encode the state resource as an argument for the continuation + let state_term = state.encode(env).as_c_arg(); + + NifReturned::Reschedule { + fun_name: CString::new("__yielding_continuation").unwrap(), + flags: SchedulerFlags::Normal, + fun: yielding_continuation_raw, + args: vec![state_term], + } +} + +/// Raw C-ABI continuation function called by enif_schedule_nif +unsafe extern "C" fn yielding_continuation_raw( + env_ptr: *mut crate::sys::ErlNifEnv, + argc: i32, + argv: *const NIF_TERM, +) -> NIF_TERM { + // Create Env from the pointer + let env = Env::new_internal(&env_ptr, env_ptr, crate::env::EnvKind::Callback); + + // Decode the state resource from argv[0] + if argc != 1 { + return env.error_tuple("Expected 1 argument").as_c_arg(); + } + + let state_term = crate::Term::new(env, *argv); + + match state_term.decode::>() { + Ok(state) => { + // Resume the computation + match resume_yielding(env, state) { + NifReturned::Term(term) => term, + NifReturned::Reschedule { + fun_name, + flags, + fun, + args, + } => { + // Call enif_schedule_nif to reschedule again + unsafe { + crate::sys::enif_schedule_nif( + env_ptr, + fun_name.as_ptr(), + flags as i32, + fun, + args.len() as i32, + args.as_ptr(), + ) + } + } + NifReturned::BadArg => crate::types::atom::error().encode(env).as_c_arg(), + NifReturned::Raise(term) => term, + } + } + Err(_) => { + // Failed to decode state - return error + 
env.error_tuple("Invalid task state").as_c_arg() + } + } +} + +/// Create a no-op waker +/// +/// Since we're using cooperative yielding with enif_schedule_nif, we don't need +/// the waker to do anything. We'll poll again when we're rescheduled. +fn noop_waker() -> Waker { + fn noop_clone(_: *const ()) -> RawWaker { + noop_raw_waker() + } + fn noop(_: *const ()) {} + + fn noop_raw_waker() -> RawWaker { + RawWaker::new( + std::ptr::null(), + &RawWakerVTable::new(noop_clone, noop, noop, noop), + ) + } + + unsafe { Waker::from_raw(noop_raw_waker()) } +} + +/// A simple future that yields once before completing. +/// +/// This is useful for inserting yield points in your async code to check +/// the timeslice and give the scheduler a chance to run other work. +/// +/// # Example +/// +/// ```ignore +/// async fn process_large_file(path: String) -> Result> { +/// let mut buffer = Vec::new(); +/// let mut file = std::fs::File::open(path)?; +/// +/// loop { +/// let mut chunk = vec![0u8; 4096]; +/// match file.read(&mut chunk)? { +/// 0 => break, +/// n => { +/// buffer.extend_from_slice(&chunk[..n]); +/// // Yield to scheduler periodically +/// yield_now().await; +/// } +/// } +/// } +/// +/// Ok(buffer) +/// } +/// ``` +pub struct YieldNow { + yielded: bool, +} + +impl Future for YieldNow { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll { + if self.yielded { + Poll::Ready(()) + } else { + self.yielded = true; + Poll::Pending + } + } +} + +/// Yield control back to the BEAM scheduler. +/// +/// This returns a future that yields once before completing, allowing +/// the scheduler to run other work if needed. 
+pub fn yield_now() -> YieldNow { + YieldNow { yielded: false } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_yield_now_completes() { + // YieldNow should return Pending once, then Ready + let mut future = Box::pin(yield_now()); + let waker = noop_waker(); + let mut ctx = Context::from_waker(&waker); + + assert!(matches!(future.as_mut().poll(&mut ctx), Poll::Pending)); + assert!(matches!(future.as_mut().poll(&mut ctx), Poll::Ready(()))); + } +} diff --git a/rustler_codegen/src/lib.rs b/rustler_codegen/src/lib.rs index 7f6deb13..98132534 100644 --- a/rustler_codegen/src/lib.rs +++ b/rustler_codegen/src/lib.rs @@ -103,17 +103,7 @@ pub fn nif(args: TokenStream, input: TokenStream) -> TokenStream { let input = syn::parse_macro_input!(input as syn::ItemFn); - // Reject async functions in #[rustler::nif] - if input.sig.asyncness.is_some() { - return syn::Error::new_spanned( - input.sig.asyncness, - "async functions are not supported with #[rustler::nif]. Use #[rustler::task] instead.", - ) - .to_compile_error() - .into(); - } - - nif::transcoder_decorator(nif_attributes, input).into() + nif::transcoder_decorator(nif_attributes, input, false).into() } /// Wrap an async function as a spawned task that returns a reference. 
@@ -162,7 +152,7 @@ pub fn task(args: TokenStream, input: TokenStream) -> TokenStream { .into(); } - nif::transcoder_decorator(nif_attributes, input).into() + nif::transcoder_decorator(nif_attributes, input, true).into() } /// Derives implementations for the `Encoder` and `Decoder` traits diff --git a/rustler_codegen/src/nif.rs b/rustler_codegen/src/nif.rs index 976816ef..3288f245 100644 --- a/rustler_codegen/src/nif.rs +++ b/rustler_codegen/src/nif.rs @@ -35,7 +35,11 @@ impl NifAttributes { } } -pub fn transcoder_decorator(nif_attributes: NifAttributes, fun: syn::ItemFn) -> TokenStream { +pub fn transcoder_decorator( + nif_attributes: NifAttributes, + fun: syn::ItemFn, + is_task: bool, +) -> TokenStream { let sig = &fun.sig; let name = &sig.ident; let inputs = &sig.inputs; @@ -53,15 +57,29 @@ pub fn transcoder_decorator(nif_attributes: NifAttributes, fun: syn::ItemFn) -> } if is_async { - generate_task( - erl_func_name, - name, - flags, - arity, - function, - inputs.clone(), - &sig.output, - ) + if is_task { + // #[rustler::task] - message-based async NIF + generate_task( + erl_func_name, + name, + flags, + arity, + function, + inputs.clone(), + &sig.output, + ) + } else { + // #[rustler::nif] async - cooperative yielding NIF + generate_yielding_nif( + erl_func_name, + name, + flags, + arity, + function, + inputs.clone(), + &sig.output, + ) + } } else { generate_nif(erl_func_name, name, flags, arity, function, inputs.clone()) } @@ -257,6 +275,73 @@ fn generate_task( } } +fn generate_yielding_nif( + erl_func_name: String, + name: &syn::Ident, + flags: TokenStream, + arity: u32, + function: TokenStream, + inputs: Punctuated, + return_type: &syn::ReturnType, +) -> TokenStream { + // Extract inputs for async functions (similar to generate_task) + let decoded_terms = extract_inputs_for_async(inputs.clone(), return_type); + let argument_names = create_function_params(inputs); + + quote! 
{ + // Define the original async function at module level + #function + + // Submit the NIF wrapper to inventory + rustler::codegen_runtime::inventory::submit!( + rustler::Nif { + name: concat!(#erl_func_name, "\0").as_ptr() + as *const rustler::codegen_runtime::c_char, + arity: #arity, + flags: #flags as rustler::codegen_runtime::c_uint, + raw_func: { + unsafe extern "C" fn nif_func( + nif_env: rustler::codegen_runtime::NIF_ENV, + argc: rustler::codegen_runtime::c_int, + argv: *const rustler::codegen_runtime::NIF_TERM + ) -> rustler::codegen_runtime::NIF_TERM { + let lifetime = (); + let env = rustler::Env::new(&lifetime, nif_env); + + let terms = std::slice::from_raw_parts(argv, argc as usize) + .iter() + .map(|term| rustler::Term::new(env, *term)) + .collect::>(); + + fn wrapper<'a>( + env: rustler::Env<'a>, + args: &[rustler::Term<'a>] + ) -> rustler::codegen_runtime::NifReturned { + let result: std::thread::Result<_> = + std::panic::catch_unwind(move || { + // Decode all arguments before creating the future + #decoded_terms + + // Call yielding_nif_run with the async function call + rustler::runtime::yielding_nif_run(env, async move { + #name(#argument_names).await + }) + }); + + match result { + Ok(nif_returned) => nif_returned, + Err(_) => rustler::codegen_runtime::NifReturned::BadArg, + } + } + wrapper(env, &terms).apply(env) + } + nif_func + } + } + ); + } +} + fn schedule_flag(schedule: Option) -> TokenStream { let mut tokens = TokenStream::new(); diff --git a/rustler_tests/lib/rustler_test.ex b/rustler_tests/lib/rustler_test.ex index 39e6b2c7..8a3ab896 100644 --- a/rustler_tests/lib/rustler_test.ex +++ b/rustler_tests/lib/rustler_test.ex @@ -50,6 +50,12 @@ defmodule RustlerTest do def stateful_worker(), do: err() def worker_send_command(_, _), do: err() + # Yielding runtime NIFs (true cooperative yielding) + def yielding_immediate(), do: err() + def yielding_sum(_), do: err() + def yielding_work_with_sleeps(), do: err() + def yielding_tuple_result(_, _), 
do: err() + def term_debug(_), do: err() def term_debug_and_reparse(term) do diff --git a/rustler_tests/native/rustler_test/src/lib.rs b/rustler_tests/native/rustler_test/src/lib.rs index 0a5665ae..1e84d9b3 100644 --- a/rustler_tests/native/rustler_test/src/lib.rs +++ b/rustler_tests/native/rustler_test/src/lib.rs @@ -1,4 +1,3 @@ -mod test_async; mod test_atom; mod test_binary; mod test_codegen; @@ -16,9 +15,11 @@ mod test_path; mod test_primitives; mod test_range; mod test_resource; +mod test_tasks; mod test_term; mod test_thread; mod test_tuple; +mod test_yielding; rustler::init!("Elixir.RustlerTest", load = load); diff --git a/rustler_tests/native/rustler_test/src/test_async.rs b/rustler_tests/native/rustler_test/src/test_tasks.rs similarity index 100% rename from rustler_tests/native/rustler_test/src/test_async.rs rename to rustler_tests/native/rustler_test/src/test_tasks.rs diff --git a/rustler_tests/native/rustler_test/src/test_yielding.rs b/rustler_tests/native/rustler_test/src/test_yielding.rs new file mode 100644 index 00000000..07e2e222 --- /dev/null +++ b/rustler_tests/native/rustler_test/src/test_yielding.rs @@ -0,0 +1,73 @@ +// ============================================================================ +// Cooperative Yielding NIFs (using enif_schedule_nif) +// ============================================================================ +// +// These NIFs use `#[rustler::nif] async fn` to implement true cooperative yielding. +// They appear synchronous to Elixir but yield internally to the BEAM scheduler. +// No enif_send, no messages - results are returned through normal NIF return values. 
+ +use rustler::runtime::yield_now; + +/// Test immediate completion - no yields needed +#[rustler::nif] +async fn yielding_immediate() -> i64 { + 42 +} + +/// Test cooperative yielding with CPU-bound work +#[rustler::nif] +async fn yielding_sum(count: i64) -> i64 { + let mut sum = 0i64; + for i in 0..count { + sum += i; + // Yield every 100 iterations to avoid blocking the scheduler + if i % 100 == 0 { + yield_now().await; + } + } + sum +} + +/// Test yielding with blocking I/O (simulated with sleep) +#[rustler::nif] +async fn yielding_work_with_sleeps() -> String { + let mut result = String::from("Processing"); + + for i in 0..5 { + // Simulate some work + for _ in 0..1000 { + result.push('.'); + } + + // Yield to scheduler + yield_now().await; + + result.push_str(&format!(" step{}", i)); + } + + result +} + +/// Test returning complex types +#[rustler::nif] +async fn yielding_tuple_result(x: i64, y: i64) -> (i64, i64, &'static str) { + // Simulate some computation with yields + let mut sum = 0i64; + for i in 0..x { + sum += i; + if i % 10 == 0 { + yield_now().await; + } + } + + let mut product = 1i64; + for i in 1..=y { + product *= i; + if i % 10 == 0 { + yield_now().await; + } + } + + // Return tuple + (sum, product, "done") +} diff --git a/rustler_tests/test/yielding_test.exs b/rustler_tests/test/yielding_test.exs new file mode 100644 index 00000000..f9946732 --- /dev/null +++ b/rustler_tests/test/yielding_test.exs @@ -0,0 +1,379 @@ +defmodule RustlerTest.YieldingTest do + use ExUnit.Case, async: false + + # These tests verify TRUE cooperative yielding NIFs using enif_schedule_nif. + # Unlike async NIFs (which return references and send messages), + # yielding NIFs appear synchronous but yield internally to the BEAM scheduler. 
defmodule RustlerTest.YieldingTest do
  use ExUnit.Case, async: false

  # These tests verify TRUE cooperative yielding NIFs using enif_schedule_nif.
  # Unlike async NIFs (which return references and send messages),
  # yielding NIFs appear synchronous but yield internally to the BEAM scheduler.

  describe "yielding_immediate/0" do
    test "returns result immediately without yielding" do
      # This should complete on first poll without needing to reschedule
      assert RustlerTest.yielding_immediate() == 42
    end

    test "returns i64 type" do
      assert is_integer(RustlerTest.yielding_immediate())
    end
  end

  describe "yielding_sum/1" do
    test "computes sum of 0..n correctly" do
      # sum(0..99) = 4950
      assert RustlerTest.yielding_sum(100) == 4950
    end

    test "yields during computation for large inputs" do
      # This should trigger multiple yield points (every 100 iterations)
      # sum(0..9999) = 49995000
      assert RustlerTest.yielding_sum(10_000) == 49_995_000
    end

    test "works with small inputs" do
      # Only the i == 0 iteration hits a yield point here.
      # sum(0..9) = 45
      assert RustlerTest.yielding_sum(10) == 45
    end

    test "handles zero input" do
      assert RustlerTest.yielding_sum(0) == 0
    end

    test "is deterministic - same input gives same output" do
      result1 = RustlerTest.yielding_sum(1000)
      result2 = RustlerTest.yielding_sum(1000)
      assert result1 == result2
      assert result1 == 499_500
    end

    test "blocks the calling process until complete" do
      # The call itself evaluates to the final value: no reference, no pending
      # marker and no message round-trip is involved.
      assert RustlerTest.yielding_sum(10_000) == 49_995_000
    end
  end

  describe "yielding_work_with_sleeps/0" do
    test "returns processed string with step markers" do
      result = RustlerTest.yielding_work_with_sleeps()

      # Should contain "Processing" at the start
      assert String.starts_with?(result, "Processing")

      # Should contain step markers
      for n <- 0..4 do
        assert result =~ "step#{n}"
      end
    end

    test "includes dots from simulated work" do
      result = RustlerTest.yielding_work_with_sleeps()

      # Should have dots from the work simulation (1000 dots per step * 5 steps)
      dot_count = result |> String.graphemes() |> Enum.count(&(&1 == "."))
      assert dot_count == 5000
    end

    test "processes all 5 steps in order" do
      result = RustlerTest.yielding_work_with_sleeps()

      # Byte offsets of the step markers, in step order
      positions =
        for n <- 0..4 do
          assert {pos, _len} = :binary.match(result, "step#{n}")
          pos
        end

      # Steps should appear in strictly increasing positions
      positions
      |> Enum.chunk_every(2, 1, :discard)
      |> Enum.each(fn [earlier, later] -> assert earlier < later end)
    end
  end

  describe "yielding_tuple_result/2" do
    test "returns tuple with sum, product, and status" do
      {sum, product, status} = RustlerTest.yielding_tuple_result(10, 5)

      # sum(0..9) = 45
      assert sum == 45
      # product(1..5) = 120
      assert product == 120
      assert status == "done"
    end

    test "handles edge cases" do
      # x=0 should give sum=0
      assert RustlerTest.yielding_tuple_result(0, 1) == {0, 1, "done"}
    end

    test "computes correct factorial in product" do
      # 5! = 120
      {_sum, product, _status} = RustlerTest.yielding_tuple_result(1, 5)
      assert product == 120

      # 6! = 720
      {_sum, product, _status} = RustlerTest.yielding_tuple_result(1, 6)
      assert product == 720
    end

    test "yields during computation for larger inputs" do
      # This should trigger multiple yield points
      {sum, product, status} = RustlerTest.yielding_tuple_result(100, 10)

      assert sum == 4950
      # 10!
      assert product == 3_628_800
      assert status == "done"
    end

    test "returns correct types" do
      result = RustlerTest.yielding_tuple_result(5, 3)

      assert is_tuple(result)
      assert tuple_size(result) == 3

      {sum, product, status} = result
      assert is_integer(sum)
      assert is_integer(product)
      assert is_binary(status)
    end
  end

  describe "cooperative yielding behavior" do
    test "multiple yielding calls can be made sequentially" do
      # These should all complete and return results directly
      assert RustlerTest.yielding_sum(100) == 4950
      assert RustlerTest.yielding_sum(200) == 19_900
      assert RustlerTest.yielding_immediate() == 42
    end

    test "yielding NIFs don't send messages" do
      # Clear mailbox
      flush_mailbox()

      # Call yielding NIF
      assert RustlerTest.yielding_sum(1000) == 499_500

      # Verify no messages were sent
      refute_receive _, 100
    end

    test "yielding NIFs block the calling process" do
      # Start a task that calls a yielding NIF
      parent = self()

      task =
        Task.async(fn ->
          send(parent, :started)
          result = RustlerTest.yielding_sum(10_000)
          send(parent, {:completed, result})
          {:completed, result}
        end)

      # Wait for task to start
      assert_receive :started, 100

      # The task should block until the NIF completes
      # We should receive :completed, not timeout
      assert_receive {:completed, 49_995_000}, 5_000

      # Verify task completes successfully and returns the result
      assert Task.await(task) == {:completed, 49_995_000}
    end

    test "concurrent yielding calls from different processes" do
      # Spawn multiple processes calling yielding NIFs concurrently
      parent = self()

      for i <- 1..5 do
        spawn(fn -> send(parent, {:result, i, RustlerTest.yielding_sum(1000)}) end)
      end

      # Every process must report the expected sum. A missing or wrong reply
      # fails right here with the offending index in the assertion message.
      for i <- 1..5 do
        assert_receive {:result, ^i, 499_500}, 5_000
      end
    end
  end

  describe "performance characteristics" do
    test "yielding NIFs don't block the scheduler excessively" do
      # A sleeping counter process and a long yielding NIF run side by side;
      # both must finish within their timeouts.
      parent = self()
      counter = :counters.new(1, [:atomics])

      # Start a process that increments a counter, sleeping between increments
      spawn(fn ->
        for _ <- 1..1000 do
          :counters.add(counter, 1, 1)
          Process.sleep(1)
        end

        send(parent, :counter_done)
      end)

      # Start the yielding NIF computation
      spawn(fn -> send(parent, {:nif_done, RustlerTest.yielding_sum(100_000)}) end)

      # Wait for both to complete
      assert_receive :counter_done, 10_000
      assert_receive {:nif_done, 4_999_950_000}, 10_000

      # The counter process ran to completion while the NIF was in flight
      assert :counters.get(counter, 1) == 1000
    end
  end

  describe "reduction consumption" do
    test "yielding NIFs consume reductions" do
      {consumed, result} = reductions_used(fn -> RustlerTest.yielding_sum(10_000) end)

      # Verify the result is correct
      assert result == 49_995_000

      # The NIF should consume reductions - we expect at least some consumption
      # since we're yielding ~100 times (every 100 iterations for 10,000 iterations)
      assert consumed > 0,
             "Expected reductions to be consumed, but consumed #{consumed}"
    end

    test "larger computations consume more reductions" do
      # Small computation
      {small_consumed, _} = reductions_used(fn -> RustlerTest.yielding_sum(1_000) end)

      # Larger computation (10x more iterations, 10x more yields)
      {large_consumed, _} = reductions_used(fn -> RustlerTest.yielding_sum(10_000) end)

      # Larger computation should consume more reductions
      assert large_consumed > small_consumed,
             "Expected large computation (#{large_consumed}) to consume more than small (#{small_consumed})"
    end

    test "immediate completion NIFs consume minimal reductions" do
      # An immediate NIF (no yields) ...
      {immediate_consumed, result} = reductions_used(fn -> RustlerTest.yielding_immediate() end)
      assert result == 42

      # ... compared against a call that yields about 100 times
      {yielding_consumed, _} = reductions_used(fn -> RustlerTest.yielding_sum(10_000) end)

      assert immediate_consumed < yielding_consumed,
             "Expected immediate NIF (#{immediate_consumed}) to consume fewer reductions " <>
               "than a yielding one (#{yielding_consumed})"
    end

    test "yielding NIFs consume reductions across multiple calls" do
      # Make multiple NIF calls and track total reduction consumption
      total_reductions =
        Enum.reduce(1..5, 0, fn _, acc ->
          {consumed, _result} = reductions_used(fn -> RustlerTest.yielding_sum(1_000) end)
          acc + consumed
        end)

      # Total reductions should be significant
      assert total_reductions > 0,
             "Expected multiple NIF calls to consume reductions, consumed #{total_reductions}"
    end

    test "reductions are consumed in spawned process" do
      parent = self()

      # Spawn a process to run the yielding NIF
      spawn(fn ->
        {consumed, result} = reductions_used(fn -> RustlerTest.yielding_sum(10_000) end)
        send(parent, {:reductions_consumed, consumed, result})
      end)

      # Receive the result
      assert_receive {:reductions_consumed, consumed, result}, 5_000

      # Verify result and reductions
      assert result == 49_995_000

      assert consumed > 0,
             "Expected spawned process to consume reductions, consumed #{consumed}"
    end
  end

  # Runs `fun` in the calling process and returns `{reductions_consumed, result}`,
  # where the first element is the difference of the process' reduction counter
  # sampled right before and right after the call.
  defp reductions_used(fun) do
    {:reductions, before} = Process.info(self(), :reductions)
    result = fun.()
    {:reductions, after_call} = Process.info(self(), :reductions)

    {after_call - before, result}
  end

  # Helper to flush mailbox
  defp flush_mailbox do
    receive do
      _ -> flush_mailbox()
    after
      0 -> :ok
    end
  end
end