From 5763320d7f50752514dd8ea0bc04f70fb0d13e6b Mon Sep 17 00:00:00 2001
From: wenlei07 <1522419171@qq.com>
Date: Thu, 27 Nov 2025 07:34:01 +0000
Subject: [PATCH] request disconnect

---
 fastdeploy/entrypoints/openai/api_server.py |  9 +++-
 fastdeploy/entrypoints/openai/utils.py      | 60 +++++++++++++++++++++
 2 files changed, 67 insertions(+), 2 deletions(-)

diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py
index 6517e5da136..22fe76bc46d 100644
--- a/fastdeploy/entrypoints/openai/api_server.py
+++ b/fastdeploy/entrypoints/openai/api_server.py
@@ -56,7 +56,11 @@
 from fastdeploy.entrypoints.openai.serving_models import ModelPath, OpenAIServingModels
 from fastdeploy.entrypoints.openai.serving_reward import OpenAIServingReward
 from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
-from fastdeploy.entrypoints.openai.utils import UVICORN_CONFIG, make_arg_parser
+from fastdeploy.entrypoints.openai.utils import (
+    UVICORN_CONFIG,
+    make_arg_parser,
+    with_cancellation,
+)
 from fastdeploy.envs import environment_variables
 from fastdeploy.metrics.metrics import (
     EXCLUDE_LABELS,
@@ -384,7 +388,8 @@ async def wrapped_generator():
 
 
 @app.post("/v1/chat/completions")
-async def create_chat_completion(request: ChatCompletionRequest):
+@with_cancellation
+async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request):
     """
     Create a chat completion for the provided prompt and parameters.
     """
diff --git a/fastdeploy/entrypoints/openai/utils.py b/fastdeploy/entrypoints/openai/utils.py
index 350f9e61f1b..041091c06fb 100644
--- a/fastdeploy/entrypoints/openai/utils.py
+++ b/fastdeploy/entrypoints/openai/utils.py
@@ -15,12 +15,14 @@
 """
 
 import asyncio
+import functools
 import heapq
 import random
 
 import aiozmq
 import msgpack
 import zmq
+from fastapi import Request
 
 from fastdeploy.engine.args_utils import EngineArgs
 from fastdeploy.utils import FlexibleArgumentParser, api_server_logger
@@ -243,3 +245,61 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
     parser = EngineArgs.add_cli_args(parser)
 
     return parser
+
+
+async def listen_for_disconnect(request: Request) -> None:
+    """Returns once a disconnect message is received."""
+    while True:
+        message = await request.receive()
+        if message["type"] == "http.disconnect":
+            # If load tracking is enabled and the server load counter exists,
+            # decrement it so that a disconnected request no longer counts
+            # towards the reported server load.
+            if getattr(request.app.state, "enable_server_load_tracking", False) and hasattr(
+                request.app.state, "server_load_metrics"
+            ):
+                request.app.state.server_load_metrics -= 1
+            break
+
+
+def with_cancellation(handler_func):
+    """Decorator that allows a route handler to be cancelled by client
+    disconnections.
+
+    This does _not_ use request.is_disconnected, which does not work with
+    middleware. Instead this follows the pattern from
+    starlette.StreamingResponse, which simultaneously awaits on two tasks: one
+    to wait for an http disconnect message, and the other to do the work that we
+    want done. When the first task finishes, the other is cancelled.
+
+    A core assumption of this method is that the body of the request has already
+    been read. This is a safe assumption to make for fastapi handlers that have
+    already parsed the body of the request into a pydantic model for us.
+    This decorator is unsafe to use elsewhere, as it will consume and throw away
+    all incoming messages for the request while it looks for a disconnect
+    message.
+
+    In the case where a `StreamingResponse` is returned by the handler, this
+    wrapper will stop listening for disconnects and instead the response object
+    will start listening for disconnects.
+    """
+
+    # functools.wraps is required for this wrapper to appear to fastapi as a
+    # normal route handler, with the correct request type hinting.
+    @functools.wraps(handler_func)
+    async def wrapper(*args, **kwargs):
+        # The request is either the second positional arg or `raw_request`.
+        request = args[1] if len(args) > 1 else kwargs["raw_request"]
+
+        handler_task = asyncio.create_task(handler_func(*args, **kwargs))
+        cancellation_task = asyncio.create_task(listen_for_disconnect(request))
+
+        done, pending = await asyncio.wait([handler_task, cancellation_task], return_when=asyncio.FIRST_COMPLETED)
+        for task in pending:
+            task.cancel()
+
+        if handler_task in done:
+            return handler_task.result()
+        return None
+
+    return wrapper
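
Note for reviewers: the standalone sketch below is not part of the patch. It reproduces the race-and-cancel pattern that with_cancellation relies on using only asyncio; FakeRequest, slow_handler and the sleep durations are made-up stand-ins for a starlette Request and a route handler.

import asyncio


class FakeRequest:
    """Stand-in for a starlette Request whose client disconnects after one second."""

    async def receive(self):
        await asyncio.sleep(1)
        return {"type": "http.disconnect"}


async def slow_handler():
    # Pretend this is the chat completion work; it takes longer than the disconnect.
    await asyncio.sleep(5)
    return "finished"


async def listen_for_disconnect(request):
    # Same loop as in the patch, minus the server-load bookkeeping.
    while True:
        message = await request.receive()
        if message["type"] == "http.disconnect":
            break


async def main():
    handler_task = asyncio.create_task(slow_handler())
    cancellation_task = asyncio.create_task(listen_for_disconnect(FakeRequest()))

    done, pending = await asyncio.wait([handler_task, cancellation_task], return_when=asyncio.FIRST_COMPLETED)
    for task in pending:
        task.cancel()

    # The disconnect wins the race here, so the handler's result is dropped.
    print("handler finished" if handler_task in done else "handler cancelled")


asyncio.run(main())

Running the sketch prints "handler cancelled" after roughly one second, which is the behaviour the decorator gives the /v1/chat/completions handler when a client goes away mid-request.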