
Commit b7f5a04

[TRTLLM-8274][feat] Check if executor is shutdown in /health entrypoint
Signed-off-by: Junyi Xu <219237550+JunyiXu-nv@users.noreply.github.com>
1 parent 1797e91 commit b7f5a04

5 files changed (+49, -13 lines)

tensorrt_llm/llmapi/llm.py

Lines changed: 11 additions & 0 deletions
@@ -766,6 +766,17 @@ def shutdown(self) -> None:
             self.mpi_session.shutdown()
             self.mpi_session = None
 
+    def _check_health(self) -> bool:
+        """Check if the LLM is healthy.
+
+        Returns:
+            bool: True if the executor is running and not shutdown, False otherwise.
+        """
+        if hasattr(self, "_executor") and self._executor is not None:
+            return not self._executor.is_shutdown()
+
+        return False
+
     @staticmethod
     def _shutdown_wrapper(self_ref):
         # Retrieve the instance if it still exists

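For context, a minimal sketch (not part of the commit) of what the new internal helper reports before and after shutdown; the model path is a placeholder and the construction arguments are illustrative only.

# Illustrative sketch: exercising the new _check_health() helper.
# Assumes a locally available model; the path below is a placeholder.
from tensorrt_llm import LLM

llm = LLM(model="/path/to/model")   # placeholder model path
print(llm._check_health())          # expected True while the executor is running
llm.shutdown()
print(llm._check_health())          # expected False once the executor is shut down or gone
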
tensorrt_llm/serve/openai_server.py

Lines changed: 7 additions & 1 deletion
@@ -236,6 +236,9 @@ def _create_response_id_not_found_error(self, response_id: str) -> Response:
             status_code=HTTPStatus.NOT_FOUND,
         )
 
+    def _check_health(self) -> bool:
+        return self.llm._check_health()
+
     def register_routes(self):
         self.app.add_api_route("/health", self.health, methods=["GET"])
         self.app.add_api_route("/health_generate", self.health_generate, methods=["GET"])

@@ -296,7 +299,10 @@ def register_mm_encoder_routes(self):
                                methods=["POST"])
 
     async def health(self) -> Response:
-        return Response(status_code=200)
+        if self._check_health():
+            return Response(status_code=200)
+        else:
+            return Response(status_code=503, content="LLM is unavailable. Please check the server logs for more details.")
 
     async def health_generate(self, raw_request: Request) -> Response:
         """Health check that performs a minimal generation."""

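With this change, an external liveness probe can distinguish a running executor from a shut-down one. A minimal client-side sketch (not from the commit; the base URL and port are assumptions to adjust per deployment):

# Illustrative health probe against a running server exposing /health.
import requests

resp = requests.get("http://localhost:8000/health", timeout=5)
if resp.status_code == 200:
    print("LLM server healthy")
else:
    # Per this commit, the server returns 503 when the executor is shut down.
    print(f"LLM server unavailable ({resp.status_code}): {resp.text}")
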
tests/integration/defs/test_e2e.py

Lines changed: 8 additions & 0 deletions
@@ -1665,6 +1665,14 @@ def test_openai_responses(llm_root, llm_venv):
                       str(test_root / "_test_openai_responses.py")])
 
 
+def test_openai_health(llm_root, llm_venv):
+    test_root = unittest_path() / "llmapi" / "apps"
+    llm_venv.run_cmd([
+        "-m", "pytest",
+        str(test_root / "_test_openai_metrics.py -k test_health")
+    ])
+
+
 def test_openai_prometheus(llm_root, llm_venv):
     test_root = unittest_path() / "llmapi" / "apps"
     llm_venv.run_cmd(

tests/integration/test_lists/test-db/l0_a10.yml

Lines changed: 1 addition & 0 deletions
@@ -137,6 +137,7 @@ l0_a10:
       - llmapi/test_llm_e2e.py::test_llmapi_exit
       - llmapi/test_llm_examples.py::test_llmapi_server_example
       - llmapi/test_llm_examples.py::test_llmapi_kv_cache_connector[Qwen2-0.5B]
+      - test_e2e.py::test_openai_health
       - test_e2e.py::test_trtllm_serve_example
       - test_e2e.py::test_trtllm_serve_top_logprobs[trt]
       - test_e2e.py::test_openai_misc_example[trt]

tests/unittest/llmapi/apps/_test_openai_metrics.py

Lines changed: 22 additions & 12 deletions
@@ -1,11 +1,12 @@
 """Test the metrics endpoint when using OpenAI API to send requests"""
 
+from unittest.mock import patch
+
 import pytest
 from fastapi.testclient import TestClient
-from transformers import AutoTokenizer
 
 from tensorrt_llm import LLM as PyTorchLLM
-from tensorrt_llm.llmapi import BuildConfig, KvCacheConfig
+from tensorrt_llm.llmapi import KvCacheConfig
 from tensorrt_llm.serve.openai_server import OpenAIServer
 
 from ..test_llm import llama_model_path

@@ -14,26 +15,35 @@
 
 
 @pytest.fixture(scope="module")
-def client():
-    build_config = BuildConfig()
-    build_config.max_batch_size = 8
-    build_config.max_seq_len = 512
+def llm():
     llm = PyTorchLLM(model=llama_model_path,
-                     build_config=build_config,
                      kv_cache_config=KvCacheConfig(),
                      enable_iter_perf_stats=True)
-    hf_tokenizer = AutoTokenizer.from_pretrained(llama_model_path)
+    yield llm
+    llm.shutdown()
+
 
+@pytest.fixture(scope="module")
+def client(llm):
     app_instance = OpenAIServer(llm,
                                 model=llama_model_path,
-                                hf_tokenizer=hf_tokenizer)
+                                tool_parser=None,
+                                server_role=None,
+                                metadata_server_cfg=None)
     client = TestClient(app_instance.app)
     yield client
 
 
-def test_health(client):
-    response = client.get("/health")
-    assert response.status_code == 200
+@pytest.mark.parametrize("is_healthy,response_code", [(True, 200),
+                                                      (False, 503)])
+def test_health(client, llm, is_healthy, response_code):
+    if not is_healthy:
+        with patch.object(llm._executor, 'is_shutdown', return_value=True):
+            response = client.get("/health")
+            assert response.status_code == response_code
+    else:
+        response = client.get("/health")
+        assert response.status_code == response_code
 
 
 def test_version(client):

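The parametrized test simulates the unhealthy case by mocking the executor. A self-contained sketch of the same mocking idea, using hypothetical stand-in classes rather than TensorRT-LLM code:

# Stand-in objects (hypothetical) showing how patching is_shutdown()
# flips the health check, which is what drives the 200 vs. 503 responses.
from unittest.mock import patch


class FakeExecutor:
    def is_shutdown(self) -> bool:
        return False


class FakeLLM:
    def __init__(self):
        self._executor = FakeExecutor()

    def _check_health(self) -> bool:
        return self._executor is not None and not self._executor.is_shutdown()


llm = FakeLLM()
assert llm._check_health() is True
with patch.object(llm._executor, "is_shutdown", return_value=True):
    assert llm._check_health() is False  # mirrors the 503 branch of /health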