diff --git a/grpc/src/context.rs b/grpc/src/context.rs new file mode 100644 index 000000000..528973df5 --- /dev/null +++ b/grpc/src/context.rs @@ -0,0 +1,125 @@ +mod extensions; +mod task_local_context; + +pub use extensions::{FutureExt, StreamExt}; +pub use task_local_context::current; + +use std::any::{Any, TypeId}; +use std::collections::HashMap; +use std::sync::Arc; +use std::time::Instant; + +/// A task-local context for propagating metadata, deadlines, and other request-scoped values. +pub trait Context: Send + Sync + 'static { + /// Get the deadline for the current context. + fn deadline(&self) -> Option; + + /// Create a new context with the given deadline. + fn with_deadline(&self, deadline: Instant) -> Arc; + + /// Get a value from the context extensions. + fn get(&self, type_id: TypeId) -> Option<&(dyn Any + Send + Sync)>; + + /// Create a new context with the given value. + fn with_value(&self, type_id: TypeId, value: Arc) -> Arc; +} + +#[derive(Clone, Default)] +struct ContextInner { + deadline: Option, + extensions: HashMap>, +} + +#[derive(Clone, Default)] +pub(crate) struct ContextImpl { + inner: Arc, +} + +impl Context for ContextImpl { + fn deadline(&self) -> Option { + self.inner.deadline + } + + fn with_deadline(&self, deadline: Instant) -> Arc { + let mut inner = (*self.inner).clone(); + inner.deadline = Some(deadline); + Arc::new(Self { + inner: Arc::new(inner), + }) + } + + fn get(&self, type_id: TypeId) -> Option<&(dyn Any + Send + Sync)> { + self.inner.extensions.get(&type_id).map(|v| &**v as _) + } + + fn with_value(&self, type_id: TypeId, value: Arc) -> Arc { + let mut inner = (*self.inner).clone(); + inner.extensions.insert(type_id, value); + Arc::new(Self { + inner: Arc::new(inner), + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::time::Duration; + + #[test] + fn default_context_has_no_deadline_or_extensions() { + let context = ContextImpl::default(); + assert!(context.deadline().is_none()); + assert!(context.get(TypeId::of::()).is_none()); + } + + #[test] + fn with_deadline_sets_deadline_and_preserves_original() { + let context = ContextImpl::default(); + let deadline = Instant::now() + Duration::from_secs(5); + let context_with_deadline = context.with_deadline(deadline); + + assert_eq!(context_with_deadline.deadline(), Some(deadline)); + // Original context should remain unchanged + assert!(context.deadline().is_none()); + } + + #[test] + fn with_value_stores_extension_and_preserves_original() { + let context = ContextImpl::default(); + + #[derive(Debug, PartialEq)] + struct MyValue(i32); + + let context_with_value = context.with_value(TypeId::of::(), Arc::new(MyValue(42))); + + let value = context_with_value + .get(TypeId::of::()) + .and_then(|v| v.downcast_ref::()); + assert_eq!(value, Some(&MyValue(42))); + + // Original context should not have the value + assert!(context.get(TypeId::of::()).is_none()); + } + + #[test] + fn with_value_overwrites_existing_extension_and_preserves_previous() { + let context = ContextImpl::default(); + + #[derive(Debug, PartialEq)] + struct MyValue(i32); + + let ctx1 = context.with_value(TypeId::of::(), Arc::new(MyValue(10))); + let ctx2 = ctx1.with_value(TypeId::of::(), Arc::new(MyValue(20))); + + let val1 = ctx1 + .get(TypeId::of::()) + .and_then(|v| v.downcast_ref::()); + let val2 = ctx2 + .get(TypeId::of::()) + .and_then(|v| v.downcast_ref::()); + + assert_eq!(val1, Some(&MyValue(10))); + assert_eq!(val2, Some(&MyValue(20))); + } +} diff --git a/grpc/src/context/extensions.rs b/grpc/src/context/extensions.rs new file mode 100644 index 000000000..a76239e60 --- /dev/null +++ b/grpc/src/context/extensions.rs @@ -0,0 +1,122 @@ +//! Extension traits for `Future` and `Stream` to provide context propagation. +//! +//! This module provides the [`FutureExt`] and [`StreamExt`] traits, which allow +//! attaching a [`Context`] to a [`Future`] or [`Stream`]. This ensures that the +//! context is set as the current task-local context whenever the future or stream +//! is polled. + +use std::future::Future; +use std::sync::Arc; +use tokio_stream::Stream; + +use super::task_local_context; +use super::Context; + +/// Extension trait for `Future` to provide context propagation. +/// +/// This trait allows attaching a [`Context`] to a [`Future`], ensuring that the context +/// is set as the current task-local context whenever the future is polled. +/// +/// # Examples +/// +/// ```rust +/// # use std::sync::Arc; +/// # use grpc::context::{Context, FutureExt}; +/// # async fn example() { +/// let context = grpc::context::current(); +/// let future = async { +/// // Context is available here +/// assert!(grpc::context::current().deadline().is_none()); +/// }; +/// +/// future.with_context(context).await; +/// # } +/// ``` +pub trait FutureExt: Future { + /// Attach a context to this future. + /// + /// The context will be set as the current task-local context whenever the future is polled. + fn with_context(self, context: Arc) -> impl Future + where + Self: Sized, + { + task_local_context::ContextScope::new(self, context) + } +} + +impl FutureExt for F {} + +/// Extension trait for `Stream` to provide context propagation. +/// +/// This trait allows attaching a [`Context`] to a [`Stream`], ensuring that the context +/// is set as the current task-local context whenever the stream is polled. +/// +/// # Examples +/// +/// ```rust +/// # use std::sync::Arc; +/// # use grpc::context::{Context, StreamExt}; +/// # use tokio_stream::StreamExt as _; +/// # async fn example() { +/// let context = grpc::context::current(); +/// let stream = tokio_stream::iter(vec![1, 2, 3]); +/// +/// let mut scoped_stream = stream.with_context(context); +/// +/// while let Some(item) = scoped_stream.next().await { +/// // Context is available here +/// assert!(grpc::context::current().deadline().is_none()); +/// } +/// # } +/// ``` +pub trait StreamExt: Stream { + /// Attach a context to this stream. + /// + /// The context will be set as the current task-local context whenever the stream is polled. + fn with_context(self, context: Arc) -> impl Stream + where + Self: Sized, + { + task_local_context::ContextScope::new(self, context) + } +} + +impl StreamExt for S {} + +#[cfg(test)] +mod tests { + use super::super::ContextImpl; + use super::*; + use tokio_stream::StreamExt as _; + + #[tokio::test] + async fn test_future_ext_attaches_context_correctly() { + let ctx = ContextImpl::default(); + let deadline = std::time::Instant::now() + std::time::Duration::from_secs(10); + let ctx = ctx.with_deadline(deadline); + + let future = async { + let current_ctx = super::task_local_context::current(); + assert_eq!(current_ctx.deadline(), Some(deadline)); + }; + + future.with_context(ctx).await; + } + + #[tokio::test] + async fn test_stream_ext_attaches_context_correctly() { + let ctx = ContextImpl::default(); + let deadline = std::time::Instant::now() + std::time::Duration::from_secs(10); + let ctx = ctx.with_deadline(deadline); + + let stream = async_stream::stream! { + let current_ctx = super::task_local_context::current(); + assert_eq!(current_ctx.deadline(), Some(deadline)); + yield 1; + }; + + let scoped_stream = stream.with_context(ctx); + tokio::pin!(scoped_stream); + scoped_stream.next().await; + } +} diff --git a/grpc/src/context/task_local_context.rs b/grpc/src/context/task_local_context.rs new file mode 100644 index 000000000..d30750035 --- /dev/null +++ b/grpc/src/context/task_local_context.rs @@ -0,0 +1,256 @@ +//! Task local context management. +//! +//! # Implementation Details +//! +//! This module implements a task-local context storage mechanism that is runtime agnostic. +//! It works by using a `std::thread_local` to store the context and swapping it in and out +//! of scope when the future is polled. This allows the context to be available to any +//! code running within the scope of the future, even if it is deeply nested. +//! +//! The implementation is very similar to `tokio::task_local` in terms of performance and +//! mechanics, but it does not depend on the Tokio runtime. +//! +//! # Performance +//! +//! It is important to note that this is **not** a zero-cost abstraction. Every time the +//! future is polled (i.e., every suspend/resume point), a cheap `Arc` clone is performed +//! to ensure the context is correctly set and restored. This overhead is generally minimal +//! but should be considered in performance-critical paths. + +use super::Context; +use super::ContextImpl; +use pin_project_lite::pin_project; +use std::cell::RefCell; +use std::future::Future; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context as TaskContext, Poll}; +use tokio_stream::Stream; + +thread_local! { + static CURRENT: RefCell>> = const { RefCell::new(None) }; +} + +/// Get the current context. +/// +/// This function returns the context associated with the current task. +/// If no context is set, it returns a default context. +/// +/// # Examples +/// +/// ```rust +/// use std::sync::Arc; +/// use grpc::context::{self, Context, FutureExt}; +/// +/// #[tokio::main] +/// async fn main() { +/// // By default, an empty context is returned +/// let ctx = context::current(); +/// assert!(ctx.deadline().is_none()); +/// +/// // You can set the context for a future +/// let deadline = std::time::Instant::now() + std::time::Duration::from_secs(5); +/// let ctx = ctx.with_deadline(deadline); +/// +/// let future = async { +/// let current_ctx = context::current(); +/// assert_eq!(current_ctx.deadline(), Some(deadline)); +/// }; +/// +/// future.with_context(ctx).await; +/// } +/// ``` +pub fn current() -> Arc { + CURRENT.with(|ctx| { + ctx.borrow() + .as_ref() + .map(|c| c.clone()) + .unwrap_or_else(|| Arc::new(ContextImpl::default())) + }) +} + +pin_project! { + pub struct ContextScope { + #[pin] + inner: T, + context: Arc, + } +} + +impl ContextScope { + pub fn new(inner: T, context: Arc) -> Self { + Self { inner, context } + } +} + +struct ContextGuard { + previous: Option>, +} + +impl ContextGuard { + fn new(context: Arc) -> Self { + let previous = CURRENT.with(|ctx| ctx.borrow_mut().replace(context)); + Self { previous } + } +} + +impl Drop for ContextGuard { + fn drop(&mut self) { + CURRENT.with(|ctx| { + if let Some(prev) = self.previous.take() { + ctx.borrow_mut().replace(prev); + } else { + ctx.borrow_mut().take(); + } + }); + } +} + +impl Future for ContextScope { + type Output = F::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll { + let this = self.project(); + let _guard = ContextGuard::new(this.context.clone()); + this.inner.poll(cx) + } +} + +impl Stream for ContextScope { + type Item = S::Item; + + fn poll_next(self: Pin<&mut Self>, cx: &mut TaskContext<'_>) -> Poll> { + let this = self.project(); + let _guard = ContextGuard::new(this.context.clone()); + this.inner.poll_next(cx) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use tokio_stream::StreamExt; + + #[test] + fn test_no_context_set_current_returns_default() { + let ctx = current(); + assert!(ctx.deadline().is_none()); + } + + #[tokio::test] + async fn test_future_wrapped_in_context_scope_sees_context() { + let ctx = ContextImpl::default(); + let deadline = std::time::Instant::now() + std::time::Duration::from_secs(10); + let ctx = ctx.with_deadline(deadline); + + let future = async { + let current_ctx = current(); + assert_eq!(current_ctx.deadline(), Some(deadline)); + }; + + ContextScope::new(future, ctx).await; + + // After scope, context should be reset (or default) + assert!(current().deadline().is_none()); + } + + #[tokio::test] + async fn test_stream_wrapped_in_context_scope_sees_context() { + let ctx = ContextImpl::default(); + let deadline = std::time::Instant::now() + std::time::Duration::from_secs(10); + let ctx = ctx.with_deadline(deadline); + + let stream = async_stream::stream! { + let current_ctx = current(); + assert_eq!(current_ctx.deadline(), Some(deadline)); + yield 1; + }; + + let scoped_stream = ContextScope::new(stream, ctx); + tokio::pin!(scoped_stream); + scoped_stream.next().await; + + assert!(current().deadline().is_none()); + } + + #[tokio::test] + async fn test_nested_context_scopes_restore_previous_context() { + let ctx1 = ContextImpl::default(); + let deadline1 = std::time::Instant::now() + std::time::Duration::from_secs(10); + let ctx1 = ctx1.with_deadline(deadline1); + + let ctx2 = ContextImpl::default(); + let deadline2 = std::time::Instant::now() + std::time::Duration::from_secs(20); + let ctx2 = ctx2.with_deadline(deadline2); + + let future = async move { + assert_eq!(current().deadline(), Some(deadline1)); + + let inner_future = async { + assert_eq!(current().deadline(), Some(deadline2)); + }; + + ContextScope::new(inner_future, ctx2).await; + + assert_eq!(current().deadline(), Some(deadline1)); + }; + + ContextScope::new(future, ctx1).await; + assert!(current().deadline().is_none()); + } + + #[tokio::test] + async fn test_spawned_task_with_context_scope_sees_context() { + let ctx = ContextImpl::default(); + let deadline = std::time::Instant::now() + std::time::Duration::from_secs(10); + let ctx = ctx.with_deadline(deadline); + + let future = async move { + // This code runs in a spawned task + let current_ctx = current(); + assert_eq!(current_ctx.deadline(), Some(deadline)); + }; + + // Spawn a new task, but wrap the future with context + let handle = tokio::spawn(ContextScope::new(future, ctx)); + handle.await.unwrap(); + } + + #[tokio::test] + async fn test_spawned_task_without_context_scope_does_not_inherit_context() { + let ctx = ContextImpl::default(); + let deadline = std::time::Instant::now() + std::time::Duration::from_secs(10); + let ctx = ctx.with_deadline(deadline); + + // Set the context for the current task + let future = async { + // Spawn a new task WITHOUT wrapping it in ContextScope + let handle = tokio::spawn(async { + let current_ctx = current(); + // Should NOT have the deadline + assert!(current_ctx.deadline().is_none()); + }); + handle.await.unwrap(); + }; + + ContextScope::new(future, ctx).await; + } + + #[tokio::test] + async fn test_context_propagates_to_nested_futures() { + let ctx = ContextImpl::default(); + let deadline = std::time::Instant::now() + std::time::Duration::from_secs(10); + let ctx = ctx.with_deadline(deadline); + + let inner_future = async { + let current_ctx = current(); + assert_eq!(current_ctx.deadline(), Some(deadline)); + }; + + let outer_future = async { + inner_future.await; + }; + + ContextScope::new(outer_future, ctx).await; + } +} diff --git a/grpc/src/lib.rs b/grpc/src/lib.rs index 0ee9831f8..895e66bfd 100644 --- a/grpc/src/lib.rs +++ b/grpc/src/lib.rs @@ -42,6 +42,7 @@ pub mod service; pub(crate) mod attributes; pub(crate) mod byte_str; pub(crate) mod codec; +pub mod context; #[cfg(test)] pub(crate) mod echo_pb { include!(concat!(