diff --git a/client-sdks/stainless/openapi.yml b/client-sdks/stainless/openapi.yml index 9269b7e39e..e658a62378 100644 --- a/client-sdks/stainless/openapi.yml +++ b/client-sdks/stainless/openapi.yml @@ -37,7 +37,7 @@ paths: description: Default Response tags: - Batches - summary: List Batches + summary: List all batches for the current user. description: List all batches for the current user. operationId: list_batches_v1_batches_get parameters: @@ -48,14 +48,18 @@ paths: anyOf: - type: string - type: 'null' + description: Optional cursor for pagination. Returns batches after this ID. title: After + description: Optional cursor for pagination. Returns batches after this ID. - name: limit in: query required: false schema: type: integer + description: Maximum number of batches to return. Defaults to 20. default: 20 title: Limit + description: Maximum number of batches to return. Defaults to 20. post: responses: '200': @@ -76,9 +80,11 @@ paths: default: $ref: '#/components/responses/DefaultError' description: Default Response + '409': + description: 'Conflict: The idempotency key was previously used with different parameters.' tags: - Batches - summary: Create Batch + summary: Create a new batch for processing multiple API requests. description: Create a new batch for processing multiple API requests. operationId: create_batch_v1_batches_post requestBody: @@ -97,20 +103,20 @@ paths: schema: $ref: '#/components/schemas/Batch' '400': - description: Bad Request $ref: '#/components/responses/BadRequest400' + description: Bad Request '429': - description: Too Many Requests $ref: '#/components/responses/TooManyRequests429' + description: Too Many Requests '500': - description: Internal Server Error $ref: '#/components/responses/InternalServerError500' + description: Internal Server Error default: - description: Default Response $ref: '#/components/responses/DefaultError' + description: Default Response tags: - Batches - summary: Retrieve Batch + summary: Retrieve information about a specific batch. description: Retrieve information about a specific batch. operationId: retrieve_batch_v1_batches__batch_id__get parameters: @@ -119,7 +125,9 @@ paths: required: true schema: type: string - description: 'Path parameter: batch_id' + description: The ID of the batch to retrieve. + title: Batch Id + description: The ID of the batch to retrieve. /v1/batches/{batch_id}/cancel: post: responses: @@ -130,20 +138,20 @@ paths: schema: $ref: '#/components/schemas/Batch' '400': - description: Bad Request $ref: '#/components/responses/BadRequest400' + description: Bad Request '429': - description: Too Many Requests $ref: '#/components/responses/TooManyRequests429' + description: Too Many Requests '500': - description: Internal Server Error $ref: '#/components/responses/InternalServerError500' + description: Internal Server Error default: - description: Default Response $ref: '#/components/responses/DefaultError' + description: Default Response tags: - Batches - summary: Cancel Batch + summary: Cancel a batch that is in progress. description: Cancel a batch that is in progress. operationId: cancel_batch_v1_batches__batch_id__cancel_post parameters: @@ -152,7 +160,9 @@ paths: required: true schema: type: string - description: 'Path parameter: batch_id' + description: The ID of the batch to cancel. + title: Batch Id + description: The ID of the batch to cancel. 
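The four Batches operations above (list, create, retrieve, cancel) keep the OpenAI-compatible wire shape, so the stock `openai` client can exercise them directly. A minimal sketch against a locally running stack; the base URL, API key, and file ID are hypothetical placeholders, and `idempotency_key` (the Llama Stack extension behind the new 409 response) is passed via `extra_body` since the upstream client does not define that parameter:

```python
from openai import OpenAI

# Point the standard OpenAI client at a Llama Stack server (hypothetical URL/key).
client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

# POST /v1/batches -- re-sending the same idempotency_key with identical
# parameters returns the same batch; different parameters trigger the 409 above.
batch = client.batches.create(
    input_file_id="file_abc123",  # hypothetical uploaded file ID
    endpoint="/v1/chat/completions",
    completion_window="24h",
    metadata={"project": "nightly-eval"},
    extra_body={"idempotency_key": "nightly-eval-2025-01-01"},
)

# GET /v1/batches/{batch_id}, then POST /v1/batches/{batch_id}/cancel
batch = client.batches.retrieve(batch.id)
if batch.status in ("validating", "in_progress"):
    batch = client.batches.cancel(batch.id)

# GET /v1/batches -- cursor pagination via the `after` and `limit` parameters.
page = client.batches.list(limit=20)
while page.has_more:
    page = client.batches.list(after=page.data[-1].id, limit=20)
```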
/v1/chat/completions: get: responses: @@ -3950,29 +3960,35 @@ components: input_file_id: type: string title: Input File Id + description: The ID of an uploaded file containing requests for the batch. endpoint: type: string title: Endpoint + description: The endpoint to be used for all requests in the batch. completion_window: type: string const: 24h title: Completion Window + description: The time window within which the batch should be processed. metadata: anyOf: - additionalProperties: type: string type: object - type: 'null' + description: Optional metadata for the batch. idempotency_key: anyOf: - type: string - type: 'null' + description: Optional idempotency key. When provided, enables idempotent behavior. type: object required: - input_file_id - endpoint - completion_window title: CreateBatchRequest + description: Request model for creating a batch. Batch: properties: id: @@ -12760,6 +12776,44 @@ components: - query title: VectorStoreSearchRequest type: object + ListBatchesRequest: + description: Request model for listing batches. + properties: + after: + anyOf: + - type: string + - type: 'null' + description: Optional cursor for pagination. Returns batches after this ID. + nullable: true + limit: + default: 20 + description: Maximum number of batches to return. Defaults to 20. + title: Limit + type: integer + title: ListBatchesRequest + type: object + RetrieveBatchRequest: + description: Request model for retrieving a batch. + properties: + batch_id: + description: The ID of the batch to retrieve. + title: Batch Id + type: string + required: + - batch_id + title: RetrieveBatchRequest + type: object + CancelBatchRequest: + description: Request model for canceling a batch. + properties: + batch_id: + description: The ID of the batch to cancel. + title: Batch Id + type: string + required: + - batch_id + title: CancelBatchRequest + type: object DialogType: description: Parameter type for dialog data with semantic output labels. properties: diff --git a/docs/static/deprecated-llama-stack-spec.yaml b/docs/static/deprecated-llama-stack-spec.yaml index cf9bd14c4b..aae0fbe442 100644 --- a/docs/static/deprecated-llama-stack-spec.yaml +++ b/docs/static/deprecated-llama-stack-spec.yaml @@ -793,29 +793,35 @@ components: input_file_id: type: string title: Input File Id + description: The ID of an uploaded file containing requests for the batch. endpoint: type: string title: Endpoint + description: The endpoint to be used for all requests in the batch. completion_window: type: string const: 24h title: Completion Window + description: The time window within which the batch should be processed. metadata: anyOf: - additionalProperties: type: string type: object - type: 'null' + description: Optional metadata for the batch. idempotency_key: anyOf: - type: string - type: 'null' + description: Optional idempotency key. When provided, enables idempotent behavior. type: object required: - input_file_id - endpoint - completion_window title: CreateBatchRequest + description: Request model for creating a batch. Batch: properties: id: @@ -9603,6 +9609,44 @@ components: - query title: VectorStoreSearchRequest type: object + ListBatchesRequest: + description: Request model for listing batches. + properties: + after: + anyOf: + - type: string + - type: 'null' + description: Optional cursor for pagination. Returns batches after this ID. + nullable: true + limit: + default: 20 + description: Maximum number of batches to return. Defaults to 20. 
+ title: Limit + type: integer + title: ListBatchesRequest + type: object + RetrieveBatchRequest: + description: Request model for retrieving a batch. + properties: + batch_id: + description: The ID of the batch to retrieve. + title: Batch Id + type: string + required: + - batch_id + title: RetrieveBatchRequest + type: object + CancelBatchRequest: + description: Request model for canceling a batch. + properties: + batch_id: + description: The ID of the batch to cancel. + title: Batch Id + type: string + required: + - batch_id + title: CancelBatchRequest + type: object DialogType: description: Parameter type for dialog data with semantic output labels. properties: diff --git a/docs/static/experimental-llama-stack-spec.yaml b/docs/static/experimental-llama-stack-spec.yaml index 18ce75562b..0fcbe8ba12 100644 --- a/docs/static/experimental-llama-stack-spec.yaml +++ b/docs/static/experimental-llama-stack-spec.yaml @@ -688,6 +688,40 @@ components: - data title: ListBatchesResponse description: Response containing a list of batch objects. + CreateBatchRequest: + properties: + input_file_id: + type: string + title: Input File Id + description: The ID of an uploaded file containing requests for the batch. + endpoint: + type: string + title: Endpoint + description: The endpoint to be used for all requests in the batch. + completion_window: + type: string + const: 24h + title: Completion Window + description: The time window within which the batch should be processed. + metadata: + anyOf: + - additionalProperties: + type: string + type: object + - type: 'null' + description: Optional metadata for the batch. + idempotency_key: + anyOf: + - type: string + - type: 'null' + description: Optional idempotency key. When provided, enables idempotent behavior. + type: object + required: + - input_file_id + - endpoint + - completion_window + title: CreateBatchRequest + description: Request model for creating a batch. Batch: properties: id: @@ -8532,6 +8566,44 @@ components: - query title: VectorStoreSearchRequest type: object + ListBatchesRequest: + description: Request model for listing batches. + properties: + after: + anyOf: + - type: string + - type: 'null' + description: Optional cursor for pagination. Returns batches after this ID. + nullable: true + limit: + default: 20 + description: Maximum number of batches to return. Defaults to 20. + title: Limit + type: integer + title: ListBatchesRequest + type: object + RetrieveBatchRequest: + description: Request model for retrieving a batch. + properties: + batch_id: + description: The ID of the batch to retrieve. + title: Batch Id + type: string + required: + - batch_id + title: RetrieveBatchRequest + type: object + CancelBatchRequest: + description: Request model for canceling a batch. + properties: + batch_id: + description: The ID of the batch to cancel. + title: Batch Id + type: string + required: + - batch_id + title: CancelBatchRequest + type: object DialogType: description: Parameter type for dialog data with semantic output labels. properties: diff --git a/docs/static/llama-stack-spec.yaml b/docs/static/llama-stack-spec.yaml index 9f7b2ed64a..403e01dd1a 100644 --- a/docs/static/llama-stack-spec.yaml +++ b/docs/static/llama-stack-spec.yaml @@ -35,7 +35,7 @@ paths: description: Default Response tags: - Batches - summary: List Batches + summary: List all batches for the current user. description: List all batches for the current user. 
operationId: list_batches_v1_batches_get parameters: @@ -46,14 +46,18 @@ paths: anyOf: - type: string - type: 'null' + description: Optional cursor for pagination. Returns batches after this ID. title: After + description: Optional cursor for pagination. Returns batches after this ID. - name: limit in: query required: false schema: type: integer + description: Maximum number of batches to return. Defaults to 20. default: 20 title: Limit + description: Maximum number of batches to return. Defaults to 20. post: responses: '200': @@ -74,9 +78,11 @@ paths: default: $ref: '#/components/responses/DefaultError' description: Default Response + '409': + description: 'Conflict: The idempotency key was previously used with different parameters.' tags: - Batches - summary: Create Batch + summary: Create a new batch for processing multiple API requests. description: Create a new batch for processing multiple API requests. operationId: create_batch_v1_batches_post requestBody: @@ -95,20 +101,20 @@ paths: schema: $ref: '#/components/schemas/Batch' '400': - description: Bad Request $ref: '#/components/responses/BadRequest400' + description: Bad Request '429': - description: Too Many Requests $ref: '#/components/responses/TooManyRequests429' + description: Too Many Requests '500': - description: Internal Server Error $ref: '#/components/responses/InternalServerError500' + description: Internal Server Error default: - description: Default Response $ref: '#/components/responses/DefaultError' + description: Default Response tags: - Batches - summary: Retrieve Batch + summary: Retrieve information about a specific batch. description: Retrieve information about a specific batch. operationId: retrieve_batch_v1_batches__batch_id__get parameters: @@ -117,7 +123,9 @@ paths: required: true schema: type: string - description: 'Path parameter: batch_id' + description: The ID of the batch to retrieve. + title: Batch Id + description: The ID of the batch to retrieve. /v1/batches/{batch_id}/cancel: post: responses: @@ -128,20 +136,20 @@ paths: schema: $ref: '#/components/schemas/Batch' '400': - description: Bad Request $ref: '#/components/responses/BadRequest400' + description: Bad Request '429': - description: Too Many Requests $ref: '#/components/responses/TooManyRequests429' + description: Too Many Requests '500': - description: Internal Server Error $ref: '#/components/responses/InternalServerError500' + description: Internal Server Error default: - description: Default Response $ref: '#/components/responses/DefaultError' + description: Default Response tags: - Batches - summary: Cancel Batch + summary: Cancel a batch that is in progress. description: Cancel a batch that is in progress. operationId: cancel_batch_v1_batches__batch_id__cancel_post parameters: @@ -150,7 +158,9 @@ paths: required: true schema: type: string - description: 'Path parameter: batch_id' + description: The ID of the batch to cancel. + title: Batch Id + description: The ID of the batch to cancel. /v1/chat/completions: get: responses: @@ -2971,29 +2981,35 @@ components: input_file_id: type: string title: Input File Id + description: The ID of an uploaded file containing requests for the batch. endpoint: type: string title: Endpoint + description: The endpoint to be used for all requests in the batch. completion_window: type: string const: 24h title: Completion Window + description: The time window within which the batch should be processed. 
metadata: anyOf: - additionalProperties: type: string type: object - type: 'null' + description: Optional metadata for the batch. idempotency_key: anyOf: - type: string - type: 'null' + description: Optional idempotency key. When provided, enables idempotent behavior. type: object required: - input_file_id - endpoint - completion_window title: CreateBatchRequest + description: Request model for creating a batch. Batch: properties: id: @@ -11430,6 +11446,44 @@ components: - query title: VectorStoreSearchRequest type: object + ListBatchesRequest: + description: Request model for listing batches. + properties: + after: + anyOf: + - type: string + - type: 'null' + description: Optional cursor for pagination. Returns batches after this ID. + nullable: true + limit: + default: 20 + description: Maximum number of batches to return. Defaults to 20. + title: Limit + type: integer + title: ListBatchesRequest + type: object + RetrieveBatchRequest: + description: Request model for retrieving a batch. + properties: + batch_id: + description: The ID of the batch to retrieve. + title: Batch Id + type: string + required: + - batch_id + title: RetrieveBatchRequest + type: object + CancelBatchRequest: + description: Request model for canceling a batch. + properties: + batch_id: + description: The ID of the batch to cancel. + title: Batch Id + type: string + required: + - batch_id + title: CancelBatchRequest + type: object DialogType: description: Parameter type for dialog data with semantic output labels. properties: diff --git a/docs/static/stainless-llama-stack-spec.yaml b/docs/static/stainless-llama-stack-spec.yaml index 9269b7e39e..e658a62378 100644 --- a/docs/static/stainless-llama-stack-spec.yaml +++ b/docs/static/stainless-llama-stack-spec.yaml @@ -37,7 +37,7 @@ paths: description: Default Response tags: - Batches - summary: List Batches + summary: List all batches for the current user. description: List all batches for the current user. operationId: list_batches_v1_batches_get parameters: @@ -48,14 +48,18 @@ paths: anyOf: - type: string - type: 'null' + description: Optional cursor for pagination. Returns batches after this ID. title: After + description: Optional cursor for pagination. Returns batches after this ID. - name: limit in: query required: false schema: type: integer + description: Maximum number of batches to return. Defaults to 20. default: 20 title: Limit + description: Maximum number of batches to return. Defaults to 20. post: responses: '200': @@ -76,9 +80,11 @@ paths: default: $ref: '#/components/responses/DefaultError' description: Default Response + '409': + description: 'Conflict: The idempotency key was previously used with different parameters.' tags: - Batches - summary: Create Batch + summary: Create a new batch for processing multiple API requests. description: Create a new batch for processing multiple API requests. 
operationId: create_batch_v1_batches_post requestBody: @@ -97,20 +103,20 @@ paths: schema: $ref: '#/components/schemas/Batch' '400': - description: Bad Request $ref: '#/components/responses/BadRequest400' + description: Bad Request '429': - description: Too Many Requests $ref: '#/components/responses/TooManyRequests429' + description: Too Many Requests '500': - description: Internal Server Error $ref: '#/components/responses/InternalServerError500' + description: Internal Server Error default: - description: Default Response $ref: '#/components/responses/DefaultError' + description: Default Response tags: - Batches - summary: Retrieve Batch + summary: Retrieve information about a specific batch. description: Retrieve information about a specific batch. operationId: retrieve_batch_v1_batches__batch_id__get parameters: @@ -119,7 +125,9 @@ paths: required: true schema: type: string - description: 'Path parameter: batch_id' + description: The ID of the batch to retrieve. + title: Batch Id + description: The ID of the batch to retrieve. /v1/batches/{batch_id}/cancel: post: responses: @@ -130,20 +138,20 @@ paths: schema: $ref: '#/components/schemas/Batch' '400': - description: Bad Request $ref: '#/components/responses/BadRequest400' + description: Bad Request '429': - description: Too Many Requests $ref: '#/components/responses/TooManyRequests429' + description: Too Many Requests '500': - description: Internal Server Error $ref: '#/components/responses/InternalServerError500' + description: Internal Server Error default: - description: Default Response $ref: '#/components/responses/DefaultError' + description: Default Response tags: - Batches - summary: Cancel Batch + summary: Cancel a batch that is in progress. description: Cancel a batch that is in progress. operationId: cancel_batch_v1_batches__batch_id__cancel_post parameters: @@ -152,7 +160,9 @@ paths: required: true schema: type: string - description: 'Path parameter: batch_id' + description: The ID of the batch to cancel. + title: Batch Id + description: The ID of the batch to cancel. /v1/chat/completions: get: responses: @@ -3950,29 +3960,35 @@ components: input_file_id: type: string title: Input File Id + description: The ID of an uploaded file containing requests for the batch. endpoint: type: string title: Endpoint + description: The endpoint to be used for all requests in the batch. completion_window: type: string const: 24h title: Completion Window + description: The time window within which the batch should be processed. metadata: anyOf: - additionalProperties: type: string type: object - type: 'null' + description: Optional metadata for the batch. idempotency_key: anyOf: - type: string - type: 'null' + description: Optional idempotency key. When provided, enables idempotent behavior. type: object required: - input_file_id - endpoint - completion_window title: CreateBatchRequest + description: Request model for creating a batch. Batch: properties: id: @@ -12760,6 +12776,44 @@ components: - query title: VectorStoreSearchRequest type: object + ListBatchesRequest: + description: Request model for listing batches. + properties: + after: + anyOf: + - type: string + - type: 'null' + description: Optional cursor for pagination. Returns batches after this ID. + nullable: true + limit: + default: 20 + description: Maximum number of batches to return. Defaults to 20. + title: Limit + type: integer + title: ListBatchesRequest + type: object + RetrieveBatchRequest: + description: Request model for retrieving a batch. 
+ properties: + batch_id: + description: The ID of the batch to retrieve. + title: Batch Id + type: string + required: + - batch_id + title: RetrieveBatchRequest + type: object + CancelBatchRequest: + description: Request model for canceling a batch. + properties: + batch_id: + description: The ID of the batch to cancel. + title: Batch Id + type: string + required: + - batch_id + title: CancelBatchRequest + type: object DialogType: description: Parameter type for dialog data with semantic output labels. properties: diff --git a/scripts/openapi_generator/app.py b/scripts/openapi_generator/app.py index d972889cdc..023a4c62eb 100644 --- a/scripts/openapi_generator/app.py +++ b/scripts/openapi_generator/app.py @@ -14,6 +14,7 @@ from fastapi import FastAPI from llama_stack.core.resolver import api_protocol_map +from llama_stack.core.server.fastapi_router_registry import build_fastapi_router from llama_stack_api import Api from .state import _protocol_methods_cache @@ -64,7 +65,8 @@ def _get_protocol_method(api: Api, method_name: str) -> Any | None: def create_llama_stack_app() -> FastAPI: """ Create a FastAPI app that represents the Llama Stack API. - This uses the existing route discovery system to automatically find all routes. + This uses both router-based routes (for migrated APIs) and the existing + route discovery system for legacy webmethod-based routes. """ app = FastAPI( title="Llama Stack API", @@ -75,15 +77,27 @@ def create_llama_stack_app() -> FastAPI: ], ) - # Get all API routes + # Include routers for APIs that have them + protocols = api_protocol_map() + for api in protocols.keys(): + # For OpenAPI generation, we don't need a real implementation + router = build_fastapi_router(api, None) + if router: + app.include_router(router) + + # Get all API routes (for legacy webmethod-based routes) from llama_stack.core.server.routes import get_all_api_routes api_routes = get_all_api_routes() - # Create FastAPI routes from the discovered routes + # Create FastAPI routes from the discovered routes (skip APIs that have routers) from . 
import endpoints for api, routes in api_routes.items(): + # Skip APIs that have routers - they're already included above + if build_fastapi_router(api, None) is not None: + continue + for route, webmethod in routes: # Convert the route to a FastAPI endpoint endpoints._create_fastapi_endpoint(app, route, webmethod, api) diff --git a/src/llama_stack/core/inspect.py b/src/llama_stack/core/inspect.py index 272c9d1bc2..45cab29707 100644 --- a/src/llama_stack/core/inspect.py +++ b/src/llama_stack/core/inspect.py @@ -10,8 +10,14 @@ from llama_stack.core.datatypes import StackRunConfig from llama_stack.core.external import load_external_apis +from llama_stack.core.server.fastapi_router_registry import ( + _ROUTER_FACTORIES, + build_fastapi_router, + get_router_routes, +) from llama_stack.core.server.routes import get_all_api_routes from llama_stack_api import ( + Api, HealthInfo, HealthStatus, Inspect, @@ -43,6 +49,7 @@ async def list_routes(self, api_filter: str | None = None) -> ListRoutesResponse run_config: StackRunConfig = self.config.run_config # Helper function to determine if a route should be included based on api_filter + # TODO: remove this once we've migrated all APIs to FastAPI routers def should_include_route(webmethod) -> bool: if api_filter is None: # Default: only non-deprecated APIs @@ -54,10 +61,59 @@ def should_include_route(webmethod) -> bool: # Filter by API level (non-deprecated routes only) return not webmethod.deprecated and webmethod.level == api_filter + # Helper function to get provider types for an API + def _get_provider_types(api: Api) -> list[str]: + if api.value in ["providers", "inspect"]: + return [] # These APIs don't have "real" providers they're internal to the stack + providers = run_config.providers.get(api.value, []) + return [p.provider_type for p in providers] if providers else [] + + # Helper function to determine if a router route should be included based on api_filter + def _should_include_router_route(route, router_prefix: str | None) -> bool: + """Check if a router-based route should be included based on api_filter.""" + # Check deprecated status + route_deprecated = getattr(route, "deprecated", False) or False + + if api_filter is None: + # Default: only non-deprecated routes + return not route_deprecated + elif api_filter == "deprecated": + # Special filter: show deprecated routes regardless of their actual level + return route_deprecated + else: + # Filter by API level (non-deprecated routes only) + # Extract level from router prefix (e.g., "/v1" -> "v1") + if router_prefix: + prefix_level = router_prefix.lstrip("/") + return not route_deprecated and prefix_level == api_filter + return not route_deprecated + ret = [] external_apis = load_external_apis(run_config) all_endpoints = get_all_api_routes(external_apis) + + # Process routes from APIs with FastAPI routers + for api_name in _ROUTER_FACTORIES.keys(): + api = Api(api_name) + router = build_fastapi_router(api, None) # we don't need the impl here, just the routes + if router: + router_routes = get_router_routes(router) + for route in router_routes: + if _should_include_router_route(route, router.prefix): + ret.append( + RouteInfo( + route=route.path, + method=next(iter([m for m in route.methods if m != "HEAD"])), + provider_types=_get_provider_types(api), + ) + ) + + # Process routes from legacy webmethod-based APIs for api, endpoints in all_endpoints.items(): + # Skip APIs that have routers (already processed above) + if api.value in _ROUTER_FACTORIES: + continue + # Always include provider 
and inspect APIs, filter others based on run config if api.value in ["providers", "inspect"]: ret.extend( diff --git a/src/llama_stack/core/server/fastapi_router_registry.py b/src/llama_stack/core/server/fastapi_router_registry.py new file mode 100644 index 0000000000..84f41693d9 --- /dev/null +++ b/src/llama_stack/core/server/fastapi_router_registry.py @@ -0,0 +1,84 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Router utilities for FastAPI routers. + +This module provides utilities to create FastAPI routers from API packages. +APIs with routers are explicitly listed here. +""" + +from collections.abc import Callable +from typing import Any, cast + +from fastapi import APIRouter +from fastapi.routing import APIRoute +from starlette.routing import Route + +from llama_stack_api.batches.fastapi_routes import create_router as create_batches_router +from llama_stack_api.datatypes import Api + +# Router factories for APIs that have FastAPI routers. +# Add new APIs here as they are migrated to the router system. +_ROUTER_FACTORIES: dict[str, Callable[[Any], APIRouter]] = { + "batches": create_batches_router, +} + + +def build_fastapi_router(api: Api, impl: Any) -> APIRouter | None: + """Build a router for an API by combining its router factory with the implementation. + + Args: + api: The API enum value + impl: The implementation instance for the API + + Returns: + APIRouter if the API has a router factory, None otherwise + """ + router_factory = _ROUTER_FACTORIES.get(api.value) + if router_factory is None: + return None + + # cast is safe here: all router factories in API packages are required to return APIRouter. + # If a router factory returns the wrong type, it will fail at runtime when + # app.include_router(router) is called. + return cast(APIRouter, router_factory(impl)) + + +def get_router_routes(router: APIRouter) -> list[Route]: + """Extract routes from a FastAPI router. + + Args: + router: The FastAPI router to extract routes from + + Returns: + List of Route objects from the router + """ + routes = [] + + for route in router.routes: + # FastAPI routers use APIRoute objects, which have path and methods attributes + if isinstance(route, APIRoute): + # APIRoute.path already includes the router prefix, so it can be used as-is + routes.append( + Route( + path=route.path, + methods=route.methods, + name=route.name, + endpoint=route.endpoint, + ) + ) + elif isinstance(route, Route): + # Fallback for regular Starlette Route objects + routes.append( + Route( + path=route.path, + methods=route.methods, + name=route.name, + endpoint=route.endpoint, + ) + ) + + return routes diff --git a/src/llama_stack/core/server/routes.py b/src/llama_stack/core/server/routes.py index af50025654..9df9e4a60e 100644 --- a/src/llama_stack/core/server/routes.py +++ b/src/llama_stack/core/server/routes.py @@ -26,6 +26,18 @@ def get_all_api_routes( external_apis: dict[Api, ExternalApiSpec] | None = None, ) -> dict[Api, list[tuple[Route, WebMethod]]]: + """Get all API routes from webmethod-based protocols. + + This function only returns routes from APIs that use the legacy @webmethod + decorator system. For APIs that have been migrated to FastAPI routers, + use the router registry (fastapi_router_registry.build_fastapi_router()).
+ + Args: + external_apis: Optional dictionary of external API specifications + + Returns: + Dictionary mapping API to list of (Route, WebMethod) tuples + """ apis = {} protocols = api_protocol_map(external_apis) diff --git a/src/llama_stack/core/server/server.py b/src/llama_stack/core/server/server.py index 0d3513980a..e316609c35 100644 --- a/src/llama_stack/core/server/server.py +++ b/src/llama_stack/core/server/server.py @@ -44,6 +44,7 @@ request_provider_data_context, user_from_scope, ) +from llama_stack.core.server.fastapi_router_registry import build_fastapi_router from llama_stack.core.server.routes import get_all_api_routes from llama_stack.core.stack import ( Stack, @@ -87,7 +88,7 @@ def create_sse_event(data: Any) -> str: async def global_exception_handler(request: Request, exc: Exception): - traceback.print_exception(exc) + traceback.print_exception(type(exc), exc, exc.__traceback__) http_exc = translate_exception(exc) return JSONResponse(status_code=http_exc.status_code, content={"error": {"detail": http_exc.detail}}) @@ -463,15 +464,22 @@ def create_app() -> StackApp: apis_to_serve.add("providers") apis_to_serve.add("prompts") apis_to_serve.add("conversations") + for api_str in apis_to_serve: api = Api(api_str) - routes = all_routes[api] - try: - impl = impls[api] - except KeyError as e: - raise ValueError(f"Could not find provider implementation for {api} API") from e + # Try to discover and use a router factory from the API package + impl = impls[api] + router = build_fastapi_router(api, impl) + if router: + app.include_router(router) + logger.debug(f"Registered FastAPI router for {api} API") + continue + + # Fall back to old webmethod-based route discovery until the migration is complete + impl = impls[api] + routes = all_routes[api] for route, _ in routes: if not hasattr(impl, route.name): # ideally this should be a typing violation already @@ -497,7 +505,15 @@ logger.debug(f"serving APIs: {apis_to_serve}") + # Register specific exception handlers before the generic Exception handler + # This prevents the re-raising behavior that causes connection resets app.exception_handler(RequestValidationError)(global_exception_handler) + app.exception_handler(ConflictError)(global_exception_handler) + app.exception_handler(ResourceNotFoundError)(global_exception_handler) + app.exception_handler(AuthenticationRequiredError)(global_exception_handler) + app.exception_handler(AccessDeniedError)(global_exception_handler) + app.exception_handler(BadRequestError)(global_exception_handler) + # Generic Exception handler should be last app.exception_handler(Exception)(global_exception_handler) if config.telemetry.enabled: diff --git a/src/llama_stack/providers/inline/batches/reference/batches.py b/src/llama_stack/providers/inline/batches/reference/batches.py index aaa2c7b220..57ef939d3a 100644 --- a/src/llama_stack/providers/inline/batches/reference/batches.py +++ b/src/llama_stack/providers/inline/batches/reference/batches.py @@ -11,7 +11,7 @@ import time import uuid from io import BytesIO -from typing import Any, Literal +from typing import Any from openai.types.batch import BatchError, Errors from pydantic import BaseModel @@ -38,6 +38,12 @@ OpenAIUserMessageParam, ResourceNotFoundError, ) +from llama_stack_api.batches.models import ( + CancelBatchRequest, + CreateBatchRequest, + ListBatchesRequest, + RetrieveBatchRequest, +) from .config import ReferenceBatchesImplConfig @@ -140,11 +146,7 @@ async def shutdown(self) -> None: # TODO (SECURITY): this currently
works w/ configured api keys, not with x-llamastack-provider-data or with user policy restrictions async def create_batch( self, - input_file_id: str, - endpoint: str, - completion_window: Literal["24h"], - metadata: dict[str, str] | None = None, - idempotency_key: str | None = None, + request: CreateBatchRequest, ) -> BatchObject: """ Create a new batch for processing multiple API requests. @@ -185,14 +187,14 @@ async def create_batch( # TODO: set expiration time for garbage collection - if endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]: + if request.endpoint not in ["/v1/chat/completions", "/v1/completions", "/v1/embeddings"]: raise ValueError( - f"Invalid endpoint: {endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint", + f"Invalid endpoint: {request.endpoint}. Supported values: /v1/chat/completions, /v1/completions, /v1/embeddings. Code: invalid_value. Param: endpoint", ) - if completion_window != "24h": + if request.completion_window != "24h": raise ValueError( - f"Invalid completion_window: {completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window", + f"Invalid completion_window: {request.completion_window}. Supported values are: 24h. Code: invalid_value. Param: completion_window", ) batch_id = f"batch_{uuid.uuid4().hex[:16]}" @@ -200,22 +202,22 @@ async def create_batch( # For idempotent requests, use the idempotency key for the batch ID # This ensures the same key always maps to the same batch ID, # allowing us to detect parameter conflicts - if idempotency_key is not None: - hash_input = idempotency_key.encode("utf-8") + if request.idempotency_key is not None: + hash_input = request.idempotency_key.encode("utf-8") hash_digest = hashlib.sha256(hash_input).hexdigest()[:24] batch_id = f"batch_{hash_digest}" try: - existing_batch = await self.retrieve_batch(batch_id) + existing_batch = await self.retrieve_batch(RetrieveBatchRequest(batch_id=batch_id)) if ( - existing_batch.input_file_id != input_file_id - or existing_batch.endpoint != endpoint - or existing_batch.completion_window != completion_window - or existing_batch.metadata != metadata + existing_batch.input_file_id != request.input_file_id + or existing_batch.endpoint != request.endpoint + or existing_batch.completion_window != request.completion_window + or existing_batch.metadata != request.metadata ): raise ConflictError( - f"Idempotency key '{idempotency_key}' was previously used with different parameters. " + f"Idempotency key '{request.idempotency_key}' was previously used with different parameters. " "Either use a new idempotency key or ensure all parameters match the original request." 
) @@ -230,12 +232,12 @@ async def create_batch( batch = BatchObject( id=batch_id, object="batch", - endpoint=endpoint, - input_file_id=input_file_id, - completion_window=completion_window, + endpoint=request.endpoint, + input_file_id=request.input_file_id, + completion_window=request.completion_window, status="validating", created_at=current_time, - metadata=metadata, + metadata=request.metadata, ) await self.kvstore.set(f"batch:{batch_id}", batch.to_json()) @@ -247,28 +249,27 @@ async def create_batch( return batch - async def cancel_batch(self, batch_id: str) -> BatchObject: + async def cancel_batch(self, request: CancelBatchRequest) -> BatchObject: """Cancel a batch that is in progress.""" - batch = await self.retrieve_batch(batch_id) + batch = await self.retrieve_batch(RetrieveBatchRequest(batch_id=request.batch_id)) if batch.status in ["cancelled", "cancelling"]: return batch if batch.status in ["completed", "failed", "expired"]: - raise ConflictError(f"Cannot cancel batch '{batch_id}' with status '{batch.status}'") + raise ConflictError(f"Cannot cancel batch '{request.batch_id}' with status '{batch.status}'") - await self._update_batch(batch_id, status="cancelling", cancelling_at=int(time.time())) + await self._update_batch(request.batch_id, status="cancelling", cancelling_at=int(time.time())) - if batch_id in self._processing_tasks: - self._processing_tasks[batch_id].cancel() + if request.batch_id in self._processing_tasks: + self._processing_tasks[request.batch_id].cancel() # note: task removal and status="cancelled" handled in finally block of _process_batch - return await self.retrieve_batch(batch_id) + return await self.retrieve_batch(RetrieveBatchRequest(batch_id=request.batch_id)) async def list_batches( self, - after: str | None = None, - limit: int = 20, + request: ListBatchesRequest, ) -> ListBatchesResponse: """ List all batches, eventually only for the current user. @@ -285,14 +286,14 @@ async def list_batches( batches.sort(key=lambda b: b.created_at, reverse=True) start_idx = 0 - if after: + if request.after: for i, batch in enumerate(batches): - if batch.id == after: + if batch.id == request.after: start_idx = i + 1 break - page_batches = batches[start_idx : start_idx + limit] - has_more = (start_idx + limit) < len(batches) + page_batches = batches[start_idx : start_idx + request.limit] + has_more = (start_idx + request.limit) < len(batches) first_id = page_batches[0].id if page_batches else None last_id = page_batches[-1].id if page_batches else None @@ -304,11 +305,11 @@ async def list_batches( has_more=has_more, ) - async def retrieve_batch(self, batch_id: str) -> BatchObject: + async def retrieve_batch(self, request: RetrieveBatchRequest) -> BatchObject: """Retrieve information about a specific batch.""" - batch_data = await self.kvstore.get(f"batch:{batch_id}") + batch_data = await self.kvstore.get(f"batch:{request.batch_id}") if not batch_data: - raise ResourceNotFoundError(batch_id, "Batch", "batches.list()") + raise ResourceNotFoundError(request.batch_id, "Batch", "batches.list()") return BatchObject.model_validate_json(batch_data) @@ -316,7 +317,7 @@ async def _update_batch(self, batch_id: str, **updates) -> None: """Update batch fields in kvstore.""" async with self._update_batch_lock: try: - batch = await self.retrieve_batch(batch_id) + batch = await self.retrieve_batch(RetrieveBatchRequest(batch_id=batch_id)) # batch processing is async. 
once cancelling, only allow "cancelled" status updates if batch.status == "cancelling" and updates.get("status") != "cancelled": @@ -536,7 +537,7 @@ async def _process_batch(self, batch_id: str) -> None: async def _process_batch_impl(self, batch_id: str) -> None: """Implementation of batch processing logic.""" errors: list[BatchError] = [] - batch = await self.retrieve_batch(batch_id) + batch = await self.retrieve_batch(RetrieveBatchRequest(batch_id=batch_id)) errors, requests = await self._validate_input(batch) if errors: diff --git a/src/llama_stack_api/__init__.py b/src/llama_stack_api/__init__.py index b6fe2fd239..f919d2afd7 100644 --- a/src/llama_stack_api/__init__.py +++ b/src/llama_stack_api/__init__.py @@ -26,7 +26,15 @@ # Import all public API symbols from .agents import Agents, ResponseGuardrail, ResponseGuardrailSpec -from .batches import Batches, BatchObject, ListBatchesResponse +from .batches import ( + Batches, + BatchObject, + CancelBatchRequest, + CreateBatchRequest, + ListBatchesRequest, + ListBatchesResponse, + RetrieveBatchRequest, +) from .benchmarks import ( Benchmark, BenchmarkInput, @@ -462,6 +470,9 @@ "BasicScoringFnParams", "Batches", "BatchObject", + "CancelBatchRequest", + "CreateBatchRequest", + "ListBatchesRequest", "Benchmark", "BenchmarkConfig", "BenchmarkInput", @@ -555,6 +566,7 @@ "LLMAsJudgeScoringFnParams", "LLMRAGQueryGeneratorConfig", "ListBatchesResponse", + "RetrieveBatchRequest", "ListBenchmarksResponse", "ListDatasetsResponse", "ListModelsResponse", diff --git a/src/llama_stack_api/batches.py b/src/llama_stack_api/batches.py deleted file mode 100644 index 00c47d39f7..0000000000 --- a/src/llama_stack_api/batches.py +++ /dev/null @@ -1,96 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Literal, Protocol, runtime_checkable - -from pydantic import BaseModel, Field - -from llama_stack_api.schema_utils import json_schema_type, webmethod -from llama_stack_api.version import LLAMA_STACK_API_V1 - -try: - from openai.types import Batch as BatchObject -except ImportError as e: - raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e - - -@json_schema_type -class ListBatchesResponse(BaseModel): - """Response containing a list of batch objects.""" - - object: Literal["list"] = "list" - data: list[BatchObject] = Field(..., description="List of batch objects") - first_id: str | None = Field(default=None, description="ID of the first batch in the list") - last_id: str | None = Field(default=None, description="ID of the last batch in the list") - has_more: bool = Field(default=False, description="Whether there are more batches available") - - -@runtime_checkable -class Batches(Protocol): - """ - The Batches API enables efficient processing of multiple requests in a single operation, - particularly useful for processing large datasets, batch evaluation workflows, and - cost-effective inference at scale. - - The API is designed to allow use of openai client libraries for seamless integration. - - This API provides the following extensions: - - idempotent batch creation - - Note: This API is currently under active development and may undergo changes. 
- """ - - @webmethod(route="/batches", method="POST", level=LLAMA_STACK_API_V1) - async def create_batch( - self, - input_file_id: str, - endpoint: str, - completion_window: Literal["24h"], - metadata: dict[str, str] | None = None, - idempotency_key: str | None = None, - ) -> BatchObject: - """Create a new batch for processing multiple API requests. - - :param input_file_id: The ID of an uploaded file containing requests for the batch. - :param endpoint: The endpoint to be used for all requests in the batch. - :param completion_window: The time window within which the batch should be processed. - :param metadata: Optional metadata for the batch. - :param idempotency_key: Optional idempotency key. When provided, enables idempotent behavior. - :returns: The created batch object. - """ - ... - - @webmethod(route="/batches/{batch_id}", method="GET", level=LLAMA_STACK_API_V1) - async def retrieve_batch(self, batch_id: str) -> BatchObject: - """Retrieve information about a specific batch. - - :param batch_id: The ID of the batch to retrieve. - :returns: The batch object. - """ - ... - - @webmethod(route="/batches/{batch_id}/cancel", method="POST", level=LLAMA_STACK_API_V1) - async def cancel_batch(self, batch_id: str) -> BatchObject: - """Cancel a batch that is in progress. - - :param batch_id: The ID of the batch to cancel. - :returns: The updated batch object. - """ - ... - - @webmethod(route="/batches", method="GET", level=LLAMA_STACK_API_V1) - async def list_batches( - self, - after: str | None = None, - limit: int = 20, - ) -> ListBatchesResponse: - """List all batches for the current user. - - :param after: A cursor for pagination; returns batches after this batch ID. - :param limit: Number of batches to return (default 20, max 100). - :returns: A list of batch objects. - """ - ... diff --git a/src/llama_stack_api/batches/__init__.py b/src/llama_stack_api/batches/__init__.py new file mode 100644 index 0000000000..33dc62b811 --- /dev/null +++ b/src/llama_stack_api/batches/__init__.py @@ -0,0 +1,75 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Batches API protocol and models. + +This module contains the Batches protocol definition. +Pydantic models are defined in llama_stack_api.batches.models. +The FastAPI router is defined in llama_stack_api.batches.fastapi_routes. +""" + +from typing import Protocol, runtime_checkable + +try: + from openai.types import Batch as BatchObject +except ImportError as e: + raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e + +# Import models for re-export +from llama_stack_api.batches.models import ( + CancelBatchRequest, + CreateBatchRequest, + ListBatchesRequest, + ListBatchesResponse, + RetrieveBatchRequest, +) + + +@runtime_checkable +class Batches(Protocol): + """ + The Batches API enables efficient processing of multiple requests in a single operation, + particularly useful for processing large datasets, batch evaluation workflows, and + cost-effective inference at scale. + + The API is designed to allow use of openai client libraries for seamless integration. + + This API provides the following extensions: + - idempotent batch creation + + Note: This API is currently under active development and may undergo changes. + """ + + async def create_batch( + self, + request: CreateBatchRequest, + ) -> BatchObject: ... 
+ + async def retrieve_batch( + self, + request: RetrieveBatchRequest, + ) -> BatchObject: ... + + async def cancel_batch( + self, + request: CancelBatchRequest, + ) -> BatchObject: ... + + async def list_batches( + self, + request: ListBatchesRequest, + ) -> ListBatchesResponse: ... + + +__all__ = [ + "Batches", + "BatchObject", + "CreateBatchRequest", + "ListBatchesRequest", + "RetrieveBatchRequest", + "CancelBatchRequest", + "ListBatchesResponse", +] diff --git a/src/llama_stack_api/batches/fastapi_routes.py b/src/llama_stack_api/batches/fastapi_routes.py new file mode 100644 index 0000000000..dd5dc7a6cf --- /dev/null +++ b/src/llama_stack_api/batches/fastapi_routes.py @@ -0,0 +1,111 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""FastAPI router for the Batches API. + +This module defines the FastAPI router for the Batches API using standard +FastAPI route decorators. The router is defined in the API package to keep +all API-related code together. +""" + +from typing import Annotated + +from fastapi import APIRouter, Body, Depends + +from llama_stack_api.batches import Batches, BatchObject, ListBatchesResponse +from llama_stack_api.batches.models import ( + CancelBatchRequest, + CreateBatchRequest, + ListBatchesRequest, + RetrieveBatchRequest, +) +from llama_stack_api.router_utils import create_path_dependency, create_query_dependency, standard_responses +from llama_stack_api.version import LLAMA_STACK_API_V1 + +# Automatically generate dependency functions from Pydantic models +# This ensures the models are the single source of truth for descriptions +get_retrieve_batch_request = create_path_dependency(RetrieveBatchRequest) +get_cancel_batch_request = create_path_dependency(CancelBatchRequest) + + +# Automatically generate dependency function from Pydantic model +# This ensures the model is the single source of truth for descriptions and defaults +get_list_batches_request = create_query_dependency(ListBatchesRequest) + + +def create_router(impl: Batches) -> APIRouter: + """Create a FastAPI router for the Batches API. 
+ + Args: + impl: The Batches implementation instance + + Returns: + APIRouter configured for the Batches API + """ + router = APIRouter( + prefix=f"/{LLAMA_STACK_API_V1}", + tags=["Batches"], + responses=standard_responses, + ) + + @router.post( + "/batches", + response_model=BatchObject, + summary="Create a new batch for processing multiple API requests.", + description="Create a new batch for processing multiple API requests.", + responses={ + 200: {"description": "The created batch object."}, + 409: {"description": "Conflict: The idempotency key was previously used with different parameters."}, + }, + ) + async def create_batch( + request: Annotated[CreateBatchRequest, Body(...)], + ) -> BatchObject: + return await impl.create_batch(request) + + @router.get( + "/batches/{batch_id}", + response_model=BatchObject, + summary="Retrieve information about a specific batch.", + description="Retrieve information about a specific batch.", + responses={ + 200: {"description": "The batch object."}, + }, + ) + async def retrieve_batch( + request: Annotated[RetrieveBatchRequest, Depends(get_retrieve_batch_request)], + ) -> BatchObject: + return await impl.retrieve_batch(request) + + @router.post( + "/batches/{batch_id}/cancel", + response_model=BatchObject, + summary="Cancel a batch that is in progress.", + description="Cancel a batch that is in progress.", + responses={ + 200: {"description": "The updated batch object."}, + }, + ) + async def cancel_batch( + request: Annotated[CancelBatchRequest, Depends(get_cancel_batch_request)], + ) -> BatchObject: + return await impl.cancel_batch(request) + + @router.get( + "/batches", + response_model=ListBatchesResponse, + summary="List all batches for the current user.", + description="List all batches for the current user.", + responses={ + 200: {"description": "A list of batch objects."}, + }, + ) + async def list_batches( + request: Annotated[ListBatchesRequest, Depends(get_list_batches_request)], + ) -> ListBatchesResponse: + return await impl.list_batches(request) + + return router diff --git a/src/llama_stack_api/batches/models.py b/src/llama_stack_api/batches/models.py new file mode 100644 index 0000000000..49fd25d169 --- /dev/null +++ b/src/llama_stack_api/batches/models.py @@ -0,0 +1,82 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Pydantic models for Batches API requests and responses. + +This module defines the request and response models for the Batches API +using Pydantic with Field descriptions for OpenAPI schema generation. +""" + +from typing import Literal + +from pydantic import BaseModel, Field + +from llama_stack_api.schema_utils import json_schema_type + +try: + from openai.types import Batch as BatchObject +except ImportError as e: + raise ImportError("OpenAI package is required for batches API. Please install it with: pip install openai") from e + + +@json_schema_type +class CreateBatchRequest(BaseModel): + """Request model for creating a batch.""" + + input_file_id: str = Field(..., description="The ID of an uploaded file containing requests for the batch.") + endpoint: str = Field(..., description="The endpoint to be used for all requests in the batch.") + completion_window: Literal["24h"] = Field( + ..., description="The time window within which the batch should be processed." 
+ ) + metadata: dict[str, str] | None = Field(default=None, description="Optional metadata for the batch.") + idempotency_key: str | None = Field( + default=None, description="Optional idempotency key. When provided, enables idempotent behavior." + ) + + +@json_schema_type +class ListBatchesRequest(BaseModel): + """Request model for listing batches.""" + + after: str | None = Field( + default=None, description="Optional cursor for pagination. Returns batches after this ID." + ) + limit: int = Field(default=20, description="Maximum number of batches to return. Defaults to 20.") + + +@json_schema_type +class RetrieveBatchRequest(BaseModel): + """Request model for retrieving a batch.""" + + batch_id: str = Field(..., description="The ID of the batch to retrieve.") + + +@json_schema_type +class CancelBatchRequest(BaseModel): + """Request model for canceling a batch.""" + + batch_id: str = Field(..., description="The ID of the batch to cancel.") + + +@json_schema_type +class ListBatchesResponse(BaseModel): + """Response containing a list of batch objects.""" + + object: Literal["list"] = "list" + data: list[BatchObject] = Field(..., description="List of batch objects") + first_id: str | None = Field(default=None, description="ID of the first batch in the list") + last_id: str | None = Field(default=None, description="ID of the last batch in the list") + has_more: bool = Field(default=False, description="Whether there are more batches available") + + +__all__ = [ + "CreateBatchRequest", + "ListBatchesRequest", + "RetrieveBatchRequest", + "CancelBatchRequest", + "ListBatchesResponse", + "BatchObject", +] diff --git a/src/llama_stack_api/pyproject.toml b/src/llama_stack_api/pyproject.toml index 0ceb2bb4e0..0fec354db5 100644 --- a/src/llama_stack_api/pyproject.toml +++ b/src/llama_stack_api/pyproject.toml @@ -24,6 +24,7 @@ classifiers = [ "Topic :: Scientific/Engineering :: Information Analysis", ] dependencies = [ + "fastapi>=0.115.0,<1.0", "pydantic>=2.11.9", "jsonschema", "opentelemetry-sdk>=1.30.0", diff --git a/src/llama_stack_api/router_utils.py b/src/llama_stack_api/router_utils.py new file mode 100644 index 0000000000..25c8f47c40 --- /dev/null +++ b/src/llama_stack_api/router_utils.py @@ -0,0 +1,155 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Utilities for creating FastAPI routers with standard error responses. + +This module provides standard error response definitions for FastAPI routers. +These responses use OpenAPI $ref references to component responses defined +in the OpenAPI specification. +""" + +import inspect +from collections.abc import Callable +from typing import Annotated, Any, TypeVar + +from fastapi import Path, Query +from pydantic import BaseModel + +standard_responses: dict[int | str, dict[str, Any]] = { + 400: {"$ref": "#/components/responses/BadRequest400"}, + 429: {"$ref": "#/components/responses/TooManyRequests429"}, + 500: {"$ref": "#/components/responses/InternalServerError500"}, + "default": {"$ref": "#/components/responses/DefaultError"}, +} + +T = TypeVar("T", bound=BaseModel) + + +def create_query_dependency[T: BaseModel](model_class: type[T]) -> Callable[..., T]: + """Create a FastAPI dependency function from a Pydantic model for query parameters. + + FastAPI does not natively support using Pydantic models as query parameters + without a dependency function. 
Using a dependency function typically leads to + duplication: field types, default values, and descriptions must be repeated in + `Query(...)` annotations even though they already exist in the Pydantic model. + + This function automatically generates a dependency function that extracts query parameters + from the request and constructs an instance of the Pydantic model. The descriptions and + defaults are automatically extracted from the model's Field definitions, making the model + the single source of truth. + + Args: + model_class: The Pydantic model class to create a dependency for + + Returns: + A dependency function that can be used with FastAPI's Depends() + + """ + # Build function signature dynamically from model fields + annotations: dict[str, Any] = {} + defaults: dict[str, Any] = {} + + for field_name, field_info in model_class.model_fields.items(): + # Extract description from Field + description = field_info.description + + # Create Query annotation with description from model + query_annotation = Query(description=description) if description else Query() + + # Create Annotated type with Query + field_type = field_info.annotation + annotations[field_name] = Annotated[field_type, query_annotation] + + # Set default value from model; pydantic marks required fields with PydanticUndefined, so skip them to keep them required + if not field_info.is_required(): + defaults[field_name] = field_info.get_default(call_default_factory=True) + + # Create the dependency function dynamically + def dependency_func(**kwargs: Any) -> T: + return model_class(**kwargs) + + # Set function signature + sig_params = [] + for field_name, field_type in annotations.items(): + default = defaults.get(field_name, inspect.Parameter.empty) + param = inspect.Parameter( + field_name, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + default=default, + annotation=field_type, + ) + sig_params.append(param) + + # These attributes are set dynamically at runtime. While mypy can't verify them statically, + # they are standard Python function attributes that exist on all callable objects at runtime. + # Setting them allows FastAPI to properly introspect the function signature for dependency injection. + dependency_func.__signature__ = inspect.Signature(sig_params) # type: ignore[attr-defined] + dependency_func.__annotations__ = annotations # type: ignore[attr-defined] + dependency_func.__name__ = f"get_{model_class.__name__.lower()}_request" # type: ignore[attr-defined] + + return dependency_func + + +def create_path_dependency[T: BaseModel](model_class: type[T]) -> Callable[..., T]: + """Create a FastAPI dependency function from a Pydantic model for path parameters. + + FastAPI requires path parameters to be explicitly annotated with `Path()`. When using + a Pydantic model that contains path parameters, you typically need a dependency function + that extracts the path parameter and constructs the model. This leads to duplication: + the parameter name, type, and description must be repeated in `Path(...)` annotations + even though they already exist in the Pydantic model. + + This function automatically generates a dependency function that extracts path parameters + from the request and constructs an instance of the Pydantic model. The descriptions are + automatically extracted from the model's Field definitions, making the model the single + source of truth. + + Args: + model_class: The Pydantic model class to create a dependency for. The model should + have exactly one field that represents the path parameter.
+    """
+    # Get the single field from the model (path parameter models must have exactly one field)
+    if len(model_class.model_fields) != 1:
+        raise ValueError(
+            f"Path parameter model {model_class.__name__} must have exactly one field, "
+            f"but has {len(model_class.model_fields)} fields"
+        )
+
+    field_name, field_info = next(iter(model_class.model_fields.items()))
+
+    # Extract description from Field
+    description = field_info.description
+
+    # Create Path annotation with description from model
+    path_annotation = Path(description=description) if description else Path()
+
+    # Create Annotated type with Path
+    field_type = field_info.annotation
+    annotations: dict[str, Any] = {field_name: Annotated[field_type, path_annotation]}
+
+    # Create the dependency function dynamically
+    def dependency_func(**kwargs: Any) -> T:
+        return model_class(**kwargs)
+
+    # Set function signature (path parameters are always required, so no default)
+    param = inspect.Parameter(
+        field_name,
+        inspect.Parameter.POSITIONAL_OR_KEYWORD,
+        annotation=annotations[field_name],
+    )
+
+    # These attributes are set dynamically at runtime. While mypy can't verify them statically,
+    # they are standard Python function attributes that exist on all callable objects at runtime.
+    # Setting them allows FastAPI to properly introspect the function signature for dependency injection.
+    dependency_func.__signature__ = inspect.Signature([param])  # type: ignore[attr-defined]
+    dependency_func.__annotations__ = annotations  # type: ignore[attr-defined]
+    dependency_func.__name__ = f"get_{model_class.__name__.lower()}_request"  # type: ignore[attr-defined]
+
+    return dependency_func
diff --git a/tests/unit/providers/batches/test_reference.py b/tests/unit/providers/batches/test_reference.py
index 32d59234d4..bff286fc74 100644
--- a/tests/unit/providers/batches/test_reference.py
+++ b/tests/unit/providers/batches/test_reference.py
@@ -58,8 +58,15 @@ from unittest.mock import AsyncMock, MagicMock
 
 import pytest
+from pydantic import ValidationError
 
 from llama_stack_api import BatchObject, ConflictError, ResourceNotFoundError
+from llama_stack_api.batches.models import (
+    CancelBatchRequest,
+    CreateBatchRequest,
+    ListBatchesRequest,
+    RetrieveBatchRequest,
+)
 
 
 class TestReferenceBatchesImpl:
@@ -169,7 +176,7 @@ def _validate_batch_type(self, batch, expected_metadata=None):
 
     async def test_create_and_retrieve_batch_success(self, provider, sample_batch_data):
         """Test successful batch creation and retrieval."""
-        created_batch = await provider.create_batch(**sample_batch_data)
+        created_batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data))
 
         self._validate_batch_type(created_batch, expected_metadata=sample_batch_data["metadata"])
 
@@ -184,7 +191,7 @@ async def test_create_and_retrieve_batch_success(self, provider, sample_batch_da
         assert isinstance(created_batch.created_at, int)
         assert created_batch.created_at > 0
 
-        retrieved_batch = await provider.retrieve_batch(created_batch.id)
+        retrieved_batch = await provider.retrieve_batch(RetrieveBatchRequest(batch_id=created_batch.id))
 
         self._validate_batch_type(retrieved_batch, expected_metadata=sample_batch_data["metadata"])
 
@@ -197,17 +204,15 @@ async def test_create_and_retrieve_batch_success(self, provider, sample_batch_da
     async def test_create_batch_without_metadata(self, provider):
         """Test batch creation without optional metadata."""
         batch = await provider.create_batch(
-            input_file_id="file_123", endpoint="/v1/chat/completions",
completion_window="24h" + CreateBatchRequest(input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="24h") ) assert batch.metadata is None async def test_create_batch_completion_window(self, provider): """Test batch creation with invalid completion window.""" - with pytest.raises(ValueError, match="Invalid completion_window"): - await provider.create_batch( - input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="now" - ) + with pytest.raises(ValidationError, match="completion_window"): + CreateBatchRequest(input_file_id="file_123", endpoint="/v1/chat/completions", completion_window="now") @pytest.mark.parametrize( "endpoint", @@ -219,37 +224,43 @@ async def test_create_batch_completion_window(self, provider): async def test_create_batch_invalid_endpoints(self, provider, endpoint): """Test batch creation with various invalid endpoints.""" with pytest.raises(ValueError, match="Invalid endpoint"): - await provider.create_batch(input_file_id="file_123", endpoint=endpoint, completion_window="24h") + await provider.create_batch( + CreateBatchRequest(input_file_id="file_123", endpoint=endpoint, completion_window="24h") + ) async def test_create_batch_invalid_metadata(self, provider): """Test that batch creation fails with invalid metadata.""" with pytest.raises(ValueError, match="should be a valid string"): await provider.create_batch( - input_file_id="file_123", - endpoint="/v1/chat/completions", - completion_window="24h", - metadata={123: "invalid_key"}, # Non-string key + CreateBatchRequest( + input_file_id="file_123", + endpoint="/v1/chat/completions", + completion_window="24h", + metadata={123: "invalid_key"}, # Non-string key + ) ) with pytest.raises(ValueError, match="should be a valid string"): await provider.create_batch( - input_file_id="file_123", - endpoint="/v1/chat/completions", - completion_window="24h", - metadata={"valid_key": 456}, # Non-string value + CreateBatchRequest( + input_file_id="file_123", + endpoint="/v1/chat/completions", + completion_window="24h", + metadata={"valid_key": 456}, # Non-string value + ) ) async def test_retrieve_batch_not_found(self, provider): """Test error when retrieving non-existent batch.""" with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"): - await provider.retrieve_batch("nonexistent_batch") + await provider.retrieve_batch(RetrieveBatchRequest(batch_id="nonexistent_batch")) async def test_cancel_batch_success(self, provider, sample_batch_data): """Test successful batch cancellation.""" - created_batch = await provider.create_batch(**sample_batch_data) + created_batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data)) assert created_batch.status == "validating" - cancelled_batch = await provider.cancel_batch(created_batch.id) + cancelled_batch = await provider.cancel_batch(CancelBatchRequest(batch_id=created_batch.id)) assert cancelled_batch.id == created_batch.id assert cancelled_batch.status in ["cancelling", "cancelled"] @@ -260,22 +271,22 @@ async def test_cancel_batch_success(self, provider, sample_batch_data): async def test_cancel_batch_invalid_statuses(self, provider, sample_batch_data, status): """Test error when cancelling batch in final states.""" provider.process_batches = False - created_batch = await provider.create_batch(**sample_batch_data) + created_batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data)) # directly update status in kvstore await provider._update_batch(created_batch.id, status=status) with 
pytest.raises(ConflictError, match=f"Cannot cancel batch '{created_batch.id}' with status '{status}'"): - await provider.cancel_batch(created_batch.id) + await provider.cancel_batch(CancelBatchRequest(batch_id=created_batch.id)) async def test_cancel_batch_not_found(self, provider): """Test error when cancelling non-existent batch.""" with pytest.raises(ResourceNotFoundError, match=r"Batch 'nonexistent_batch' not found"): - await provider.cancel_batch("nonexistent_batch") + await provider.cancel_batch(CancelBatchRequest(batch_id="nonexistent_batch")) async def test_list_batches_empty(self, provider): """Test listing batches when none exist.""" - response = await provider.list_batches() + response = await provider.list_batches(ListBatchesRequest()) assert response.object == "list" assert response.data == [] @@ -285,9 +296,9 @@ async def test_list_batches_empty(self, provider): async def test_list_batches_single_batch(self, provider, sample_batch_data): """Test listing batches with single batch.""" - created_batch = await provider.create_batch(**sample_batch_data) + created_batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data)) - response = await provider.list_batches() + response = await provider.list_batches(ListBatchesRequest()) assert len(response.data) == 1 self._validate_batch_type(response.data[0], expected_metadata=sample_batch_data["metadata"]) @@ -300,12 +311,12 @@ async def test_list_batches_multiple_batches(self, provider): """Test listing multiple batches.""" batches = [ await provider.create_batch( - input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h" + CreateBatchRequest(input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h") ) for i in range(3) ] - response = await provider.list_batches() + response = await provider.list_batches(ListBatchesRequest()) assert len(response.data) == 3 @@ -321,12 +332,12 @@ async def test_list_batches_with_limit(self, provider): """Test listing batches with limit parameter.""" batches = [ await provider.create_batch( - input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h" + CreateBatchRequest(input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h") ) for i in range(3) ] - response = await provider.list_batches(limit=2) + response = await provider.list_batches(ListBatchesRequest(limit=2)) assert len(response.data) == 2 assert response.has_more is True @@ -340,36 +351,36 @@ async def test_list_batches_with_pagination(self, provider): """Test listing batches with pagination using 'after' parameter.""" for i in range(3): await provider.create_batch( - input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h" + CreateBatchRequest(input_file_id=f"file_{i}", endpoint="/v1/chat/completions", completion_window="24h") ) # Get first page - first_page = await provider.list_batches(limit=1) + first_page = await provider.list_batches(ListBatchesRequest(limit=1)) assert len(first_page.data) == 1 assert first_page.has_more is True # Get second page using 'after' - second_page = await provider.list_batches(limit=1, after=first_page.data[0].id) + second_page = await provider.list_batches(ListBatchesRequest(limit=1, after=first_page.data[0].id)) assert len(second_page.data) == 1 assert second_page.data[0].id != first_page.data[0].id # Verify we got the next batch in order - all_batches = await provider.list_batches() + all_batches = await provider.list_batches(ListBatchesRequest()) 
expected_second_batch_id = all_batches.data[1].id assert second_page.data[0].id == expected_second_batch_id async def test_list_batches_invalid_after(self, provider, sample_batch_data): """Test listing batches with invalid 'after' parameter.""" - await provider.create_batch(**sample_batch_data) + await provider.create_batch(CreateBatchRequest(**sample_batch_data)) - response = await provider.list_batches(after="nonexistent_batch") + response = await provider.list_batches(ListBatchesRequest(after="nonexistent_batch")) # Should return all batches (no filtering when 'after' batch not found) assert len(response.data) == 1 async def test_kvstore_persistence(self, provider, sample_batch_data): """Test that batches are properly persisted in kvstore.""" - batch = await provider.create_batch(**sample_batch_data) + batch = await provider.create_batch(CreateBatchRequest(**sample_batch_data)) stored_data = await provider.kvstore.get(f"batch:{batch.id}") assert stored_data is not None @@ -757,7 +768,7 @@ async def add_and_wait(batch_id: str): for _ in range(3): await provider.create_batch( - input_file_id="file_id", endpoint="/v1/chat/completions", completion_window="24h" + CreateBatchRequest(input_file_id="file_id", endpoint="/v1/chat/completions", completion_window="24h") ) await asyncio.sleep(0.042) # let tasks start @@ -767,8 +778,10 @@ async def add_and_wait(batch_id: str): async def test_create_batch_embeddings_endpoint(self, provider): """Test that batch creation succeeds with embeddings endpoint.""" batch = await provider.create_batch( - input_file_id="file_123", - endpoint="/v1/embeddings", - completion_window="24h", + CreateBatchRequest( + input_file_id="file_123", + endpoint="/v1/embeddings", + completion_window="24h", + ) ) assert batch.endpoint == "/v1/embeddings" diff --git a/tests/unit/providers/batches/test_reference_idempotency.py b/tests/unit/providers/batches/test_reference_idempotency.py index acb7ca01c5..0ac73841eb 100644 --- a/tests/unit/providers/batches/test_reference_idempotency.py +++ b/tests/unit/providers/batches/test_reference_idempotency.py @@ -45,6 +45,7 @@ import pytest from llama_stack_api import ConflictError +from llama_stack_api.batches.models import CreateBatchRequest, RetrieveBatchRequest class TestReferenceBatchesIdempotency: @@ -56,18 +57,22 @@ async def test_idempotent_batch_creation_same_params(self, provider, sample_batc del sample_batch_data["metadata"] batch1 = await provider.create_batch( - **sample_batch_data, - metadata={"test": "value1", "other": "value2"}, - idempotency_key="unique-token-1", + CreateBatchRequest( + **sample_batch_data, + metadata={"test": "value1", "other": "value2"}, + idempotency_key="unique-token-1", + ) ) # sleep for 1 second to allow created_at timestamps to be different await asyncio.sleep(1) batch2 = await provider.create_batch( - **sample_batch_data, - metadata={"other": "value2", "test": "value1"}, # Different order - idempotency_key="unique-token-1", + CreateBatchRequest( + **sample_batch_data, + metadata={"other": "value2", "test": "value1"}, # Different order + idempotency_key="unique-token-1", + ) ) assert batch1.id == batch2.id @@ -77,23 +82,17 @@ async def test_idempotent_batch_creation_same_params(self, provider, sample_batc async def test_different_idempotency_keys_create_different_batches(self, provider, sample_batch_data): """Test that different idempotency keys create different batches even with same params.""" - batch1 = await provider.create_batch( - **sample_batch_data, - idempotency_key="token-A", - ) + batch1 = 
await provider.create_batch(CreateBatchRequest(**sample_batch_data, idempotency_key="token-A")) - batch2 = await provider.create_batch( - **sample_batch_data, - idempotency_key="token-B", - ) + batch2 = await provider.create_batch(CreateBatchRequest(**sample_batch_data, idempotency_key="token-B")) assert batch1.id != batch2.id async def test_non_idempotent_behavior_without_key(self, provider, sample_batch_data): """Test that batches without idempotency key create unique batches even with identical parameters.""" - batch1 = await provider.create_batch(**sample_batch_data) + batch1 = await provider.create_batch(CreateBatchRequest(**sample_batch_data)) - batch2 = await provider.create_batch(**sample_batch_data) + batch2 = await provider.create_batch(CreateBatchRequest(**sample_batch_data)) assert batch1.id != batch2.id assert batch1.input_file_id == batch2.input_file_id @@ -117,12 +116,12 @@ async def test_same_idempotency_key_different_params_conflict( sample_batch_data[param_name] = first_value - batch1 = await provider.create_batch(**sample_batch_data) + batch1 = await provider.create_batch(CreateBatchRequest(**sample_batch_data)) with pytest.raises(ConflictError, match="Idempotency key.*was previously used with different parameters"): sample_batch_data[param_name] = second_value - await provider.create_batch(**sample_batch_data) + await provider.create_batch(CreateBatchRequest(**sample_batch_data)) - retrieved_batch = await provider.retrieve_batch(batch1.id) + retrieved_batch = await provider.retrieve_batch(RetrieveBatchRequest(batch_id=batch1.id)) assert retrieved_batch.id == batch1.id assert getattr(retrieved_batch, param_name) == first_value diff --git a/uv.lock b/uv.lock index 8c648c3624..93ad53e67a 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.12" resolution-markers = [ "(python_full_version >= '3.13' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.13' and sys_platform != 'darwin' and sys_platform != 'linux')", @@ -2292,6 +2292,7 @@ name = "llama-stack-api" version = "0.4.0.dev0" source = { editable = "src/llama_stack_api" } dependencies = [ + { name = "fastapi" }, { name = "jsonschema" }, { name = "opentelemetry-exporter-otlp-proto-http" }, { name = "opentelemetry-sdk" }, @@ -2300,6 +2301,7 @@ dependencies = [ [package.metadata] requires-dist = [ + { name = "fastapi", specifier = ">=0.115.0,<1.0" }, { name = "jsonschema" }, { name = "opentelemetry-exporter-otlp-proto-http", specifier = ">=1.30.0" }, { name = "opentelemetry-sdk", specifier = ">=1.30.0" },