Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion fastdeploy/engine/common_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import zmq
from opentelemetry import trace

from fastdeploy.engine.request import Request, RequestOutput, RequestType
from fastdeploy.engine.request import Request, RequestOutput, RequestStatus, RequestType
from fastdeploy.engine.resource_manager import ResourceManager
from fastdeploy.engine.sched.resource_manager_v1 import ResourceManagerV1
from fastdeploy.eplb.utils import init_eplb_signals
Expand Down Expand Up @@ -911,6 +911,26 @@ def _insert_zmq_task_to_scheduler(self):
request, insert_task = None, []
results: List[Tuple[str, Optional[str]]] = list()
if data:
status_value = data.get("status", None)
if status_value is not None and status_value == RequestStatus.ABORT.value:
req_id = data["request_id"]
batch_id = self.resource_manager.req_dict[req_id]
abort_res = RequestOutput(
request_id=req_id,
finished=True,
error_code=499,
error_msg=f"Your request with request_id:{req_id} is aborted.",
)
abort_task = self.resource_manager.tasks_list[batch_id]
is_prefill = (
abort_task.disaggregate_info is not None
and abort_task.disaggregate_info["role"] == "prefill"
)
self.token_processor._recycle_resources(
req_id, batch_id, abort_task, abort_res, is_prefill, True
)
self.scheduler.put_results([abort_res])
continue
err_msg = None
try:
request = Request.from_dict(data)
Expand Down
1 change: 1 addition & 0 deletions fastdeploy/engine/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class RequestStatus(Enum):
RUNNING = 1
PREEMPTED = 2
FINISHED = 3
ABORT = 4


class RequestType(Enum):
Expand Down
24 changes: 24 additions & 0 deletions fastdeploy/entrypoints/engine_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import inspect
import os
import re
import time
import traceback
import uuid
Expand All @@ -27,6 +28,7 @@

from fastdeploy import envs
from fastdeploy.config import FDConfig
from fastdeploy.engine.request import RequestStatus
from fastdeploy.entrypoints.openai.utils import DealerConnectionManager
from fastdeploy.envs import FD_SUPPORT_MAX_CONNECTIONS
from fastdeploy.eplb.utils import RedundantExpertWorkload
Expand Down Expand Up @@ -824,3 +826,25 @@ async def check_redundant(self, request_dict: dict):
content = {"code": 0, "msg": "ok", "data": update_weight_from_disk_list}
status_code = HTTPStatus.OK
return content, status_code

async def abort(self, request_id, n=1) -> None:
    """Fan out abort notifications for the sub-requests of one client call.

    ``request_id`` is expected to end in ``_<number>``; that numeric suffix
    is stripped and abort tasks are sent for ``<prefix>_0`` through
    ``<prefix>_{n-1}`` so every parallel sub-request is cancelled.

    Args:
        request_id: Representative request id (normally the ``_0`` variant).
        n: Number of sub-requests sharing the prefix; non-positive is a no-op.
    """
    if n <= 0:
        api_server_logger.warning("Abort function called with non-positive n: %d. No requests aborted.", n)
        return
    suffix = re.search(r"_\d+$", request_id)
    if suffix is None:
        # No trailing _<number>: fall back to treating the whole id as the prefix.
        api_server_logger.warning(
            "request_id format error: %s does not end with _<number>. Using it as prefix.", request_id
        )
        prefix = request_id
    else:
        prefix = request_id[: suffix.start()]
    request_ids = [f"{prefix}_{i}" for i in range(n)]
    for req_id in request_ids:
        self._send_task({"request_id": req_id, "status": RequestStatus.ABORT.value})

    api_server_logger.info("Aborted request(s) %s.", ",".join(request_ids))
12 changes: 9 additions & 3 deletions fastdeploy/entrypoints/openai/api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@
from fastdeploy.entrypoints.openai.serving_models import ModelPath, OpenAIServingModels
from fastdeploy.entrypoints.openai.serving_reward import OpenAIServingReward
from fastdeploy.entrypoints.openai.tool_parsers import ToolParserManager
from fastdeploy.entrypoints.openai.utils import UVICORN_CONFIG, make_arg_parser
from fastdeploy.entrypoints.openai.utils import (
UVICORN_CONFIG,
make_arg_parser,
with_cancellation,
)
from fastdeploy.envs import environment_variables
from fastdeploy.metrics.metrics import get_filtered_metrics
from fastdeploy.metrics.metrics_middleware import PrometheusMiddleware
Expand Down Expand Up @@ -371,7 +375,8 @@ async def wrapped_generator():


@app.post("/v1/chat/completions")
async def create_chat_completion(request: ChatCompletionRequest):
@with_cancellation
async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request):
"""
Create a chat completion for the provided prompt and parameters.
"""
Expand Down Expand Up @@ -403,7 +408,8 @@ async def create_chat_completion(request: ChatCompletionRequest):


