Skip to content

Commit 4a7f4b9

Browse files
committed
test: add trio tests and use Hypercorn as the server
1 parent b35f2ab commit 4a7f4b9

File tree

4 files changed

+121
-77
lines changed

4 files changed

+121
-77
lines changed

src/fastapi_proxy_lib/core/websocket.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@
4242
DEFAULT_MAX_MESSAGE_SIZE_BYTES,
4343
DEFAULT_QUEUE_SIZE,
4444
)
45-
except ImportError:
45+
except ImportError: # pragma: no cover
4646
# ref: https://github.com/frankie567/httpx-ws/blob/b2135792141b71551b022ff0d76542a0263a890c/httpx_ws/_api.py#L31-L34
4747
DEFAULT_KEEPALIVE_PING_TIMEOUT_SECONDS = ( # pyright: ignore[reportConstantRedefinition]
4848
20.0

tests/app/tool.py

Lines changed: 66 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from contextlib import AsyncExitStack
44
from dataclasses import dataclass
5-
from typing import Any, Literal, TypedDict, Union
5+
from typing import Any, Literal, Optional, TypedDict, Union
66

77
import anyio
88
import httpx
@@ -11,17 +11,20 @@
1111
from fastapi import FastAPI
1212
from hypercorn import Config as HyperConfig
1313
from hypercorn.asyncio.run import (
14-
worker_serve as hyper_aio_serve, # pyright: ignore[reportUnknownVariableType]
14+
worker_serve as hyper_aio_worker_serve, # pyright: ignore[reportUnknownVariableType]
1515
)
1616
from hypercorn.trio.run import (
17-
worker_serve as hyper_trio_serve, # pyright: ignore[reportUnknownVariableType]
17+
worker_serve as hyper_trio_worker_serve, # pyright: ignore[reportUnknownVariableType]
18+
)
19+
from hypercorn.utils import (
20+
repr_socket_addr, # pyright: ignore[reportUnknownVariableType]
1821
)
1922
from hypercorn.utils import (
2023
wrap_app as hyper_wrap_app, # pyright: ignore[reportUnknownVariableType]
2124
)
2225
from starlette.requests import Request
2326
from starlette.websockets import WebSocket
24-
from typing_extensions import Self
27+
from typing_extensions import Self, assert_never
2528

2629
ServerRecvRequestsTypes = Union[Request, WebSocket]
2730

@@ -55,7 +58,7 @@ def get_request(self) -> ServerRecvRequestsTypes:
5558
return server_recv_request
5659

5760

58-
class UvicornServer(uvicorn.Server):
61+
class _UvicornServer(uvicorn.Server):
5962
"""subclass of `uvicorn.Server` which can use AsyncContext to launch and shutdown automatically."""
6063

6164
async def __aenter__(self) -> Self:
@@ -89,7 +92,9 @@ def contx_socket_url(self) -> httpx.URL:
8992
config = self.config
9093
if config.fd is not None or config.uds is not None:
9194
raise RuntimeError("Only support tcp socket.")
92-
# refer to: https://docs.python.org/zh-cn/3/library/socket.html#socket-families
95+
96+
# Implement ref:
97+
# https://github.com/encode/uvicorn/blob/a2219eb2ed2bbda4143a0fb18c4b0578881b1ae8/uvicorn/server.py#L201-L220
9398
host, port = self._socket.getsockname()[:2]
9499
return httpx.URL(
95100
host=host,
@@ -99,25 +104,42 @@ def contx_socket_url(self) -> httpx.URL:
99104
)
100105

101106

102-
class HypercornServer:
107+
class _HypercornServer:
103108
"""An AsyncContext to launch and shutdown Hypercorn server automatically."""
104109

105-
def __init__(self, app: FastAPI, config: HyperConfig): # noqa: D107
110+
def __init__(self, app: FastAPI, config: HyperConfig):
106111
self.config = config
107112
self.app = app
108113
self.should_exit = anyio.Event()
109114

110115
async def __aenter__(self) -> Self:
111116
"""Launch the server."""
112-
self._sockets = self.config.create_sockets()
113117
self._exit_stack = AsyncExitStack()
114118

115119
self.current_async_lib = sniffio.current_async_library()
116120

117121
if self.current_async_lib == "asyncio":
118-
serve_func = hyper_aio_serve # pyright: ignore[reportUnknownVariableType]
122+
serve_func = ( # pyright: ignore[reportUnknownVariableType]
123+
hyper_aio_worker_serve
124+
)
125+
126+
# Implement ref:
127+
# https://github.com/pgjones/hypercorn/blob/3fbd5f245e5dfeaba6ad852d9135d6a32b228d05/src/hypercorn/asyncio/run.py#L89-L90
128+
self._sockets = self.config.create_sockets()
129+
119130
elif self.current_async_lib == "trio":
120-
serve_func = hyper_trio_serve # pyright: ignore[reportUnknownVariableType]
131+
serve_func = ( # pyright: ignore[reportUnknownVariableType]
132+
hyper_trio_worker_serve
133+
)
134+
135+
# Implement ref:
136+
# https://github.com/pgjones/hypercorn/blob/3fbd5f245e5dfeaba6ad852d9135d6a32b228d05/src/hypercorn/trio/run.py#L51-L56
137+
self._sockets = self.config.create_sockets()
138+
for sock in self._sockets.secure_sockets:
139+
sock.listen(self.config.backlog)
140+
for sock in self._sockets.insecure_sockets:
141+
sock.listen(self.config.backlog)
142+
121143
else:
122144
raise RuntimeError(f"Unsupported async library {self.current_async_lib!r}")
123145

@@ -133,6 +155,7 @@ async def serve() -> None:
133155
),
134156
self.config,
135157
shutdown_trigger=self.should_exit.wait,
158+
sockets=self._sockets,
136159
)
137160

