diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py
index 8dd697308c9..f6e409b1ecb 100644
--- a/fastdeploy/engine/common_engine.py
+++ b/fastdeploy/engine/common_engine.py
@@ -31,7 +31,7 @@
 import zmq
 from opentelemetry import trace
 
-from fastdeploy.engine.request import Request, RequestOutput, RequestType
+from fastdeploy.engine.request import Request, RequestOutput, RequestStatus, RequestType
 from fastdeploy.engine.resource_manager import ResourceManager
 from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1
 from fastdeploy.eplb.utils import init_eplb_signals
@@ -911,6 +911,26 @@ def _insert_zmq_task_to_scheduler(self):
             request, insert_task = None, []
             results: List[Tuple[str, Optional[str]]] = list()
             if data:
+                status_value = data.get("status", None)
+                if status_value is not None and status_value == RequestStatus.ABORT.value:
+                    req_id = data["request_id"]
+                    batch_id = self.resource_manager.req_dict[req_id]
+                    abort_res = RequestOutput(
+                        request_id=req_id,
+                        finished=True,
+                        error_code=499,
+                        error_msg=f"Your request with request_id:{req_id} is aborted.",
+                    )
+                    abort_task = self.resource_manager.tasks_list[batch_id]
+                    is_prefill = (
+                        abort_task.disaggregate_info is not None
+                        and abort_task.disaggregate_info["role"] == "prefill"
+                    )
+                    self.token_processor._recycle_resources(
+                        req_id, batch_id, abort_task, abort_res, is_prefill, True
+                    )
+                    self.scheduler.put_results([abort_res])
+                    continue
                 err_msg = None
                 try:
                     request = Request.from_dict(data)
diff --git a/fastdeploy/engine/request.py b/fastdeploy/engine/request.py
index 888364bb672..0940c0375a7 100644
--- a/fastdeploy/engine/request.py
+++ b/fastdeploy/engine/request.py
@@ -39,6 +39,7 @@ class RequestStatus(Enum):
     RUNNING = 1
     PREEMPTED = 2
     FINISHED = 3
+    ABORT = 4
 
 
 class RequestType(Enum):
diff --git a/fastdeploy/entrypoints/engine_client.py b/fastdeploy/entrypoints/engine_client.py
index 78918314509..fdfd6ff355b 100644
--- a/fastdeploy/entrypoints/engine_client.py
+++ b/fastdeploy/entrypoints/engine_client.py
@@ -16,6 +16,7 @@
 
 import inspect
 import os
+import re
 import time
 import traceback
 import uuid
@@ -27,6 +28,7 @@
 
 from fastdeploy import envs
 from fastdeploy.config import FDConfig
+from fastdeploy.engine.request import RequestStatus
 from fastdeploy.entrypoints.openai.utils import DealerConnectionManager
 from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS
 from fastdeploy.eplb.utils import RedundantExpertWorkload
@@ -824,3 +826,25 @@ async def check_redundant(self, request_dict: dict):
             content = {"code": 0, "msg": "ok", "data": update_weight_from_disk_list}
             status_code = HTTPStatus.OK
         return content, status_code
+
+    async def abort(self, request_id, n=1) -> None:
+        if n <= 0:
+            api_server_logger.warning("Abort function called with non-positive n: %d. No requests aborted.", n)
+            return
+        match = re.search(r"_\d+$", request_id)
+        if match:
+            prefix = request_id[: match.start()]
+        else:
+            api_server_logger.warning(
+                "request_id format error: %s does not end with _<number>. Using it as prefix.", request_id
+            )
+            prefix = request_id
+        request_ids = [f"{prefix}_{i}" for i in range(n)]
+        for req_id in request_ids:
+            data = {
+                "request_id": req_id,
+                "status": RequestStatus.ABORT.value,
+            }
+            self._send_task(data)
+
+        api_server_logger.info("Aborted request(s) %s.", ",".join(request_ids))
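
The scheduler-side branch above keys off a literal `status` field in the ZMQ payload, and `EngineClient.abort` fans one client-facing request id out into the per-sample ids the engine actually tracks. A standalone sketch of that fan-out, assuming the `<prefix>_<index>` id convention implied by the `_\d+$` regex in the patch (the example id is hypothetical):

    import re

    def expand_abort_ids(request_id: str, n: int = 1) -> list[str]:
        # Strip a trailing _<index> if present; otherwise use the whole id as
        # the prefix, mirroring the fallback branch in EngineClient.abort.
        match = re.search(r"_\d+$", request_id)
        prefix = request_id[: match.start()] if match else request_id
        return [f"{prefix}_{i}" for i in range(n)]

    print(expand_abort_ids("chatcmpl-abc123_0", 3))
    # ['chatcmpl-abc123_0', 'chatcmpl-abc123_1', 'chatcmpl-abc123_2']
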
Using it as prefix.", request_id + ) + prefix = request_id + request_ids = [f"{prefix}_{i}" for i in range(n)] + for req_id in request_ids: + data = { + "request_id": req_id, + "status": RequestStatus.ABORT.value, + } + self._send_task(data) + + api_server_logger.info("Aborted request(s) %s.", ",".join(request_ids)) diff --git a/fastdeploy/entrypoints/openai/api_server.py b/fastdeploy/entrypoints/openai/api_server.py index 2eb62a6ff84..5a1ac9ee437 100644 --- a/fastdeploy/entrypoints/openai/api_server.py +++ b/fastdeploy/entrypoints/openai/api_server.py @@ -55,7 +55,11 @@ from fastdeploy.entrypoints.openai.serving_models import ModelPath, OpenAIServingModels from fastdeploy.entrypoints.openai.serving_reward import OpenAIServingReward from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager -from fastdeploy.entrypoints.openai.utils import UVICORN_CONFIG, make_arg_parser +from fastdeploy.entrypoints.openai.utils import ( + UVICORN_CONFIG, + make_arg_parser, + with_cancellation, +) from fastdeploy.envs import environment_variables from fastdeploy.metrics.metrics import get_filtered_metrics from fastdeploy.metrics.metrics_middleware import PrometheusMiddleware @@ -371,7 +375,8 @@ async def wrapped_generator(): @app.post("/v1/chat/completions") -async def create_chat_completion(request: ChatCompletionRequest): +@with_cancellation +async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request): """ Create a chat completion for the provided prompt and parameters. """ @@ -403,7 +408,8 @@ async def create_chat_completion(request: ChatCompletionRequest): @app.post("/v1/completions") -async def create_completion(request: CompletionRequest): +@with_cancellation +async def create_completion(request: CompletionRequest, raw_request: Request): """ Create a completion for the provided prompt and parameters. 
""" diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py index 9bb15f90942..a61aba377df 100644 --- a/fastdeploy/entrypoints/openai/serving_chat.py +++ b/fastdeploy/entrypoints/openai/serving_chat.py @@ -174,6 +174,13 @@ async def create_chat_completion(self, request: ChatCompletionRequest): error_msg = f"request[{request_id}]full generator error: {str(e)}, {str(traceback.format_exc())}" api_server_logger.error(error_msg) return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.INTERNAL_ERROR)) + except asyncio.CancelledError as e: + await self.engine_client.abort(f"{request_id}_0", request.n) + error_msg = f"request[{request_id}_0] client disconnected: {str(e)}, {str(traceback.format_exc())}" + api_server_logger.error(error_msg) + return ErrorResponse( + error=ErrorInfo(message=error_msg, type=ErrorType.INVALID_REQUEST_ERROR, code=ErrorCode.CLIENT_ABORTED) + ) except Exception as e: error_msg = ( f"request[{request_id}] waiting error: {str(e)}, {str(traceback.format_exc())}, " diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py index 93013531759..bfc61b9a09c 100644 --- a/fastdeploy/entrypoints/openai/serving_completion.py +++ b/fastdeploy/entrypoints/openai/serving_completion.py @@ -223,7 +223,19 @@ async def create_completion(self, request: CompletionRequest): ) api_server_logger.error(error_msg) return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.INTERNAL_ERROR)) - + except asyncio.CancelledError as e: + num = 1 + if request_prompt_ids is not None: + num = len(request_prompt_ids) + elif request_prompts is not None: + num = len(request_prompts) + num = num * 1 if request.n is None else request.n + await self.engine_client.abort(f"{request_id}_0", num) + error_msg = f"request[{request_id}_0] client disconnected: {str(e)}, {str(traceback.format_exc())}" + api_server_logger.error(error_msg) + return ErrorResponse( + error=ErrorInfo(message=error_msg, type=ErrorType.INVALID_REQUEST_ERROR, code=ErrorCode.CLIENT_ABORTED) + ) except Exception as e: error_msg = f"OpenAIServingCompletion create_completion error: {e}, {str(traceback.format_exc())}" api_server_logger.error(error_msg) diff --git a/fastdeploy/entrypoints/openai/utils.py b/fastdeploy/entrypoints/openai/utils.py index 9a7fe239a9d..1faa15d5c2a 100644 --- a/fastdeploy/entrypoints/openai/utils.py +++ b/fastdeploy/entrypoints/openai/utils.py @@ -15,6 +15,7 @@ """ import asyncio +import functools import heapq import random import time @@ -22,6 +23,7 @@ import aiozmq import zmq +from fastapi import Request from fastdeploy.engine.args_utils import EngineArgs from fastdeploy.metrics.metrics import main_process_metrics @@ -253,3 +255,54 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser = EngineArgs.add_cli_args(parser) return parser + + +async def listen_for_disconnect(request: Request) -> None: + """Returns if a disconnect message is received""" + while True: + message = await request.receive() + if message["type"] == "http.disconnect": + break + + +def with_cancellation(handler_func): + """Decorator that allows a route handler to be cancelled by client + disconnections. + + This does _not_ use request.is_disconnected, which does not work with + middleware. 
diff --git a/fastdeploy/entrypoints/openai/utils.py b/fastdeploy/entrypoints/openai/utils.py
index 9a7fe239a9d..1faa15d5c2a 100644
--- a/fastdeploy/entrypoints/openai/utils.py
+++ b/fastdeploy/entrypoints/openai/utils.py
@@ -15,6 +15,7 @@
 """
 
 import asyncio
+import functools
 import heapq
 import random
 import time
@@ -22,6 +23,7 @@
 
 import aiozmq
 import zmq
+from fastapi import Request
 
 from fastdeploy.engine.args_utils import EngineArgs
 from fastdeploy.metrics.metrics import main_process_metrics
@@ -253,3 +255,54 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
     parser = EngineArgs.add_cli_args(parser)
 
     return parser
+
+
+async def listen_for_disconnect(request: Request) -> None:
+    """Returns if a disconnect message is received"""
+    while True:
+        message = await request.receive()
+        if message["type"] == "http.disconnect":
+            break
+
+
+def with_cancellation(handler_func):
+    """Decorator that allows a route handler to be cancelled by client
+    disconnections.
+
+    This does _not_ use request.is_disconnected, which does not work with
+    middleware. Instead this follows the pattern from
+    starlette.StreamingResponse, which simultaneously awaits on two tasks: one
+    to wait for an http disconnect message, and the other to do the work that
+    we want done. When the first task finishes, the other is cancelled.
+
+    A core assumption of this method is that the body of the request has
+    already been read. This is a safe assumption to make for fastapi handlers
+    that have already parsed the body of the request into a pydantic model for
+    us. This decorator is unsafe to use elsewhere, as it will consume and
+    throw away all incoming messages for the request while it looks for a
+    disconnect message.
+
+    In the case where a `StreamingResponse` is returned by the handler, this
+    wrapper will stop listening for disconnects and instead the response
+    object will start listening for disconnects.
+    """
+
+    # functools.wraps is required for this wrapper to appear to fastapi as a
+    # normal route handler, with the correct request type hinting.
+    @functools.wraps(handler_func)
+    async def wrapper(*args, **kwargs):
+        # The request is either the second positional arg or `raw_request`
+        request = args[1] if len(args) > 1 else kwargs["raw_request"]
+
+        handler_task = asyncio.create_task(handler_func(*args, **kwargs))
+        cancellation_task = asyncio.create_task(listen_for_disconnect(request))
+
+        done, pending = await asyncio.wait([handler_task, cancellation_task], return_when=asyncio.FIRST_COMPLETED)
+        for task in pending:
+            task.cancel()
+
+        if handler_task in done:
+            return handler_task.result()
+        return None
+
+    return wrapper
diff --git a/fastdeploy/output/token_processor.py b/fastdeploy/output/token_processor.py
index 67ad1ce5409..9d175447549 100644
--- a/fastdeploy/output/token_processor.py
+++ b/fastdeploy/output/token_processor.py
@@ -451,11 +451,11 @@ def postprocess(self, batch_result: List[RequestOutput], mtype=3):
         except Exception as e:
             llm_logger.error(f"Error in TokenProcessor's postprocess: {e}, {str(traceback.format_exc())}")
 
-    def _recycle_resources(self, task_id, index, task, result=None, is_prefill=False):
+    def _recycle_resources(self, task_id, index, task, result=None, is_prefill=False, is_abort=False):
         """
         recycle resources
         """
-        if is_prefill:
+        if is_prefill and not is_abort:
             start_time = time.time()
             while True:
                 finished_task_ids = self.engine_worker_queue.get_finished_req()
diff --git a/fastdeploy/utils.py b/fastdeploy/utils.py
index a0878fa7c73..1cb1bfd17ca 100644
--- a/fastdeploy/utils.py
+++ b/fastdeploy/utils.py
@@ -219,6 +219,7 @@ class ErrorCode(str, Enum):
     CONNECTION_ERROR = "connection_error"
     MISSING_REQUIRED_PARAMETER = "missing_required_parameter"
     INTERNAL_ERROR = "internal_error"
+    CLIENT_ABORTED = "client_aborted"
 
 
 class ColoredFormatter(logging.Formatter):
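
The decorator reduces to a generic asyncio race between the handler and a disconnect watcher. An isolated, runnable sketch of that idiom (the sleeps are stand-ins; none of these names are FastDeploy code):

    import asyncio

    async def handler_work() -> str:
        await asyncio.sleep(2)  # stand-in for the real route handler
        return "done"

    async def watch_disconnect() -> None:
        await asyncio.sleep(0.1)  # stand-in for awaiting "http.disconnect"

    async def main() -> None:
        handler = asyncio.create_task(handler_work())
        watcher = asyncio.create_task(watch_disconnect())
        done, pending = await asyncio.wait({handler, watcher}, return_when=asyncio.FIRST_COMPLETED)
        for task in pending:
            # Cancelling the handler raises asyncio.CancelledError inside it,
            # which the serving_chat/serving_completion hunks catch and turn
            # into an engine-side abort.
            task.cancel()
        print("handler finished" if handler in done else "client disconnected")

    asyncio.run(main())  # prints "client disconnected"
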