@app.post("/v1/completions")
async def create_completion(request: CompletionRequest):
@with_cancellation
async def create_completion(request: CompletionRequest, raw_request: Request):
"""
Create a completion for the provided prompt and parameters.
"""
Expand Down
7 changes: 7 additions & 0 deletions fastdeploy/entrypoints/openai/serving_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,13 @@ async def create_chat_completion(self, request: ChatCompletionRequest):
error_msg = f"request[{request_id}]full generator error: {str(e)}, {str(traceback.format_exc())}"
api_server_logger.error(error_msg)
return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.INTERNAL_ERROR))
except asyncio.CancelledError as e:
await self.engine_client.abort(f"{request_id}_0", request.n)
error_msg = f"request[{request_id}_0] client disconnected: {str(e)}, {str(traceback.format_exc())}"
api_server_logger.error(error_msg)
return ErrorResponse(
error=ErrorInfo(message=error_msg, type=ErrorType.INVALID_REQUEST_ERROR, code=ErrorCode.CLIENT_ABORTED)
)
except Exception as e:
error_msg = (
f"request[{request_id}] waiting error: {str(e)}, {str(traceback.format_exc())}, "
Expand Down
14 changes: 13 additions & 1 deletion fastdeploy/entrypoints/openai/serving_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,19 @@ async def create_completion(self, request: CompletionRequest):
)
api_server_logger.error(error_msg)
return ErrorResponse(error=ErrorInfo(message=error_msg, type=ErrorType.INTERNAL_ERROR))

except asyncio.CancelledError as e:
num = 1
if request_prompt_ids is not None:
num = len(request_prompt_ids)
elif request_prompts is not None:
num = len(request_prompts)
num = num * 1 if request.n is None else request.n
await self.engine_client.abort(f"{request_id}_0", num)
error_msg = f"request[{request_id}_0] client disconnected: {str(e)}, {str(traceback.format_exc())}"
api_server_logger.error(error_msg)
return ErrorResponse(
error=ErrorInfo(message=error_msg, type=ErrorType.INVALID_REQUEST_ERROR, code=ErrorCode.CLIENT_ABORTED)
)
except Exception as e:
error_msg = f"OpenAIServingCompletion create_completion error: {e}, {str(traceback.format_exc())}"
api_server_logger.error(error_msg)
Expand Down
53 changes: 53 additions & 0 deletions fastdeploy/entrypoints/openai/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@
"""

import asyncio
import functools
import heapq
import random
import time
from multiprocessing.reduction import ForkingPickler

import aiozmq
import zmq
from fastapi import Request

from fastdeploy.engine.args_utils import EngineArgs
from fastdeploy.metrics.metrics import main_process_metrics
Expand Down Expand Up @@ -253,3 +255,54 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:

parser = EngineArgs.add_cli_args(parser)
return parser


async def listen_for_disconnect(request: Request) -> None:
    """Wait until the ASGI connection reports a client disconnect."""
    disconnected = False
    while not disconnected:
        event = await request.receive()
        disconnected = event["type"] == "http.disconnect"


def with_cancellation(handler_func):
    """Decorator that cancels a route handler when the client disconnects.

    Deliberately avoids ``request.is_disconnected`` (unreliable behind
    middleware). Instead it mirrors ``starlette.StreamingResponse``: two
    tasks run concurrently — the wrapped handler and a disconnect listener —
    and whichever finishes first cancels the other.

    This assumes the request body has already been consumed (true for
    FastAPI handlers whose body was parsed into a pydantic model), because
    the listener drains and discards every remaining message while waiting
    for ``http.disconnect``. Do not reuse this decorator elsewhere.

    When the handler returns a ``StreamingResponse``, the listener is
    cancelled here and the response object takes over disconnect handling.
    """

    # functools.wraps keeps the original signature visible to FastAPI's
    # routing/validation machinery.
    @functools.wraps(handler_func)
    async def wrapper(*args, **kwargs):
        # The raw request arrives either as the second positional argument
        # or as the `raw_request` keyword argument.
        raw_request = kwargs["raw_request"] if len(args) <= 1 else args[1]

        work = asyncio.create_task(handler_func(*args, **kwargs))
        watchdog = asyncio.create_task(listen_for_disconnect(raw_request))

        finished, unfinished = await asyncio.wait([work, watchdog], return_when=asyncio.FIRST_COMPLETED)
        for leftover in unfinished:
            leftover.cancel()

        # Client disconnected first: the handler was cancelled, nothing to return.
        return work.result() if work in finished else None

    return wrapper
4 changes: 2 additions & 2 deletions fastdeploy/output/token_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -451,11 +451,11 @@ def postprocess(self, batch_result: List[RequestOutput], mtype=3):
except Exception as e:
llm_logger.error(f"Error in TokenProcessor's postprocess: {e}, {str(traceback.format_exc())}")

def _recycle_resources(self, task_id, index, task, result=None, is_prefill=False):
def _recycle_resources(self, task_id, index, task, result=None, is_prefill=False, is_abort=False):
"""
recycle resources
"""
if is_prefill:
if is_prefill and not is_abort:
start_time = time.time()
while True:
finished_task_ids = self.engine_worker_queue.get_finished_req()
Expand Down
1 change: 1 addition & 0 deletions fastdeploy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,7 @@ class ErrorCode(str, Enum):
CONNECTION_ERROR = "connection_error"
MISSING_REQUIRED_PARAMETER = "missing_required_parameter"
INTERNAL_ERROR = "internal_error"
CLIENT_ABORTED = "client_aborted"


class ColoredFormatter(logging.Formatter):
Expand Down
Loading