Commit 6909e17 (1 parent: 548f5ce)

[TRTLLM-8274][feat] Check if executor is shutdown in /health entrypoint

Signed-off-by: Junyi Xu <219237550+JunyiXu-nv@users.noreply.github.com>

3 files changed: 28 additions, 7 deletions

  tensorrt_llm/llmapi/llm.py                          +6 / -0
  tensorrt_llm/serve/openai_server.py                 +7 / -1
  tests/unittest/llmapi/apps/_test_openai_metrics.py  +15 / -6

tensorrt_llm/llmapi/llm.py (6 additions, 0 deletions)

@@ -766,6 +766,12 @@ def shutdown(self) -> None:
             self.mpi_session.shutdown()
             self.mpi_session = None

+    def check_health(self) -> bool:
+        if hasattr(self, "_executor") and self._executor is not None:
+            return not self._executor.is_shutdown()
+
+        return False
+
     @staticmethod
     def _shutdown_wrapper(self_ref):
         # Retrieve the instance if it still exists
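The new LLM.check_health() reports healthy only while an executor exists and has not been shut down. Below is a minimal, self-contained sketch of that guard pattern; the FakeExecutor and FakeLLM classes are hypothetical stand-ins, not part of the commit:

# Hypothetical stand-ins illustrating the check_health() guard; the real
# executor exposes is_shutdown() as used in the diff above.
class FakeExecutor:
    def __init__(self):
        self._shut_down = False

    def shutdown(self):
        self._shut_down = True

    def is_shutdown(self) -> bool:
        return self._shut_down


class FakeLLM:
    def __init__(self):
        self._executor = FakeExecutor()

    def check_health(self) -> bool:
        # Healthy only while an executor exists and has not been shut down.
        if hasattr(self, "_executor") and self._executor is not None:
            return not self._executor.is_shutdown()
        return False


llm = FakeLLM()
assert llm.check_health() is True   # live executor -> healthy
llm._executor.shutdown()
assert llm.check_health() is False  # shut-down executor -> unhealthy
llm._executor = None
assert llm.check_health() is False  # no executor at all -> unhealthy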

tensorrt_llm/serve/openai_server.py (7 additions, 1 deletion)

@@ -233,6 +233,9 @@ def _create_response_id_not_found_error(self, response_id: str) -> Response:
             status_code=HTTPStatus.NOT_FOUND,
         )

+    def _check_health(self) -> bool:
+        return self.llm.check_health()
+
     def register_routes(self):
         self.app.add_api_route("/health", self.health, methods=["GET"])
         self.app.add_api_route("/health_generate", self.health_generate, methods=["GET"])
@@ -293,7 +296,10 @@ def register_mm_encoder_routes(self):
                                methods=["POST"])

     async def health(self) -> Response:
-        return Response(status_code=200)
+        if self._check_health():
+            return Response(status_code=200)
+        else:
+            return Response(status_code=503, content="LLM is unavailable. Please check the server logs for more details.")

     async def health_generate(self, raw_request: Request) -> Response:
         """Health check that performs a minimal generation."""

tests/unittest/llmapi/apps/_test_openai_metrics.py (15 additions, 6 deletions)

@@ -2,7 +2,6 @@

 import pytest
 from fastapi.testclient import TestClient
-from transformers import AutoTokenizer

 from tensorrt_llm import LLM as PyTorchLLM
 from tensorrt_llm.llmapi import BuildConfig, KvCacheConfig
@@ -14,26 +13,36 @@


 @pytest.fixture(scope="module")
-def client():
+def llm():
     build_config = BuildConfig()
     build_config.max_batch_size = 8
     build_config.max_seq_len = 512
     llm = PyTorchLLM(model=llama_model_path,
                      build_config=build_config,
                      kv_cache_config=KvCacheConfig(),
                      enable_iter_perf_stats=True)
-    hf_tokenizer = AutoTokenizer.from_pretrained(llama_model_path)
+    yield llm
+    llm.shutdown()

+
+@pytest.fixture(scope="module")
+def client(llm):
     app_instance = OpenAIServer(llm,
                                 model=llama_model_path,
-                                hf_tokenizer=hf_tokenizer)
+                                tool_parser=None,
+                                server_role=None,
+                                metadata_server_cfg=None)
     client = TestClient(app_instance.app)
     yield client


-def test_health(client):
+@pytest.mark.parametrize("is_healthy,response_code", [(True, 200),
+                                                      (False, 503)])
+def test_health(client, llm, is_healthy, response_code):
+    if not is_healthy:
+        llm.shutdown()
     response = client.get("/health")
-    assert response.status_code == 200
+    assert response.status_code == response_code


 def test_version(client):
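The parametrized test mutates a module-scoped fixture: the (False, 503) case shuts the shared llm down, so it depends on pytest running parameters in listed order, with (True, 200) first. The same 200/503 contract can also be exercised in isolation against a stub app, as in this sketch; the stub classes and route below are illustrative, not the server's actual wiring:

# Standalone sketch of the parametrized health test against a stub FastAPI
# app, so the 200/503 logic can be exercised without loading a model.
import pytest
from fastapi import FastAPI, Response
from fastapi.testclient import TestClient


class StubLLM:
    """Illustrative stand-in for the real LLM object."""

    def __init__(self):
        self._alive = True

    def shutdown(self):
        self._alive = False

    def check_health(self) -> bool:
        return self._alive


def make_app(llm: StubLLM) -> FastAPI:
    # Mirrors the shape of the health() handler added in openai_server.py.
    app = FastAPI()

    @app.get("/health")
    async def health() -> Response:
        if llm.check_health():
            return Response(status_code=200)
        return Response(status_code=503, content="LLM is unavailable.")

    return app


@pytest.mark.parametrize("is_healthy,response_code", [(True, 200),
                                                      (False, 503)])
def test_health_stub(is_healthy, response_code):
    llm = StubLLM()
    client = TestClient(make_app(llm))
    if not is_healthy:
        llm.shutdown()
    assert client.get("/health").status_code == response_code

Building a fresh StubLLM per test case avoids the shared-state ordering concern entirely.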