138161
task_group = await self._exit_stack.enter_async_context(
@@ -154,13 +177,32 @@ def contx_socket_url(self) -> httpx.URL:
154177
Note: The path of url is explicitly set to "/".
155178
"""
156179
config = self.config
180+
sockets = self._sockets
181+
182+
# Implement ref:
183+
# https://github.com/pgjones/hypercorn/blob/3fbd5f245e5dfeaba6ad852d9135d6a32b228d05/src/hypercorn/asyncio/run.py#L112-L149
184+
# https://github.com/pgjones/hypercorn/blob/3fbd5f245e5dfeaba6ad852d9135d6a32b228d05/src/hypercorn/trio/run.py#L61-L82
185+
186+
# We only run on one socket each time,
187+
# so we raise `RuntimeError` to avoid other unknown errors during testing.
188+
if sockets.insecure_sockets:
189+
if len(sockets.insecure_sockets) > 1:
190+
raise RuntimeError("Hypercorn test: Multiple insecure_sockets found.")
191+
socket = sockets.insecure_sockets[0]
192+
elif sockets.secure_sockets:
193+
if len(sockets.secure_sockets) > 1:
194+
raise RuntimeError("Hypercorn test: secure_sockets sockets found.")
195+
socket = sockets.secure_sockets[0]
196+
else:
197+
raise RuntimeError("Hypercorn test: No socket found.")
157198

158-
bind = config.bind[0]
199+
bind = repr_socket_addr(socket.family, socket.getsockname())
159200
if bind.startswith(("unix:", "fd://")):
160201
raise RuntimeError("Only support tcp socket.")
161202

162-
# refer to: https://docs.python.org/zh-cn/3/library/socket.html#socket-families
163-
host, port = config.bind[0].split(":")
203+
# Implement ref:
204+
# https://docs.python.org/zh-cn/3/library/socket.html#socket-families
205+
host, port = bind.split(":")
164206
port = int(port)
165207

166208
return httpx.URL(
@@ -179,12 +221,16 @@ def __init__(
179221
app: FastAPI,
180222
host: str,
181223
port: int,
182-
server_type: Literal["uvicorn", "hypercorn"] = "hypercorn",
224+
server_type: Optional[Literal["uvicorn", "hypercorn"]] = None,
183225
):
184226
"""Only support ipv4 address.
185227
186228
If use uvicorn, it only support asyncio backend.
229+
230+
If `host` == 0, then use random port.
187231
"""
232+
server_type = server_type if server_type is not None else "hypercorn"
233+
188234
self.app = app
189235
self.host = host
190236
self.port = port
@@ -195,10 +241,12 @@ def __init__(
195241
config.bind = f"{host}:{port}"
196242

197243
self.config = config
198-
self.server = HypercornServer(app, config)
199-
else:
244+
self.server = _HypercornServer(app, config)
245+
elif self.server_type == "uvicorn":
200246
self.config = uvicorn.Config(app, host=host, port=port)
201-
self.server = UvicornServer(self.config)
247+
self.server = _UvicornServer(self.config)
248+
else:
249+
assert_never(self.server_type)
202250

203251
async def __aenter__(self) -> Self:
204252
"""Launch the server."""

tests/conftest.py

Lines changed: 32 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,13 @@
1515
Callable,
1616
Coroutine,
1717
Literal,
18+
Optional,
1819
Protocol,
1920
)
2021

2122
import pytest
22-
import uvicorn
2323
from asgi_lifespan import LifespanManager
24+
from fastapi import FastAPI
2425
from fastapi_proxy_lib.fastapi.app import (
2526
forward_http_app,
2627
reverse_http_app,
@@ -30,7 +31,7 @@
3031

3132
from .app.echo_http_app import get_app as get_http_test_app
3233
from .app.echo_ws_app import get_app as get_ws_test_app
33-
from .app.tool import AppDataclass4Test, UvicornServer
34+
from .app.tool import AppDataclass4Test, TestServer
3435

3536
# ASGI types.
3637
# Copied from: https://github.com/florimondmanca/asgi-lifespan/blob/fbb0f440337314be97acaae1a3c0c7a2ec8298dd/src/asgi_lifespan/_types.py
@@ -61,17 +62,28 @@ class LifeAppDataclass4Test(AppDataclass4Test):
6162
"""The lifespan of app will be managed automatically by pytest."""
6263

6364

64-
class UvicornServerFixture(Protocol): # noqa: D101
65+
class TestServerFixture(Protocol): # noqa: D101
6566
def __call__( # noqa: D102
66-
self, config: uvicorn.Config
67-
) -> Coroutine[None, None, UvicornServer]: ...
67+
self,
68+
app: FastAPI,
69+
host: str,
70+
port: int,
71+
server_type: Optional[Literal["uvicorn", "hypercorn"]] = None,
72+
) -> Coroutine[None, None, TestServer]: ...
6873

6974

7075
# https://anyio.readthedocs.io/en/stable/testing.html#specifying-the-backends-to-run-on
71-
@pytest.fixture()
72-
def anyio_backend() -> Literal["asyncio"]:
76+
@pytest.fixture(
77+
params=[
78+
pytest.param(("asyncio", {"use_uvloop": False}), id="asyncio"),
79+
pytest.param(
80+
("trio", {"restrict_keyboard_interrupt_to_checkpoints": True}), id="trio"
81+
),
82+
],
83+
)
84+
def anyio_backend(request: pytest.FixtureRequest):
7385
"""Specify the async backend for `pytest.mark.anyio`."""
74-
return "asyncio"
86+
return request.param
7587

7688

7789
@pytest.fixture()
@@ -191,17 +203,22 @@ def reverse_ws_app_fct(
191203

192204

193205
@pytest.fixture()
194-
async def uvicorn_server_fixture() -> AsyncIterator[UvicornServerFixture]:
195-
"""Fixture for UvicornServer.
206+
async def test_server_fixture() -> AsyncIterator[TestServerFixture]:
207+
"""Fixture for TestServer.
196208
197209
Will launch and shutdown automatically.
198210
"""
199211
async with AsyncExitStack() as exit_stack:
200212

201-
async def uvicorn_server_fct(config: uvicorn.Config) -> UvicornServer:
202-
uvicorn_server = await exit_stack.enter_async_context(
203-
UvicornServer(config=config)
213+
async def test_server_fct(
214+
app: FastAPI,
215+
host: str,
216+
port: int,
217+
server_type: Optional[Literal["uvicorn", "hypercorn"]] = None,
218+
) -> TestServer:
219+
test_server = await exit_stack.enter_async_context(
220+
TestServer(app=app, host=host, port=port, server_type=server_type)
204221
)
205-
return uvicorn_server
222+
return test_server
206223

207-
yield uvicorn_server_fct
224+
yield test_server_fct

0 commit comments

Comments
 (0)