diff --git a/opentelemetry-instrumentation-tower/CHANGELOG.md b/opentelemetry-instrumentation-tower/CHANGELOG.md index c9eb6c6ea..f8b42ddeb 100644 --- a/opentelemetry-instrumentation-tower/CHANGELOG.md +++ b/opentelemetry-instrumentation-tower/CHANGELOG.md @@ -2,6 +2,49 @@ ## vNext +### Changed + +* **BREAKING**: Removed `with_meter()` method. The middleware now uses global meter and tracer providers by default via `opentelemetry::global::meter()` and `opentelemetry::global::tracer()`, with optional overrides via `with_tracer_provider()` and `with_meter_provider()` methods. +* **BREAKING**: Renamed types. Use the new names: + - `HTTPMetricsLayer` → `HTTPLayer` + - `HTTPMetricsService` → `HTTPService` + - `HTTPMetricsResponseFuture` → `HTTPResponseFuture` + - `HTTPMetricsLayerBuilder` → `HTTPLayerBuilder` +* Added OpenTelemetry trace support + +### Migration Guide + +#### API Changes + +Before: +```rust +use opentelemetry_instrumentation_tower::HTTPMetricsLayerBuilder; + +let layer = HTTPMetricsLayerBuilder::builder() + .with_meter(meter) + .build() + .unwrap(); +``` + +After: +```rust +use opentelemetry_instrumentation_tower::HTTPLayer; + +// Set global providers +global::set_meter_provider(meter_provider); +global::set_tracer_provider(tracer_provider); // for tracing support + +// Then create the layer - simple API using global providers +let layer = HTTPLayer::new(); +``` + +#### Type Name Changes + +- Replace `HTTPMetricsLayerBuilder` with `HTTPLayerBuilder` +- Replace `HTTPMetricsLayer` with `HTTPLayer` +- Replace `HTTPMetricsService` with `HTTPService` +- Replace `HTTPMetricsResponseFuture` with `HTTPResponseFuture` + ## v0.17.0 ### Changed diff --git a/opentelemetry-instrumentation-tower/Cargo.toml b/opentelemetry-instrumentation-tower/Cargo.toml index e1ae34241..8bded4822 100644 --- a/opentelemetry-instrumentation-tower/Cargo.toml +++ b/opentelemetry-instrumentation-tower/Cargo.toml @@ -5,7 +5,7 @@ rust-version = "1.75.0" version = "0.17.0" license = "Apache-2.0" -description = "OpenTelemetry Metrics Middleware for Tower-compatible Rust HTTP servers" +description = "OpenTelemetry Metrics and Tracing Middleware for Tower-compatible Rust HTTP servers" homepage = "https://github.com/open-telemetry/opentelemetry-rust-contrib" repository = "https://github.com/open-telemetry/opentelemetry-rust-contrib" documentation = "https://docs.rs/tower-otel-http-metrics" @@ -18,12 +18,11 @@ axum = ["dep:axum"] [dependencies] axum = { features = ["matched-path", "macros"], version = "0.8", default-features = false, optional = true } -futures-util = { version = "0.3", default-features = false } http = { version = "1", features = ["std"], default-features = false } http-body = { version = "1", default-features = false } -opentelemetry = { workspace = true, features = ["futures", "metrics"]} +opentelemetry = { workspace = true, features = ["futures", "metrics", "trace"] } +opentelemetry-http = "0.31" opentelemetry-semantic-conventions = { workspace = true, features = ["semconv_experimental"] } -pin-project-lite = { version = "0.2", default-features = false } tower-service = { version = "0.3", default-features = false } tower-layer = { version = "0.3", default-features = false } @@ -31,6 +30,7 @@ tower-layer = { version = "0.3", default-features = false } opentelemetry_sdk = { workspace = true, features = ["metrics", "testing"] } tokio = { version = "1.0", features = ["macros", "rt"] } tower = { version = "0.5", features = ["util"] } +tower-test = { version = "0.4" } [lints] workspace = true diff --git a/opentelemetry-instrumentation-tower/README.md b/opentelemetry-instrumentation-tower/README.md index 60f2ea709..ea04d9b99 100644 --- a/opentelemetry-instrumentation-tower/README.md +++ b/opentelemetry-instrumentation-tower/README.md @@ -1,6 +1,38 @@ -# Tower OTEL Metrics Middleware +# Tower OTEL HTTP Instrumentation Middleware -OpenTelemetry Metrics Middleware for Tower-compatible Rust HTTP servers. +OpenTelemetry HTTP Metrics and Tracing Middleware for Tower-compatible Rust HTTP servers. + +This middleware provides both metrics and distributed tracing for HTTP requests, following OpenTelemetry semantic conventions. + +## Features + +- **HTTP Metrics**: Request duration, active requests, request/response body sizes +- **Distributed Tracing**: HTTP spans with semantic attributes +- **Semantic Conventions**: Uses OpenTelemetry semantic conventions for consistent attribute naming +- **Flexible Configuration**: Support for custom attribute extractors and tracer configuration +- **Framework Support**: Works with any Tower-compatible HTTP framework (Axum, Hyper, Tonic etc.) + +## Usage + +## Metrics + +The middleware exports the following metrics: + +- `http.server.request.duration` - Duration of HTTP requests +- `http.server.active_requests` - Number of active HTTP requests +- `http.server.request.body.size` - Size of HTTP request bodies +- `http.server.response.body.size` - Size of HTTP response bodies + +## Tracing + +HTTP spans are created with the following attributes (following OpenTelemetry semantic conventions): + +- `http.request.method` - HTTP method +- `url.scheme` - URL scheme (http/https) +- `url.path` - Request path +- `url.full` - Full URL +- `user_agent.original` - User agent string +- `http.response.status_code` - HTTP response status code ## Examples diff --git a/opentelemetry-instrumentation-tower/examples/axum-http-service/Cargo.toml b/opentelemetry-instrumentation-tower/examples/axum-http-service/Cargo.toml index 986bb69d3..f62e6f8de 100644 --- a/opentelemetry-instrumentation-tower/examples/axum-http-service/Cargo.toml +++ b/opentelemetry-instrumentation-tower/examples/axum-http-service/Cargo.toml @@ -12,9 +12,9 @@ axum = { features = ["http1", "tokio"], version = "0.8", default-features = fals bytes = { version = "1", default-features = false } opentelemetry = { workspace = true} opentelemetry_sdk = { workspace = true, default-features = false } -opentelemetry-otlp = { version = "0.31.0", features = ["grpc-tonic", "metrics"], default-features = false } +opentelemetry-otlp = { version = "0.31.0", features = ["grpc-tonic", "metrics", "trace"], default-features = false } tokio = { version = "1", features = ["rt-multi-thread"], default-features = false } rand_09 = { package = "rand", version = "0.9" } [lints] -workspace = true \ No newline at end of file +workspace = true diff --git a/opentelemetry-instrumentation-tower/examples/axum-http-service/src/main.rs b/opentelemetry-instrumentation-tower/examples/axum-http-service/src/main.rs index 559c74c2e..1718387aa 100644 --- a/opentelemetry-instrumentation-tower/examples/axum-http-service/src/main.rs +++ b/opentelemetry-instrumentation-tower/examples/axum-http-service/src/main.rs @@ -1,7 +1,12 @@ use axum::routing::{get, post, put, Router}; use bytes::Bytes; use opentelemetry::global; -use opentelemetry_instrumentation_tower as otel_tower_metrics; +use opentelemetry_instrumentation_tower::HTTPLayer; +use opentelemetry_otlp::{MetricExporter, SpanExporter}; +use opentelemetry_sdk::{ + metrics::{PeriodicReader, SdkMeterProvider}, + trace::SdkTracerProvider, +}; use std::time::Duration; const SERVICE_NAME: &str = "example-axum-http-service"; @@ -40,34 +45,47 @@ async fn handle() -> Bytes { #[tokio::main] async fn main() { - let exporter = opentelemetry_otlp::MetricExporter::builder() - .with_tonic() - // .with_endpoint("http://localhost:4317") // default; leave out in favor of env var OTEL_EXPORTER_OTLP_ENDPOINT - .build() - .unwrap(); + { + let exporter = MetricExporter::builder() + .with_tonic() + // .with_endpoint("http://localhost:4317") // default; leave out in favor of env var OTEL_EXPORTER_OTLP_ENDPOINT + .build() + .unwrap(); - let reader = opentelemetry_sdk::metrics::PeriodicReader::builder(exporter) - .with_interval(_OTEL_METRIC_EXPORT_INTERVAL) - .build(); + let reader = PeriodicReader::builder(exporter) + .with_interval(_OTEL_METRIC_EXPORT_INTERVAL) + .build(); - let meter_provider = opentelemetry_sdk::metrics::SdkMeterProvider::builder() - .with_reader(reader) - .with_resource(init_otel_resource()) - .build(); + let provider = SdkMeterProvider::builder() + .with_reader(reader) + .with_resource(init_otel_resource()) + .build(); - global::set_meter_provider(meter_provider); - // init our otel metrics middleware - let global_meter = global::meter(SERVICE_NAME); - let otel_metrics_service_layer = otel_tower_metrics::HTTPMetricsLayerBuilder::builder() - .with_meter(global_meter) - .build() - .unwrap(); + global::set_meter_provider(provider); + } + + { + let exporter = SpanExporter::builder() + .with_tonic() + // .with_endpoint("http://localhost:4317") // default; leave out in favor of env var OTEL_EXPORTER_OTLP_ENDPOINT + .build() + .unwrap(); + + let provider = SdkTracerProvider::builder() + .with_batch_exporter(exporter) + .with_resource(init_otel_resource()) + .build(); + + global::set_tracer_provider(provider); + } + + let otel_service_layer = HTTPLayer::new(); let app = Router::new() .route("/", get(handle)) .route("/", post(handle)) .route("/", put(handle)) - .layer(otel_metrics_service_layer); + .layer(otel_service_layer); let listener = tokio::net::TcpListener::bind("0.0.0.0:5000").await.unwrap(); let server = axum::serve(listener, app); diff --git a/opentelemetry-instrumentation-tower/examples/hyper-http-service/Cargo.toml b/opentelemetry-instrumentation-tower/examples/hyper-http-service/Cargo.toml index b04da5798..943b4ff5a 100644 --- a/opentelemetry-instrumentation-tower/examples/hyper-http-service/Cargo.toml +++ b/opentelemetry-instrumentation-tower/examples/hyper-http-service/Cargo.toml @@ -13,10 +13,10 @@ http-body-util = { version = "0.1", default-features = false } hyper-util = { version = "0.1", features = ["http1", "service", "server", "tokio"], default-features = false } opentelemetry = { workspace = true} opentelemetry_sdk = { workspace = true, default-features = false } -opentelemetry-otlp = { version = "0.31.0", features = ["grpc-tonic", "metrics"], default-features = false } +opentelemetry-otlp = { version = "0.31.0", features = ["grpc-tonic", "metrics", "trace"], default-features = false } tokio = { version = "1", features = ["rt-multi-thread", "macros"], default-features = false } tower = { version = "0.5", default-features = false } rand_09 = { package = "rand", version = "0.9" } [lints] -workspace = true \ No newline at end of file +workspace = true diff --git a/opentelemetry-instrumentation-tower/examples/hyper-http-service/src/main.rs b/opentelemetry-instrumentation-tower/examples/hyper-http-service/src/main.rs index ccd31bd5b..8134f3fcb 100644 --- a/opentelemetry-instrumentation-tower/examples/hyper-http-service/src/main.rs +++ b/opentelemetry-instrumentation-tower/examples/hyper-http-service/src/main.rs @@ -2,7 +2,10 @@ use http_body_util::Full; use hyper::body::Bytes; use hyper::{Request, Response}; use opentelemetry::global; -use opentelemetry_instrumentation_tower as otel_tower_metrics; +use opentelemetry_instrumentation_tower::HTTPLayer; +use opentelemetry_otlp::{MetricExporter, SpanExporter}; +use opentelemetry_sdk::metrics::{PeriodicReader, SdkMeterProvider}; +use opentelemetry_sdk::trace::SdkTracerProvider; use std::convert::Infallible; use std::net::SocketAddr; use std::time::Duration; @@ -45,31 +48,44 @@ async fn handle(_req: Request) -> Result, pub server_active_requests: UpDownCounter, pub server_request_body_size: Histogram, @@ -131,23 +129,38 @@ struct HTTPMetricsLayerState { } #[derive(Clone)] -/// [`Service`] used by [`HTTPMetricsLayer`] -pub struct HTTPMetricsService { - pub(crate) state: Arc, +/// [`Service`] used by [`HTTPLayer`] +pub struct HTTPService { + pub(crate) state: Arc, request_extractor: ReqExt, response_extractor: ResExt, inner_service: S, + tracer: Arc, } #[derive(Clone)] -/// [`Layer`] which applies the OTEL HTTP server metrics middleware -pub struct HTTPMetricsLayer { - state: Arc, +/// [`Layer`] which applies the OTEL HTTP server metrics and tracing middleware +pub struct HTTPLayer { + state: Arc, request_extractor: ReqExt, response_extractor: ResExt, + tracer: Arc, } -pub struct HTTPMetricsLayerBuilder { +impl HTTPLayer { + /// Create a new HTTP layer with default configuration using global providers + pub fn new() -> Self { + HTTPLayerBuilder::builder().build().unwrap() + } +} + +impl Default for HTTPLayer { + fn default() -> Self { + Self::new() + } +} + +pub struct HTTPLayerBuilder { meter: Option, req_dur_bounds: Option>, request_extractor: ReqExt, @@ -190,9 +203,9 @@ impl fmt::Debug for Error { } } -impl HTTPMetricsLayerBuilder { +impl HTTPLayerBuilder { pub fn builder() -> Self { - HTTPMetricsLayerBuilder { + HTTPLayerBuilder { meter: None, req_dur_bounds: Some(LIBRARY_DEFAULT_HTTP_SERVER_DURATION_BOUNDARIES.to_vec()), request_extractor: NoOpExtractor, @@ -201,16 +214,16 @@ impl HTTPMetricsLayerBuilder { } } -impl HTTPMetricsLayerBuilder { +impl HTTPLayerBuilder { /// Set a request attribute extractor pub fn with_request_extractor( self, extractor: NewReqExt, - ) -> HTTPMetricsLayerBuilder + ) -> HTTPLayerBuilder where NewReqExt: RequestAttributeExtractor, { - HTTPMetricsLayerBuilder { + HTTPLayerBuilder { meter: self.meter, req_dur_bounds: self.req_dur_bounds, request_extractor: extractor, @@ -222,11 +235,11 @@ impl HTTPMetricsLayerBuilder { pub fn with_response_extractor( self, extractor: NewResExt, - ) -> HTTPMetricsLayerBuilder + ) -> HTTPLayerBuilder where NewResExt: ResponseAttributeExtractor, { - HTTPMetricsLayerBuilder { + HTTPLayerBuilder { meter: self.meter, req_dur_bounds: self.req_dur_bounds, request_extractor: self.request_extractor, @@ -238,7 +251,7 @@ impl HTTPMetricsLayerBuilder { pub fn with_request_extractor_fn( self, f: F, - ) -> HTTPMetricsLayerBuilder, ResExt> + ) -> HTTPLayerBuilder, ResExt> where F: Fn(&http::Request) -> Vec + Clone + Send + Sync + 'static, { @@ -249,27 +262,30 @@ impl HTTPMetricsLayerBuilder { pub fn with_response_extractor_fn( self, f: F, - ) -> HTTPMetricsLayerBuilder> + ) -> HTTPLayerBuilder> where F: Fn(&http::Response) -> Vec + Clone + Send + Sync + 'static, { self.with_response_extractor(FnResponseExtractor::new(f)) } - pub fn build(self) -> Result> { + pub fn build(self) -> Result> { let req_dur_bounds = self .req_dur_bounds .unwrap_or_else(|| LIBRARY_DEFAULT_HTTP_SERVER_DURATION_BOUNDARIES.to_vec()); - match self.meter { - Some(meter) => Ok(HTTPMetricsLayer { - state: Arc::from(Self::make_state(meter, req_dur_bounds)), - request_extractor: self.request_extractor, - response_extractor: self.response_extractor, - }), - None => Err(Error { - inner: ErrorKind::Config(String::from("no meter provided")), - }), - } + + let tracer = Arc::new(global::tracer("opentelemetry-instrumentation-tower")); + + let meter: Meter = self + .meter + .unwrap_or_else(|| global::meter("opentelemetry-instrumentation-tower")); + + Ok(HTTPLayer { + state: Arc::from(Self::make_state(meter, req_dur_bounds)), + request_extractor: self.request_extractor, + response_extractor: self.response_extractor, + tracer, + }) } pub fn with_meter(mut self, meter: Meter) -> Self { @@ -282,8 +298,8 @@ impl HTTPMetricsLayerBuilder { self } - fn make_state(meter: Meter, req_dur_bounds: Vec) -> HTTPMetricsLayerState { - HTTPMetricsLayerState { + fn make_state(meter: Meter, req_dur_bounds: Vec) -> HTTPLayerState { + HTTPLayerState { server_request_duration: meter .f64_histogram(Cow::from(HTTP_SERVER_DURATION_METRIC)) .with_description("Duration of HTTP server requests.") @@ -309,31 +325,27 @@ impl HTTPMetricsLayerBuilder { } } -impl Layer for HTTPMetricsLayer +impl Layer for HTTPLayer where ReqExt: Clone, ResExt: Clone, { - type Service = HTTPMetricsService; + type Service = HTTPService; fn layer(&self, service: S) -> Self::Service { - HTTPMetricsService { + HTTPService { state: self.state.clone(), request_extractor: self.request_extractor.clone(), response_extractor: self.response_extractor.clone(), inner_service: service, + tracer: self.tracer.clone(), } } } -/// ResponseFutureMetricsState holds request-scoped data for metrics and their attributes. -/// -/// ResponseFutureMetricsState lives inside the response future, as it needs to hold data -/// initialized or extracted from the request before it is forwarded to the inner Service. -/// The rest of the data (e.g. status code, error) can be extracted from the response -/// or calculated with respect to the data held here (e.g., duration = now - duration start). -#[derive(Clone)] -struct ResponseFutureMetricsState { +/// Request data extracted before the inner service call. +/// This data is needed for metrics and span finalization after the response is received. +struct RequestData { // fields for the metric values // https://opentelemetry.io/docs/specs/semconv/http/http-metrics/#metric-httpserverrequestduration duration_start: Instant, @@ -351,28 +363,19 @@ struct ResponseFutureMetricsState { custom_request_attributes: Vec, } -pin_project! { - /// Response [`Future`] for [`HTTPMetricsService`]. - pub struct HTTPMetricsResponseFuture { - #[pin] - inner_response_future: F, - layer_state: Arc, - metrics_state: ResponseFutureMetricsState, - response_extractor: ResExt, - } -} - impl Service> - for HTTPMetricsService + for HTTPService where S: Service, Response = http::Response>, + S::Future: Send + 'static, + S::Error: std::fmt::Debug, ResBody: http_body::Body, ReqExt: RequestAttributeExtractor, ResExt: ResponseAttributeExtractor, { type Response = S::Response; type Error = S::Error; - type Future = HTTPMetricsResponseFuture; + type Future = Pin> + Send>>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner_service.poll_ready(cx) @@ -394,105 +397,190 @@ where let url_scheme_kv = KeyValue::new(URL_SCHEME_LABEL, scheme); let method = req.method().as_str().to_owned(); - let method_kv = KeyValue::new(HTTP_REQUEST_METHOD_LABEL, method); + let method_kv = KeyValue::new(HTTP_REQUEST_METHOD_LABEL, method.clone()); - #[allow(unused_mut)] - let mut route_kv_opt = None; #[cfg(feature = "axum")] - if let Some(matched_path) = req.extensions().get::() { - route_kv_opt = Some(KeyValue::new( - HTTP_ROUTE_LABEL, - matched_path.as_str().to_owned(), - )); - }; + let route_kv_opt = req + .extensions() + .get::() + .map(|matched_path| KeyValue::new(HTTP_ROUTE_LABEL, matched_path.as_str().to_owned())); + + #[cfg(not(feature = "axum"))] + let route_kv_opt = None; // Extract custom request attributes let custom_request_attributes = self.request_extractor.extract_attributes(&req); + // Extract the context from the incoming request headers + let parent_cx = global::get_text_map_propagator(|propagator| { + propagator.extract(&HeaderExtractor(req.headers())) + }); + + let mut span_attributes = vec![ + KeyValue::new(semconv::trace::HTTP_REQUEST_METHOD, method.clone()), + url_scheme_kv.clone(), + KeyValue::new(semconv::attribute::URL_PATH, req.uri().path().to_string()), + KeyValue::new(semconv::trace::URL_FULL, req.uri().to_string()), + ]; + + if let Some(user_agent) = req + .headers() + .get("user-agent") + .and_then(|v| v.to_str().ok()) + { + span_attributes.push(KeyValue::new( + semconv::trace::USER_AGENT_ORIGINAL, + user_agent.to_string(), + )); + } + + span_attributes.extend(custom_request_attributes.clone()); + + let span_name = format!("{} {}", method, req.uri().path()); + + let span = self + .tracer + .span_builder(span_name) + .with_kind(SpanKind::Server) + .with_attributes(span_attributes) + .start_with_context(self.tracer.as_ref(), &parent_cx); + + let cx = parent_cx.with_span(span); + self.state .server_active_requests .add(1, &[url_scheme_kv.clone(), method_kv.clone()]); - HTTPMetricsResponseFuture { - inner_response_future: self.inner_service.call(req), - layer_state: self.state.clone(), - metrics_state: ResponseFutureMetricsState { - duration_start, - req_body_size: content_length, - - protocol_name_kv, - protocol_version_kv, - url_scheme_kv, - method_kv, - route_kv_opt, - custom_request_attributes, - }, - response_extractor: self.response_extractor.clone(), - } + let request_data = RequestData { + duration_start, + req_body_size: content_length, + protocol_name_kv, + protocol_version_kv, + url_scheme_kv, + method_kv, + route_kv_opt, + custom_request_attributes, + }; + + let layer_state = self.state.clone(); + let response_extractor = self.response_extractor.clone(); + + let inner_future = self.inner_service.call(req); + + Box::pin( + async move { + let result = inner_future.await; + finalize_request(&result, &request_data, &layer_state, &response_extractor); + result + } + .with_context(cx), + ) } } -impl Future for HTTPMetricsResponseFuture -where - F: Future, E>>, +/// Finalizes the request by updating the span and recording metrics after the response is received. +fn finalize_request( + result: &result::Result, E>, + request_data: &RequestData, + layer_state: &Arc, + response_extractor: &ResExt, +) where ResBody: http_body::Body, ResExt: ResponseAttributeExtractor, + E: std::fmt::Debug, { - type Output = F::Output; - - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - let this = self.project(); - let response = ready!(this.inner_response_future.poll(cx))?; - let status = response.status(); - - // Build base label set - let mut label_superset = vec![ - this.metrics_state.protocol_name_kv.clone(), - this.metrics_state.protocol_version_kv.clone(), - this.metrics_state.url_scheme_kv.clone(), - this.metrics_state.method_kv.clone(), - KeyValue::new(HTTP_RESPONSE_STATUS_CODE_LABEL, i64::from(status.as_u16())), - ]; + let cx = OtelContext::current(); + let span = cx.span(); + + match result { + Ok(response) => { + let status = response.status(); + + // Build base label set + let mut label_superset = vec![ + request_data.protocol_name_kv.clone(), + request_data.protocol_version_kv.clone(), + request_data.url_scheme_kv.clone(), + request_data.method_kv.clone(), + KeyValue::new(HTTP_RESPONSE_STATUS_CODE_LABEL, i64::from(status.as_u16())), + ]; + + if let Some(route_kv) = &request_data.route_kv_opt { + label_superset.push(route_kv.clone()); + } - if let Some(route_kv) = this.metrics_state.route_kv_opt.clone() { - label_superset.push(route_kv); - } + // Add custom request attributes + label_superset.extend(request_data.custom_request_attributes.clone()); - // Add custom request attributes - label_superset.extend(this.metrics_state.custom_request_attributes.clone()); + // Extract and add custom response attributes + let custom_response_attributes = response_extractor.extract_attributes(response); + label_superset.extend(custom_response_attributes.clone()); - // Extract and add custom response attributes - let custom_response_attributes = this.response_extractor.extract_attributes(&response); - label_superset.extend(custom_response_attributes); + // Update span + span.set_attribute(KeyValue::new( + semconv::trace::HTTP_RESPONSE_STATUS_CODE, + status.as_u16() as i64, + )); - this.layer_state.server_request_duration.record( - this.metrics_state.duration_start.elapsed().as_secs_f64(), - &label_superset, - ); + // Add custom response attributes to span + for attr in &custom_response_attributes { + span.set_attribute(attr.clone()); + } - if let Some(req_content_length) = this.metrics_state.req_body_size { - this.layer_state - .server_request_body_size - .record(req_content_length, &label_superset); - } + // Set span status based on HTTP status code + if status.is_server_error() { + span.set_status(Status::Error { + description: format!("HTTP {}", status.as_u16()).into(), + }); + } - // use same approach for `http.server.response.body.size` as hyper does to set content-length - if let Some(resp_content_length) = response.body().size_hint().exact() { - this.layer_state - .server_response_body_size - .record(resp_content_length, &label_superset); - } + // Record metrics + layer_state.server_request_duration.record( + request_data.duration_start.elapsed().as_secs_f64(), + &label_superset, + ); - this.layer_state.server_active_requests.add( - -1, - &[ - this.metrics_state.url_scheme_kv.clone(), - this.metrics_state.method_kv.clone(), - ], - ); + if let Some(req_content_length) = request_data.req_body_size { + layer_state + .server_request_body_size + .record(req_content_length, &label_superset); + } - Ready(Ok(response)) + if let Some(resp_content_length) = response.body().size_hint().exact() { + layer_state + .server_response_body_size + .record(resp_content_length, &label_superset); + } + } + Err(error) => { + // Mark span as error + span.set_status(Status::Error { + description: format!("{:?}", error).into(), + }); + + // Still record duration metric with error label + let label_superset = vec![ + request_data.protocol_name_kv.clone(), + request_data.protocol_version_kv.clone(), + request_data.url_scheme_kv.clone(), + request_data.method_kv.clone(), + ]; + + layer_state.server_request_duration.record( + request_data.duration_start.elapsed().as_secs_f64(), + &label_superset, + ); + } } + + // Always decrement active requests counter + layer_state.server_active_requests.add( + -1, + &[ + request_data.url_scheme_kv.clone(), + request_data.method_kv.clone(), + ], + ); } fn split_and_format_protocol_version(http_version: http::Version) -> (String, String) { @@ -509,17 +597,121 @@ fn split_and_format_protocol_version(http_version: http::Version) -> (String, St #[cfg(test)] mod tests { + // Tests use optional provider overrides instead of global providers to avoid interference. use super::*; + use http::{Request, Response, StatusCode}; use opentelemetry::metrics::MeterProvider; + use opentelemetry::trace::TracerProvider; + use opentelemetry::trace::{FutureExt, TraceContextExt, Tracer}; + use opentelemetry_sdk::metrics::SdkMeterProvider; use opentelemetry_sdk::metrics::{ data::{AggregatedMetrics, MetricData}, - InMemoryMetricExporter, PeriodicReader, SdkMeterProvider, + InMemoryMetricExporter, PeriodicReader, }; + use opentelemetry_sdk::trace::{InMemorySpanExporterBuilder, SdkTracerProvider}; + use std::result::Result; use std::time::Duration; - use tower::Service; + use tower::{Service, ServiceBuilder, ServiceExt}; + + #[tokio::test(flavor = "current_thread")] + async fn test_tracing_with_in_memory_tracer() { + let trace_exporter = InMemorySpanExporterBuilder::new().build(); + let tracer_provider = SdkTracerProvider::builder() + .with_simple_exporter(trace_exporter.clone()) + .build(); + + let tracer = Arc::new(BoxedTracer::new(Box::new( + tracer_provider.tracer("test_tracer"), + ))); + + let mut layer = HTTPLayerBuilder::builder().build().unwrap(); + layer.tracer = tracer.clone(); + + let mut service = ServiceBuilder::new() + .layer(layer) + .service(tower::service_fn(echo)); + + // Create a parent span and set it as the current context + let parent_span = tracer.start("parent_operation"); + let cx = OtelContext::current_with_span(parent_span); + + let request_body = "test".to_string(); + let request = http::Request::builder() + .uri("http://example.com/api/users/123") + .header("Content-Length", request_body.len().to_string()) + .header("User-Agent", "tower-test-client/1.0") + .body(request_body) + .unwrap(); + + // Execute the service call within the parent span context + let _response = async { service.ready().await.unwrap().call(request).await.unwrap() } + .with_context(cx) + .await; + + tracer_provider.force_flush().unwrap(); + + let spans = trace_exporter.get_finished_spans().unwrap(); + assert_eq!( + spans.len(), + 2, + "Expected exactly two spans to be recorded (parent + HTTP)" + ); + + // Find the HTTP span (should be the child) + let http_span = spans + .iter() + .find(|span| span.name == "GET /api/users/123") + .expect("Should find HTTP span"); + + // Find the parent span + let parent_span = spans + .iter() + .find(|span| span.name == "parent_operation") + .expect("Should find parent span"); + + // Verify the HTTP span has the correct parent + assert_eq!( + http_span.parent_span_id, + parent_span.span_context.span_id(), + "HTTP span should have parent span as parent" + ); + + // Verify they share the same trace ID + assert_eq!( + http_span.span_context.trace_id(), + parent_span.span_context.trace_id(), + "Parent and child spans should share the same trace ID" + ); + + assert_eq!( + http_span.name, "GET /api/users/123", + "Span name should match the request" + ); + // Build expected attributes + let expected_attributes = vec![ + KeyValue::new(semconv::trace::HTTP_REQUEST_METHOD, "GET".to_string()), + KeyValue::new(semconv::trace::URL_SCHEME, "http".to_string()), + KeyValue::new(semconv::trace::URL_PATH, "/api/users/123".to_string()), + KeyValue::new( + semconv::trace::URL_FULL, + "http://example.com/api/users/123".to_string(), + ), + KeyValue::new( + semconv::trace::USER_AGENT_ORIGINAL, + "tower-test-client/1.0".to_string(), + ), + KeyValue::new(semconv::trace::HTTP_RESPONSE_STATUS_CODE, 200), + ]; + + assert_eq!(http_span.attributes, expected_attributes); + } + + async fn echo(req: http::Request) -> Result, Error> { + Ok(http::Response::new(req.into_body())) + } - #[tokio::test] + #[tokio::test(flavor = "current_thread")] async fn test_metrics_labels() { let exporter = InMemoryMetricExporter::default(); let reader = PeriodicReader::builder(exporter.clone()) @@ -528,7 +720,7 @@ mod tests { let meter_provider = SdkMeterProvider::builder().with_reader(reader).build(); let meter = meter_provider.meter("test"); - let layer = HTTPMetricsLayerBuilder::builder() + let layer = HTTPLayerBuilder::builder() .with_meter(meter) .build() .unwrap(); @@ -642,7 +834,7 @@ mod tests { .iter() .find(|kv| kv.key.as_str() == NETWORK_PROTOCOL_NAME_LABEL) .expect("Protocol name should be present in request body size"); - assert_eq!(protocol_name.value.as_str(), "http"); + assert_eq!(protocol_name.value.as_str(), "https"); let protocol_version = attributes .iter() @@ -761,4 +953,52 @@ mod tests { } } } + + #[tokio::test(flavor = "current_thread")] + async fn test_context_available_in_handler() { + let trace_exporter = InMemorySpanExporterBuilder::new().build(); + let tracer_provider = SdkTracerProvider::builder() + .with_simple_exporter(trace_exporter.clone()) + .build(); + + let tracer = Arc::new(BoxedTracer::new(Box::new( + tracer_provider.tracer("test_tracer"), + ))); + + let mut layer = HTTPLayerBuilder::builder().build().unwrap(); + layer.tracer = tracer.clone(); + + let service = tower::service_fn(|_req: Request| async { + // Access the current context - this should have the HTTP span + let cx = OtelContext::current(); + let span = cx.span(); + + // Verify we can get span context (means context is attached) + let span_context = span.span_context(); + assert!(span_context.is_valid(), "Span context should be valid"); + + Ok::<_, std::convert::Infallible>( + Response::builder() + .status(StatusCode::OK) + .body(String::from("OK")) + .unwrap(), + ) + }); + + let mut service = layer.layer(service); + + let request = Request::builder() + .method("GET") + .uri("http://example.com/test") + .body("test".to_string()) + .unwrap(); + + let _response = service.call(request).await.unwrap(); + + tracer_provider.force_flush().unwrap(); + + let spans = trace_exporter.get_finished_spans().unwrap(); + assert_eq!(spans.len(), 1, "Expected one HTTP span"); + assert_eq!(spans[0].name, "GET /test"); + } }