Skip to content

Commit 9406a99

Browse files
authored
chore: refactor tracingmiddelware (#3520)
# What does this PR do? Just moving TracingMiddleware to a new file ## Test Plan CI
1 parent 2be869b commit 9406a99

File tree

2 files changed

+75
-68
lines changed

2 files changed

+75
-68
lines changed

llama_stack/core/server/server.py

Lines changed: 3 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import httpx
2626
import rich.pretty
2727
import yaml
28-
from aiohttp import hdrs
2928
from fastapi import Body, FastAPI, HTTPException, Request, Response
3029
from fastapi import Path as FastapiPath
3130
from fastapi.exceptions import RequestValidationError
@@ -45,17 +44,13 @@
4544
process_cors_config,
4645
)
4746
from llama_stack.core.distribution import builtin_automatically_routed_apis
48-
from llama_stack.core.external import ExternalApiSpec, load_external_apis
47+
from llama_stack.core.external import load_external_apis
4948
from llama_stack.core.request_headers import (
5049
PROVIDER_DATA_VAR,
5150
request_provider_data_context,
5251
user_from_scope,
5352
)
54-
from llama_stack.core.server.routes import (
55-
find_matching_route,
56-
get_all_api_routes,
57-
initialize_route_impls,
58-
)
53+
from llama_stack.core.server.routes import get_all_api_routes
5954
from llama_stack.core.stack import (
6055
Stack,
6156
cast_image_name_to_string,
@@ -73,13 +68,12 @@
7368
)
7469
from llama_stack.providers.utils.telemetry.tracing import (
7570
CURRENT_TRACE_CONTEXT,
76-
end_trace,
7771
setup_logger,
78-
start_trace,
7972
)
8073

8174
from .auth import AuthenticationMiddleware
8275
from .quota import QuotaMiddleware
76+
from .tracing import TracingMiddleware
8377

8478
REPO_ROOT = Path(__file__).parent.parent.parent.parent
8579

@@ -299,65 +293,6 @@ async def route_handler(request: Request, **kwargs):
299293
return route_handler
300294

301295

302-
class TracingMiddleware:
303-
def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]):
304-
self.app = app
305-
self.impls = impls
306-
self.external_apis = external_apis
307-
# FastAPI built-in paths that should bypass custom routing
308-
self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
309-
310-
async def __call__(self, scope, receive, send):
311-
if scope.get("type") == "lifespan":
312-
return await self.app(scope, receive, send)
313-
314-
path = scope.get("path", "")
315-
316-
# Check if the path is a FastAPI built-in path
317-
if path.startswith(self.fastapi_paths):
318-
# Pass through to FastAPI's built-in handlers
319-
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
320-
return await self.app(scope, receive, send)
321-
322-
if not hasattr(self, "route_impls"):
323-
self.route_impls = initialize_route_impls(self.impls, self.external_apis)
324-
325-
try:
326-
_, _, route_path, webmethod = find_matching_route(
327-
scope.get("method", hdrs.METH_GET), path, self.route_impls
328-
)
329-
except ValueError:
330-
# If no matching endpoint is found, pass through to FastAPI
331-
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
332-
return await self.app(scope, receive, send)
333-
334-
trace_attributes = {"__location__": "server", "raw_path": path}
335-
336-
# Extract W3C trace context headers and store as trace attributes
337-
headers = dict(scope.get("headers", []))
338-
traceparent = headers.get(b"traceparent", b"").decode()
339-
if traceparent:
340-
trace_attributes["traceparent"] = traceparent
341-
tracestate = headers.get(b"tracestate", b"").decode()
342-
if tracestate:
343-
trace_attributes["tracestate"] = tracestate
344-
345-
trace_path = webmethod.descriptive_name or route_path
346-
trace_context = await start_trace(trace_path, trace_attributes)
347-
348-
async def send_with_trace_id(message):
349-
if message["type"] == "http.response.start":
350-
headers = message.get("headers", [])
351-
headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
352-
message["headers"] = headers
353-
await send(message)
354-
355-
try:
356-
return await self.app(scope, receive, send_with_trace_id)
357-
finally:
358-
await end_trace()
359-
360-
361296
class ClientVersionMiddleware:
362297
def __init__(self, app):
363298
self.app = app

llama_stack/core/server/tracing.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the terms described in the LICENSE file in
5+
# the root directory of this source tree.
6+
from aiohttp import hdrs
7+
8+
from llama_stack.core.external import ExternalApiSpec
9+
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
10+
from llama_stack.log import get_logger
11+
from llama_stack.providers.utils.telemetry.tracing import end_trace, start_trace
12+
13+
logger = get_logger(name=__name__, category="core::server")
14+
15+
16+
class TracingMiddleware:
17+
def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]):
18+
self.app = app
19+
self.impls = impls
20+
self.external_apis = external_apis
21+
# FastAPI built-in paths that should bypass custom routing
22+
self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
23+
24+
async def __call__(self, scope, receive, send):
25+
if scope.get("type") == "lifespan":
26+
return await self.app(scope, receive, send)
27+
28+
path = scope.get("path", "")
29+
30+
# Check if the path is a FastAPI built-in path
31+
if path.startswith(self.fastapi_paths):
32+
# Pass through to FastAPI's built-in handlers
33+
logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
34+
return await self.app(scope, receive, send)
35+
36+
if not hasattr(self, "route_impls"):
37+
self.route_impls = initialize_route_impls(self.impls, self.external_apis)
38+
39+
try:
40+
_, _, route_path, webmethod = find_matching_route(
41+
scope.get("method", hdrs.METH_GET), path, self.route_impls
42+
)
43+
except ValueError:
44+
# If no matching endpoint is found, pass through to FastAPI
45+
logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
46+
return await self.app(scope, receive, send)
47+
48+
trace_attributes = {"__location__": "server", "raw_path": path}
49+
50+
# Extract W3C trace context headers and store as trace attributes
51+
headers = dict(scope.get("headers", []))
52+
traceparent = headers.get(b"traceparent", b"").decode()
53+
if traceparent:
54+
trace_attributes["traceparent"] = traceparent
55+
tracestate = headers.get(b"tracestate", b"").decode()
56+
if tracestate:
57+
trace_attributes["tracestate"] = tracestate
58+
59+
trace_path = webmethod.descriptive_name or route_path
60+
trace_context = await start_trace(trace_path, trace_attributes)
61+
62+
async def send_with_trace_id(message):
63+
if message["type"] == "http.response.start":
64+
headers = message.get("headers", [])
65+
headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
66+
message["headers"] = headers
67+
await send(message)
68+
69+
try:
70+
return await self.app(scope, receive, send_with_trace_id)
71+
finally:
72+
await end_trace()

0 commit comments

Comments
 (0)