From eb3cab1eece1cf72591bf1fa11a33d19e5d82eec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Wed, 19 Nov 2025 15:29:37 +0100 Subject: [PATCH 01/24] feat: Implement FastAPI router system MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit introduces a new FastAPI router-based system for defining API endpoints, enabling a migration path away from the legacy @webmethod decorator system. The implementation includes router infrastructure, migration of the Batches API as the first example, and updates to server, OpenAPI generation, and inspection systems to support both routing approaches. The router infrastructure consists of a router registry system that allows APIs to register FastAPI router factories, which are then automatically discovered and included in the server application. Standard error responses are centralized in router_utils to ensure consistent OpenAPI specification generation with proper $ref references to component responses. The Batches API has been migrated to demonstrate the new pattern. The protocol definition and models remain in llama_stack_api/batches, maintaining clear separation between API contracts and server implementation. The FastAPI router implementation lives in llama_stack/core/server/routers/batches, following the established pattern where API contracts are defined in llama_stack_api and server routing logic lives in llama_stack/core/server. The server now checks for registered routers before falling back to the legacy webmethod-based route discovery, ensuring backward compatibility during the migration period. The OpenAPI generator has been updated to handle both router-based and webmethod-based routes, correctly extracting metadata from FastAPI route decorators and Pydantic Field descriptions. The inspect endpoint now includes routes from both systems, with proper filtering for deprecated routes and API levels. 
Response descriptions are now explicitly defined in router decorators, ensuring the generated OpenAPI specification matches the previous format. Error responses use $ref references to component responses (BadRequest400, TooManyRequests429, etc.) as required by the specification. This is neat and will allow us to remove a lot of boilerplate code from our generator once the migration is done. This implementation provides a foundation for incrementally migrating other APIs to the router system while maintaining full backward compatibility with existing webmethod-based APIs. Closes: https://github.com/llamastack/llama-stack/issues/4188 Signed-off-by: Sébastien Han --- client-sdks/stainless/openapi.yml | 36 ++++-- docs/static/deprecated-llama-stack-spec.yaml | 6 + .../static/experimental-llama-stack-spec.yaml | 34 +++++ docs/static/llama-stack-spec.yaml | 36 ++++-- docs/static/stainless-llama-stack-spec.yaml | 36 ++++-- scripts/openapi_generator/app.py | 34 ++++- src/llama_stack/core/inspect.py | 87 +++++++++++-- .../core/server/router_registry.py | 64 +++++++++ src/llama_stack/core/server/router_utils.py | 14 ++ .../core/server/routers/__init__.py | 12 ++ .../core/server/routers/batches.py | 121 ++++++++++++++++++ src/llama_stack/core/server/routes.py | 12 ++ src/llama_stack/core/server/server.py | 88 +++++++++---- src/llama_stack/core/server/tracing.py | 59 +++++++++ .../{batches.py => batches/__init__.py} | 59 +++------ src/llama_stack_api/batches/models.py | 37 ++++++ 16 files changed, 608 insertions(+), 127 deletions(-) create mode 100644 src/llama_stack/core/server/router_registry.py create mode 100644 src/llama_stack/core/server/router_utils.py create mode 100644 src/llama_stack/core/server/routers/__init__.py create mode 100644 src/llama_stack/core/server/routers/batches.py rename src/llama_stack_api/{batches.py => batches/__init__.py} (52%) create mode 100644 src/llama_stack_api/batches/models.py diff --git a/client-sdks/stainless/openapi.yml 
b/client-sdks/stainless/openapi.yml index ff86e30e10..be941f6528 100644 --- a/client-sdks/stainless/openapi.yml +++ b/client-sdks/stainless/openapi.yml @@ -37,7 +37,7 @@ paths: description: Default Response tags: - Batches - summary: List Batches + summary: List all batches for the current user. description: List all batches for the current user. operationId: list_batches_v1_batches_get parameters: @@ -76,9 +76,11 @@ paths: default: $ref: '#/components/responses/DefaultError' description: Default Response + '409': + description: 'Conflict: The idempotency key was previously used with different parameters.' tags: - Batches - summary: Create Batch + summary: Create a new batch for processing multiple API requests. description: Create a new batch for processing multiple API requests. operationId: create_batch_v1_batches_post requestBody: @@ -97,20 +99,20 @@ paths: schema: $ref: '#/components/schemas/Batch' '400': - description: Bad Request $ref: '#/components/responses/BadRequest400' + description: Bad Request '429': - description: Too Many Requests $ref: '#/components/responses/TooManyRequests429' + description: Too Many Requests '500': - description: Internal Server Error $ref: '#/components/responses/InternalServerError500' + description: Internal Server Error default: - description: Default Response $ref: '#/components/responses/DefaultError' + description: Default Response tags: - Batches - summary: Retrieve Batch + summary: Retrieve information about a specific batch. description: Retrieve information about a specific batch. 
operationId: retrieve_batch_v1_batches__batch_id__get parameters: @@ -119,7 +121,7 @@ paths: required: true schema: type: string - description: 'Path parameter: batch_id' + title: Batch Id /v1/batches/{batch_id}/cancel: post: responses: @@ -130,20 +132,20 @@ paths: schema: $ref: '#/components/schemas/Batch' '400': - description: Bad Request $ref: '#/components/responses/BadRequest400' + description: Bad Request '429': - description: Too Many Requests $ref: '#/components/responses/TooManyRequests429' + description: Too Many Requests '500': - description: Internal Server Error $ref: '#/components/responses/InternalServerError500' + description: Internal Server Error default: - description: Default Response $ref: '#/components/responses/DefaultError' + description: Default Response tags: - Batches - summary: Cancel Batch + summary: Cancel a batch that is in progress. description: Cancel a batch that is in progress. operationId: cancel_batch_v1_batches__batch_id__cancel_post parameters: @@ -152,7 +154,7 @@ paths: required: true schema: type: string - description: 'Path parameter: batch_id' + title: Batch Id /v1/chat/completions: get: responses: @@ -3950,29 +3952,35 @@ components: input_file_id: type: string title: Input File Id + description: The ID of an uploaded file containing requests for the batch. endpoint: type: string title: Endpoint + description: The endpoint to be used for all requests in the batch. completion_window: type: string const: 24h title: Completion Window + description: The time window within which the batch should be processed. metadata: anyOf: - additionalProperties: type: string type: object - type: 'null' + description: Optional metadata for the batch. idempotency_key: anyOf: - type: string - type: 'null' + description: Optional idempotency key. When provided, enables idempotent behavior. type: object required: - input_file_id - endpoint - completion_window title: CreateBatchRequest + description: Request model for creating a batch. 
Batch: properties: id: diff --git a/docs/static/deprecated-llama-stack-spec.yaml b/docs/static/deprecated-llama-stack-spec.yaml index 3bc06d7d7e..94b1a69a77 100644 --- a/docs/static/deprecated-llama-stack-spec.yaml +++ b/docs/static/deprecated-llama-stack-spec.yaml @@ -793,29 +793,35 @@ components: input_file_id: type: string title: Input File Id + description: The ID of an uploaded file containing requests for the batch. endpoint: type: string title: Endpoint + description: The endpoint to be used for all requests in the batch. completion_window: type: string const: 24h title: Completion Window + description: The time window within which the batch should be processed. metadata: anyOf: - additionalProperties: type: string type: object - type: 'null' + description: Optional metadata for the batch. idempotency_key: anyOf: - type: string - type: 'null' + description: Optional idempotency key. When provided, enables idempotent behavior. type: object required: - input_file_id - endpoint - completion_window title: CreateBatchRequest + description: Request model for creating a batch. Batch: properties: id: diff --git a/docs/static/experimental-llama-stack-spec.yaml b/docs/static/experimental-llama-stack-spec.yaml index 2b36ebf473..dfd3545447 100644 --- a/docs/static/experimental-llama-stack-spec.yaml +++ b/docs/static/experimental-llama-stack-spec.yaml @@ -688,6 +688,40 @@ components: - data title: ListBatchesResponse description: Response containing a list of batch objects. + CreateBatchRequest: + properties: + input_file_id: + type: string + title: Input File Id + description: The ID of an uploaded file containing requests for the batch. + endpoint: + type: string + title: Endpoint + description: The endpoint to be used for all requests in the batch. + completion_window: + type: string + const: 24h + title: Completion Window + description: The time window within which the batch should be processed. 
+ metadata: + anyOf: + - additionalProperties: + type: string + type: object + - type: 'null' + description: Optional metadata for the batch. + idempotency_key: + anyOf: + - type: string + - type: 'null' + description: Optional idempotency key. When provided, enables idempotent behavior. + type: object + required: + - input_file_id + - endpoint + - completion_window + title: CreateBatchRequest + description: Request model for creating a batch. Batch: properties: id: diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml index a12ac342f9..a736fc8f98 100644 --- a/docs/static/llama-stack-spec.yaml +++ b/docs/static/llama-stack-spec.yaml @@ -35,7 +35,7 @@ paths: description: Default Response tags: - Batches - summary: List Batches + summary: List all batches for the current user. description: List all batches for the current user. operationId: list_batches_v1_batches_get parameters: @@ -74,9 +74,11 @@ paths: default: $ref: '#/components/responses/DefaultError' description: Default Response + '409': + description: 'Conflict: The idempotency key was previously used with different parameters.' tags: - Batches - summary: Create Batch + summary: Create a new batch for processing multiple API requests. description: Create a new batch for processing multiple API requests. 
operationId: create_batch_v1_batches_post requestBody: @@ -95,20 +97,20 @@ paths: schema: $ref: '#/components/schemas/Batch' '400': - description: Bad Request $ref: '#/components/responses/BadRequest400' + description: Bad Request '429': - description: Too Many Requests $ref: '#/components/responses/TooManyRequests429' + description: Too Many Requests '500': - description: Internal Server Error $ref: '#/components/responses/InternalServerError500' + description: Internal Server Error default: - description: Default Response $ref: '#/components/responses/DefaultError' + description: Default Response tags: - Batches - summary: Retrieve Batch + summary: Retrieve information about a specific batch. description: Retrieve information about a specific batch. operationId: retrieve_batch_v1_batches__batch_id__get parameters: @@ -117,7 +119,7 @@ paths: required: true schema: type: string - description: 'Path parameter: batch_id' + title: Batch Id /v1/batches/{batch_id}/cancel: post: responses: @@ -128,20 +130,20 @@ paths: schema: $ref: '#/components/schemas/Batch' '400': - description: Bad Request $ref: '#/components/responses/BadRequest400' + description: Bad Request '429': - description: Too Many Requests $ref: '#/components/responses/TooManyRequests429' + description: Too Many Requests '500': - description: Internal Server Error $ref: '#/components/responses/InternalServerError500' + description: Internal Server Error default: - description: Default Response $ref: '#/components/responses/DefaultError' + description: Default Response tags: - Batches - summary: Cancel Batch + summary: Cancel a batch that is in progress. description: Cancel a batch that is in progress. 
operationId: cancel_batch_v1_batches__batch_id__cancel_post parameters: @@ -150,7 +152,7 @@ paths: required: true schema: type: string - description: 'Path parameter: batch_id' + title: Batch Id /v1/chat/completions: get: responses: @@ -2971,29 +2973,35 @@ components: input_file_id: type: string title: Input File Id + description: The ID of an uploaded file containing requests for the batch. endpoint: type: string title: Endpoint + description: The endpoint to be used for all requests in the batch. completion_window: type: string const: 24h title: Completion Window + description: The time window within which the batch should be processed. metadata: anyOf: - additionalProperties: type: string type: object - type: 'null' + description: Optional metadata for the batch. idempotency_key: anyOf: - type: string - type: 'null' + description: Optional idempotency key. When provided, enables idempotent behavior. type: object required: - input_file_id - endpoint - completion_window title: CreateBatchRequest + description: Request model for creating a batch. Batch: properties: id: diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml index ff86e30e10..be941f6528 100644 --- a/docs/static/stainless-llama-stack-spec.yaml +++ b/docs/static/stainless-llama-stack-spec.yaml @@ -37,7 +37,7 @@ paths: description: Default Response tags: - Batches - summary: List Batches + summary: List all batches for the current user. description: List all batches for the current user. operationId: list_batches_v1_batches_get parameters: @@ -76,9 +76,11 @@ paths: default: $ref: '#/components/responses/DefaultError' description: Default Response + '409': + description: 'Conflict: The idempotency key was previously used with different parameters.' tags: - Batches - summary: Create Batch + summary: Create a new batch for processing multiple API requests. description: Create a new batch for processing multiple API requests. 
operationId: create_batch_v1_batches_post requestBody: @@ -97,20 +99,20 @@ paths: schema: $ref: '#/components/schemas/Batch' '400': - description: Bad Request $ref: '#/components/responses/BadRequest400' + description: Bad Request '429': - description: Too Many Requests $ref: '#/components/responses/TooManyRequests429' + description: Too Many Requests '500': - description: Internal Server Error $ref: '#/components/responses/InternalServerError500' + description: Internal Server Error default: - description: Default Response $ref: '#/components/responses/DefaultError' + description: Default Response tags: - Batches - summary: Retrieve Batch + summary: Retrieve information about a specific batch. description: Retrieve information about a specific batch. operationId: retrieve_batch_v1_batches__batch_id__get parameters: @@ -119,7 +121,7 @@ paths: required: true schema: type: string - description: 'Path parameter: batch_id' + title: Batch Id /v1/batches/{batch_id}/cancel: post: responses: @@ -130,20 +132,20 @@ paths: schema: $ref: '#/components/schemas/Batch' '400': - description: Bad Request $ref: '#/components/responses/BadRequest400' + description: Bad Request '429': - description: Too Many Requests $ref: '#/components/responses/TooManyRequests429' + description: Too Many Requests '500': - description: Internal Server Error $ref: '#/components/responses/InternalServerError500' + description: Internal Server Error default: - description: Default Response $ref: '#/components/responses/DefaultError' + description: Default Response tags: - Batches - summary: Cancel Batch + summary: Cancel a batch that is in progress. description: Cancel a batch that is in progress. 
operationId: cancel_batch_v1_batches__batch_id__cancel_post parameters: @@ -152,7 +154,7 @@ paths: required: true schema: type: string - description: 'Path parameter: batch_id' + title: Batch Id /v1/chat/completions: get: responses: @@ -3950,29 +3952,35 @@ components: input_file_id: type: string title: Input File Id + description: The ID of an uploaded file containing requests for the batch. endpoint: type: string title: Endpoint + description: The endpoint to be used for all requests in the batch. completion_window: type: string const: 24h title: Completion Window + description: The time window within which the batch should be processed. metadata: anyOf: - additionalProperties: type: string type: object - type: 'null' + description: Optional metadata for the batch. idempotency_key: anyOf: - type: string - type: 'null' + description: Optional idempotency key. When provided, enables idempotent behavior. type: object required: - input_file_id - endpoint - completion_window title: CreateBatchRequest + description: Request model for creating a batch. Batch: properties: id: diff --git a/scripts/openapi_generator/app.py b/scripts/openapi_generator/app.py index d972889cdc..48afc157d3 100644 --- a/scripts/openapi_generator/app.py +++ b/scripts/openapi_generator/app.py @@ -64,7 +64,8 @@ def _get_protocol_method(api: Api, method_name: str) -> Any | None: def create_llama_stack_app() -> FastAPI: """ Create a FastAPI app that represents the Llama Stack API. - This uses the existing route discovery system to automatically find all routes. + This uses both router-based routes (for migrated APIs) and the existing + route discovery system for legacy webmethod-based routes. 
""" app = FastAPI( title="Llama Stack API", @@ -75,15 +76,42 @@ def create_llama_stack_app() -> FastAPI: ], ) - # Get all API routes + # Import batches router to trigger router registration + try: + from llama_stack.core.server.routers import batches # noqa: F401 + except ImportError: + pass + + # Include routers for APIs that have them registered + from llama_stack.core.server.router_registry import create_router, has_router + + def dummy_impl_getter(api: Api) -> Any: + """Dummy implementation getter for OpenAPI generation.""" + return None + + # Get all APIs that might have routers + from llama_stack.core.resolver import api_protocol_map + + protocols = api_protocol_map() + for api in protocols.keys(): + if has_router(api): + router = create_router(api, dummy_impl_getter) + if router: + app.include_router(router) + + # Get all API routes (for legacy webmethod-based routes) from llama_stack.core.server.routes import get_all_api_routes api_routes = get_all_api_routes() - # Create FastAPI routes from the discovered routes + # Create FastAPI routes from the discovered routes (skip APIs that have routers) from . 
import endpoints for api, routes in api_routes.items(): + # Skip APIs that have routers - they're already included above + if has_router(api): + continue + for route, webmethod in routes: # Convert the route to a FastAPI endpoint endpoints._create_fastapi_endpoint(app, route, webmethod, api) diff --git a/src/llama_stack/core/inspect.py b/src/llama_stack/core/inspect.py index 272c9d1bc2..be5a26c14a 100644 --- a/src/llama_stack/core/inspect.py +++ b/src/llama_stack/core/inspect.py @@ -10,8 +10,10 @@ from llama_stack.core.datatypes import StackRunConfig from llama_stack.core.external import load_external_apis +from llama_stack.core.server.router_registry import create_router, has_router from llama_stack.core.server.routes import get_all_api_routes from llama_stack_api import ( + Api, HealthInfo, HealthStatus, Inspect, @@ -57,34 +59,91 @@ def should_include_route(webmethod) -> bool: ret = [] external_apis = load_external_apis(run_config) all_endpoints = get_all_api_routes(external_apis) + + # Helper function to get provider types for an API + def get_provider_types(api: Api) -> list[str]: + if api.value in ["providers", "inspect"]: + return [] # These APIs don't have "real" providers they're internal to the stack + providers = run_config.providers.get(api.value, []) + return [p.provider_type for p in providers] if providers else [] + + # Process webmethod-based routes (legacy) for api, endpoints in all_endpoints.items(): + # Skip APIs that have routers - they'll be processed separately + if has_router(api): + continue + + provider_types = get_provider_types(api) # Always include provider and inspect APIs, filter others based on run config - if api.value in ["providers", "inspect"]: + if api.value in ["providers", "inspect"] or provider_types: ret.extend( [ RouteInfo( route=e.path, method=next(iter([m for m in e.methods if m != "HEAD"])), - provider_types=[], # These APIs don't have "real" providers - they're internal to the stack + provider_types=provider_types, ) for 
e, webmethod in endpoints if e.methods is not None and should_include_route(webmethod) ] ) + + # Helper function to determine if a router route should be included based on api_filter + def should_include_router_route(route, router_prefix: str | None) -> bool: + """Check if a router-based route should be included based on api_filter.""" + # Check deprecated status + route_deprecated = getattr(route, "deprecated", False) or False + + if api_filter is None: + # Default: only non-deprecated routes + return not route_deprecated + elif api_filter == "deprecated": + # Special filter: show deprecated routes regardless of their actual level + return route_deprecated else: - providers = run_config.providers.get(api.value, []) - if providers: # Only process if there are providers for this API - ret.extend( - [ - RouteInfo( - route=e.path, - method=next(iter([m for m in e.methods if m != "HEAD"])), - provider_types=[p.provider_type for p in providers], + # Filter by API level (non-deprecated routes only) + # Extract level from router prefix (e.g., "/v1" -> "v1") + if router_prefix: + prefix_level = router_prefix.lstrip("/") + return not route_deprecated and prefix_level == api_filter + return not route_deprecated + + # Process router-based routes + def dummy_impl_getter(api: Api) -> None: + """Dummy implementation getter for route inspection.""" + return None + + from llama_stack.core.resolver import api_protocol_map + + protocols = api_protocol_map(external_apis) + for api in protocols.keys(): + if not has_router(api): + continue + + router = create_router(api, dummy_impl_getter) + if not router: + continue + + provider_types = get_provider_types(api) + # Only include if there are providers (or it's a special API) + if api.value in ["providers", "inspect"] or provider_types: + router_prefix = getattr(router, "prefix", None) + for route in router.routes: + # Extract HTTP methods from the route + # FastAPI routes have methods as a set + if hasattr(route, "methods") and 
route.methods: + methods = {m for m in route.methods if m != "HEAD"} + if methods and should_include_router_route(route, router_prefix): + # FastAPI already combines router prefix with route path + path = route.path + + ret.append( + RouteInfo( + route=path, + method=next(iter(methods)), + provider_types=provider_types, + ) ) - for e, webmethod in endpoints - if e.methods is not None and should_include_route(webmethod) - ] - ) return ListRoutesResponse(data=ret) diff --git a/src/llama_stack/core/server/router_registry.py b/src/llama_stack/core/server/router_registry.py new file mode 100644 index 0000000000..e149d13465 --- /dev/null +++ b/src/llama_stack/core/server/router_registry.py @@ -0,0 +1,64 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Router registry for FastAPI routers. + +This module provides a way to register FastAPI routers for APIs that have been +migrated to use explicit FastAPI routers instead of Protocol-based route discovery. +""" + +from collections.abc import Callable +from typing import TYPE_CHECKING, Any + +from fastapi import APIRouter + +if TYPE_CHECKING: + from llama_stack_api.datatypes import Api + +# Registry of router factory functions +# Each factory function takes a callable that returns the implementation for a given API +# and returns an APIRouter +# Use string keys to avoid circular imports +_router_factories: dict[str, Callable[[Callable[["Api"], Any]], APIRouter]] = {} + + +def register_router(api: "Api", router_factory: Callable[[Callable[["Api"], Any]], APIRouter]) -> None: + """Register a router factory for an API. 
+ + Args: + api: The API enum value + router_factory: A function that takes an impl_getter function and returns an APIRouter + """ + _router_factories[api.value] = router_factory # type: ignore[attr-defined] + + +def has_router(api: "Api") -> bool: + """Check if an API has a registered router. + + Args: + api: The API enum value + + Returns: + True if a router factory is registered for this API + """ + return api.value in _router_factories # type: ignore[attr-defined] + + +def create_router(api: "Api", impl_getter: Callable[["Api"], Any]) -> APIRouter | None: + """Create a router for an API if one is registered. + + Args: + api: The API enum value + impl_getter: Function that returns the implementation for a given API + + Returns: + APIRouter if registered, None otherwise + """ + api_value = api.value # type: ignore[attr-defined] + if api_value not in _router_factories: + return None + + return _router_factories[api_value](impl_getter) diff --git a/src/llama_stack/core/server/router_utils.py b/src/llama_stack/core/server/router_utils.py new file mode 100644 index 0000000000..1c508af76b --- /dev/null +++ b/src/llama_stack/core/server/router_utils.py @@ -0,0 +1,14 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +"""Utilities for creating FastAPI routers with standard error responses.""" + +standard_responses = { + 400: {"$ref": "#/components/responses/BadRequest400"}, + 429: {"$ref": "#/components/responses/TooManyRequests429"}, + 500: {"$ref": "#/components/responses/InternalServerError500"}, + "default": {"$ref": "#/components/responses/DefaultError"}, +} diff --git a/src/llama_stack/core/server/routers/__init__.py b/src/llama_stack/core/server/routers/__init__.py new file mode 100644 index 0000000000..213cb75c8d --- /dev/null +++ b/src/llama_stack/core/server/routers/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""FastAPI router implementations for server endpoints. + +This package contains FastAPI router implementations that define the HTTP +endpoints for each API. The API contracts (protocols and models) are defined +in llama_stack_api, while the server routing implementation lives here. +""" diff --git a/src/llama_stack/core/server/routers/batches.py b/src/llama_stack/core/server/routers/batches.py new file mode 100644 index 0000000000..fb7f8ebfa0 --- /dev/null +++ b/src/llama_stack/core/server/routers/batches.py @@ -0,0 +1,121 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""FastAPI router for the Batches API. + +This module defines the FastAPI router for the Batches API using standard +FastAPI route decorators instead of Protocol-based route discovery. 
+""" + +from collections.abc import Callable +from typing import Annotated + +from fastapi import APIRouter, Body, Depends + +from llama_stack.core.server.router_registry import register_router +from llama_stack.core.server.router_utils import standard_responses +from llama_stack_api.batches import Batches, BatchObject, ListBatchesResponse +from llama_stack_api.batches.models import CreateBatchRequest +from llama_stack_api.datatypes import Api +from llama_stack_api.version import LLAMA_STACK_API_V1 + + +def create_batches_router(impl_getter: Callable[[Api], Batches]) -> APIRouter: + """Create a FastAPI router for the Batches API. + + Args: + impl_getter: Function that returns the Batches implementation for the batches API + + Returns: + APIRouter configured for the Batches API + """ + router = APIRouter( + prefix=f"/{LLAMA_STACK_API_V1}", + tags=["Batches"], + responses=standard_responses, + ) + + def get_batch_service() -> Batches: + """Dependency function to get the batch service implementation.""" + return impl_getter(Api.batches) + + @router.post( + "/batches", + response_model=BatchObject, + summary="Create a new batch for processing multiple API requests.", + description="Create a new batch for processing multiple API requests.", + responses={ + 200: {"description": "The created batch object."}, + 409: {"description": "Conflict: The idempotency key was previously used with different parameters."}, + }, + ) + async def create_batch( + request: Annotated[CreateBatchRequest, Body(...)], + svc: Annotated[Batches, Depends(get_batch_service)], + ) -> BatchObject: + """Create a new batch.""" + return await svc.create_batch( + input_file_id=request.input_file_id, + endpoint=request.endpoint, + completion_window=request.completion_window, + metadata=request.metadata, + idempotency_key=request.idempotency_key, + ) + + @router.get( + "/batches/{batch_id}", + response_model=BatchObject, + summary="Retrieve information about a specific batch.", + description="Retrieve 
information about a specific batch.", + responses={ + 200: {"description": "The batch object."}, + }, + ) + async def retrieve_batch( + batch_id: str, + svc: Annotated[Batches, Depends(get_batch_service)], + ) -> BatchObject: + """Retrieve information about a specific batch.""" + return await svc.retrieve_batch(batch_id) + + @router.post( + "/batches/{batch_id}/cancel", + response_model=BatchObject, + summary="Cancel a batch that is in progress.", + description="Cancel a batch that is in progress.", + responses={ + 200: {"description": "The updated batch object."}, + }, + ) + async def cancel_batch( + batch_id: str, + svc: Annotated[Batches, Depends(get_batch_service)], + ) -> BatchObject: + """Cancel a batch that is in progress.""" + return await svc.cancel_batch(batch_id) + + @router.get( + "/batches", + response_model=ListBatchesResponse, + summary="List all batches for the current user.", + description="List all batches for the current user.", + responses={ + 200: {"description": "A list of batch objects."}, + }, + ) + async def list_batches( + svc: Annotated[Batches, Depends(get_batch_service)], + after: str | None = None, + limit: int = 20, + ) -> ListBatchesResponse: + """List all batches for the current user.""" + return await svc.list_batches(after=after, limit=limit) + + return router + + +# Register the router factory +register_router(Api.batches, create_batches_router) diff --git a/src/llama_stack/core/server/routes.py b/src/llama_stack/core/server/routes.py index af50025654..25027267f8 100644 --- a/src/llama_stack/core/server/routes.py +++ b/src/llama_stack/core/server/routes.py @@ -26,6 +26,18 @@ def get_all_api_routes( external_apis: dict[Api, ExternalApiSpec] | None = None, ) -> dict[Api, list[tuple[Route, WebMethod]]]: + """Get all API routes from webmethod-based protocols. + + This function only returns routes from APIs that use the legacy @webmethod + decorator system. 
For APIs that have been migrated to FastAPI routers, + use the router registry (router_registry.has_router() and router_registry.create_router()). + + Args: + external_apis: Optional dictionary of external API specifications + + Returns: + Dictionary mapping API to list of (Route, WebMethod) tuples + """ apis = {} protocols = api_protocol_map(external_apis) diff --git a/src/llama_stack/core/server/server.py b/src/llama_stack/core/server/server.py index 0d3513980a..76f283f3a7 100644 --- a/src/llama_stack/core/server/server.py +++ b/src/llama_stack/core/server/server.py @@ -44,6 +44,7 @@ request_provider_data_context, user_from_scope, ) +from llama_stack.core.server.router_registry import create_router, has_router from llama_stack.core.server.routes import get_all_api_routes from llama_stack.core.stack import ( Stack, @@ -87,7 +88,7 @@ def create_sse_event(data: Any) -> str: async def global_exception_handler(request: Request, exc: Exception): - traceback.print_exception(exc) + traceback.print_exception(type(exc), exc, exc.__traceback__) http_exc = translate_exception(exc) return JSONResponse(status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}}) @@ -448,6 +449,14 @@ def create_app() -> StackApp: external_apis = load_external_apis(config) all_routes = get_all_api_routes(external_apis) + # Import batches router to trigger router registration + # This ensures the router is registered before we try to use it + # We will make this code better once the migration is complete + try: + from llama_stack.core.server.routers import batches # noqa: F401 + except ImportError: + pass + if config.apis: apis_to_serve = set(config.apis) else: @@ -463,41 +472,68 @@ def create_app() -> StackApp: apis_to_serve.add("providers") apis_to_serve.add("prompts") apis_to_serve.add("conversations") - for api_str in apis_to_serve: - api = Api(api_str) - routes = all_routes[api] + def impl_getter(api: Api) -> Any: + """Get the implementation for a given API.""" try: - 
impl = impls[api] + return impls[api] except KeyError as e: raise ValueError(f"Could not find provider implementation for {api} API") from e - for route, _ in routes: - if not hasattr(impl, route.name): - # ideally this should be a typing violation already - raise ValueError(f"Could not find method {route.name} on {impl}!") - - impl_method = getattr(impl, route.name) - # Filter out HEAD method since it's automatically handled by FastAPI for GET routes - available_methods = [m for m in route.methods if m != "HEAD"] - if not available_methods: - raise ValueError(f"No methods found for {route.name} on {impl}") - method = available_methods[0] - logger.debug(f"{method} {route.path}") - - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields") - getattr(app, method.lower())(route.path, response_model=None)( - create_dynamic_typed_route( - impl_method, - method.lower(), - route.path, - ) + for api_str in apis_to_serve: + api = Api(api_str) + + if has_router(api): + router = create_router(api, impl_getter) + if router: + app.include_router(router) + logger.debug(f"Registered router for {api} API") + else: + logger.warning( + f"API '{api.value}' has a registered router factory but it returned None. Skipping this API." 
) + else: + # Fall back to old webmethod-based route discovery until the migration is complete + routes = all_routes[api] + try: + impl = impls[api] + except KeyError as e: + raise ValueError(f"Could not find provider implementation for {api} API") from e + + for route, _ in routes: + if not hasattr(impl, route.name): + # ideally this should be a typing violation already + raise ValueError(f"Could not find method {route.name} on {impl}!") + + impl_method = getattr(impl, route.name) + # Filter out HEAD method since it's automatically handled by FastAPI for GET routes + available_methods = [m for m in route.methods if m != "HEAD"] + if not available_methods: + raise ValueError(f"No methods found for {route.name} on {impl}") + method = available_methods[0] + logger.debug(f"{method} {route.path}") + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields") + getattr(app, method.lower())(route.path, response_model=None)( + create_dynamic_typed_route( + impl_method, + method.lower(), + route.path, + ) + ) logger.debug(f"serving APIs: {apis_to_serve}") + # Register specific exception handlers before the generic Exception handler + # This prevents the re-raising behavior that causes connection resets app.exception_handler(RequestValidationError)(global_exception_handler) + app.exception_handler(ConflictError)(global_exception_handler) + app.exception_handler(ResourceNotFoundError)(global_exception_handler) + app.exception_handler(AuthenticationRequiredError)(global_exception_handler) + app.exception_handler(AccessDeniedError)(global_exception_handler) + app.exception_handler(BadRequestError)(global_exception_handler) + # Generic Exception handler should be last app.exception_handler(Exception)(global_exception_handler) if config.telemetry.enabled: diff --git a/src/llama_stack/core/server/tracing.py b/src/llama_stack/core/server/tracing.py index c4901d9b12..352badcaae 100644 --- 
a/src/llama_stack/core/server/tracing.py +++ b/src/llama_stack/core/server/tracing.py @@ -6,9 +6,11 @@ from aiohttp import hdrs from llama_stack.core.external import ExternalApiSpec +from llama_stack.core.server.router_registry import has_router from llama_stack.core.server.routes import find_matching_route, initialize_route_impls from llama_stack.core.telemetry.tracing import end_trace, start_trace from llama_stack.log import get_logger +from llama_stack_api.datatypes import Api logger = get_logger(name=__name__, category="core::server") @@ -21,6 +23,25 @@ def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]): # FastAPI built-in paths that should bypass custom routing self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static") + def _is_router_based_route(self, path: str) -> bool: + """Check if a path belongs to a router-based API. + + Router-based APIs use FastAPI routers instead of the old webmethod system. + We need to check if the path matches any router-based API prefix. + """ + # Extract API name from path (e.g., /v1/batches -> batches) + # Paths are typically /v1/{api_name} or /v1/{api_name}/... 
+ parts = path.strip("/").split("/") + if len(parts) >= 2 and parts[0].startswith("v"): + api_name = parts[1] + try: + api = Api(api_name) + return has_router(api) + except (ValueError, KeyError): + # Not a known API or not router-based + return False + return False + async def __call__(self, scope, receive, send): if scope.get("type") == "lifespan": return await self.app(scope, receive, send) @@ -33,6 +54,44 @@ async def __call__(self, scope, receive, send): logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}") return await self.app(scope, receive, send) + # Check if this is a router-based route - if so, pass through to FastAPI + # Router-based routes are handled by FastAPI directly, so we skip the old route lookup + # but still need to set up tracing + is_router_based = self._is_router_based_route(path) + if is_router_based: + logger.debug(f"Router-based route detected: {path}, setting up tracing") + # Set up tracing for router-based routes + trace_attributes = {"__location__": "server", "raw_path": path} + + # Extract W3C trace context headers and store as trace attributes + headers = dict(scope.get("headers", [])) + traceparent = headers.get(b"traceparent", b"").decode() + if traceparent: + trace_attributes["traceparent"] = traceparent + tracestate = headers.get(b"tracestate", b"").decode() + if tracestate: + trace_attributes["tracestate"] = tracestate + + trace_context = await start_trace(path, trace_attributes) + + async def send_with_trace_id(message): + if message["type"] == "http.response.start": + headers = message.get("headers", []) + headers.append([b"x-trace-id", str(trace_context.trace_id).encode()]) + message["headers"] = headers + await send(message) + + try: + return await self.app(scope, receive, send_with_trace_id) + finally: + # Always end trace, even if exception occurred + # FastAPI's exception handler will handle the exception and send the response + # The exception will continue to propagate for logging, which is normal 
+ try: + await end_trace() + except Exception: + logger.exception("Error ending trace") + if not hasattr(self, "route_impls"): self.route_impls = initialize_route_impls(self.impls, self.external_apis) diff --git a/src/llama_stack_api/batches.py b/src/llama_stack_api/batches/__init__.py similarity index 52% rename from src/llama_stack_api/batches.py rename to src/llama_stack_api/batches/__init__.py index 00c47d39f7..636dd0c52a 100644 --- a/src/llama_stack_api/batches.py +++ b/src/llama_stack_api/batches/__init__.py @@ -4,12 +4,17 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +"""Batches API protocol and models. + +This module contains the Batches protocol definition and related models. +The router implementation is in llama_stack.core.server.routers.batches. +""" + from typing import Literal, Protocol, runtime_checkable from pydantic import BaseModel, Field -from llama_stack_api.schema_utils import json_schema_type, webmethod -from llama_stack_api.version import LLAMA_STACK_API_V1 +from llama_stack_api.schema_utils import json_schema_type try: from openai.types import Batch as BatchObject @@ -43,7 +48,6 @@ class Batches(Protocol): Note: This API is currently under active development and may undergo changes. """ - @webmethod(route="/batches", method="POST", level=LLAMA_STACK_API_V1) async def create_batch( self, input_file_id: str, @@ -51,46 +55,17 @@ async def create_batch( completion_window: Literal["24h"], metadata: dict[str, str] | None = None, idempotency_key: str | None = None, - ) -> BatchObject: - """Create a new batch for processing multiple API requests. - - :param input_file_id: The ID of an uploaded file containing requests for the batch. - :param endpoint: The endpoint to be used for all requests in the batch. - :param completion_window: The time window within which the batch should be processed. - :param metadata: Optional metadata for the batch. 
- :param idempotency_key: Optional idempotency key. When provided, enables idempotent behavior. - :returns: The created batch object. - """ - ... - - @webmethod(route="/batches/{batch_id}", method="GET", level=LLAMA_STACK_API_V1) - async def retrieve_batch(self, batch_id: str) -> BatchObject: - """Retrieve information about a specific batch. - - :param batch_id: The ID of the batch to retrieve. - :returns: The batch object. - """ - ... - - @webmethod(route="/batches/{batch_id}/cancel", method="POST", level=LLAMA_STACK_API_V1) - async def cancel_batch(self, batch_id: str) -> BatchObject: - """Cancel a batch that is in progress. - - :param batch_id: The ID of the batch to cancel. - :returns: The updated batch object. - """ - ... - - @webmethod(route="/batches", method="GET", level=LLAMA_STACK_API_V1) + ) -> BatchObject: ... + + async def retrieve_batch(self, batch_id: str) -> BatchObject: ... + + async def cancel_batch(self, batch_id: str) -> BatchObject: ... + async def list_batches( self, after: str | None = None, limit: int = 20, - ) -> ListBatchesResponse: - """List all batches for the current user. - - :param after: A cursor for pagination; returns batches after this batch ID. - :param limit: Number of batches to return (default 20, max 100). - :returns: A list of batch objects. - """ - ... + ) -> ListBatchesResponse: ... + + +__all__ = ["Batches", "BatchObject", "ListBatchesResponse"] diff --git a/src/llama_stack_api/batches/models.py b/src/llama_stack_api/batches/models.py new file mode 100644 index 0000000000..22e024be21 --- /dev/null +++ b/src/llama_stack_api/batches/models.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Pydantic models for Batches API requests and responses. 
+ +This module defines the request and response models for the Batches API +using Pydantic with Field descriptions for OpenAPI schema generation. +""" + +from typing import Literal + +from pydantic import BaseModel, Field + +from llama_stack_api.batches import BatchObject, ListBatchesResponse +from llama_stack_api.schema_utils import json_schema_type + + +@json_schema_type +class CreateBatchRequest(BaseModel): + """Request model for creating a batch.""" + + input_file_id: str = Field(..., description="The ID of an uploaded file containing requests for the batch.") + endpoint: str = Field(..., description="The endpoint to be used for all requests in the batch.") + completion_window: Literal["24h"] = Field( + ..., description="The time window within which the batch should be processed." + ) + metadata: dict[str, str] | None = Field(default=None, description="Optional metadata for the batch.") + idempotency_key: str | None = Field( + default=None, description="Optional idempotency key. When provided, enables idempotent behavior." 
+ ) + + +# Re-export response models for convenience +__all__ = ["CreateBatchRequest", "BatchObject", "ListBatchesResponse"] From 2fe24a6df8d6b650cdfc64843f3e6fd593cceba6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Thu, 20 Nov 2025 12:41:24 +0100 Subject: [PATCH 02/24] chore: move ListBatchesResponse to models.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Sébastien Han --- .../core/server/routers/batches.py | 4 ---- src/llama_stack_api/batches/__init__.py | 19 ++++-------------- src/llama_stack_api/batches/models.py | 20 ++++++++++++++++--- 3 files changed, 21 insertions(+), 22 deletions(-) diff --git a/src/llama_stack/core/server/routers/batches.py b/src/llama_stack/core/server/routers/batches.py index fb7f8ebfa0..de1605eafb 100644 --- a/src/llama_stack/core/server/routers/batches.py +++ b/src/llama_stack/core/server/routers/batches.py @@ -56,7 +56,6 @@ async def create_batch( request: Annotated[CreateBatchRequest, Body(...)], svc: Annotated[Batches, Depends(get_batch_service)], ) -> BatchObject: - """Create a new batch.""" return await svc.create_batch( input_file_id=request.input_file_id, endpoint=request.endpoint, @@ -78,7 +77,6 @@ async def retrieve_batch( batch_id: str, svc: Annotated[Batches, Depends(get_batch_service)], ) -> BatchObject: - """Retrieve information about a specific batch.""" return await svc.retrieve_batch(batch_id) @router.post( @@ -94,7 +92,6 @@ async def cancel_batch( batch_id: str, svc: Annotated[Batches, Depends(get_batch_service)], ) -> BatchObject: - """Cancel a batch that is in progress.""" return await svc.cancel_batch(batch_id) @router.get( @@ -111,7 +108,6 @@ async def list_batches( after: str | None = None, limit: int = 20, ) -> ListBatchesResponse: - """List all batches for the current user.""" return await svc.list_batches(after=after, limit=limit) return router diff --git a/src/llama_stack_api/batches/__init__.py 
b/src/llama_stack_api/batches/__init__.py index 636dd0c52a..2dff546b42 100644 --- a/src/llama_stack_api/batches/__init__.py +++ b/src/llama_stack_api/batches/__init__.py @@ -6,31 +6,20 @@ """Batches API protocol and models. -This module contains the Batches protocol definition and related models. +This module contains the Batches protocol definition. +Pydantic models are defined in llama_stack_api.batches.models. The router implementation is in llama_stack.core.server.routers.batches. """ from typing import Literal, Protocol, runtime_checkable -from pydantic import BaseModel, Field - -from llama_stack_api.schema_utils import json_schema_type - try: from openai.types import Batch as BatchObject except ImportError as e: raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e - -@json_schema_type -class ListBatchesResponse(BaseModel): - """Response containing a list of batch objects.""" - - object: Literal["list"] = "list" - data: list[BatchObject] = Field(..., description="List of batch objects") - first_id: str | None = Field(default=None, description="ID of the first batch in the list") - last_id: str | None = Field(default=None, description="ID of the last batch in the list") - has_more: bool = Field(default=False, description="Whether there are more batches available") +# Import models for re-export +from llama_stack_api.batches.models import ListBatchesResponse @runtime_checkable diff --git a/src/llama_stack_api/batches/models.py b/src/llama_stack_api/batches/models.py index 22e024be21..fe449280d7 100644 --- a/src/llama_stack_api/batches/models.py +++ b/src/llama_stack_api/batches/models.py @@ -14,9 +14,13 @@ from pydantic import BaseModel, Field -from llama_stack_api.batches import BatchObject, ListBatchesResponse from llama_stack_api.schema_utils import json_schema_type +try: + from openai.types import Batch as BatchObject +except ImportError as e: + raise ImportError("OpenAI package is required for 
batches API. Please install it with: pip install openai") from e + @json_schema_type class CreateBatchRequest(BaseModel): @@ -33,5 +37,15 @@ class CreateBatchRequest(BaseModel): ) -# Re-export response models for convenience -__all__ = ["CreateBatchRequest", "BatchObject", "ListBatchesResponse"] +@json_schema_type +class ListBatchesResponse(BaseModel): + """Response containing a list of batch objects.""" + + object: Literal["list"] = "list" + data: list[BatchObject] = Field(..., description="List of batch objects") + first_id: str | None = Field(default=None, description="ID of the first batch in the list") + last_id: str | None = Field(default=None, description="ID of the last batch in the list") + has_more: bool = Field(default=False, description="Whether there are more batches available") + + +__all__ = ["CreateBatchRequest", "ListBatchesResponse", "BatchObject"] From 00e7ea6c3b2156f677f4999c53918f12fc945267 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Thu, 20 Nov 2025 15:00:11 +0100 Subject: [PATCH 03/24] fix: adopt FastAPI directly in llama-stack-api MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit migrates the Batches API to use FastAPI routers directly in the API package, removing the need for custom decorator systems and manual router registration. The API package now defines FastAPI routers using standard FastAPI route decorators, making it self-sufficient and eliminating dependencies on the server package. The router implementation has been moved from llama_stack/core/server/routers/batches.py to llama_stack_api/batches/routes.py, where it belongs alongside the protocol and models. Standard error responses (standard_responses) have been moved from the server package to llama_stack_api/router_utils.py, ensuring the API package can define complete routers without server dependencies. 
FastAPI has been added as an explicit dependency to the llama-stack-api package, making it an intentional dependency rather than an implicit one. Router discovery is now fully automatic. The server discovers routers by checking for routes modules in each API package and looking for a create_router function. This eliminates the need for manual registration and makes the system scalable - new APIs with router modules are automatically discovered and used. The router registry has been simplified to use automatic discovery instead of maintaining a manual registry. The build_router function (renamed from create_router to better reflect its purpose) discovers and combines router factories with implementations to create the final router instances. Exposing Routers from the API is nice for the Bring Your Own API use case too. Signed-off-by: Sébastien Han --- scripts/openapi_generator/app.py | 12 +--- src/llama_stack/core/inspect.py | 4 +- .../core/server/router_registry.py | 56 +++++++++---------- src/llama_stack/core/server/routes.py | 2 +- src/llama_stack/core/server/server.py | 24 ++------ src/llama_stack_api/batches/__init__.py | 2 +- .../batches/routes.py} | 12 ++-- src/llama_stack_api/pyproject.toml | 1 + .../router_utils.py | 7 ++- uv.lock | 6 +- 10 files changed, 55 insertions(+), 71 deletions(-) rename src/{llama_stack/core/server/routers/batches.py => llama_stack_api/batches/routes.py} (90%) rename src/{llama_stack/core/server => llama_stack_api}/router_utils.py (73%) diff --git a/scripts/openapi_generator/app.py b/scripts/openapi_generator/app.py index 48afc157d3..e98aafb353 100644 --- a/scripts/openapi_generator/app.py +++ b/scripts/openapi_generator/app.py @@ -76,14 +76,8 @@ def create_llama_stack_app() -> FastAPI: ], ) - # Import batches router to trigger router registration - try: - from llama_stack.core.server.routers import batches # noqa: F401 - except ImportError: - pass - - # Include routers for APIs that have them registered - from 
llama_stack.core.server.router_registry import create_router, has_router + # Include routers for APIs that have them (automatic discovery) + from llama_stack.core.server.router_registry import build_router, has_router def dummy_impl_getter(api: Api) -> Any: """Dummy implementation getter for OpenAPI generation.""" @@ -95,7 +89,7 @@ def dummy_impl_getter(api: Api) -> Any: protocols = api_protocol_map() for api in protocols.keys(): if has_router(api): - router = create_router(api, dummy_impl_getter) + router = build_router(api, dummy_impl_getter) if router: app.include_router(router) diff --git a/src/llama_stack/core/inspect.py b/src/llama_stack/core/inspect.py index be5a26c14a..ac7958968b 100644 --- a/src/llama_stack/core/inspect.py +++ b/src/llama_stack/core/inspect.py @@ -10,7 +10,7 @@ from llama_stack.core.datatypes import StackRunConfig from llama_stack.core.external import load_external_apis -from llama_stack.core.server.router_registry import create_router, has_router +from llama_stack.core.server.router_registry import build_router, has_router from llama_stack.core.server.routes import get_all_api_routes from llama_stack_api import ( Api, @@ -120,7 +120,7 @@ def dummy_impl_getter(api: Api) -> None: if not has_router(api): continue - router = create_router(api, dummy_impl_getter) + router = build_router(api, dummy_impl_getter) if not router: continue diff --git a/src/llama_stack/core/server/router_registry.py b/src/llama_stack/core/server/router_registry.py index e149d13465..4849241e24 100644 --- a/src/llama_stack/core/server/router_registry.py +++ b/src/llama_stack/core/server/router_registry.py @@ -4,12 +4,13 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -"""Router registry for FastAPI routers. +"""Router utilities for FastAPI routers. 
-This module provides a way to register FastAPI routers for APIs that have been -migrated to use explicit FastAPI routers instead of Protocol-based route discovery. +This module provides utilities to discover and create FastAPI routers from API packages. +Routers are automatically discovered by checking for routes modules in each API package. """ +import importlib from collections.abc import Callable from typing import TYPE_CHECKING, Any @@ -18,47 +19,42 @@ if TYPE_CHECKING: from llama_stack_api.datatypes import Api -# Registry of router factory functions -# Each factory function takes a callable that returns the implementation for a given API -# and returns an APIRouter -# Use string keys to avoid circular imports -_router_factories: dict[str, Callable[[Callable[["Api"], Any]], APIRouter]] = {} - - -def register_router(api: "Api", router_factory: Callable[[Callable[["Api"], Any]], APIRouter]) -> None: - """Register a router factory for an API. - - Args: - api: The API enum value - router_factory: A function that takes an impl_getter function and returns an APIRouter - """ - _router_factories[api.value] = router_factory # type: ignore[attr-defined] - def has_router(api: "Api") -> bool: - """Check if an API has a registered router. + """Check if an API has a router factory in its routes module. Args: api: The API enum value Returns: - True if a router factory is registered for this API + True if the API has a routes module with a create_router function """ - return api.value in _router_factories # type: ignore[attr-defined] + try: + routes_module = importlib.import_module(f"llama_stack_api.{api.value}.routes") + return hasattr(routes_module, "create_router") + except (ImportError, AttributeError): + return False -def create_router(api: "Api", impl_getter: Callable[["Api"], Any]) -> APIRouter | None: - """Create a router for an API if one is registered. 
+def build_router(api: "Api", impl_getter: Callable[["Api"], Any]) -> APIRouter | None: + """Build a router for an API by combining its router factory with the implementation. + + This function discovers the router factory from the API package's routes module + and calls it with the impl_getter to create the final router instance. Args: api: The API enum value impl_getter: Function that returns the implementation for a given API Returns: - APIRouter if registered, None otherwise + APIRouter if the API has a routes module with create_router, None otherwise """ - api_value = api.value # type: ignore[attr-defined] - if api_value not in _router_factories: - return None - - return _router_factories[api_value](impl_getter) + try: + routes_module = importlib.import_module(f"llama_stack_api.{api.value}.routes") + if hasattr(routes_module, "create_router"): + router_factory = routes_module.create_router + return router_factory(impl_getter) + except (ImportError, AttributeError): + pass + + return None diff --git a/src/llama_stack/core/server/routes.py b/src/llama_stack/core/server/routes.py index 25027267f8..cd24e9af28 100644 --- a/src/llama_stack/core/server/routes.py +++ b/src/llama_stack/core/server/routes.py @@ -30,7 +30,7 @@ def get_all_api_routes( This function only returns routes from APIs that use the legacy @webmethod decorator system. For APIs that have been migrated to FastAPI routers, - use the router registry (router_registry.has_router() and router_registry.create_router()). + use the router registry (router_registry.has_router() and router_registry.build_router()). 
Args: external_apis: Optional dictionary of external API specifications diff --git a/src/llama_stack/core/server/server.py b/src/llama_stack/core/server/server.py index 76f283f3a7..f34cef4a91 100644 --- a/src/llama_stack/core/server/server.py +++ b/src/llama_stack/core/server/server.py @@ -44,7 +44,7 @@ request_provider_data_context, user_from_scope, ) -from llama_stack.core.server.router_registry import create_router, has_router +from llama_stack.core.server.router_registry import build_router from llama_stack.core.server.routes import get_all_api_routes from llama_stack.core.stack import ( Stack, @@ -449,14 +449,6 @@ def create_app() -> StackApp: external_apis = load_external_apis(config) all_routes = get_all_api_routes(external_apis) - # Import batches router to trigger router registration - # This ensures the router is registered before we try to use it - # We will make this code better once the migration is complete - try: - from llama_stack.core.server.routers import batches # noqa: F401 - except ImportError: - pass - if config.apis: apis_to_serve = set(config.apis) else: @@ -483,15 +475,11 @@ def impl_getter(api: Api) -> Any: for api_str in apis_to_serve: api = Api(api_str) - if has_router(api): - router = create_router(api, impl_getter) - if router: - app.include_router(router) - logger.debug(f"Registered router for {api} API") - else: - logger.warning( - f"API '{api.value}' has a registered router factory but it returned None. Skipping this API." 
- ) + # Try to discover and use a router factory from the API package + router = build_router(api, impl_getter) + if router: + app.include_router(router) + logger.debug(f"Registered router for {api} API") else: # Fall back to old webmethod-based route discovery until the migration is complete routes = all_routes[api] diff --git a/src/llama_stack_api/batches/__init__.py b/src/llama_stack_api/batches/__init__.py index 2dff546b42..6f778de8e0 100644 --- a/src/llama_stack_api/batches/__init__.py +++ b/src/llama_stack_api/batches/__init__.py @@ -8,7 +8,7 @@ This module contains the Batches protocol definition. Pydantic models are defined in llama_stack_api.batches.models. -The router implementation is in llama_stack.core.server.routers.batches. +The FastAPI router is defined in llama_stack_api.batches.routes. """ from typing import Literal, Protocol, runtime_checkable diff --git a/src/llama_stack/core/server/routers/batches.py b/src/llama_stack_api/batches/routes.py similarity index 90% rename from src/llama_stack/core/server/routers/batches.py rename to src/llama_stack_api/batches/routes.py index de1605eafb..e8b6aaf411 100644 --- a/src/llama_stack/core/server/routers/batches.py +++ b/src/llama_stack_api/batches/routes.py @@ -7,7 +7,8 @@ """FastAPI router for the Batches API. This module defines the FastAPI router for the Batches API using standard -FastAPI route decorators instead of Protocol-based route discovery. +FastAPI route decorators. The router is defined in the API package to keep +all API-related code together. 
""" from collections.abc import Callable @@ -15,15 +16,14 @@ from fastapi import APIRouter, Body, Depends -from llama_stack.core.server.router_registry import register_router -from llama_stack.core.server.router_utils import standard_responses from llama_stack_api.batches import Batches, BatchObject, ListBatchesResponse from llama_stack_api.batches.models import CreateBatchRequest from llama_stack_api.datatypes import Api +from llama_stack_api.router_utils import standard_responses from llama_stack_api.version import LLAMA_STACK_API_V1 -def create_batches_router(impl_getter: Callable[[Api], Batches]) -> APIRouter: +def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter: """Create a FastAPI router for the Batches API. Args: @@ -111,7 +111,3 @@ async def list_batches( return await svc.list_batches(after=after, limit=limit) return router - - -# Register the router factory -register_router(Api.batches, create_batches_router) diff --git a/src/llama_stack_api/pyproject.toml b/src/llama_stack_api/pyproject.toml index 0ceb2bb4e0..0fec354db5 100644 --- a/src/llama_stack_api/pyproject.toml +++ b/src/llama_stack_api/pyproject.toml @@ -24,6 +24,7 @@ classifiers = [ "Topic :: Scientific/Engineering :: Information Analysis", ] dependencies = [ + "fastapi>=0.115.0,<1.0", "pydantic>=2.11.9", "jsonschema", "opentelemetry-sdk>=1.30.0", diff --git a/src/llama_stack/core/server/router_utils.py b/src/llama_stack_api/router_utils.py similarity index 73% rename from src/llama_stack/core/server/router_utils.py rename to src/llama_stack_api/router_utils.py index 1c508af76b..1ad19b05ff 100644 --- a/src/llama_stack/core/server/router_utils.py +++ b/src/llama_stack_api/router_utils.py @@ -4,7 +4,12 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -"""Utilities for creating FastAPI routers with standard error responses.""" +"""Utilities for creating FastAPI routers with standard error responses. 
+ +This module provides standard error response definitions for FastAPI routers. +These responses use OpenAPI $ref references to component responses defined +in the OpenAPI specification. +""" standard_responses = { 400: {"$ref": "#/components/responses/BadRequest400"}, diff --git a/uv.lock b/uv.lock index a343eb5d87..a5eded2fdd 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.12" resolution-markers = [ "(python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'linux')", @@ -2294,6 +2294,7 @@ name = "llama-stack-api" version = "0.4.0.dev0" source = { editable = "src/llama_stack_api" } dependencies = [ + { name = "fastapi" }, { name = "jsonschema" }, { name = "opentelemetry-exporter-otlp-proto-http" }, { name = "opentelemetry-sdk" }, @@ -2302,6 +2303,7 @@ dependencies = [ [package.metadata] requires-dist = [ + { name = "fastapi", specifier = ">=0.115.0,<1.0" }, { name = "jsonschema" }, { name = "opentelemetry-exporter-otlp-proto-http", specifier = ">=1.30.0" }, { name = "opentelemetry-sdk", specifier = ">=1.30.0" }, @@ -4656,6 +4658,8 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/6b/fa/3234f913fe9a6525a7b97c6dad1f51e72b917e6872e051a5e2ffd8b16fbb/ruamel.yaml.clib-0.2.14-cp314-cp314-macosx_15_0_arm64.whl", hash = "sha256:70eda7703b8126f5e52fcf276e6c0f40b0d314674f896fc58c47b0aef2b9ae83", size = 137970, upload-time = "2025-09-22T19:51:09.472Z" }, { url = "https://files.pythonhosted.org/packages/ef/ec/4edbf17ac2c87fa0845dd366ef8d5852b96eb58fcd65fc1ecf5fe27b4641/ruamel.yaml.clib-0.2.14-cp314-cp314-musllinux_1_2_i686.whl", hash = "sha256:a0cb71ccc6ef9ce36eecb6272c81afdc2f565950cdcec33ae8e6cd8f7fc86f27", size = 739639, upload-time = "2025-09-22T19:51:10.566Z" }, { url = 
"https://files.pythonhosted.org/packages/15/18/b0e1fafe59051de9e79cdd431863b03593ecfa8341c110affad7c8121efc/ruamel.yaml.clib-0.2.14-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:e7cb9ad1d525d40f7d87b6df7c0ff916a66bc52cb61b66ac1b2a16d0c1b07640", size = 764456, upload-time = "2025-09-22T19:51:11.736Z" }, + { url = "https://files.pythonhosted.org/packages/e7/cd/150fdb96b8fab27fe08d8a59fe67554568727981806e6bc2677a16081ec7/ruamel_yaml_clib-0.2.14-cp314-cp314-win32.whl", hash = "sha256:9b4104bf43ca0cd4e6f738cb86326a3b2f6eef00f417bd1e7efb7bdffe74c539", size = 102394, upload-time = "2025-11-14T21:57:36.703Z" }, + { url = "https://files.pythonhosted.org/packages/bd/e6/a3fa40084558c7e1dc9546385f22a93949c890a8b2e445b2ba43935f51da/ruamel_yaml_clib-0.2.14-cp314-cp314-win_amd64.whl", hash = "sha256:13997d7d354a9890ea1ec5937a219817464e5cc344805b37671562a401ca3008", size = 122673, upload-time = "2025-11-14T21:57:38.177Z" }, ] [[package]] From 30cab020835ae0e2472be0090bf53d03cca5f75d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Thu, 20 Nov 2025 15:54:07 +0100 Subject: [PATCH 04/24] chore: refactor Batches protocol to use request models MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This commit refactors the Batches protocol to use Pydantic request models for both create_batch and list_batches methods, improving consistency, readability, and maintainability. - create_batch now accepts a single CreateBatchRequest parameter instead of individual arguments. This aligns the protocol with FastAPI’s request model pattern, allowing the router to pass the request object directly without unpacking parameters. Provider implementations now access fields via request.input_file_id, request.endpoint, etc. - list_batches now accepts a single ListBatchesRequest parameter, replacing individual query parameters. The model includes after and limit fields with proper OpenAPI descriptions. 
FastAPI automatically parses query parameters into the model for GET requests, keeping router code clean. Provider implementations access fields via request.after and request.limit. Signed-off-by: Sébastien Han --- client-sdks/stainless/openapi.yml | 20 ++++++++ docs/static/deprecated-llama-stack-spec.yaml | 16 ++++++ .../static/experimental-llama-stack-spec.yaml | 16 ++++++ docs/static/llama-stack-spec.yaml | 20 ++++++++ docs/static/stainless-llama-stack-spec.yaml | 20 ++++++++ .../inline/batches/reference/batches.py | 50 +++++++++---------- src/llama_stack_api/batches/__init__.py | 15 ++---- src/llama_stack_api/batches/models.py | 12 ++++- src/llama_stack_api/batches/routes.py | 26 +++++----- 9 files changed, 145 insertions(+), 50 deletions(-) diff --git a/client-sdks/stainless/openapi.yml b/client-sdks/stainless/openapi.yml index be941f6528..c477ae32c7 100644 --- a/client-sdks/stainless/openapi.yml +++ b/client-sdks/stainless/openapi.yml @@ -48,14 +48,18 @@ paths: anyOf: - type: string - type: 'null' + description: Optional cursor for pagination. Returns batches after this ID. title: After + description: Optional cursor for pagination. Returns batches after this ID. - name: limit in: query required: false schema: type: integer + description: Maximum number of batches to return. Defaults to 20. default: 20 title: Limit + description: Maximum number of batches to return. Defaults to 20. post: responses: '200': @@ -12690,6 +12694,22 @@ components: - query title: VectorStoreSearchRequest type: object + ListBatchesRequest: + description: Request model for listing batches. + properties: + after: + anyOf: + - type: string + - type: 'null' + description: Optional cursor for pagination. Returns batches after this ID. + nullable: true + limit: + default: 20 + description: Maximum number of batches to return. Defaults to 20. 
+ title: Limit + type: integer + title: ListBatchesRequest + type: object DialogType: description: Parameter type for dialog data with semantic output labels. properties: diff --git a/docs/static/deprecated-llama-stack-spec.yaml b/docs/static/deprecated-llama-stack-spec.yaml index 94b1a69a77..2e517da70c 100644 --- a/docs/static/deprecated-llama-stack-spec.yaml +++ b/docs/static/deprecated-llama-stack-spec.yaml @@ -9531,6 +9531,22 @@ components: - query title: VectorStoreSearchRequest type: object + ListBatchesRequest: + description: Request model for listing batches. + properties: + after: + anyOf: + - type: string + - type: 'null' + description: Optional cursor for pagination. Returns batches after this ID. + nullable: true + limit: + default: 20 + description: Maximum number of batches to return. Defaults to 20. + title: Limit + type: integer + title: ListBatchesRequest + type: object DialogType: description: Parameter type for dialog data with semantic output labels. properties: diff --git a/docs/static/experimental-llama-stack-spec.yaml b/docs/static/experimental-llama-stack-spec.yaml index dfd3545447..21112924ce 100644 --- a/docs/static/experimental-llama-stack-spec.yaml +++ b/docs/static/experimental-llama-stack-spec.yaml @@ -8544,6 +8544,22 @@ components: - query title: VectorStoreSearchRequest type: object + ListBatchesRequest: + description: Request model for listing batches. + properties: + after: + anyOf: + - type: string + - type: 'null' + description: Optional cursor for pagination. Returns batches after this ID. + nullable: true + limit: + default: 20 + description: Maximum number of batches to return. Defaults to 20. + title: Limit + type: integer + title: ListBatchesRequest + type: object DialogType: description: Parameter type for dialog data with semantic output labels. 
properties: diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml index a736fc8f98..5335e69b80 100644 --- a/docs/static/llama-stack-spec.yaml +++ b/docs/static/llama-stack-spec.yaml @@ -46,14 +46,18 @@ paths: anyOf: - type: string - type: 'null' + description: Optional cursor for pagination. Returns batches after this ID. title: After + description: Optional cursor for pagination. Returns batches after this ID. - name: limit in: query required: false schema: type: integer + description: Maximum number of batches to return. Defaults to 20. default: 20 title: Limit + description: Maximum number of batches to return. Defaults to 20. post: responses: '200': @@ -11419,6 +11423,22 @@ components: - query title: VectorStoreSearchRequest type: object + ListBatchesRequest: + description: Request model for listing batches. + properties: + after: + anyOf: + - type: string + - type: 'null' + description: Optional cursor for pagination. Returns batches after this ID. + nullable: true + limit: + default: 20 + description: Maximum number of batches to return. Defaults to 20. + title: Limit + type: integer + title: ListBatchesRequest + type: object DialogType: description: Parameter type for dialog data with semantic output labels. properties: diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml index be941f6528..c477ae32c7 100644 --- a/docs/static/stainless-llama-stack-spec.yaml +++ b/docs/static/stainless-llama-stack-spec.yaml @@ -48,14 +48,18 @@ paths: anyOf: - type: string - type: 'null' + description: Optional cursor for pagination. Returns batches after this ID. title: After + description: Optional cursor for pagination. Returns batches after this ID. - name: limit in: query required: false schema: type: integer + description: Maximum number of batches to return. Defaults to 20. default: 20 title: Limit + description: Maximum number of batches to return. Defaults to 20. 
post: responses: '200': @@ -12690,6 +12694,22 @@ components: - query title: VectorStoreSearchRequest type: object + ListBatchesRequest: + description: Request model for listing batches. + properties: + after: + anyOf: + - type: string + - type: 'null' + description: Optional cursor for pagination. Returns batches after this ID. + nullable: true + limit: + default: 20 + description: Maximum number of batches to return. Defaults to 20. + title: Limit + type: integer + title: ListBatchesRequest + type: object DialogType: description: Parameter type for dialog data with semantic output labels. properties: diff --git a/src/llama_stack/providers/inline/batches/reference/batches.py b/src/llama_stack/providers/inline/batches/reference/batches.py index 73727799df..aaa105f28d 100644 --- a/src/llama_stack/providers/inline/batches/reference/batches.py +++ b/src/llama_stack/providers/inline/batches/reference/batches.py @@ -11,7 +11,7 @@ import time import uuid from io import BytesIO -from typing import Any, Literal +from typing import Any from openai.types.batch import BatchError, Errors from pydantic import BaseModel @@ -38,6 +38,7 @@ OpenAIUserMessageParam, ResourceNotFoundError, ) +from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest from .config import ReferenceBatchesImplConfig @@ -140,11 +141,7 @@ async def shutdown(self) -> None: # TODO (SECURITY): this currently works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions async def create_batch( self, - input_file_id: str, - endpoint: str, - completion_window: Literal["24h"], - metadata: dict[str, str] | None = None, - idempotency_key: str | None = None, + request: CreateBatchRequest, ) -> BatchObject: """ Create a new batch for processing multiple API requests. 
@@ -185,14 +182,14 @@ async def create_batch( # TODO: set expiration time for garbage collection - if endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]: + if request.endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]: raise ValueError( - f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint", + f"Invalid endpoint: {request.endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint", ) - if completion_window != "24h": + if request.completion_window != "24h": raise ValueError( - f"Invalid completion_window: {completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window", + f"Invalid completion_window: {request.completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window", ) batch_id = f"batch_{uuid.uuid4().hex[:16]}" @@ -200,8 +197,8 @@ async def create_batch( # For idempotent requests, use the idempotency key for the batch ID # This ensures the same key always maps to the same batch ID, # allowing us to detect parameter conflicts - if idempotency_key is not None: - hash_input = idempotency_key.encode("utf-8") + if request.idempotency_key is not None: + hash_input = request.idempotency_key.encode("utf-8") hash_digest = hashlib.sha256(hash_input).hexdigest()[:24] batch_id = f"batch_{hash_digest}" @@ -209,13 +206,13 @@ async def create_batch( existing_batch = await self.retrieve_batch(batch_id) if ( - existing_batch.input_file_id != input_file_id - or existing_batch.endpoint != endpoint - or existing_batch.completion_window != completion_window - or existing_batch.metadata != metadata + existing_batch.input_file_id != request.input_file_id + or existing_batch.endpoint != request.endpoint + or existing_batch.completion_window != request.completion_window + or existing_batch.metadata != request.metadata ): 
raise ConflictError( - f"Idempotency key '{idempotency_key}' was previously used with different parameters. " + f"Idempotency key '{request.idempotency_key}' was previously used with different parameters. " "Either use a new idempotency key or ensure all parameters match the original request." ) @@ -230,12 +227,12 @@ async def create_batch( batch = BatchObject( id=batch_id, object="batch", - endpoint=endpoint, - input_file_id=input_file_id, - completion_window=completion_window, + endpoint=request.endpoint, + input_file_id=request.input_file_id, + completion_window=request.completion_window, status="validating", created_at=current_time, - metadata=metadata, + metadata=request.metadata, ) await self.kvstore.set(f"batch:{batch_id}", batch.to_json()) @@ -267,8 +264,7 @@ async def cancel_batch(self, batch_id: str) -> BatchObject: async def list_batches( self, - after: str | None = None, - limit: int = 20, + request: ListBatchesRequest, ) -> ListBatchesResponse: """ List all batches, eventually only for the current user. @@ -285,14 +281,14 @@ async def list_batches( batches.sort(key=lambda b: b.created_at, reverse=True) start_idx = 0 - if after: + if request.after: for i, batch in enumerate(batches): - if batch.id == after: + if batch.id == request.after: start_idx = i + 1 break - page_batches = batches[start_idx : start_idx + limit] - has_more = (start_idx + limit) < len(batches) + page_batches = batches[start_idx : start_idx + request.limit] + has_more = (start_idx + request.limit) < len(batches) first_id = page_batches[0].id if page_batches else None last_id = page_batches[-1].id if page_batches else None diff --git a/src/llama_stack_api/batches/__init__.py b/src/llama_stack_api/batches/__init__.py index 6f778de8e0..3d301598c0 100644 --- a/src/llama_stack_api/batches/__init__.py +++ b/src/llama_stack_api/batches/__init__.py @@ -11,7 +11,7 @@ The FastAPI router is defined in llama_stack_api.batches.routes. 
""" -from typing import Literal, Protocol, runtime_checkable +from typing import Protocol, runtime_checkable try: from openai.types import Batch as BatchObject @@ -19,7 +19,7 @@ raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e # Import models for re-export -from llama_stack_api.batches.models import ListBatchesResponse +from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest, ListBatchesResponse @runtime_checkable @@ -39,11 +39,7 @@ class Batches(Protocol): async def create_batch( self, - input_file_id: str, - endpoint: str, - completion_window: Literal["24h"], - metadata: dict[str, str] | None = None, - idempotency_key: str | None = None, + request: CreateBatchRequest, ) -> BatchObject: ... async def retrieve_batch(self, batch_id: str) -> BatchObject: ... @@ -52,9 +48,8 @@ async def cancel_batch(self, batch_id: str) -> BatchObject: ... async def list_batches( self, - after: str | None = None, - limit: int = 20, + request: ListBatchesRequest, ) -> ListBatchesResponse: ... -__all__ = ["Batches", "BatchObject", "ListBatchesResponse"] +__all__ = ["Batches", "BatchObject", "CreateBatchRequest", "ListBatchesRequest", "ListBatchesResponse"] diff --git a/src/llama_stack_api/batches/models.py b/src/llama_stack_api/batches/models.py index fe449280d7..bb6d7e3d0e 100644 --- a/src/llama_stack_api/batches/models.py +++ b/src/llama_stack_api/batches/models.py @@ -37,6 +37,16 @@ class CreateBatchRequest(BaseModel): ) +@json_schema_type +class ListBatchesRequest(BaseModel): + """Request model for listing batches.""" + + after: str | None = Field( + default=None, description="Optional cursor for pagination. Returns batches after this ID." + ) + limit: int = Field(default=20, description="Maximum number of batches to return. 
Defaults to 20.") + + @json_schema_type class ListBatchesResponse(BaseModel): """Response containing a list of batch objects.""" @@ -48,4 +58,4 @@ class ListBatchesResponse(BaseModel): has_more: bool = Field(default=False, description="Whether there are more batches available") -__all__ = ["CreateBatchRequest", "ListBatchesResponse", "BatchObject"] +__all__ = ["CreateBatchRequest", "ListBatchesRequest", "ListBatchesResponse", "BatchObject"] diff --git a/src/llama_stack_api/batches/routes.py b/src/llama_stack_api/batches/routes.py index e8b6aaf411..adbc10be54 100644 --- a/src/llama_stack_api/batches/routes.py +++ b/src/llama_stack_api/batches/routes.py @@ -14,10 +14,10 @@ from collections.abc import Callable from typing import Annotated -from fastapi import APIRouter, Body, Depends +from fastapi import APIRouter, Body, Depends, Query from llama_stack_api.batches import Batches, BatchObject, ListBatchesResponse -from llama_stack_api.batches.models import CreateBatchRequest +from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest from llama_stack_api.datatypes import Api from llama_stack_api.router_utils import standard_responses from llama_stack_api.version import LLAMA_STACK_API_V1 @@ -56,13 +56,7 @@ async def create_batch( request: Annotated[CreateBatchRequest, Body(...)], svc: Annotated[Batches, Depends(get_batch_service)], ) -> BatchObject: - return await svc.create_batch( - input_file_id=request.input_file_id, - endpoint=request.endpoint, - completion_window=request.completion_window, - metadata=request.metadata, - idempotency_key=request.idempotency_key, - ) + return await svc.create_batch(request) @router.get( "/batches/{batch_id}", @@ -94,6 +88,15 @@ async def cancel_batch( ) -> BatchObject: return await svc.cancel_batch(batch_id) + def get_list_batches_request( + after: Annotated[ + str | None, Query(description="Optional cursor for pagination. 
Returns batches after this ID.") + ] = None, + limit: Annotated[int, Query(description="Maximum number of batches to return. Defaults to 20.")] = 20, + ) -> ListBatchesRequest: + """Dependency function to create ListBatchesRequest from query parameters.""" + return ListBatchesRequest(after=after, limit=limit) + @router.get( "/batches", response_model=ListBatchesResponse, @@ -104,10 +107,9 @@ async def cancel_batch( }, ) async def list_batches( + request: Annotated[ListBatchesRequest, Depends(get_list_batches_request)], svc: Annotated[Batches, Depends(get_batch_service)], - after: str | None = None, - limit: int = 20, ) -> ListBatchesResponse: - return await svc.list_batches(after=after, limit=limit) + return await svc.list_batches(request) return router From 20030429e731eb13e9426c227cee001386843897 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Thu, 20 Nov 2025 16:12:52 +0100 Subject: [PATCH 05/24] chore: same as previous commit but for more fields MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Sébastien Han --- client-sdks/stainless/openapi.yml | 26 +++++++++++++++ docs/static/deprecated-llama-stack-spec.yaml | 22 +++++++++++++ .../static/experimental-llama-stack-spec.yaml | 22 +++++++++++++ docs/static/llama-stack-spec.yaml | 26 +++++++++++++++ docs/static/stainless-llama-stack-spec.yaml | 26 +++++++++++++++ .../inline/batches/reference/batches.py | 33 +++++++++++-------- src/llama_stack_api/batches/__init__.py | 28 +++++++++++++--- src/llama_stack_api/batches/models.py | 23 ++++++++++++- src/llama_stack_api/batches/routes.py | 29 ++++++++++++---- 9 files changed, 210 insertions(+), 25 deletions(-) diff --git a/client-sdks/stainless/openapi.yml b/client-sdks/stainless/openapi.yml index c477ae32c7..7e81dbd600 100644 --- a/client-sdks/stainless/openapi.yml +++ b/client-sdks/stainless/openapi.yml @@ -125,7 +125,9 @@ paths: required: true schema: type: string + description: The ID of the 
batch to retrieve. title: Batch Id + description: The ID of the batch to retrieve. /v1/batches/{batch_id}/cancel: post: responses: @@ -158,7 +160,9 @@ paths: required: true schema: type: string + description: The ID of the batch to cancel. title: Batch Id + description: The ID of the batch to cancel. /v1/chat/completions: get: responses: @@ -12710,6 +12714,28 @@ components: type: integer title: ListBatchesRequest type: object + RetrieveBatchRequest: + description: Request model for retrieving a batch. + properties: + batch_id: + description: The ID of the batch to retrieve. + title: Batch Id + type: string + required: + - batch_id + title: RetrieveBatchRequest + type: object + CancelBatchRequest: + description: Request model for canceling a batch. + properties: + batch_id: + description: The ID of the batch to cancel. + title: Batch Id + type: string + required: + - batch_id + title: CancelBatchRequest + type: object DialogType: description: Parameter type for dialog data with semantic output labels. properties: diff --git a/docs/static/deprecated-llama-stack-spec.yaml b/docs/static/deprecated-llama-stack-spec.yaml index 2e517da70c..558fbbf6c9 100644 --- a/docs/static/deprecated-llama-stack-spec.yaml +++ b/docs/static/deprecated-llama-stack-spec.yaml @@ -9547,6 +9547,28 @@ components: type: integer title: ListBatchesRequest type: object + RetrieveBatchRequest: + description: Request model for retrieving a batch. + properties: + batch_id: + description: The ID of the batch to retrieve. + title: Batch Id + type: string + required: + - batch_id + title: RetrieveBatchRequest + type: object + CancelBatchRequest: + description: Request model for canceling a batch. + properties: + batch_id: + description: The ID of the batch to cancel. + title: Batch Id + type: string + required: + - batch_id + title: CancelBatchRequest + type: object DialogType: description: Parameter type for dialog data with semantic output labels. 
properties: diff --git a/docs/static/experimental-llama-stack-spec.yaml b/docs/static/experimental-llama-stack-spec.yaml index 21112924ce..79a18161bd 100644 --- a/docs/static/experimental-llama-stack-spec.yaml +++ b/docs/static/experimental-llama-stack-spec.yaml @@ -8560,6 +8560,28 @@ components: type: integer title: ListBatchesRequest type: object + RetrieveBatchRequest: + description: Request model for retrieving a batch. + properties: + batch_id: + description: The ID of the batch to retrieve. + title: Batch Id + type: string + required: + - batch_id + title: RetrieveBatchRequest + type: object + CancelBatchRequest: + description: Request model for canceling a batch. + properties: + batch_id: + description: The ID of the batch to cancel. + title: Batch Id + type: string + required: + - batch_id + title: CancelBatchRequest + type: object DialogType: description: Parameter type for dialog data with semantic output labels. properties: diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml index 5335e69b80..3aebf95cb1 100644 --- a/docs/static/llama-stack-spec.yaml +++ b/docs/static/llama-stack-spec.yaml @@ -123,7 +123,9 @@ paths: required: true schema: type: string + description: The ID of the batch to retrieve. title: Batch Id + description: The ID of the batch to retrieve. /v1/batches/{batch_id}/cancel: post: responses: @@ -156,7 +158,9 @@ paths: required: true schema: type: string + description: The ID of the batch to cancel. title: Batch Id + description: The ID of the batch to cancel. /v1/chat/completions: get: responses: @@ -11439,6 +11443,28 @@ components: type: integer title: ListBatchesRequest type: object + RetrieveBatchRequest: + description: Request model for retrieving a batch. + properties: + batch_id: + description: The ID of the batch to retrieve. + title: Batch Id + type: string + required: + - batch_id + title: RetrieveBatchRequest + type: object + CancelBatchRequest: + description: Request model for canceling a batch. 
+ properties: + batch_id: + description: The ID of the batch to cancel. + title: Batch Id + type: string + required: + - batch_id + title: CancelBatchRequest + type: object DialogType: description: Parameter type for dialog data with semantic output labels. properties: diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml index c477ae32c7..7e81dbd600 100644 --- a/docs/static/stainless-llama-stack-spec.yaml +++ b/docs/static/stainless-llama-stack-spec.yaml @@ -125,7 +125,9 @@ paths: required: true schema: type: string + description: The ID of the batch to retrieve. title: Batch Id + description: The ID of the batch to retrieve. /v1/batches/{batch_id}/cancel: post: responses: @@ -158,7 +160,9 @@ paths: required: true schema: type: string + description: The ID of the batch to cancel. title: Batch Id + description: The ID of the batch to cancel. /v1/chat/completions: get: responses: @@ -12710,6 +12714,28 @@ components: type: integer title: ListBatchesRequest type: object + RetrieveBatchRequest: + description: Request model for retrieving a batch. + properties: + batch_id: + description: The ID of the batch to retrieve. + title: Batch Id + type: string + required: + - batch_id + title: RetrieveBatchRequest + type: object + CancelBatchRequest: + description: Request model for canceling a batch. + properties: + batch_id: + description: The ID of the batch to cancel. + title: Batch Id + type: string + required: + - batch_id + title: CancelBatchRequest + type: object DialogType: description: Parameter type for dialog data with semantic output labels. 
properties: diff --git a/src/llama_stack/providers/inline/batches/reference/batches.py b/src/llama_stack/providers/inline/batches/reference/batches.py index aaa105f28d..b0169a4126 100644 --- a/src/llama_stack/providers/inline/batches/reference/batches.py +++ b/src/llama_stack/providers/inline/batches/reference/batches.py @@ -38,7 +38,12 @@ OpenAIUserMessageParam, ResourceNotFoundError, ) -from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest +from llama_stack_api.batches.models import ( + CancelBatchRequest, + CreateBatchRequest, + ListBatchesRequest, + RetrieveBatchRequest, +) from .config import ReferenceBatchesImplConfig @@ -203,7 +208,7 @@ async def create_batch( batch_id = f"batch_{hash_digest}" try: - existing_batch = await self.retrieve_batch(batch_id) + existing_batch = await self.retrieve_batch(RetrieveBatchRequest(batch_id=batch_id)) if ( existing_batch.input_file_id != request.input_file_id @@ -244,23 +249,23 @@ async def create_batch( return batch - async def cancel_batch(self, batch_id: str) -> BatchObject: + async def cancel_batch(self, request: CancelBatchRequest) -> BatchObject: """Cancel a batch that is in progress.""" - batch = await self.retrieve_batch(batch_id) + batch = await self.retrieve_batch(RetrieveBatchRequest(batch_id=request.batch_id)) if batch.status in ["cancelled", "cancelling"]: return batch if batch.status in ["completed", "failed", "expired"]: - raise ConflictError(f"Cannot cancel batch '{batch_id}' with status '{batch.status}'") + raise ConflictError(f"Cannot cancel batch '{request.batch_id}' with status '{batch.status}'") - await self._update_batch(batch_id, status="cancelling", cancelling_at=int(time.time())) + await self._update_batch(request.batch_id, status="cancelling", cancelling_at=int(time.time())) - if batch_id in self._processing_tasks: - self._processing_tasks[batch_id].cancel() + if request.batch_id in self._processing_tasks: + self._processing_tasks[request.batch_id].cancel() # note: task 
removal and status="cancelled" handled in finally block of _process_batch - return await self.retrieve_batch(batch_id) + return await self.retrieve_batch(RetrieveBatchRequest(batch_id=request.batch_id)) async def list_batches( self, @@ -300,11 +305,11 @@ async def list_batches( has_more=has_more, ) - async def retrieve_batch(self, batch_id: str) -> BatchObject: + async def retrieve_batch(self, request: RetrieveBatchRequest) -> BatchObject: """Retrieve information about a specific batch.""" - batch_data = await self.kvstore.get(f"batch:{batch_id}") + batch_data = await self.kvstore.get(f"batch:{request.batch_id}") if not batch_data: - raise ResourceNotFoundError(batch_id, "Batch", "batches.list()") + raise ResourceNotFoundError(request.batch_id, "Batch", "batches.list()") return BatchObject.model_validate_json(batch_data) @@ -312,7 +317,7 @@ async def _update_batch(self, batch_id: str, **updates) -> None: """Update batch fields in kvstore.""" async with self._update_batch_lock: try: - batch = await self.retrieve_batch(batch_id) + batch = await self.retrieve_batch(RetrieveBatchRequest(batch_id=batch_id)) # batch processing is async. once cancelling, only allow "cancelled" status updates if batch.status == "cancelling" and updates.get("status") != "cancelled": @@ -532,7 +537,7 @@ async def _process_batch(self, batch_id: str) -> None: async def _process_batch_impl(self, batch_id: str) -> None: """Implementation of batch processing logic.""" errors: list[BatchError] = [] - batch = await self.retrieve_batch(batch_id) + batch = await self.retrieve_batch(RetrieveBatchRequest(batch_id=batch_id)) errors, requests = await self._validate_input(batch) if errors: diff --git a/src/llama_stack_api/batches/__init__.py b/src/llama_stack_api/batches/__init__.py index 3d301598c0..ebde458447 100644 --- a/src/llama_stack_api/batches/__init__.py +++ b/src/llama_stack_api/batches/__init__.py @@ -19,7 +19,13 @@ raise ImportError("OpenAI package is required for batches API. 
Please install it with: pip install openai") from e # Import models for re-export -from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest, ListBatchesResponse +from llama_stack_api.batches.models import ( + CancelBatchRequest, + CreateBatchRequest, + ListBatchesRequest, + ListBatchesResponse, + RetrieveBatchRequest, +) @runtime_checkable @@ -42,9 +48,15 @@ async def create_batch( request: CreateBatchRequest, ) -> BatchObject: ... - async def retrieve_batch(self, batch_id: str) -> BatchObject: ... + async def retrieve_batch( + self, + request: RetrieveBatchRequest, + ) -> BatchObject: ... - async def cancel_batch(self, batch_id: str) -> BatchObject: ... + async def cancel_batch( + self, + request: CancelBatchRequest, + ) -> BatchObject: ... async def list_batches( self, @@ -52,4 +64,12 @@ async def list_batches( ) -> ListBatchesResponse: ... -__all__ = ["Batches", "BatchObject", "CreateBatchRequest", "ListBatchesRequest", "ListBatchesResponse"] +__all__ = [ + "Batches", + "BatchObject", + "CreateBatchRequest", + "ListBatchesRequest", + "RetrieveBatchRequest", + "CancelBatchRequest", + "ListBatchesResponse", +] diff --git a/src/llama_stack_api/batches/models.py b/src/llama_stack_api/batches/models.py index bb6d7e3d0e..49fd25d169 100644 --- a/src/llama_stack_api/batches/models.py +++ b/src/llama_stack_api/batches/models.py @@ -47,6 +47,20 @@ class ListBatchesRequest(BaseModel): limit: int = Field(default=20, description="Maximum number of batches to return. 
Defaults to 20.") +@json_schema_type +class RetrieveBatchRequest(BaseModel): + """Request model for retrieving a batch.""" + + batch_id: str = Field(..., description="The ID of the batch to retrieve.") + + +@json_schema_type +class CancelBatchRequest(BaseModel): + """Request model for canceling a batch.""" + + batch_id: str = Field(..., description="The ID of the batch to cancel.") + + @json_schema_type class ListBatchesResponse(BaseModel): """Response containing a list of batch objects.""" @@ -58,4 +72,11 @@ class ListBatchesResponse(BaseModel): has_more: bool = Field(default=False, description="Whether there are more batches available") -__all__ = ["CreateBatchRequest", "ListBatchesRequest", "ListBatchesResponse", "BatchObject"] +__all__ = [ + "CreateBatchRequest", + "ListBatchesRequest", + "RetrieveBatchRequest", + "CancelBatchRequest", + "ListBatchesResponse", + "BatchObject", +] diff --git a/src/llama_stack_api/batches/routes.py b/src/llama_stack_api/batches/routes.py index adbc10be54..0946f89e6d 100644 --- a/src/llama_stack_api/batches/routes.py +++ b/src/llama_stack_api/batches/routes.py @@ -14,10 +14,15 @@ from collections.abc import Callable from typing import Annotated -from fastapi import APIRouter, Body, Depends, Query +from fastapi import APIRouter, Body, Depends, Path, Query from llama_stack_api.batches import Batches, BatchObject, ListBatchesResponse -from llama_stack_api.batches.models import CreateBatchRequest, ListBatchesRequest +from llama_stack_api.batches.models import ( + CancelBatchRequest, + CreateBatchRequest, + ListBatchesRequest, + RetrieveBatchRequest, +) from llama_stack_api.datatypes import Api from llama_stack_api.router_utils import standard_responses from llama_stack_api.version import LLAMA_STACK_API_V1 @@ -58,6 +63,12 @@ async def create_batch( ) -> BatchObject: return await svc.create_batch(request) + def get_retrieve_batch_request( + batch_id: Annotated[str, Path(description="The ID of the batch to retrieve.")], + ) -> 
RetrieveBatchRequest: + """Dependency function to create RetrieveBatchRequest from path parameter.""" + return RetrieveBatchRequest(batch_id=batch_id) + @router.get( "/batches/{batch_id}", response_model=BatchObject, @@ -68,10 +79,16 @@ async def create_batch( }, ) async def retrieve_batch( - batch_id: str, + request: Annotated[RetrieveBatchRequest, Depends(get_retrieve_batch_request)], svc: Annotated[Batches, Depends(get_batch_service)], ) -> BatchObject: - return await svc.retrieve_batch(batch_id) + return await svc.retrieve_batch(request) + + def get_cancel_batch_request( + batch_id: Annotated[str, Path(description="The ID of the batch to cancel.")], + ) -> CancelBatchRequest: + """Dependency function to create CancelBatchRequest from path parameter.""" + return CancelBatchRequest(batch_id=batch_id) @router.post( "/batches/{batch_id}/cancel", @@ -83,10 +100,10 @@ async def retrieve_batch( }, ) async def cancel_batch( - batch_id: str, + request: Annotated[CancelBatchRequest, Depends(get_cancel_batch_request)], svc: Annotated[Batches, Depends(get_batch_service)], ) -> BatchObject: - return await svc.cancel_batch(batch_id) + return await svc.cancel_batch(request) def get_list_batches_request( after: Annotated[ From 9595619b9f6add0956ccb5616248563de47faa75 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Thu, 20 Nov 2025 16:29:14 +0100 Subject: [PATCH 06/24] chore: remove empty dir MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Sébastien Han --- src/llama_stack/core/server/routers/__init__.py | 12 ------------ 1 file changed, 12 deletions(-) delete mode 100644 src/llama_stack/core/server/routers/__init__.py diff --git a/src/llama_stack/core/server/routers/__init__.py b/src/llama_stack/core/server/routers/__init__.py deleted file mode 100644 index 213cb75c8d..0000000000 --- a/src/llama_stack/core/server/routers/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -# Copyright (c) Meta Platforms, 
Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -"""FastAPI router implementations for server endpoints. - -This package contains FastAPI router implementations that define the HTTP -endpoints for each API. The API contracts (protocols and models) are defined -in llama_stack_api, while the server routing implementation lives here. -""" From f62c6044b3934ea72e13b635c9ad9b8bf29291e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Thu, 20 Nov 2025 16:40:49 +0100 Subject: [PATCH 07/24] chore: update unit test to use previously created Class MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Sébastien Han --- .../unit/providers/batches/test_reference.py | 93 +++++++++++-------- .../batches/test_reference_idempotency.py | 37 ++++---- 2 files changed, 71 insertions(+), 59 deletions(-) diff --git a/tests/unit/providers/batches/test_reference.py b/tests/unit/providers/batches/test_reference.py index 32d59234d4..bff286fc74 100644 --- a/tests/unit/providers/batches/test_reference.py +++ b/tests/unit/providers/batches/test_reference.py @@ -58,8 +58,15 @@ from unittest.mock import AsyncMock, MagicMock import pytest +from pydantic import ValidationError from llama_stack_api import BatchObject, ConflictError, ResourceNotFoundError +from llama_stack_api.batches.models import ( + CancelBatchRequest, + CreateBatchRequest, + ListBatchesRequest, + RetrieveBatchRequest, +) class TestReferenceBatchesImpl: @@ -169,7 +176,7 @@ def _validate_batch_type(self, batch, expected_metadata=None): async def test_create_and_retrieve_batch_success(self, provider, sample_batch_data): """Test successful batch creation and retrieval.""" - created_batch = await provider.create_batch(**sample_batch_data) + created_batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data)) 
self._validate_batch_type(created_batch, expected_metadata=sample_batch_data["metadata"]) @@ -184,7 +191,7 @@ async def test_create_and_retrieve_batch_success(self, provider, sample_batch_da assert isinstance(created_batch.created_at, int) assert created_batch.created_at > 0 - retrieved_batch = await provider.retrieve_batch(created_batch.id) + retrieved_batch = await provider.retrieve_batch(RetrieveBatchRequest(batch_id=created_batch.id)) self._validate_batch_type(retrieved_batch, expected_metadata=sample_batch_data["metadata"]) @@ -197,17 +204,15 @@ async def test_create_and_retrieve_batch_success(self, provider, sample_batch_da async def test_create_batch_without_metadata(self, provider): """Test batch creation without optional metadata.""" batch = await provider.create_batch( - input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="24h" + CreateBatchRequest(input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="24h") ) assert batch.metadata is None async def test_create_batch_completion_window(self, provider): """Test batch creation with invalid completion window.""" - with pytest.raises(ValueError, match="Invalid completion_window"): - await provider.create_batch( - input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="now" - ) + with pytest.raises(ValidationError, match="completion_window"): + CreateBatchRequest(input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="now") @pytest.mark.parametrize( "endpoint", @@ -219,37 +224,43 @@ async def test_create_batch_completion_window(self, provider): async def test_create_batch_invalid_endpoints(self, provider, endpoint): """Test batch creation with various invalid endpoints.""" with pytest.raises(ValueError, match="Invalid endpoint"): - await provider.create_batch(input_file_id="file_123", endpoint=endpoint, completion_window="24h") + await provider.create_batch( + CreateBatchRequest(input_file_id="file_123", 
endpoint=endpoint, completion_window="24h") + ) async def test_create_batch_invalid_metadata(self, provider): """Test that batch creation fails with invalid metadata.""" with pytest.raises(ValueError, match="should be a valid string"): await provider.create_batch( - input_file_id="file_123", - endpoint="/v1/chat/completions", - completion_window="24h", - metadata={123: "invalid_key"}, # Non-string key + CreateBatchRequest( + input_file_id="file_123", + endpoint="/v1/chat/completions", + completion_window="24h", + metadata={123: "invalid_key"}, # Non-string key + ) ) with pytest.raises(ValueError, match="should be a valid string"): await provider.create_batch( - input_file_id="file_123", - endpoint="/v1/chat/completions", - completion_window="24h", - metadata={"valid_key": 456}, # Non-string value + CreateBatchRequest( + input_file_id="file_123", + endpoint="/v1/chat/completions", + completion_window="24h", + metadata={"valid_key": 456}, # Non-string value + ) ) async def test_retrieve_batch_not_found(self, provider): """Test error when retrieving non-existent batch.""" with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"): - await provider.retrieve_batch("nonexistent_batch") + await provider.retrieve_batch(RetrieveBatchRequest(batch_id="nonexistent_batch")) async def test_cancel_batch_success(self, provider, sample_batch_data): """Test successful batch cancellation.""" - created_batch = await provider.create_batch(**sample_batch_data) + created_batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data)) assert created_batch.status == "validating" - cancelled_batch = await provider.cancel_batch(created_batch.id) + cancelled_batch = await provider.cancel_batch(CancelBatchRequest(batch_id=created_batch.id)) assert cancelled_batch.id == created_batch.id assert cancelled_batch.status in ["cancelling", "cancelled"] @@ -260,22 +271,22 @@ async def test_cancel_batch_success(self, provider, sample_batch_data): async def 
test_cancel_batch_invalid_statuses(self, provider, sample_batch_data, status): """Test error when cancelling batch in final states.""" provider.process_batches = False - created_batch = await provider.create_batch(**sample_batch_data) + created_batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data)) # directly update status in kvstore await provider._update_batch(created_batch.id, status=status) with pytest.raises(ConflictError, match=f"Cannot cancel batch '{created_batch.id}' with status '{status}'"): - await provider.cancel_batch(created_batch.id) + await provider.cancel_batch(CancelBatchRequest(batch_id=created_batch.id)) async def test_cancel_batch_not_found(self, provider): """Test error when cancelling non-existent batch.""" with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"): - await provider.cancel_batch("nonexistent_batch") + await provider.cancel_batch(CancelBatchRequest(batch_id="nonexistent_batch")) async def test_list_batches_empty(self, provider): """Test listing batches when none exist.""" - response = await provider.list_batches() + response = await provider.list_batches(ListBatchesRequest()) assert response.object == "list" assert response.data == [] @@ -285,9 +296,9 @@ async def test_list_batches_empty(self, provider): async def test_list_batches_single_batch(self, provider, sample_batch_data): """Test listing batches with single batch.""" - created_batch = await provider.create_batch(**sample_batch_data) + created_batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data)) - response = await provider.list_batches() + response = await provider.list_batches(ListBatchesRequest()) assert len(response.data) == 1 self._validate_batch_type(response.data[0], expected_metadata=sample_batch_data["metadata"]) @@ -300,12 +311,12 @@ async def test_list_batches_multiple_batches(self, provider): """Test listing multiple batches.""" batches = [ await provider.create_batch( - 
input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h" + CreateBatchRequest(input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h") ) for i in range(3) ] - response = await provider.list_batches() + response = await provider.list_batches(ListBatchesRequest()) assert len(response.data) == 3 @@ -321,12 +332,12 @@ async def test_list_batches_with_limit(self, provider): """Test listing batches with limit parameter.""" batches = [ await provider.create_batch( - input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h" + CreateBatchRequest(input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h") ) for i in range(3) ] - response = await provider.list_batches(limit=2) + response = await provider.list_batches(ListBatchesRequest(limit=2)) assert len(response.data) == 2 assert response.has_more is True @@ -340,36 +351,36 @@ async def test_list_batches_with_pagination(self, provider): """Test listing batches with pagination using 'after' parameter.""" for i in range(3): await provider.create_batch( - input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h" + CreateBatchRequest(input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h") ) # Get first page - first_page = await provider.list_batches(limit=1) + first_page = await provider.list_batches(ListBatchesRequest(limit=1)) assert len(first_page.data) == 1 assert first_page.has_more is True # Get second page using 'after' - second_page = await provider.list_batches(limit=1, after=first_page.data[0].id) + second_page = await provider.list_batches(ListBatchesRequest(limit=1, after=first_page.data[0].id)) assert len(second_page.data) == 1 assert second_page.data[0].id != first_page.data[0].id # Verify we got the next batch in order - all_batches = await provider.list_batches() + all_batches = await provider.list_batches(ListBatchesRequest()) expected_second_batch_id 
= all_batches.data[1].id assert second_page.data[0].id == expected_second_batch_id async def test_list_batches_invalid_after(self, provider, sample_batch_data): """Test listing batches with invalid 'after' parameter.""" - await provider.create_batch(**sample_batch_data) + await provider.create_batch(CreateBatchRequest(**sample_batch_data)) - response = await provider.list_batches(after="nonexistent_batch") + response = await provider.list_batches(ListBatchesRequest(after="nonexistent_batch")) # Should return all batches (no filtering when 'after' batch not found) assert len(response.data) == 1 async def test_kvstore_persistence(self, provider, sample_batch_data): """Test that batches are properly persisted in kvstore.""" - batch = await provider.create_batch(**sample_batch_data) + batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data)) stored_data = await provider.kvstore.get(f"batch:{batch.id}") assert stored_data is not None @@ -757,7 +768,7 @@ async def add_and_wait(batch_id: str): for _ in range(3): await provider.create_batch( - input_file_id="file_id", endpoint="/v1/chat/completions", completion_window="24h" + CreateBatchRequest(input_file_id="file_id", endpoint="/v1/chat/completions", completion_window="24h") ) await asyncio.sleep(0.042) # let tasks start @@ -767,8 +778,10 @@ async def add_and_wait(batch_id: str): async def test_create_batch_embeddings_endpoint(self, provider): """Test that batch creation succeeds with embeddings endpoint.""" batch = await provider.create_batch( - input_file_id="file_123", - endpoint="/v1/embeddings", - completion_window="24h", + CreateBatchRequest( + input_file_id="file_123", + endpoint="/v1/embeddings", + completion_window="24h", + ) ) assert batch.endpoint == "/v1/embeddings" diff --git a/tests/unit/providers/batches/test_reference_idempotency.py b/tests/unit/providers/batches/test_reference_idempotency.py index acb7ca01c5..0ac73841eb 100644 --- 
a/tests/unit/providers/batches/test_reference_idempotency.py +++ b/tests/unit/providers/batches/test_reference_idempotency.py @@ -45,6 +45,7 @@ import pytest from llama_stack_api import ConflictError +from llama_stack_api.batches.models import CreateBatchRequest, RetrieveBatchRequest class TestReferenceBatchesIdempotency: @@ -56,18 +57,22 @@ async def test_idempotent_batch_creation_same_params(self, provider, sample_batc del sample_batch_data["metadata"] batch1 = await provider.create_batch( - **sample_batch_data, - metadata={"test": "value1", "other": "value2"}, - idempotency_key="unique-token-1", + CreateBatchRequest( + **sample_batch_data, + metadata={"test": "value1", "other": "value2"}, + idempotency_key="unique-token-1", + ) ) # sleep for 1 second to allow created_at timestamps to be different await asyncio.sleep(1) batch2 = await provider.create_batch( - **sample_batch_data, - metadata={"other": "value2", "test": "value1"}, # Different order - idempotency_key="unique-token-1", + CreateBatchRequest( + **sample_batch_data, + metadata={"other": "value2", "test": "value1"}, # Different order + idempotency_key="unique-token-1", + ) ) assert batch1.id == batch2.id @@ -77,23 +82,17 @@ async def test_idempotent_batch_creation_same_params(self, provider, sample_batc async def test_different_idempotency_keys_create_different_batches(self, provider, sample_batch_data): """Test that different idempotency keys create different batches even with same params.""" - batch1 = await provider.create_batch( - **sample_batch_data, - idempotency_key="token-A", - ) + batch1 = await provider.create_batch(CreateBatchRequest(**sample_batch_data, idempotency_key="token-A")) - batch2 = await provider.create_batch( - **sample_batch_data, - idempotency_key="token-B", - ) + batch2 = await provider.create_batch(CreateBatchRequest(**sample_batch_data, idempotency_key="token-B")) assert batch1.id != batch2.id async def test_non_idempotent_behavior_without_key(self, provider, 
sample_batch_data): """Test that batches without idempotency key create unique batches even with identical parameters.""" - batch1 = await provider.create_batch(**sample_batch_data) + batch1 = await provider.create_batch(CreateBatchRequest(**sample_batch_data)) - batch2 = await provider.create_batch(**sample_batch_data) + batch2 = await provider.create_batch(CreateBatchRequest(**sample_batch_data)) assert batch1.id != batch2.id assert batch1.input_file_id == batch2.input_file_id @@ -117,12 +116,12 @@ async def test_same_idempotency_key_different_params_conflict( sample_batch_data[param_name] = first_value - batch1 = await provider.create_batch(**sample_batch_data) + batch1 = await provider.create_batch(CreateBatchRequest(**sample_batch_data)) with pytest.raises(ConflictError, match="Idempotency key.*was previously used with different parameters"): sample_batch_data[param_name] = second_value - await provider.create_batch(**sample_batch_data) + await provider.create_batch(CreateBatchRequest(**sample_batch_data)) - retrieved_batch = await provider.retrieve_batch(batch1.id) + retrieved_batch = await provider.retrieve_batch(RetrieveBatchRequest(batch_id=batch1.id)) assert retrieved_batch.id == batch1.id assert getattr(retrieved_batch, param_name) == first_value From 23e74446db217ddbf9559ab9b0f83cf10055297a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Fri, 21 Nov 2025 11:41:53 +0100 Subject: [PATCH 08/24] chore: rename routes.py to fastapi_routes.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Sébastien Han --- src/llama_stack/core/server/router_registry.py | 4 ++-- src/llama_stack_api/batches/__init__.py | 2 +- src/llama_stack_api/batches/{routes.py => fastapi_routes.py} | 0 3 files changed, 3 insertions(+), 3 deletions(-) rename src/llama_stack_api/batches/{routes.py => fastapi_routes.py} (100%) diff --git a/src/llama_stack/core/server/router_registry.py 
b/src/llama_stack/core/server/router_registry.py index 4849241e24..5851c0b5be 100644 --- a/src/llama_stack/core/server/router_registry.py +++ b/src/llama_stack/core/server/router_registry.py @@ -30,7 +30,7 @@ def has_router(api: "Api") -> bool: True if the API has a routes module with a create_router function """ try: - routes_module = importlib.import_module(f"llama_stack_api.{api.value}.routes") + routes_module = importlib.import_module(f"llama_stack_api.{api.value}.fastapi_routes") return hasattr(routes_module, "create_router") except (ImportError, AttributeError): return False @@ -50,7 +50,7 @@ def build_router(api: "Api", impl_getter: Callable[["Api"], Any]) -> APIRouter | APIRouter if the API has a routes module with create_router, None otherwise """ try: - routes_module = importlib.import_module(f"llama_stack_api.{api.value}.routes") + routes_module = importlib.import_module(f"llama_stack_api.{api.value}.fastapi_routes") if hasattr(routes_module, "create_router"): router_factory = routes_module.create_router return router_factory(impl_getter) diff --git a/src/llama_stack_api/batches/__init__.py b/src/llama_stack_api/batches/__init__.py index ebde458447..33dc62b811 100644 --- a/src/llama_stack_api/batches/__init__.py +++ b/src/llama_stack_api/batches/__init__.py @@ -8,7 +8,7 @@ This module contains the Batches protocol definition. Pydantic models are defined in llama_stack_api.batches.models. -The FastAPI router is defined in llama_stack_api.batches.routes. +The FastAPI router is defined in llama_stack_api.batches.fastapi_routes. 
""" from typing import Protocol, runtime_checkable diff --git a/src/llama_stack_api/batches/routes.py b/src/llama_stack_api/batches/fastapi_routes.py similarity index 100% rename from src/llama_stack_api/batches/routes.py rename to src/llama_stack_api/batches/fastapi_routes.py From 8a21d8debedff0c5f2529be08e1c0a59a59369ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Fri, 21 Nov 2025 11:44:25 +0100 Subject: [PATCH 09/24] chore: mv router_registry.py to fastapi_router_registry.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit For clarity Signed-off-by: Sébastien Han --- scripts/openapi_generator/app.py | 2 +- src/llama_stack/core/inspect.py | 2 +- .../server/{router_registry.py => fastapi_router_registry.py} | 0 src/llama_stack/core/server/routes.py | 2 +- src/llama_stack/core/server/server.py | 2 +- src/llama_stack/core/server/tracing.py | 2 +- 6 files changed, 5 insertions(+), 5 deletions(-) rename src/llama_stack/core/server/{router_registry.py => fastapi_router_registry.py} (100%) diff --git a/scripts/openapi_generator/app.py b/scripts/openapi_generator/app.py index e98aafb353..28d49f0b74 100644 --- a/scripts/openapi_generator/app.py +++ b/scripts/openapi_generator/app.py @@ -77,7 +77,7 @@ def create_llama_stack_app() -> FastAPI: ) # Include routers for APIs that have them (automatic discovery) - from llama_stack.core.server.router_registry import build_router, has_router + from llama_stack.core.server.fastapi_router_registry import build_router, has_router def dummy_impl_getter(api: Api) -> Any: """Dummy implementation getter for OpenAPI generation.""" diff --git a/src/llama_stack/core/inspect.py b/src/llama_stack/core/inspect.py index ac7958968b..dfaf1fd5c5 100644 --- a/src/llama_stack/core/inspect.py +++ b/src/llama_stack/core/inspect.py @@ -10,7 +10,7 @@ from llama_stack.core.datatypes import StackRunConfig from llama_stack.core.external import load_external_apis -from 
llama_stack.core.server.router_registry import build_router, has_router +from llama_stack.core.server.fastapi_router_registry import build_router, has_router from llama_stack.core.server.routes import get_all_api_routes from llama_stack_api import ( Api, diff --git a/src/llama_stack/core/server/router_registry.py b/src/llama_stack/core/server/fastapi_router_registry.py similarity index 100% rename from src/llama_stack/core/server/router_registry.py rename to src/llama_stack/core/server/fastapi_router_registry.py diff --git a/src/llama_stack/core/server/routes.py b/src/llama_stack/core/server/routes.py index cd24e9af28..b6508a7a47 100644 --- a/src/llama_stack/core/server/routes.py +++ b/src/llama_stack/core/server/routes.py @@ -30,7 +30,7 @@ def get_all_api_routes( This function only returns routes from APIs that use the legacy @webmethod decorator system. For APIs that have been migrated to FastAPI routers, - use the router registry (router_registry.has_router() and router_registry.build_router()). + use the router registry (fastapi_router_registry.has_router() and fastapi_router_registry.build_router()). 
Args: external_apis: Optional dictionary of external API specifications diff --git a/src/llama_stack/core/server/server.py b/src/llama_stack/core/server/server.py index f34cef4a91..15c377413f 100644 --- a/src/llama_stack/core/server/server.py +++ b/src/llama_stack/core/server/server.py @@ -44,7 +44,7 @@ request_provider_data_context, user_from_scope, ) -from llama_stack.core.server.router_registry import build_router +from llama_stack.core.server.fastapi_router_registry import build_router from llama_stack.core.server.routes import get_all_api_routes from llama_stack.core.stack import ( Stack, diff --git a/src/llama_stack/core/server/tracing.py b/src/llama_stack/core/server/tracing.py index 352badcaae..7a6aec4364 100644 --- a/src/llama_stack/core/server/tracing.py +++ b/src/llama_stack/core/server/tracing.py @@ -6,7 +6,7 @@ from aiohttp import hdrs from llama_stack.core.external import ExternalApiSpec -from llama_stack.core.server.router_registry import has_router +from llama_stack.core.server.fastapi_router_registry import has_router from llama_stack.core.server.routes import find_matching_route, initialize_route_impls from llama_stack.core.telemetry.tracing import end_trace, start_trace from llama_stack.log import get_logger From 95e945533533d961952e9ff0aeadfb0fd471e04d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Fri, 21 Nov 2025 12:02:09 +0100 Subject: [PATCH 10/24] chore: removed impl_getter from router function MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Refactored the router to accept the implementation directly instead of using the impl_getter pattern. 
The caller already knows which API it's building a router for. Signed-off-by: Sébastien Han --- scripts/openapi_generator/app.py | 13 +++---------- src/llama_stack/core/inspect.py | 7 ++----- .../core/server/fastapi_router_registry.py | 9 ++++----- src/llama_stack_api/batches/fastapi_routes.py | 8 +++----- 4 files changed, 12 insertions(+), 25 deletions(-) diff --git a/scripts/openapi_generator/app.py b/scripts/openapi_generator/app.py index 28d49f0b74..01b51046a1 100644 --- a/scripts/openapi_generator/app.py +++ b/scripts/openapi_generator/app.py @@ -14,6 +14,7 @@ from fastapi import FastAPI from llama_stack.core.resolver import api_protocol_map +from llama_stack.core.server.fastapi_router_registry import build_router, has_router from llama_stack_api import Api from .state import _protocol_methods_cache @@ -77,19 +78,11 @@ def create_llama_stack_app() -> FastAPI: ) # Include routers for APIs that have them (automatic discovery) - from llama_stack.core.server.fastapi_router_registry import build_router, has_router - - def dummy_impl_getter(api: Api) -> Any: - """Dummy implementation getter for OpenAPI generation.""" - return None - - # Get all APIs that might have routers - from llama_stack.core.resolver import api_protocol_map - protocols = api_protocol_map() for api in protocols.keys(): if has_router(api): - router = build_router(api, dummy_impl_getter) + # For OpenAPI generation, we don't need a real implementation + router = build_router(api, None) if router: app.include_router(router) diff --git a/src/llama_stack/core/inspect.py b/src/llama_stack/core/inspect.py index dfaf1fd5c5..7d8937339e 100644 --- a/src/llama_stack/core/inspect.py +++ b/src/llama_stack/core/inspect.py @@ -109,10 +109,6 @@ def should_include_router_route(route, router_prefix: str | None) -> bool: return not route_deprecated # Process router-based routes - def dummy_impl_getter(api: Api) -> None: """Dummy implementation getter for route inspection.""" - return None - from
llama_stack.core.resolver import api_protocol_map protocols = api_protocol_map(external_apis) @@ -120,7 +116,8 @@ def dummy_impl_getter(api: Api) -> None: if not has_router(api): continue - router = build_router(api, dummy_impl_getter) + # For route inspection, we don't need a real implementation + router = build_router(api, None) if not router: continue diff --git a/src/llama_stack/core/server/fastapi_router_registry.py b/src/llama_stack/core/server/fastapi_router_registry.py index 5851c0b5be..a220ed78cf 100644 --- a/src/llama_stack/core/server/fastapi_router_registry.py +++ b/src/llama_stack/core/server/fastapi_router_registry.py @@ -11,7 +11,6 @@ """ import importlib -from collections.abc import Callable from typing import TYPE_CHECKING, Any from fastapi import APIRouter @@ -36,15 +35,15 @@ def has_router(api: "Api") -> bool: return False -def build_router(api: "Api", impl_getter: Callable[["Api"], Any]) -> APIRouter | None: +def build_router(api: "Api", impl: Any) -> APIRouter | None: """Build a router for an API by combining its router factory with the implementation. This function discovers the router factory from the API package's routes module - and calls it with the impl_getter to create the final router instance. + and calls it with the implementation to create the final router instance. 
Args: api: The API enum value - impl_getter: Function that returns the implementation for a given API + impl: The implementation instance for the API Returns: APIRouter if the API has a routes module with create_router, None otherwise @@ -53,7 +52,7 @@ def build_router(api: "Api", impl_getter: Callable[["Api"], Any]) -> APIRouter | routes_module = importlib.import_module(f"llama_stack_api.{api.value}.fastapi_routes") if hasattr(routes_module, "create_router"): router_factory = routes_module.create_router - return router_factory(impl_getter) + return router_factory(impl) except (ImportError, AttributeError): pass diff --git a/src/llama_stack_api/batches/fastapi_routes.py b/src/llama_stack_api/batches/fastapi_routes.py index 0946f89e6d..3e916c033c 100644 --- a/src/llama_stack_api/batches/fastapi_routes.py +++ b/src/llama_stack_api/batches/fastapi_routes.py @@ -11,7 +11,6 @@ all API-related code together. """ -from collections.abc import Callable from typing import Annotated from fastapi import APIRouter, Body, Depends, Path, Query @@ -23,16 +22,15 @@ ListBatchesRequest, RetrieveBatchRequest, ) -from llama_stack_api.datatypes import Api from llama_stack_api.router_utils import standard_responses from llama_stack_api.version import LLAMA_STACK_API_V1 -def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter: +def create_router(impl: Batches) -> APIRouter: """Create a FastAPI router for the Batches API. 
Args: - impl_getter: Function that returns the Batches implementation for the batches API + impl: The Batches implementation instance Returns: APIRouter configured for the Batches API @@ -45,7 +43,7 @@ def create_router(impl_getter: Callable[[Api], Batches]) -> APIRouter: def get_batch_service() -> Batches: """Dependency function to get the batch service implementation.""" - return impl_getter(Api.batches) + return impl @router.post( "/batches", From 234eaf4709d0990f3888596711834722a37402c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Fri, 21 Nov 2025 12:03:06 +0100 Subject: [PATCH 11/24] chore: remove impl_getter function MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit We already have an impl at this point, no need to validate this again. Signed-off-by: Sébastien Han --- src/llama_stack/core/server/server.py | 75 ++++++++++++--------------- 1 file changed, 34 insertions(+), 41 deletions(-) diff --git a/src/llama_stack/core/server/server.py b/src/llama_stack/core/server/server.py index 15c377413f..67db3de2e1 100644 --- a/src/llama_stack/core/server/server.py +++ b/src/llama_stack/core/server/server.py @@ -44,7 +44,7 @@ request_provider_data_context, user_from_scope, ) -from llama_stack.core.server.fastapi_router_registry import build_router +from llama_stack.core.server.fastapi_router_registry import build_router, has_router from llama_stack.core.server.routes import get_all_api_routes from llama_stack.core.stack import ( Stack, @@ -465,51 +465,44 @@ def create_app() -> StackApp: apis_to_serve.add("prompts") apis_to_serve.add("conversations") - def impl_getter(api: Api) -> Any: - """Get the implementation for a given API.""" - try: - return impls[api] - except KeyError as e: - raise ValueError(f"Could not find provider implementation for {api} API") from e - for api_str in apis_to_serve: api = Api(api_str) # Try to discover and use a router factory from the API package - router = build_router(api, 
impl_getter) - if router: - app.include_router(router) - logger.debug(f"Registered router for {api} API") - else: - # Fall back to old webmethod-based route discovery until the migration is complete - routes = all_routes[api] - try: - impl = impls[api] - except KeyError as e: - raise ValueError(f"Could not find provider implementation for {api} API") from e - - for route, _ in routes: - if not hasattr(impl, route.name): - # ideally this should be a typing violation already - raise ValueError(f"Could not find method {route.name} on {impl}!") - - impl_method = getattr(impl, route.name) - # Filter out HEAD method since it's automatically handled by FastAPI for GET routes - available_methods = [m for m in route.methods if m != "HEAD"] - if not available_methods: - raise ValueError(f"No methods found for {route.name} on {impl}") - method = available_methods[0] - logger.debug(f"{method} {route.path}") - - with warnings.catch_warnings(): - warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields") - getattr(app, method.lower())(route.path, response_model=None)( - create_dynamic_typed_route( - impl_method, - method.lower(), - route.path, - ) + if has_router(api): + impl = impls[api] + router = build_router(api, impl) + if router: + app.include_router(router) + logger.debug(f"Registered router for {api} API") + continue + + # Fall back to old webmethod-based route discovery until the migration is complete + impl = impls[api] + + routes = all_routes[api] + for route, _ in routes: + if not hasattr(impl, route.name): + # ideally this should be a typing violation already + raise ValueError(f"Could not find method {route.name} on {impl}!") + + impl_method = getattr(impl, route.name) + # Filter out HEAD method since it's automatically handled by FastAPI for GET routes + available_methods = [m for m in route.methods if m != "HEAD"] + if not available_methods: + raise ValueError(f"No methods found for {route.name} on {impl}") + method = 
available_methods[0] + logger.debug(f"{method} {route.path}") + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields") + getattr(app, method.lower())(route.path, response_model=None)( + create_dynamic_typed_route( + impl_method, + method.lower(), + route.path, ) + ) logger.debug(f"serving APIs: {apis_to_serve}") From 6f552e0a3142425acdaf0329b162548f880c34f8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Fri, 21 Nov 2025 12:18:25 +0100 Subject: [PATCH 12/24] fix: mypy MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Sébastien Han --- src/llama_stack/core/inspect.py | 5 ++++- .../core/server/fastapi_router_registry.py | 16 ++++++++++------ src/llama_stack_api/router_utils.py | 4 +++- 3 files changed, 17 insertions(+), 8 deletions(-) diff --git a/src/llama_stack/core/inspect.py b/src/llama_stack/core/inspect.py index 7d8937339e..0af04d4f82 100644 --- a/src/llama_stack/core/inspect.py +++ b/src/llama_stack/core/inspect.py @@ -132,7 +132,10 @@ def should_include_router_route(route, router_prefix: str | None) -> bool: methods = {m for m in route.methods if m != "HEAD"} if methods and should_include_router_route(route, router_prefix): # FastAPI already combines router prefix with route path - path = route.path + # Only APIRoute has a path attribute, use getattr to safely access it + path = getattr(route, "path", None) + if path is None: + continue ret.append( RouteInfo( diff --git a/src/llama_stack/core/server/fastapi_router_registry.py b/src/llama_stack/core/server/fastapi_router_registry.py index a220ed78cf..42e83b7cda 100644 --- a/src/llama_stack/core/server/fastapi_router_registry.py +++ b/src/llama_stack/core/server/fastapi_router_registry.py @@ -7,11 +7,11 @@ """Router utilities for FastAPI routers. This module provides utilities to discover and create FastAPI routers from API packages. 
-Routers are automatically discovered by checking for routes modules in each API package. +Routers are automatically discovered by checking for fastapi_routes modules in each API package. """ import importlib -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast from fastapi import APIRouter @@ -20,13 +20,13 @@ def has_router(api: "Api") -> bool: - """Check if an API has a router factory in its routes module. + """Check if an API has a router factory in its fastapi_routes module. Args: api: The API enum value Returns: - True if the API has a routes module with a create_router function + True if the API has a fastapi_routes module with a create_router function """ try: routes_module = importlib.import_module(f"llama_stack_api.{api.value}.fastapi_routes") @@ -46,13 +46,17 @@ def build_router(api: "Api", impl: Any) -> APIRouter | None: impl: The implementation instance for the API Returns: - APIRouter if the API has a routes module with create_router, None otherwise + APIRouter if the API has a fastapi_routes module with create_router, None otherwise """ try: routes_module = importlib.import_module(f"llama_stack_api.{api.value}.fastapi_routes") if hasattr(routes_module, "create_router"): router_factory = routes_module.create_router - return router_factory(impl) + # cast is safe here: mypy can't verify the return type statically because + # we're dynamically importing the module. However, all router factories in + # API packages are required to return APIRouter. If a router factory returns the wrong + # type, it will fail at runtime when app.include_router(router) is called + return cast(APIRouter, router_factory(impl)) except (ImportError, AttributeError): pass diff --git a/src/llama_stack_api/router_utils.py b/src/llama_stack_api/router_utils.py index 1ad19b05ff..fd0efe0608 100644 --- a/src/llama_stack_api/router_utils.py +++ b/src/llama_stack_api/router_utils.py @@ -11,7 +11,9 @@ in the OpenAPI specification. 
""" -standard_responses = { +from typing import Any + +standard_responses: dict[int | str, dict[str, Any]] = { 400: {"$ref": "#/components/responses/BadRequest400"}, 429: {"$ref": "#/components/responses/TooManyRequests429"}, 500: {"$ref": "#/components/responses/InternalServerError500"}, From ac816a6b255f0ba201fda17d49952014f91ebef0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Fri, 21 Nov 2025 15:56:44 +0100 Subject: [PATCH 13/24] fix: move models.py to top-level init MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit All batch models are now exported from the top level for better discoverability and IDE support. Signed-off-by: Sébastien Han --- src/llama_stack_api/__init__.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/llama_stack_api/__init__.py b/src/llama_stack_api/__init__.py index b6fe2fd239..f919d2afd7 100644 --- a/src/llama_stack_api/__init__.py +++ b/src/llama_stack_api/__init__.py @@ -26,7 +26,15 @@ # Import all public API symbols from .agents import Agents, ResponseGuardrail, ResponseGuardrailSpec -from .batches import Batches, BatchObject, ListBatchesResponse +from .batches import ( + Batches, + BatchObject, + CancelBatchRequest, + CreateBatchRequest, + ListBatchesRequest, + ListBatchesResponse, + RetrieveBatchRequest, +) from .benchmarks import ( Benchmark, BenchmarkInput, @@ -462,6 +470,9 @@ "BasicScoringFnParams", "Batches", "BatchObject", + "CancelBatchRequest", + "CreateBatchRequest", + "ListBatchesRequest", "Benchmark", "BenchmarkConfig", "BenchmarkInput", @@ -555,6 +566,7 @@ "LLMAsJudgeScoringFnParams", "LLMRAGQueryGeneratorConfig", "ListBatchesResponse", + "RetrieveBatchRequest", "ListBenchmarksResponse", "ListDatasetsResponse", "ListModelsResponse", From 03a31269adb92b272c74ecf907f152ee44a9c397 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Mon, 24 Nov 2025 09:00:41 +0100 Subject: [PATCH 14/24] chore: more accurate 
route parsing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Use our built-in version levels. Signed-off-by: Sébastien Han --- src/llama_stack/core/server/tracing.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/llama_stack/core/server/tracing.py b/src/llama_stack/core/server/tracing.py index 7a6aec4364..210b74de33 100644 --- a/src/llama_stack/core/server/tracing.py +++ b/src/llama_stack/core/server/tracing.py @@ -11,9 +11,17 @@ from llama_stack.core.telemetry.tracing import end_trace, start_trace from llama_stack.log import get_logger from llama_stack_api.datatypes import Api +from llama_stack_api.version import ( + LLAMA_STACK_API_V1, + LLAMA_STACK_API_V1ALPHA, + LLAMA_STACK_API_V1BETA, +) logger = get_logger(name=__name__, category="core::server") +# Valid API version levels - all routes must start with one of these +VALID_API_LEVELS = {LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA, LLAMA_STACK_API_V1BETA} + class TracingMiddleware: def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]): @@ -30,9 +38,9 @@ def _is_router_based_route(self, path: str) -> bool: We need to check if the path matches any router-based API prefix. """ # Extract API name from path (e.g., /v1/batches -> batches) - # Paths are typically /v1/{api_name} or /v1/{api_name}/...
+ # Paths must start with a valid API level: /v1/{api_name} or /v1alpha/{api_name} or /v1beta/{api_name} parts = path.strip("/").split("/") - if len(parts) >= 2 and parts[0].startswith("v"): + if len(parts) >= 2 and parts[0] in VALID_API_LEVELS: api_name = parts[1] try: api = Api(api_name) From 49005f1a39c16c991403e6f4bc3d7f926fbf7de1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Mon, 24 Nov 2025 09:56:39 +0100 Subject: [PATCH 15/24] fix: use hardcoded list and dictionary mapping for router registry MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace dynamic import-based router discovery with an explicit hardcoded list of APIs that have routers. Signed-off-by: Sébastien Han --- scripts/openapi_generator/app.py | 15 +++--- src/llama_stack/core/inspect.py | 7 +-- .../core/server/fastapi_router_registry.py | 51 +++++++------------ src/llama_stack/core/server/server.py | 15 +++--- src/llama_stack/core/server/tracing.py | 21 +++----- 5 files changed, 40 insertions(+), 69 deletions(-) diff --git a/scripts/openapi_generator/app.py b/scripts/openapi_generator/app.py index 01b51046a1..8495c0013f 100644 --- a/scripts/openapi_generator/app.py +++ b/scripts/openapi_generator/app.py @@ -14,7 +14,7 @@ from fastapi import FastAPI from llama_stack.core.resolver import api_protocol_map -from llama_stack.core.server.fastapi_router_registry import build_router, has_router +from llama_stack.core.server.fastapi_router_registry import build_router from llama_stack_api import Api from .state import _protocol_methods_cache @@ -77,14 +77,13 @@ def create_llama_stack_app() -> FastAPI: ], ) - # Include routers for APIs that have them (automatic discovery) + # Include routers for APIs that have them protocols = api_protocol_map() for api in protocols.keys(): - if has_router(api): - # For OpenAPI generation, we don't need a real implementation - router = build_router(api, None) - if router: - app.include_router(router) + # 
For OpenAPI generation, we don't need a real implementation + router = build_router(api, None) + if router: + app.include_router(router) # Get all API routes (for legacy webmethod-based routes) from llama_stack.core.server.routes import get_all_api_routes @@ -96,7 +95,7 @@ def create_llama_stack_app() -> FastAPI: for api, routes in api_routes.items(): # Skip APIs that have routers - they're already included above - if has_router(api): + if build_router(api, None) is not None: continue for route, webmethod in routes: diff --git a/src/llama_stack/core/inspect.py b/src/llama_stack/core/inspect.py index 0af04d4f82..dfc4281fb1 100644 --- a/src/llama_stack/core/inspect.py +++ b/src/llama_stack/core/inspect.py @@ -10,7 +10,7 @@ from llama_stack.core.datatypes import StackRunConfig from llama_stack.core.external import load_external_apis -from llama_stack.core.server.fastapi_router_registry import build_router, has_router +from llama_stack.core.server.fastapi_router_registry import build_router from llama_stack.core.server.routes import get_all_api_routes from llama_stack_api import ( Api, @@ -70,7 +70,7 @@ def get_provider_types(api: Api) -> list[str]: # Process webmethod-based routes (legacy) for api, endpoints in all_endpoints.items(): # Skip APIs that have routers - they'll be processed separately - if has_router(api): + if build_router(api, None) is not None: continue provider_types = get_provider_types(api) @@ -113,9 +113,6 @@ def should_include_router_route(route, router_prefix: str | None) -> bool: protocols = api_protocol_map(external_apis) for api in protocols.keys(): - if not has_router(api): - continue - # For route inspection, we don't need a real implementation router = build_router(api, None) if not router: diff --git a/src/llama_stack/core/server/fastapi_router_registry.py b/src/llama_stack/core/server/fastapi_router_registry.py index 42e83b7cda..4b4e9fe8c5 100644 --- a/src/llama_stack/core/server/fastapi_router_registry.py +++ 
b/src/llama_stack/core/server/fastapi_router_registry.py @@ -6,11 +6,10 @@ """Router utilities for FastAPI routers. -This module provides utilities to discover and create FastAPI routers from API packages. -Routers are automatically discovered by checking for fastapi_routes modules in each API package. +This module provides utilities to create FastAPI routers from API packages. +APIs with routers are explicitly listed here. """ -import importlib from typing import TYPE_CHECKING, Any, cast from fastapi import APIRouter @@ -18,46 +17,30 @@ if TYPE_CHECKING: from llama_stack_api.datatypes import Api +# Router factories for APIs that have FastAPI routers +# Add new APIs here as they are migrated to the router system +from llama_stack_api.batches.fastapi_routes import create_router as create_batches_router -def has_router(api: "Api") -> bool: - """Check if an API has a router factory in its fastapi_routes module. - - Args: - api: The API enum value - - Returns: - True if the API has a fastapi_routes module with a create_router function - """ - try: - routes_module = importlib.import_module(f"llama_stack_api.{api.value}.fastapi_routes") - return hasattr(routes_module, "create_router") - except (ImportError, AttributeError): - return False +_ROUTER_FACTORIES: dict[str, APIRouter] = { + "batches": create_batches_router, +} def build_router(api: "Api", impl: Any) -> APIRouter | None: """Build a router for an API by combining its router factory with the implementation. - This function discovers the router factory from the API package's routes module - and calls it with the implementation to create the final router instance. 
- Args: api: The API enum value impl: The implementation instance for the API Returns: - APIRouter if the API has a fastapi_routes module with create_router, None otherwise + APIRouter if the API has a router factory, None otherwise """ - try: - routes_module = importlib.import_module(f"llama_stack_api.{api.value}.fastapi_routes") - if hasattr(routes_module, "create_router"): - router_factory = routes_module.create_router - # cast is safe here: mypy can't verify the return type statically because - # we're dynamically importing the module. However, all router factories in - # API packages are required to return APIRouter. If a router factory returns the wrong - # type, it will fail at runtime when app.include_router(router) is called - return cast(APIRouter, router_factory(impl)) - except (ImportError, AttributeError): - pass - - return None + router_factory = _ROUTER_FACTORIES.get(api.value) + if router_factory is None: + return None + + # cast is safe here: all router factories in API packages are required to return APIRouter. 
+ # If a router factory returns the wrong type, it will fail at runtime when + # app.include_router(router) is called + return cast(APIRouter, router_factory(impl)) diff --git a/src/llama_stack/core/server/server.py b/src/llama_stack/core/server/server.py index 67db3de2e1..84f724fdbb 100644 --- a/src/llama_stack/core/server/server.py +++ b/src/llama_stack/core/server/server.py @@ -44,7 +44,7 @@ request_provider_data_context, user_from_scope, ) -from llama_stack.core.server.fastapi_router_registry import build_router, has_router +from llama_stack.core.server.fastapi_router_registry import build_router from llama_stack.core.server.routes import get_all_api_routes from llama_stack.core.stack import ( Stack, @@ -469,13 +469,12 @@ def create_app() -> StackApp: api = Api(api_str) # Try to discover and use a router factory from the API package - if has_router(api): - impl = impls[api] - router = build_router(api, impl) - if router: - app.include_router(router) - logger.debug(f"Registered router for {api} API") - continue + impl = impls[api] + router = build_router(api, impl) + if router: + app.include_router(router) + logger.debug(f"Registered router for {api} API") + continue # Fall back to old webmethod-based route discovery until the migration is complete impl = impls[api] diff --git a/src/llama_stack/core/server/tracing.py b/src/llama_stack/core/server/tracing.py index 210b74de33..7e851fbb5f 100644 --- a/src/llama_stack/core/server/tracing.py +++ b/src/llama_stack/core/server/tracing.py @@ -6,11 +6,10 @@ from aiohttp import hdrs from llama_stack.core.external import ExternalApiSpec -from llama_stack.core.server.fastapi_router_registry import has_router +from llama_stack.core.server.fastapi_router_registry import _ROUTER_FACTORIES from llama_stack.core.server.routes import find_matching_route, initialize_route_impls from llama_stack.core.telemetry.tracing import end_trace, start_trace from llama_stack.log import get_logger -from llama_stack_api.datatypes import Api 
from llama_stack_api.version import ( LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA, @@ -35,20 +34,14 @@ def _is_router_based_route(self, path: str) -> bool: """Check if a path belongs to a router-based API. Router-based APIs use FastAPI routers instead of the old webmethod system. - We need to check if the path matches any router-based API prefix. + Paths must start with a valid API level (v1, v1alpha, v1beta) followed by an API name. """ - # Extract API name from path (e.g., /v1/batches -> batches) - # Paths must start with a valid API level: /v1/{api_name} or /v1alpha/{api_name} or /v1beta/{api_name} parts = path.strip("/").split("/") - if len(parts) >= 2 and parts[0] in VALID_API_LEVELS: - api_name = parts[1] - try: - api = Api(api_name) - return has_router(api) - except (ValueError, KeyError): - # Not a known API or not router-based - return False - return False + if len(parts) < 2 or parts[0] not in VALID_API_LEVELS: + return False + + # Check directly if the API name is in the router factories list + return parts[1] in _ROUTER_FACTORIES async def __call__(self, scope, receive, send): if scope.get("type") == "lifespan": From 87e60bc48fef33e3e32a1cc44e6739338e701e27 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Mon, 24 Nov 2025 11:30:44 +0100 Subject: [PATCH 16/24] chore: move dep functions outside of create_router MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Less indirection and clearer declarations. 
Signed-off-by: Sébastien Han --- src/llama_stack_api/batches/fastapi_routes.py | 61 +++++++++---------- 1 file changed, 28 insertions(+), 33 deletions(-) diff --git a/src/llama_stack_api/batches/fastapi_routes.py b/src/llama_stack_api/batches/fastapi_routes.py index 3e916c033c..b53a4fc034 100644 --- a/src/llama_stack_api/batches/fastapi_routes.py +++ b/src/llama_stack_api/batches/fastapi_routes.py @@ -26,6 +26,30 @@ from llama_stack_api.version import LLAMA_STACK_API_V1 +def get_retrieve_batch_request( + batch_id: Annotated[str, Path(description="The ID of the batch to retrieve.")], +) -> RetrieveBatchRequest: + """Dependency function to create RetrieveBatchRequest from path parameter.""" + return RetrieveBatchRequest(batch_id=batch_id) + + +def get_cancel_batch_request( + batch_id: Annotated[str, Path(description="The ID of the batch to cancel.")], +) -> CancelBatchRequest: + """Dependency function to create CancelBatchRequest from path parameter.""" + return CancelBatchRequest(batch_id=batch_id) + + +def get_list_batches_request( + after: Annotated[ + str | None, Query(description="Optional cursor for pagination. Returns batches after this ID.") + ] = None, + limit: Annotated[int, Query(description="Maximum number of batches to return. Defaults to 20.")] = 20, +) -> ListBatchesRequest: + """Dependency function to create ListBatchesRequest from query parameters.""" + return ListBatchesRequest(after=after, limit=limit) + + def create_router(impl: Batches) -> APIRouter: """Create a FastAPI router for the Batches API. 
@@ -41,10 +65,6 @@ def create_router(impl: Batches) -> APIRouter: responses=standard_responses, ) - def get_batch_service() -> Batches: - """Dependency function to get the batch service implementation.""" - return impl - @router.post( "/batches", response_model=BatchObject, @@ -57,15 +77,8 @@ def get_batch_service() -> Batches: ) async def create_batch( request: Annotated[CreateBatchRequest, Body(...)], - svc: Annotated[Batches, Depends(get_batch_service)], ) -> BatchObject: - return await svc.create_batch(request) - - def get_retrieve_batch_request( - batch_id: Annotated[str, Path(description="The ID of the batch to retrieve.")], - ) -> RetrieveBatchRequest: - """Dependency function to create RetrieveBatchRequest from path parameter.""" - return RetrieveBatchRequest(batch_id=batch_id) + return await impl.create_batch(request) @router.get( "/batches/{batch_id}", @@ -78,15 +91,8 @@ def get_retrieve_batch_request( ) async def retrieve_batch( request: Annotated[RetrieveBatchRequest, Depends(get_retrieve_batch_request)], - svc: Annotated[Batches, Depends(get_batch_service)], ) -> BatchObject: - return await svc.retrieve_batch(request) - - def get_cancel_batch_request( - batch_id: Annotated[str, Path(description="The ID of the batch to cancel.")], - ) -> CancelBatchRequest: - """Dependency function to create CancelBatchRequest from path parameter.""" - return CancelBatchRequest(batch_id=batch_id) + return await impl.retrieve_batch(request) @router.post( "/batches/{batch_id}/cancel", @@ -99,18 +105,8 @@ def get_cancel_batch_request( ) async def cancel_batch( request: Annotated[CancelBatchRequest, Depends(get_cancel_batch_request)], - svc: Annotated[Batches, Depends(get_batch_service)], ) -> BatchObject: - return await svc.cancel_batch(request) - - def get_list_batches_request( - after: Annotated[ - str | None, Query(description="Optional cursor for pagination. 
Returns batches after this ID.") - ] = None, - limit: Annotated[int, Query(description="Maximum number of batches to return. Defaults to 20.")] = 20, - ) -> ListBatchesRequest: - """Dependency function to create ListBatchesRequest from query parameters.""" - return ListBatchesRequest(after=after, limit=limit) + return await impl.cancel_batch(request) @router.get( "/batches", @@ -123,8 +119,7 @@ def get_list_batches_request( ) async def list_batches( request: Annotated[ListBatchesRequest, Depends(get_list_batches_request)], - svc: Annotated[Batches, Depends(get_batch_service)], ) -> ListBatchesResponse: - return await svc.list_batches(request) + return await impl.list_batches(request) return router From 4f08a62fa1d4fc4a9871c6a28ed11a02d08b6718 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Mon, 24 Nov 2025 11:52:29 +0100 Subject: [PATCH 17/24] chore: remove telemetry code for routers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit addressed https://github.com/llamastack/llama-stack/pull/4191/files#r2554273774 Signed-off-by: Sébastien Han --- src/llama_stack/core/server/tracing.py | 60 -------------------------- 1 file changed, 60 deletions(-) diff --git a/src/llama_stack/core/server/tracing.py b/src/llama_stack/core/server/tracing.py index 7e851fbb5f..c4901d9b12 100644 --- a/src/llama_stack/core/server/tracing.py +++ b/src/llama_stack/core/server/tracing.py @@ -6,21 +6,12 @@ from aiohttp import hdrs from llama_stack.core.external import ExternalApiSpec -from llama_stack.core.server.fastapi_router_registry import _ROUTER_FACTORIES from llama_stack.core.server.routes import find_matching_route, initialize_route_impls from llama_stack.core.telemetry.tracing import end_trace, start_trace from llama_stack.log import get_logger -from llama_stack_api.version import ( - LLAMA_STACK_API_V1, - LLAMA_STACK_API_V1ALPHA, - LLAMA_STACK_API_V1BETA, -) logger = get_logger(name=__name__, category="core::server") -# 
Valid API version levels - all routes must start with one of these -VALID_API_LEVELS = {LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA, LLAMA_STACK_API_V1BETA} - class TracingMiddleware: def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]): @@ -30,19 +21,6 @@ def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]): # FastAPI built-in paths that should bypass custom routing self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static") - def _is_router_based_route(self, path: str) -> bool: - """Check if a path belongs to a router-based API. - - Router-based APIs use FastAPI routers instead of the old webmethod system. - Paths must start with a valid API level (v1, v1alpha, v1beta) followed by an API name. - """ - parts = path.strip("/").split("/") - if len(parts) < 2 or parts[0] not in VALID_API_LEVELS: - return False - - # Check directly if the API name is in the router factories list - return parts[1] in _ROUTER_FACTORIES - async def __call__(self, scope, receive, send): if scope.get("type") == "lifespan": return await self.app(scope, receive, send) @@ -55,44 +33,6 @@ async def __call__(self, scope, receive, send): logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}") return await self.app(scope, receive, send) - # Check if this is a router-based route - if so, pass through to FastAPI - # Router-based routes are handled by FastAPI directly, so we skip the old route lookup - # but still need to set up tracing - is_router_based = self._is_router_based_route(path) - if is_router_based: - logger.debug(f"Router-based route detected: {path}, setting up tracing") - # Set up tracing for router-based routes - trace_attributes = {"__location__": "server", "raw_path": path} - - # Extract W3C trace context headers and store as trace attributes - headers = dict(scope.get("headers", [])) - traceparent = headers.get(b"traceparent", b"").decode() - if traceparent: - trace_attributes["traceparent"] = 
traceparent - tracestate = headers.get(b"tracestate", b"").decode() - if tracestate: - trace_attributes["tracestate"] = tracestate - - trace_context = await start_trace(path, trace_attributes) - - async def send_with_trace_id(message): - if message["type"] == "http.response.start": - headers = message.get("headers", []) - headers.append([b"x-trace-id", str(trace_context.trace_id).encode()]) - message["headers"] = headers - await send(message) - - try: - return await self.app(scope, receive, send_with_trace_id) - finally: - # Always end trace, even if exception occurred - # FastAPI's exception handler will handle the exception and send the response - # The exception will continue to propagate for logging, which is normal - try: - await end_trace() - except Exception: - logger.exception("Error ending trace") - if not hasattr(self, "route_impls"): self.route_impls = initialize_route_impls(self.impls, self.external_apis) From a6aaf18bb6dc3ed81a375eb0a8eabc445036f447 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Mon, 24 Nov 2025 14:47:46 +0100 Subject: [PATCH 18/24] chore: generate FastAPI dependency functions from Pydantic models to eliminate duplication MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Added create_query_dependency() and create_path_dependency() helpers that automatically generate FastAPI dependency functions from Pydantic models. This makes the models the single source of truth for field types, descriptions, and defaults, eliminating duplication between models.py and fastapi_routes.py. 
Signed-off-by: Sébastien Han --- src/llama_stack_api/batches/fastapi_routes.py | 32 ++-- src/llama_stack_api/router_utils.py | 150 +++++++++++++++++- 2 files changed, 158 insertions(+), 24 deletions(-) diff --git a/src/llama_stack_api/batches/fastapi_routes.py b/src/llama_stack_api/batches/fastapi_routes.py index b53a4fc034..dd5dc7a6cf 100644 --- a/src/llama_stack_api/batches/fastapi_routes.py +++ b/src/llama_stack_api/batches/fastapi_routes.py @@ -13,7 +13,7 @@ from typing import Annotated -from fastapi import APIRouter, Body, Depends, Path, Query +from fastapi import APIRouter, Body, Depends from llama_stack_api.batches import Batches, BatchObject, ListBatchesResponse from llama_stack_api.batches.models import ( @@ -22,32 +22,18 @@ ListBatchesRequest, RetrieveBatchRequest, ) -from llama_stack_api.router_utils import standard_responses +from llama_stack_api.router_utils import create_path_dependency, create_query_dependency, standard_responses from llama_stack_api.version import LLAMA_STACK_API_V1 +# Automatically generate dependency functions from Pydantic models +# This ensures the models are the single source of truth for descriptions +get_retrieve_batch_request = create_path_dependency(RetrieveBatchRequest) +get_cancel_batch_request = create_path_dependency(CancelBatchRequest) -def get_retrieve_batch_request( - batch_id: Annotated[str, Path(description="The ID of the batch to retrieve.")], -) -> RetrieveBatchRequest: - """Dependency function to create RetrieveBatchRequest from path parameter.""" - return RetrieveBatchRequest(batch_id=batch_id) - -def get_cancel_batch_request( - batch_id: Annotated[str, Path(description="The ID of the batch to cancel.")], -) -> CancelBatchRequest: - """Dependency function to create CancelBatchRequest from path parameter.""" - return CancelBatchRequest(batch_id=batch_id) - - -def get_list_batches_request( - after: Annotated[ - str | None, Query(description="Optional cursor for pagination. 
Returns batches after this ID.") - ] = None, - limit: Annotated[int, Query(description="Maximum number of batches to return. Defaults to 20.")] = 20, -) -> ListBatchesRequest: - """Dependency function to create ListBatchesRequest from query parameters.""" - return ListBatchesRequest(after=after, limit=limit) +# Automatically generate dependency function from Pydantic model +# This ensures the model is the single source of truth for descriptions and defaults +get_list_batches_request = create_query_dependency(ListBatchesRequest) def create_router(impl: Batches) -> APIRouter: diff --git a/src/llama_stack_api/router_utils.py b/src/llama_stack_api/router_utils.py index fd0efe0608..5d934826c9 100644 --- a/src/llama_stack_api/router_utils.py +++ b/src/llama_stack_api/router_utils.py @@ -11,7 +11,12 @@ in the OpenAPI specification. """ -from typing import Any +import inspect +from collections.abc import Callable +from typing import Annotated, Any, TypeVar + +from fastapi import Path, Query +from pydantic import BaseModel standard_responses: dict[int | str, dict[str, Any]] = { 400: {"$ref": "#/components/responses/BadRequest400"}, @@ -19,3 +24,146 @@ 500: {"$ref": "#/components/responses/InternalServerError500"}, "default": {"$ref": "#/components/responses/DefaultError"}, } + +T = TypeVar("T", bound=BaseModel) + + +def create_query_dependency[T: BaseModel](model_class: type[T]) -> Callable[..., T]: + """Create a FastAPI dependency function from a Pydantic model for query parameters. + + FastAPI does not natively support using Pydantic models as query parameters + without a dependency function. Using a dependency function typically leads to + duplication: field types, default values, and descriptions must be repeated in + `Query(...)` annotations even though they already exist in the Pydantic model. + + This function automatically generates a dependency function that extracts query parameters + from the request and constructs an instance of the Pydantic model. 
The descriptions and + defaults are automatically extracted from the model's Field definitions, making the model + the single source of truth. + + Args: + model_class: The Pydantic model class to create a dependency for + + Returns: + A dependency function that can be used with FastAPI's Depends() + + Example: + ```python + get_list_batches_request = create_query_dependency(ListBatchesRequest) + + @router.get("/batches") + async def list_batches( + request: Annotated[ListBatchesRequest, Depends(get_list_batches_request)] + ): + ... + ``` + """ + # Build function signature dynamically from model fields + annotations: dict[str, Any] = {} + defaults: dict[str, Any] = {} + + for field_name, field_info in model_class.model_fields.items(): + # Extract description from Field + description = field_info.description + + # Create Query annotation with description from model + query_annotation = Query(description=description) if description else Query() + + # Create Annotated type with Query + field_type = field_info.annotation + annotations[field_name] = Annotated[field_type, query_annotation] + + # Set default value from model + if field_info.default is not inspect.Parameter.empty: + defaults[field_name] = field_info.default + + # Create the dependency function dynamically + def dependency_func(**kwargs: Any) -> T: + return model_class(**kwargs) + + # Set function signature + sig_params = [] + for field_name, field_type in annotations.items(): + default = defaults.get(field_name, inspect.Parameter.empty) + param = inspect.Parameter( + field_name, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + default=default, + annotation=field_type, + ) + sig_params.append(param) + + dependency_func.__signature__ = inspect.Signature(sig_params) + dependency_func.__annotations__ = annotations + dependency_func.__name__ = f"get_{model_class.__name__.lower()}_request" + + return dependency_func + + +def create_path_dependency[T: BaseModel](model_class: type[T]) -> Callable[..., T]: + """Create a 
FastAPI dependency function from a Pydantic model for path parameters. + + FastAPI requires path parameters to be explicitly annotated with `Path()`. When using + a Pydantic model that contains path parameters, you typically need a dependency function + that extracts the path parameter and constructs the model. This leads to duplication: + the parameter name, type, and description must be repeated in `Path(...)` annotations + even though they already exist in the Pydantic model. + + This function automatically generates a dependency function that extracts path parameters + from the request and constructs an instance of the Pydantic model. The descriptions are + automatically extracted from the model's Field definitions, making the model the single + source of truth. + + Args: + model_class: The Pydantic model class to create a dependency for. The model should + have exactly one field that represents the path parameter. + + Returns: + A dependency function that can be used with FastAPI's Depends() + + Example: + ```python + get_retrieve_batch_request = create_path_dependency(RetrieveBatchRequest) + + @router.get("/batches/{batch_id}") + async def retrieve_batch( + request: Annotated[RetrieveBatchRequest, Depends(get_retrieve_batch_request)] + ): + ... 
+ ``` + """ + # Get the single field from the model (path parameter models typically have one field) + if len(model_class.model_fields) != 1: + raise ValueError( + f"Path parameter model {model_class.__name__} must have exactly one field, " + f"but has {len(model_class.model_fields)} fields" + ) + + field_name, field_info = next(iter(model_class.model_fields.items())) + + # Extract description from Field + description = field_info.description + + # Create Path annotation with description from model + path_annotation = Path(description=description) if description else Path() + + # Create Annotated type with Path + field_type = field_info.annotation + annotations: dict[str, Any] = {field_name: Annotated[field_type, path_annotation]} + + # Create the dependency function dynamically + def dependency_func(**kwargs: Any) -> T: + return model_class(**kwargs) + + # Set function signature + param = inspect.Parameter( + field_name, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=annotations[field_name], + ) + + dependency_func.__signature__ = inspect.Signature([param]) + dependency_func.__annotations__ = annotations + dependency_func.__name__ = f"get_{model_class.__name__.lower()}_request" + + return dependency_func From 6d76a63eb754bae372abcff68d10db5c395b565d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Mon, 24 Nov 2025 14:53:26 +0100 Subject: [PATCH 19/24] fix: mypy MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Sébastien Han --- .../core/server/fastapi_router_registry.py | 2 +- src/llama_stack_api/router_utils.py | 38 ++++++------------- 2 files changed, 13 insertions(+), 27 deletions(-) diff --git a/src/llama_stack/core/server/fastapi_router_registry.py b/src/llama_stack/core/server/fastapi_router_registry.py index 4b4e9fe8c5..1e340bb75b 100644 --- a/src/llama_stack/core/server/fastapi_router_registry.py +++ b/src/llama_stack/core/server/fastapi_router_registry.py @@ -21,7 +21,7 
@@ # Add new APIs here as they are migrated to the router system from llama_stack_api.batches.fastapi_routes import create_router as create_batches_router -_ROUTER_FACTORIES: dict[str, APIRouter] = { +_ROUTER_FACTORIES: dict[str, Any] = { "batches": create_batches_router, } diff --git a/src/llama_stack_api/router_utils.py b/src/llama_stack_api/router_utils.py index 5d934826c9..25c8f47c40 100644 --- a/src/llama_stack_api/router_utils.py +++ b/src/llama_stack_api/router_utils.py @@ -46,16 +46,6 @@ def create_query_dependency[T: BaseModel](model_class: type[T]) -> Callable[..., Returns: A dependency function that can be used with FastAPI's Depends() - - Example: - ```python - get_list_batches_request = create_query_dependency(ListBatchesRequest) - - @router.get("/batches") - async def list_batches( - request: Annotated[ListBatchesRequest, Depends(get_list_batches_request)] - ): - ... ``` """ # Build function signature dynamically from model fields @@ -93,9 +83,12 @@ def dependency_func(**kwargs: Any) -> T: ) sig_params.append(param) - dependency_func.__signature__ = inspect.Signature(sig_params) - dependency_func.__annotations__ = annotations - dependency_func.__name__ = f"get_{model_class.__name__.lower()}_request" + # These attributes are set dynamically at runtime. While mypy can't verify them statically, + # they are standard Python function attributes that exist on all callable objects at runtime. + # Setting them allows FastAPI to properly introspect the function signature for dependency injection. 
+ dependency_func.__signature__ = inspect.Signature(sig_params) # type: ignore[attr-defined] + dependency_func.__annotations__ = annotations # type: ignore[attr-defined] + dependency_func.__name__ = f"get_{model_class.__name__.lower()}_request" # type: ignore[attr-defined] return dependency_func @@ -120,16 +113,6 @@ def create_path_dependency[T: BaseModel](model_class: type[T]) -> Callable[..., Returns: A dependency function that can be used with FastAPI's Depends() - - Example: - ```python - get_retrieve_batch_request = create_path_dependency(RetrieveBatchRequest) - - @router.get("/batches/{batch_id}") - async def retrieve_batch( - request: Annotated[RetrieveBatchRequest, Depends(get_retrieve_batch_request)] - ): - ... ``` """ # Get the single field from the model (path parameter models typically have one field) @@ -162,8 +145,11 @@ def dependency_func(**kwargs: Any) -> T: annotation=annotations[field_name], ) - dependency_func.__signature__ = inspect.Signature([param]) - dependency_func.__annotations__ = annotations - dependency_func.__name__ = f"get_{model_class.__name__.lower()}_request" + # These attributes are set dynamically at runtime. While mypy can't verify them statically, + # they are standard Python function attributes that exist on all callable objects at runtime. + # Setting them allows FastAPI to properly introspect the function signature for dependency injection. 
+ dependency_func.__signature__ = inspect.Signature([param]) # type: ignore[attr-defined] + dependency_func.__annotations__ = annotations # type: ignore[attr-defined] + dependency_func.__name__ = f"get_{model_class.__name__.lower()}_request" # type: ignore[attr-defined] return dependency_func From 9a2b4efabd2142c5512a3f52aaedf5cac1eb26ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Tue, 25 Nov 2025 10:51:52 +0100 Subject: [PATCH 20/24] chore: clarify function and log about which router MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit It's FastAPI Signed-off-by: Sébastien Han --- scripts/openapi_generator/app.py | 6 +++--- src/llama_stack/core/inspect.py | 6 +++--- src/llama_stack/core/server/fastapi_router_registry.py | 2 +- src/llama_stack/core/server/routes.py | 2 +- src/llama_stack/core/server/server.py | 6 +++--- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/scripts/openapi_generator/app.py b/scripts/openapi_generator/app.py index 8495c0013f..023a4c62eb 100644 --- a/scripts/openapi_generator/app.py +++ b/scripts/openapi_generator/app.py @@ -14,7 +14,7 @@ from fastapi import FastAPI from llama_stack.core.resolver import api_protocol_map -from llama_stack.core.server.fastapi_router_registry import build_router +from llama_stack.core.server.fastapi_router_registry import build_fastapi_router from llama_stack_api import Api from .state import _protocol_methods_cache @@ -81,7 +81,7 @@ def create_llama_stack_app() -> FastAPI: protocols = api_protocol_map() for api in protocols.keys(): # For OpenAPI generation, we don't need a real implementation - router = build_router(api, None) + router = build_fastapi_router(api, None) if router: app.include_router(router) @@ -95,7 +95,7 @@ def create_llama_stack_app() -> FastAPI: for api, routes in api_routes.items(): # Skip APIs that have routers - they're already included above - if build_router(api, None) is not None: + if 
build_fastapi_router(api, None) is not None: continue for route, webmethod in routes: diff --git a/src/llama_stack/core/inspect.py b/src/llama_stack/core/inspect.py index dfc4281fb1..fe185d4338 100644 --- a/src/llama_stack/core/inspect.py +++ b/src/llama_stack/core/inspect.py @@ -10,7 +10,7 @@ from llama_stack.core.datatypes import StackRunConfig from llama_stack.core.external import load_external_apis -from llama_stack.core.server.fastapi_router_registry import build_router +from llama_stack.core.server.fastapi_router_registry import build_fastapi_router from llama_stack.core.server.routes import get_all_api_routes from llama_stack_api import ( Api, @@ -70,7 +70,7 @@ def get_provider_types(api: Api) -> list[str]: # Process webmethod-based routes (legacy) for api, endpoints in all_endpoints.items(): # Skip APIs that have routers - they'll be processed separately - if build_router(api, None) is not None: + if build_fastapi_router(api, None) is not None: continue provider_types = get_provider_types(api) @@ -114,7 +114,7 @@ def should_include_router_route(route, router_prefix: str | None) -> bool: protocols = api_protocol_map(external_apis) for api in protocols.keys(): # For route inspection, we don't need a real implementation - router = build_router(api, None) + router = build_fastapi_router(api, None) if not router: continue diff --git a/src/llama_stack/core/server/fastapi_router_registry.py b/src/llama_stack/core/server/fastapi_router_registry.py index 1e340bb75b..f097ac0a46 100644 --- a/src/llama_stack/core/server/fastapi_router_registry.py +++ b/src/llama_stack/core/server/fastapi_router_registry.py @@ -26,7 +26,7 @@ } -def build_router(api: "Api", impl: Any) -> APIRouter | None: +def build_fastapi_router(api: "Api", impl: Any) -> APIRouter | None: """Build a router for an API by combining its router factory with the implementation. 
Args: diff --git a/src/llama_stack/core/server/routes.py b/src/llama_stack/core/server/routes.py index b6508a7a47..9df9e4a60e 100644 --- a/src/llama_stack/core/server/routes.py +++ b/src/llama_stack/core/server/routes.py @@ -30,7 +30,7 @@ def get_all_api_routes( This function only returns routes from APIs that use the legacy @webmethod decorator system. For APIs that have been migrated to FastAPI routers, - use the router registry (fastapi_router_registry.has_router() and fastapi_router_registry.build_router()). + use the router registry (fastapi_router_registry.has_router() and fastapi_router_registry.build_fastapi_router()). Args: external_apis: Optional dictionary of external API specifications diff --git a/src/llama_stack/core/server/server.py b/src/llama_stack/core/server/server.py index 84f724fdbb..e316609c35 100644 --- a/src/llama_stack/core/server/server.py +++ b/src/llama_stack/core/server/server.py @@ -44,7 +44,7 @@ request_provider_data_context, user_from_scope, ) -from llama_stack.core.server.fastapi_router_registry import build_router +from llama_stack.core.server.fastapi_router_registry import build_fastapi_router from llama_stack.core.server.routes import get_all_api_routes from llama_stack.core.stack import ( Stack, @@ -470,10 +470,10 @@ def create_app() -> StackApp: # Try to discover and use a router factory from the API package impl = impls[api] - router = build_router(api, impl) + router = build_fastapi_router(api, impl) if router: app.include_router(router) - logger.debug(f"Registered router for {api} API") + logger.debug(f"Registered FastAPI router for {api} API") continue # Fall back to old webmethod-based route discovery until the migration is complete From b0b3034f16006aa39fedf939ccf40d77a5ff469c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Tue, 25 Nov 2025 10:54:43 +0100 Subject: [PATCH 21/24] chore: rm leftover MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by:
Sébastien Han --- src/llama_stack/core/server/fastapi_router_registry.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/llama_stack/core/server/fastapi_router_registry.py b/src/llama_stack/core/server/fastapi_router_registry.py index f097ac0a46..dea3250367 100644 --- a/src/llama_stack/core/server/fastapi_router_registry.py +++ b/src/llama_stack/core/server/fastapi_router_registry.py @@ -10,16 +10,14 @@ APIs with routers are explicitly listed here. """ -from typing import TYPE_CHECKING, Any, cast +from typing import Any, cast from fastapi import APIRouter -if TYPE_CHECKING: - from llama_stack_api.datatypes import Api - # Router factories for APIs that have FastAPI routers # Add new APIs here as they are migrated to the router system from llama_stack_api.batches.fastapi_routes import create_router as create_batches_router +from llama_stack_api.datatypes import Api _ROUTER_FACTORIES: dict[str, Any] = { "batches": create_batches_router, From 3dc5b5d3a016fa8113e02336eb208515ba9fa21b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Tue, 25 Nov 2025 10:57:27 +0100 Subject: [PATCH 22/24] fix: more accurate type MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit https://github.com/llamastack/llama-stack/pull/4191#discussion_r2557389025 Signed-off-by: Sébastien Han --- src/llama_stack/core/server/fastapi_router_registry.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/llama_stack/core/server/fastapi_router_registry.py b/src/llama_stack/core/server/fastapi_router_registry.py index dea3250367..1786228534 100644 --- a/src/llama_stack/core/server/fastapi_router_registry.py +++ b/src/llama_stack/core/server/fastapi_router_registry.py @@ -10,6 +10,7 @@ APIs with routers are explicitly listed here. 
""" +from collections.abc import Callable from typing import Any, cast from fastapi import APIRouter @@ -19,7 +20,7 @@ from llama_stack_api.batches.fastapi_routes import create_router as create_batches_router from llama_stack_api.datatypes import Api -_ROUTER_FACTORIES: dict[str, Any] = { +_ROUTER_FACTORIES: dict[str, Callable[[Any], APIRouter]] = { "batches": create_batches_router, } From ead9e63ef897c2b90bd2d1cdb835d84729276a5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Tue, 25 Nov 2025 11:04:33 +0100 Subject: [PATCH 23/24] fix: no inline import MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit https://github.com/llamastack/llama-stack/pull/4191#discussion_r2557412421 Signed-off-by: Sébastien Han --- src/llama_stack/core/inspect.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/llama_stack/core/inspect.py b/src/llama_stack/core/inspect.py index fe185d4338..28d23b8151 100644 --- a/src/llama_stack/core/inspect.py +++ b/src/llama_stack/core/inspect.py @@ -10,6 +10,7 @@ from llama_stack.core.datatypes import StackRunConfig from llama_stack.core.external import load_external_apis +from llama_stack.core.resolver import api_protocol_map from llama_stack.core.server.fastapi_router_registry import build_fastapi_router from llama_stack.core.server.routes import get_all_api_routes from llama_stack_api import ( @@ -108,9 +109,6 @@ def should_include_router_route(route, router_prefix: str | None) -> bool: return not route_deprecated and prefix_level == api_filter return not route_deprecated - # Process router-based routes - from llama_stack.core.resolver import api_protocol_map - protocols = api_protocol_map(external_apis) for api in protocols.keys(): # For route inspection, we don't need a real implementation From f330c8eb2f51567bf44b12e618d23abc81a63629 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Tue, 25 Nov 2025 13:48:47 +0100 Subject: [PATCH 24/24] chore: 
simplify route addition when calling inspect MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit https://github.com/llamastack/llama-stack/pull/4191/files#r2557411918 Signed-off-by: Sébastien Han --- src/llama_stack/core/inspect.py | 114 +++++++++--------- .../core/server/fastapi_router_registry.py | 39 ++++++ 2 files changed, 97 insertions(+), 56 deletions(-) diff --git a/src/llama_stack/core/inspect.py b/src/llama_stack/core/inspect.py index 28d23b8151..45cab29707 100644 --- a/src/llama_stack/core/inspect.py +++ b/src/llama_stack/core/inspect.py @@ -10,8 +10,11 @@ from llama_stack.core.datatypes import StackRunConfig from llama_stack.core.external import load_external_apis -from llama_stack.core.resolver import api_protocol_map -from llama_stack.core.server.fastapi_router_registry import build_fastapi_router +from llama_stack.core.server.fastapi_router_registry import ( + _ROUTER_FACTORIES, + build_fastapi_router, + get_router_routes, +) from llama_stack.core.server.routes import get_all_api_routes from llama_stack_api import ( Api, @@ -46,6 +49,7 @@ async def list_routes(self, api_filter: str | None = None) -> ListRoutesResponse run_config: StackRunConfig = self.config.run_config # Helper function to determine if a route should be included based on api_filter + # TODO: remove this once we've migrated all APIs to FastAPI routers def should_include_route(webmethod) -> bool: if api_filter is None: # Default: only non-deprecated APIs @@ -57,40 +61,15 @@ def should_include_route(webmethod) -> bool: # Filter by API level (non-deprecated routes only) return not webmethod.deprecated and webmethod.level == api_filter - ret = [] - external_apis = load_external_apis(run_config) - all_endpoints = get_all_api_routes(external_apis) - # Helper function to get provider types for an API - def get_provider_types(api: Api) -> list[str]: + def _get_provider_types(api: Api) -> list[str]: if api.value in ["providers", "inspect"]: return [] # 
These APIs don't have "real" providers they're internal to the stack providers = run_config.providers.get(api.value, []) return [p.provider_type for p in providers] if providers else [] - # Process webmethod-based routes (legacy) - for api, endpoints in all_endpoints.items(): - # Skip APIs that have routers - they'll be processed separately - if build_fastapi_router(api, None) is not None: - continue - - provider_types = get_provider_types(api) - # Always include provider and inspect APIs, filter others based on run config - if api.value in ["providers", "inspect"] or provider_types: - ret.extend( - [ - RouteInfo( - route=e.path, - method=next(iter([m for m in e.methods if m != "HEAD"])), - provider_types=provider_types, - ) - for e, webmethod in endpoints - if e.methods is not None and should_include_route(webmethod) - ] - ) - # Helper function to determine if a router route should be included based on api_filter - def should_include_router_route(route, router_prefix: str | None) -> bool: + def _should_include_router_route(route, router_prefix: str | None) -> bool: """Check if a router-based route should be included based on api_filter.""" # Check deprecated status route_deprecated = getattr(route, "deprecated", False) or False @@ -109,36 +88,59 @@ def should_include_router_route(route, router_prefix: str | None) -> bool: return not route_deprecated and prefix_level == api_filter return not route_deprecated - protocols = api_protocol_map(external_apis) - for api in protocols.keys(): - # For route inspection, we don't need a real implementation - router = build_fastapi_router(api, None) - if not router: + ret = [] + external_apis = load_external_apis(run_config) + all_endpoints = get_all_api_routes(external_apis) + + # Process routes from APIs with FastAPI routers + for api_name in _ROUTER_FACTORIES.keys(): + api = Api(api_name) + router = build_fastapi_router(api, None) # we don't need the impl here, just the routes + if router: + router_routes = 
get_router_routes(router) + for route in router_routes: + if _should_include_router_route(route, router.prefix): + ret.append( + RouteInfo( + route=route.path, + method=next(iter([m for m in route.methods if m != "HEAD"])), + provider_types=_get_provider_types(api), + ) + ) + + # Process routes from legacy webmethod-based APIs + for api, endpoints in all_endpoints.items(): + # Skip APIs that have routers (already processed above) + if api.value in _ROUTER_FACTORIES: continue - provider_types = get_provider_types(api) - # Only include if there are providers (or it's a special API) - if api.value in ["providers", "inspect"] or provider_types: - router_prefix = getattr(router, "prefix", None) - for route in router.routes: - # Extract HTTP methods from the route - # FastAPI routes have methods as a set - if hasattr(route, "methods") and route.methods: - methods = {m for m in route.methods if m != "HEAD"} - if methods and should_include_router_route(route, router_prefix): - # FastAPI already combines router prefix with route path - # Only APIRoute has a path attribute, use getattr to safely access it - path = getattr(route, "path", None) - if path is None: - continue - - ret.append( - RouteInfo( - route=path, - method=next(iter(methods)), - provider_types=provider_types, - ) + # Always include provider and inspect APIs, filter others based on run config + if api.value in ["providers", "inspect"]: + ret.extend( + [ + RouteInfo( + route=e.path, + method=next(iter([m for m in e.methods if m != "HEAD"])), + provider_types=[], # These APIs don't have "real" providers - they're internal to the stack + ) + for e, webmethod in endpoints + if e.methods is not None and should_include_route(webmethod) + ] + ) + else: + providers = run_config.providers.get(api.value, []) + if providers: # Only process if there are providers for this API + ret.extend( + [ + RouteInfo( + route=e.path, + method=next(iter([m for m in e.methods if m != "HEAD"])), + provider_types=[p.provider_type for p 
in providers], ) + for e, webmethod in endpoints + if e.methods is not None and should_include_route(webmethod) + ] + ) return ListRoutesResponse(data=ret) diff --git a/src/llama_stack/core/server/fastapi_router_registry.py b/src/llama_stack/core/server/fastapi_router_registry.py index 1786228534..84f41693d9 100644 --- a/src/llama_stack/core/server/fastapi_router_registry.py +++ b/src/llama_stack/core/server/fastapi_router_registry.py @@ -14,6 +14,8 @@ from typing import Any, cast from fastapi import APIRouter +from fastapi.routing import APIRoute +from starlette.routing import Route # Router factories for APIs that have FastAPI routers # Add new APIs here as they are migrated to the router system @@ -43,3 +45,40 @@ def build_fastapi_router(api: "Api", impl: Any) -> APIRouter | None: # If a router factory returns the wrong type, it will fail at runtime when # app.include_router(router) is called return cast(APIRouter, router_factory(impl)) + + +def get_router_routes(router: APIRouter) -> list[Route]: + """Extract routes from a FastAPI router. + + Args: + router: The FastAPI router to extract routes from + + Returns: + List of Route objects from the router + """ + routes = [] + + for route in router.routes: + # FastAPI routers use APIRoute objects, which have path and methods attributes + if isinstance(route, APIRoute): + # Combine router prefix with route path + routes.append( + Route( + path=route.path, + methods=route.methods, + name=route.name, + endpoint=route.endpoint, + ) + ) + elif isinstance(route, Route): + # Fallback for regular Starlette Route objects + routes.append( + Route( + path=route.path, + methods=route.methods, + name=route.name, + endpoint=route.endpoint, + ) + ) + + return routes