@@ -7,6 +7,7 @@ use crate::kserve::{
77 kerve_server_metadata, kserve_health_live, kserve_health_ready, kserve_model_infer,
88 kserve_model_metadata, kserve_model_metadata_ready,
99} ;
10+ use crate :: logging:: trace_context_middleware;
1011use crate :: sagemaker:: {
1112 sagemaker_compatibility, SagemakerRequest , SagemakerResponse , SagemakerStreamResponse ,
1213 __path_sagemaker_compatibility,
@@ -63,6 +64,7 @@ use tokio::sync::oneshot;
6364use tokio:: time:: Instant ;
6465use tower_http:: cors:: { AllowOrigin , CorsLayer } ;
6566use tracing:: { info_span, instrument, Instrument } ;
67+ use tracing_opentelemetry:: OpenTelemetrySpanExt ;
6668use utoipa:: OpenApi ;
6769use utoipa_swagger_ui:: SwaggerUi ;
6870
@@ -125,6 +127,7 @@ pub(crate) async fn compat_generate(
125127 Extension ( default_return_full_text) : Extension < bool > ,
126128 infer : Extension < Infer > ,
127129 compute_type : Extension < ComputeType > ,
130+ context : Extension < Option < opentelemetry:: Context > > ,
128131 Json ( mut req) : Json < CompatGenerateRequest > ,
129132) -> Result < Response , ( StatusCode , Json < ErrorResponse > ) > {
130133 // default return_full_text given the pipeline_tag
@@ -134,11 +137,14 @@ pub(crate) async fn compat_generate(
134137
135138 // switch on stream
136139 if req. stream {
137- Ok ( generate_stream ( infer, compute_type, Json ( req. into ( ) ) )
138- . await
139- . into_response ( ) )
140+ Ok (
141+ generate_stream ( infer, compute_type, context, Json ( req. into ( ) ) )
142+ . await
143+ . into_response ( ) ,
144+ )
140145 } else {
141- let ( headers, Json ( generation) ) = generate ( infer, compute_type, Json ( req. into ( ) ) ) . await ?;
146+ let ( headers, Json ( generation) ) =
147+ generate ( infer, compute_type, context, Json ( req. into ( ) ) ) . await ?;
142148 // wrap generation inside a Vec to match api-inference
143149 Ok ( ( headers, Json ( vec ! [ generation] ) ) . into_response ( ) )
144150 }
@@ -267,9 +273,14 @@ seed,
267273async fn generate (
268274 infer : Extension < Infer > ,
269275 Extension ( ComputeType ( compute_type) ) : Extension < ComputeType > ,
276+ Extension ( context) : Extension < Option < opentelemetry:: Context > > ,
270277 Json ( req) : Json < GenerateRequest > ,
271278) -> Result < ( HeaderMap , Json < GenerateResponse > ) , ( StatusCode , Json < ErrorResponse > ) > {
272279 let span = tracing:: Span :: current ( ) ;
280+ if let Some ( context) = context {
281+ span. set_parent ( context) ;
282+ }
283+
273284 let ( headers, _, response) =
274285 generate_internal ( infer, ComputeType ( compute_type) , Json ( req) , span) . await ?;
275286 Ok ( ( headers, response) )
@@ -465,12 +476,17 @@ seed,
465476async fn generate_stream (
466477 Extension ( infer) : Extension < Infer > ,
467478 Extension ( compute_type) : Extension < ComputeType > ,
479+ Extension ( context) : Extension < Option < opentelemetry:: Context > > ,
468480 Json ( req) : Json < GenerateRequest > ,
469481) -> (
470482 HeaderMap ,
471483 Sse < impl Stream < Item = Result < Event , Infallible > > > ,
472484) {
473485 let span = tracing:: Span :: current ( ) ;
486+ if let Some ( context) = context {
487+ span. set_parent ( context) ;
488+ }
489+
474490 let ( headers, response_stream) =
475491 generate_stream_internal ( infer, compute_type, Json ( req) , span) . await ;
476492
@@ -700,9 +716,14 @@ pub(crate) async fn completions(
700716 Extension ( infer) : Extension < Infer > ,
701717 Extension ( compute_type) : Extension < ComputeType > ,
702718 Extension ( info) : Extension < Info > ,
719+ Extension ( context) : Extension < Option < opentelemetry:: Context > > ,
703720 Json ( req) : Json < CompletionRequest > ,
704721) -> Result < Response , ( StatusCode , Json < ErrorResponse > ) > {
705722 let span = tracing:: Span :: current ( ) ;
723+ if let Some ( context) = context {
724+ span. set_parent ( context) ;
725+ }
726+
706727 metrics:: counter!( "tgi_request_count" ) . increment ( 1 ) ;
707728
708729 let CompletionRequest {
@@ -1148,9 +1169,14 @@ pub(crate) async fn chat_completions(
11481169 Extension ( infer) : Extension < Infer > ,
11491170 Extension ( compute_type) : Extension < ComputeType > ,
11501171 Extension ( info) : Extension < Info > ,
1172+ Extension ( context) : Extension < Option < opentelemetry:: Context > > ,
11511173 Json ( mut chat) : Json < ChatRequest > ,
11521174) -> Result < Response , ( StatusCode , Json < ErrorResponse > ) > {
11531175 let span = tracing:: Span :: current ( ) ;
1176+ if let Some ( context) = context {
1177+ span. set_parent ( context) ;
1178+ }
1179+
11541180 metrics:: counter!( "tgi_request_count" ) . increment ( 1 ) ;
11551181 let ChatRequest {
11561182 model,
@@ -2258,6 +2284,7 @@ async fn start(
22582284 . layer ( Extension ( prom_handle. clone ( ) ) )
22592285 . layer ( OtelAxumLayer :: default ( ) )
22602286 . layer ( DefaultBodyLimit :: max ( payload_limit) )
2287+ . layer ( axum:: middleware:: from_fn ( trace_context_middleware) )
22612288 . layer ( cors_layer) ;
22622289
22632290 tracing:: info!( "Connected" ) ;
0 commit comments