Commit ba464e6

Add ORCA endpoint load metrics support (vllm-project#24905)
Signed-off-by: Misha Efimov <mef@google.com>
1 parent 7f4bdad · commit ba464e6

File tree: 3 files changed, +265 −2 lines changed
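
This commit lets clients request an ORCA load report from the OpenAI-compatible endpoints: sending an `endpoint-load-metrics-format: TEXT` (or `JSON`) request header makes non-streaming `/v1/chat/completions` and `/v1/completions` responses carry an `endpoint-load-metrics` response header with current load metrics (KV-cache usage and the number of waiting requests). A minimal client-side sketch, assuming a vLLM server is already running on `localhost:8000` with `httpx` installed; the port and model name are illustrative:

```python
import httpx

MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"

response = httpx.post(
    "http://localhost:8000/v1/completions",
    # Opt in to the ORCA report; "TEXT" is the other supported format.
    headers={"endpoint-load-metrics-format": "JSON"},
    json={"model": MODEL_NAME, "prompt": "Hello, my name is", "max_tokens": 5},
    timeout=60.0,
)
# e.g. 'JSON {"named_metrics": {"kv_cache_usage_perc": 0.4, ...}}'
print(response.headers.get("endpoint-load-metrics"))
```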
New test file · 128 additions & 0 deletions

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import openai
import pytest
import pytest_asyncio

from ...utils import RemoteOpenAIServer

# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"


@pytest.fixture(scope="module")
def monkeypatch_module():
    from _pytest.monkeypatch import MonkeyPatch

    mpatch = MonkeyPatch()
    yield mpatch
    mpatch.undo()


@pytest.fixture(scope="module", params=[True])
def server(request, monkeypatch_module):
    use_v1 = request.param
    monkeypatch_module.setenv("VLLM_USE_V1", "1" if use_v1 else "0")

    args = [
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--enforce-eager",
    ]

    with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server):
    async with server.get_async_client() as async_client:
        yield async_client


@pytest.mark.asyncio
async def test_chat_completion_with_orca_header(server: RemoteOpenAIServer):
    messages = [
        {"role": "system", "content": "you are a helpful assistant"},
        {"role": "user", "content": "what is 1+1?"},
    ]

    client = openai.OpenAI(
        api_key="EMPTY",
        base_url=f"http://localhost:{server.port}/v1",
        default_headers={"endpoint-load-metrics-format": "TEXT"},
    )

    # 1. Use the raw client to get access to response headers.
    raw_client = client.with_raw_response

    # 2. Make the API call using the raw client.
    response_with_raw = raw_client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        extra_headers={"endpoint-load-metrics-format": "TEXT"},
    )

    # 3. Access the raw httpx.Response object.
    raw_http_response = response_with_raw.http_response

    # 4. Get the headers from the httpx.Response object.
    response_headers = raw_http_response.headers

    assert "endpoint-load-metrics" in response_headers


@pytest.mark.asyncio
async def test_completion_with_orca_header(client: openai.AsyncOpenAI):
    # 1. Use the raw client to get access to response headers.
    raw_client = client.with_raw_response

    # 2. Make the API call using the raw client.
    completion = await raw_client.completions.create(
        model=MODEL_NAME,
        prompt="Hello, my name is",
        max_tokens=5,
        extra_headers={"endpoint-load-metrics-format": "JSON"},
    )

    # 3. Access the raw httpx.Response object.
    raw_http_response = completion.http_response

    # 4. Get the headers from the httpx.Response object.
    response_headers = raw_http_response.headers

    assert "endpoint-load-metrics" in response_headers


@pytest.mark.asyncio
async def test_single_completion(client: openai.AsyncOpenAI):
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt="Hello, my name is",
        max_tokens=5,
        extra_headers={"endpoint-load-metrics-format": "JSON"},
        temperature=0.0,
    )

    assert completion.id is not None
    assert completion.choices is not None and len(completion.choices) == 1

    choice = completion.choices[0]
    assert len(choice.text) >= 5
    assert choice.finish_reason == "length"
    assert completion.usage == openai.types.CompletionUsage(
        completion_tokens=5, prompt_tokens=6, total_tokens=11
    )

    # test using token IDs
    completion = await client.completions.create(
        model=MODEL_NAME,
        prompt=[0, 0, 0, 0, 0],
        max_tokens=5,
        temperature=0.0,
    )
    assert len(completion.choices[0].text) >= 1
    assert completion.choices[0].prompt_logprobs is None
```

vllm/entrypoints/openai/api_server.py · 17 additions & 2 deletions

```diff
@@ -51,6 +51,7 @@
 from vllm.entrypoints.launcher import serve_http
 from vllm.entrypoints.logger import RequestLogger
 from vllm.entrypoints.openai.cli_args import make_arg_parser, validate_parsed_serve_args
+from vllm.entrypoints.openai.orca_metrics import metrics_header
 from vllm.entrypoints.openai.protocol import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -128,6 +129,8 @@
 # Cannot use __name__ (https://github.com/vllm-project/vllm/pull/4765)
 logger = init_logger("vllm.entrypoints.openai.api_server")
 
+ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL = "endpoint-load-metrics-format"
+
 _running_tasks: set[asyncio.Task] = set()
 
 
@@ -672,6 +675,9 @@ def translate_error_response(response: ErrorResponse) -> JSONResponse:
 @with_cancellation
 @load_aware_call
 async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request):
+    metrics_header_format = raw_request.headers.get(
+        ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL, ""
+    )
     handler = chat(raw_request)
     if handler is None:
         return base(raw_request).create_error_response(
@@ -689,7 +695,10 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request):
         )
 
     elif isinstance(generator, ChatCompletionResponse):
-        return JSONResponse(content=generator.model_dump())
+        return JSONResponse(
+            content=generator.model_dump(),
+            headers=metrics_header(metrics_header_format),
+        )
 
     return StreamingResponse(content=generator, media_type="text/event-stream")
 
@@ -707,6 +716,9 @@ async def create_chat_completion(request: ChatCompletionRequest, raw_request: Request):
 @with_cancellation
 @load_aware_call
 async def create_completion(request: CompletionRequest, raw_request: Request):
+    metrics_header_format = raw_request.headers.get(
+        ENDPOINT_LOAD_METRICS_FORMAT_HEADER_LABEL, ""
+    )
     handler = completion(raw_request)
     if handler is None:
         return base(raw_request).create_error_response(
@@ -729,7 +741,10 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
             content=generator.model_dump(), status_code=generator.error.code
         )
     elif isinstance(generator, CompletionResponse):
-        return JSONResponse(content=generator.model_dump())
+        return JSONResponse(
+            content=generator.model_dump(),
+            headers=metrics_header(metrics_header_format),
+        )
 
     return StreamingResponse(content=generator, media_type="text/event-stream")
```
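Both handlers follow the same pattern: read the requested format from the incoming `endpoint-load-metrics-format` header, then pass whatever `metrics_header()` returns, a one-entry mapping or `None`, straight to `JSONResponse`; Starlette treats `headers=None` as "no extra headers", so clients that never opted in are unaffected. A standalone sketch of that pattern, with a hypothetical `/v1/echo` route and a stubbed header builder (neither is part of this commit):

```python
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse

app = FastAPI()


def stub_metrics_header(fmt: str) -> dict[str, str] | None:
    # Hypothetical stand-in for vllm.entrypoints.openai.orca_metrics.metrics_header().
    if not fmt:
        return None
    return {"endpoint-load-metrics": f"{fmt.upper()} named_metrics.kv_cache_usage_perc=0.4"}


@app.post("/v1/echo")
async def echo(raw_request: Request) -> JSONResponse:
    fmt = raw_request.headers.get("endpoint-load-metrics-format", "")
    # headers=None is valid: the response simply gains no extra headers.
    return JSONResponse(content={"ok": True}, headers=stub_metrics_header(fmt))
```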

vllm/entrypoints/openai/orca_metrics.py · 120 additions & 0 deletions (new file)

```python
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
Utility functions that create ORCA endpoint load report response headers.
"""

import json
from collections.abc import Mapping

from vllm.logger import init_logger
from vllm.v1.metrics.reader import Gauge, get_metrics_snapshot

logger = init_logger(__name__)


def create_orca_header(
    metrics_format: str, named_metrics: list[tuple[str, float]]
) -> Mapping[str, str] | None:
    """
    Creates an ORCA header named 'endpoint-load-metrics' in the specified
    format, populated with the given named_metrics.
    ORCA header format description: https://docs.google.com/document/d/1C1ybMmDKJIVlrbOLbywhu9iRYo4rilR-cT50OTtOFTs/edit?tab=t.0
    ORCA proto: https://github.com/cncf/xds/blob/main/xds/data/orca/v3/orca_load_report.proto

    Parameters:
    - metrics_format (str): The format of the header ('TEXT', 'JSON').
    - named_metrics (list[tuple[str, float]]): List of tuples with metric
      names and their corresponding float values.

    Returns:
    - Optional[Mapping[str, str]]: A dictionary with 'endpoint-load-metrics'
      as the key and the ORCA header string (format prefix followed by the
      named_metrics data) as the value, or None for unsupported formats.
    """

    if metrics_format.lower() not in ["text", "json"]:
        logger.warning(
            "`%s` format is not supported in the ORCA response header",
            metrics_format,
        )
        return None

    header = {}
    orca_report = {
        "named_metrics": {
            metric_name: value
            for metric_name, value in named_metrics
            if isinstance(metric_name, str) and isinstance(value, float)
        }
    }
    # output example:
    # endpoint-load-metrics: TEXT named_metrics.kv_cache_utilization=0.4
    if metrics_format.lower() == "text":
        native_http_header = ", ".join(
            [
                f"named_metrics.{metric_name}={value}"
                for metric_name, value in named_metrics
                if isinstance(metric_name, str) and isinstance(value, float)
            ]
        )
        header["endpoint-load-metrics"] = f"TEXT {native_http_header}"

    # output example:
    # endpoint-load-metrics: JSON {"named_metrics": {"custom-metric-util": 0.4}}
    elif metrics_format.lower() == "json":
        header["endpoint-load-metrics"] = f"JSON {json.dumps(orca_report)}"

    logger.info("Created ORCA header %s", header)

    return header


def get_named_metrics_from_prometheus() -> list[tuple[str, float]]:
    """
    Collects current metrics from Prometheus and returns a subset of them
    as the `named_metrics` list expected by `create_orca_header()`.

    Returns:
    - list[tuple[str, float]]: List of tuples of metric names and their values.
    """
    named_metrics: list[tuple[str, float]] = []
    # Map from Prometheus metric names to ORCA named metrics.
    prometheus_to_orca_metrics = {
        "vllm:kv_cache_usage_perc": "kv_cache_usage_perc",
        "vllm:num_requests_waiting": "num_requests_waiting",
    }
    metrics = get_metrics_snapshot()
    for metric in metrics:
        orca_name = prometheus_to_orca_metrics.get(metric.name)
        # If this metric is mapped into ORCA, add it to the report.
        # Note: only Gauge metrics are currently supported.
        if orca_name is not None and isinstance(metric, Gauge):
            named_metrics.append((str(orca_name), float(metric.value)))
    return named_metrics


def metrics_header(metrics_format: str) -> Mapping[str, str] | None:
    """
    Creates an ORCA header named 'endpoint-load-metrics' in the specified
    format. Metrics are collected from Prometheus using
    `get_named_metrics_from_prometheus()`.

    ORCA header format description: https://docs.google.com/document/d/1C1ybMmDKJIVlrbOLbywhu9iRYo4rilR-cT50OTtOFTs/edit?tab=t.0
    ORCA proto: https://github.com/cncf/xds/blob/main/xds/data/orca/v3/orca_load_report.proto

    Parameters:
    - metrics_format (str): The format of the header ('TEXT', 'JSON').

    Returns:
    - Optional[Mapping[str, str]]: A dictionary with 'endpoint-load-metrics'
      as the key and the ORCA header string as the value, or None when no
      format was requested.
    """
    if not metrics_format:
        return None
    # Get named metrics from Prometheus.
    named_metrics = get_named_metrics_from_prometheus()
    return create_orca_header(metrics_format, named_metrics)
```
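
For reference, the TEXT form produced by `create_orca_header`:

```python
>>> create_orca_header("TEXT", [("kv_cache_usage_perc", 0.4)])
{'endpoint-load-metrics': 'TEXT named_metrics.kv_cache_usage_perc=0.4'}
```

On the consuming side (for example, a load balancer doing ORCA-aware routing), the header value is the format tag followed by the payload. A hypothetical parsing helper, illustrative only and not part of this commit:

```python
import json


def parse_endpoint_load_metrics(header_value: str) -> dict[str, float]:
    # Hypothetical consumer-side helper: splits the "<FORMAT> <data>" value
    # produced by create_orca_header() and returns the named metrics.
    fmt, _, data = header_value.partition(" ")
    if fmt == "JSON":
        return json.loads(data).get("named_metrics", {})
    if fmt == "TEXT":
        # e.g. "named_metrics.kv_cache_usage_perc=0.4, named_metrics.num_requests_waiting=2.0"
        parsed: dict[str, float] = {}
        for item in data.split(", "):
            key, _, value = item.partition("=")
            parsed[key.removeprefix("named_metrics.")] = float(value)
        return parsed
    raise ValueError(f"unsupported endpoint-load-metrics format: {fmt}")
```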
