Skip to content

Commit b35f2ab

Browse files
committed
test: add HypercornServer and TestServer class for testing
1 parent ef8da55 commit b35f2ab

File tree

1 file changed

+139
-2
lines changed

1 file changed

+139
-2
lines changed

tests/app/tool.py

Lines changed: 139 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,23 @@
22

33
from contextlib import AsyncExitStack
44
from dataclasses import dataclass
5-
from typing import Any, TypedDict, Union
5+
from typing import Any, Literal, TypedDict, Union
66

77
import anyio
88
import httpx
9+
import sniffio
910
import uvicorn
1011
from fastapi import FastAPI
12+
from hypercorn import Config as HyperConfig
13+
from hypercorn.asyncio.run import (
14+
worker_serve as hyper_aio_serve, # pyright: ignore[reportUnknownVariableType]
15+
)
16+
from hypercorn.trio.run import (
17+
worker_serve as hyper_trio_serve, # pyright: ignore[reportUnknownVariableType]
18+
)
19+
from hypercorn.utils import (
20+
wrap_app as hyper_wrap_app, # pyright: ignore[reportUnknownVariableType]
21+
)
1122
from starlette.requests import Request
1223
from starlette.websockets import WebSocket
1324
from typing_extensions import Self
@@ -65,7 +76,7 @@ async def __aenter__(self) -> Self:
6576
async def __aexit__(self, *_: Any, **__: Any) -> None:
6677
"""Shutdown the server."""
6778
# 在 uvicorn.Server 的实现中,设置 should_exit 可以使得 server 任务结束
68-
assert self.should_exit is False, "The server has already exited."
79+
assert not self.should_exit, "The server has already exited."
6980
self.should_exit = True
7081
await self._exit_stack.__aexit__(*_, **__)
7182

@@ -86,3 +97,129 @@ def contx_socket_url(self) -> httpx.URL:
8697
scheme="https" if config.is_ssl else "http",
8798
path="/",
8899
)
100+
101+
102+
class HypercornServer:
103+
"""An AsyncContext to launch and shutdown Hypercorn server automatically."""
104+
105+
def __init__(self, app: FastAPI, config: HyperConfig): # noqa: D107
106+
self.config = config
107+
self.app = app
108+
self.should_exit = anyio.Event()
109+
110+
async def __aenter__(self) -> Self:
111+
"""Launch the server."""
112+
self._sockets = self.config.create_sockets()
113+
self._exit_stack = AsyncExitStack()
114+
115+
self.current_async_lib = sniffio.current_async_library()
116+
117+
if self.current_async_lib == "asyncio":
118+
serve_func = hyper_aio_serve # pyright: ignore[reportUnknownVariableType]
119+
elif self.current_async_lib == "trio":
120+
serve_func = hyper_trio_serve # pyright: ignore[reportUnknownVariableType]
121+
else:
122+
raise RuntimeError(f"Unsupported async library {self.current_async_lib!r}")
123+
124+
async def serve() -> None:
125+
# Implement ref:
126+
# https://github.com/pgjones/hypercorn/blob/3fbd5f245e5dfeaba6ad852d9135d6a32b228d05/src/hypercorn/asyncio/__init__.py#L12-L46
127+
# https://github.com/pgjones/hypercorn/blob/3fbd5f245e5dfeaba6ad852d9135d6a32b228d05/src/hypercorn/trio/__init__.py#L14-L52
128+
await serve_func(
129+
hyper_wrap_app(
130+
self.app, # pyright: ignore[reportArgumentType]
131+
self.config.wsgi_max_body_size,
132+
mode=None,
133+
),
134+
self.config,
135+
shutdown_trigger=self.should_exit.wait,
136+
)
137+
138+
task_group = await self._exit_stack.enter_async_context(
139+
anyio.create_task_group()
140+
)
141+
task_group.start_soon(serve, name=f"Hypercorn Server Task of {self}")
142+
return self
143+
144+
async def __aexit__(self, *_: Any, **__: Any) -> None:
145+
"""Shutdown the server."""
146+
assert not self.should_exit.is_set(), "The server has already exited."
147+
self.should_exit.set()
148+
await self._exit_stack.__aexit__(*_, **__)
149+
150+
@property
151+
def contx_socket_url(self) -> httpx.URL:
152+
"""If server is tcp socket, return the url of server.
153+
154+
Note: The path of url is explicitly set to "/".
155+
"""
156+
config = self.config
157+
158+
bind = config.bind[0]
159+
if bind.startswith(("unix:", "fd://")):
160+
raise RuntimeError("Only support tcp socket.")
161+
162+
# refer to: https://docs.python.org/zh-cn/3/library/socket.html#socket-families
163+
host, port = config.bind[0].split(":")
164+
port = int(port)
165+
166+
return httpx.URL(
167+
host=host,
168+
port=port,
169+
scheme="https" if config.ssl_enabled else "http",
170+
path="/",
171+
)
172+
173+
174+
class TestServer:
175+
"""An AsyncContext to launch and shutdown Hypercorn or Uvicorn server automatically."""
176+
177+
def __init__(
178+
self,
179+
app: FastAPI,
180+
host: str,
181+
port: int,
182+
server_type: Literal["uvicorn", "hypercorn"] = "hypercorn",
183+
):
184+
"""Only support ipv4 address.
185+
186+
If use uvicorn, it only support asyncio backend.
187+
"""
188+
self.app = app
189+
self.host = host
190+
self.port = port
191+
self.server_type = server_type
192+
193+
if self.server_type == "hypercorn":
194+
config = HyperConfig()
195+
config.bind = f"{host}:{port}"
196+
197+
self.config = config
198+
self.server = HypercornServer(app, config)
199+
else:
200+
self.config = uvicorn.Config(app, host=host, port=port)
201+
self.server = UvicornServer(self.config)
202+
203+
async def __aenter__(self) -> Self:
204+
"""Launch the server."""
205+
if (
206+
self.server_type == "uvicorn"
207+
and sniffio.current_async_library() != "asyncio"
208+
):
209+
raise RuntimeError("Uvicorn server does not support trio backend.")
210+
211+
self._exit_stack = AsyncExitStack()
212+
await self._exit_stack.enter_async_context(self.server)
213+
return self
214+
215+
async def __aexit__(self, *_: Any, **__: Any) -> None:
216+
"""Shutdown the server."""
217+
await self._exit_stack.__aexit__(*_, **__)
218+
219+
@property
220+
def contx_socket_url(self) -> httpx.URL:
221+
"""If server is tcp socket, return the url of server.
222+
223+
Note: The path of url is explicitly set to "/".
224+
"""
225+
return self.server.contx_socket_url

0 commit comments

Comments
 (0)