Skip to content

Commit ef8da55

Browse files
committed
test: using anyio instead of asyncio in tests
And simplify the `UvicornServer` class in test
1 parent 430acd6 commit ef8da55

File tree

5 files changed

+42
-197
lines changed

5 files changed

+42
-197
lines changed

pyproject.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,9 +98,11 @@ dependencies = [
9898
"pytest == 7.*",
9999
"pytest-cov == 4.*",
100100
"uvicorn[standard] < 1.0.0", # TODO: Once it releases version 1.0.0, we will remove this restriction.
101+
"hypercorn[trio] == 0.16.*",
101102
"httpx[http2]", # we don't set version here, instead set it in `[project].dependencies`.
102-
"asgi-lifespan==2.*",
103-
"pytest-timeout==2.*",
103+
"asgi-lifespan == 2.*",
104+
"pytest-timeout == 2.*",
105+
"sniffio == 1.3.*",
104106
]
105107

106108
[tool.hatch.envs.default.scripts]

tests/app/echo_ws_app.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
# ruff: noqa: D100
22
# pyright: reportUnusedFunction=false
33

4-
import asyncio
54

5+
import anyio
66
from fastapi import FastAPI, WebSocket
77
from starlette.websockets import WebSocketDisconnect
88

@@ -76,7 +76,7 @@ async def just_close_with_1001(websocket: WebSocket):
7676
test_app_dataclass.request_dict["request"] = websocket
7777

7878
await websocket.accept()
79-
await asyncio.sleep(0.3)
79+
await anyio.sleep(0.3)
8080
await websocket.close(1001)
8181

8282
@app.websocket("/reject_handshake")

tests/app/tool.py

Lines changed: 19 additions & 168 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,16 @@
11
# noqa: D100
22

3-
import asyncio
4-
import socket
3+
from contextlib import AsyncExitStack
54
from dataclasses import dataclass
6-
from typing import Any, Callable, List, Optional, Type, TypedDict, TypeVar, Union
5+
from typing import Any, TypedDict, Union
76

7+
import anyio
88
import httpx
99
import uvicorn
1010
from fastapi import FastAPI
1111
from starlette.requests import Request
1212
from starlette.websockets import WebSocket
13-
from typing_extensions import Self, override
14-
15-
_Decoratable_T = TypeVar("_Decoratable_T", bound=Union[Callable[..., Any], Type[Any]])
13+
from typing_extensions import Self
1614

1715
ServerRecvRequestsTypes = Union[Request, WebSocket]
1816

@@ -46,180 +44,32 @@ def get_request(self) -> ServerRecvRequestsTypes:
4644
return server_recv_request
4745

4846

49-
def _no_override_uvicorn_server(_method: _Decoratable_T) -> _Decoratable_T:
50-
"""Check if the method is already in `uvicorn.Server`."""
51-
assert not hasattr(
52-
uvicorn.Server, _method.__name__
53-
), f"Override method of `uvicorn.Server` cls : {_method.__name__}"
54-
return _method
55-
56-
57-
class AeixtTimeoutUndefine:
58-
"""Didn't set `contx_exit_timeout` in `aexit()`."""
59-
60-
61-
aexit_timeout_undefine = AeixtTimeoutUndefine()
62-
63-
64-
# HACK: 不能继承 AbstractAsyncContextManager[Self]
65-
# 目前有问题,继承 AbstractAsyncContextManager 的话pyright也推测不出来类型
66-
# 只能依靠 __aenter__ 和 __aexit__ 的类型注解
6747
class UvicornServer(uvicorn.Server):
68-
"""subclass of `uvicorn.Server` which can use AsyncContext to launch and shutdown automatically.
69-
70-
Attributes:
71-
contx_server_task: The task of server.
72-
contx_socket: The socket of server.
73-
74-
other attributes are same as `uvicorn.Server`:
75-
- config: The config arg that be passed in.
76-
...
77-
"""
48+
"""subclass of `uvicorn.Server` which can use AsyncContext to launch and shutdown automatically."""
7849

79-
_contx_server_task: Union["asyncio.Task[None]", None]
80-
assert not hasattr(uvicorn.Server, "_contx_server_task")
81-
82-
_contx_socket: Union[socket.socket, None]
83-
assert not hasattr(uvicorn.Server, "_contx_socket")
84-
85-
_contx_server_started_event: Union[asyncio.Event, None]
86-
assert not hasattr(uvicorn.Server, "_contx_server_started_event")
87-
88-
contx_exit_timeout: Union[int, float, None]
89-
assert not hasattr(uvicorn.Server, "contx_exit_timeout")
90-
91-
@override
92-
def __init__(
93-
self, config: uvicorn.Config, contx_exit_timeout: Union[int, float, None] = None
94-
) -> None:
95-
"""The same as `uvicorn.Server.__init__`."""
96-
super().__init__(config=config)
97-
self._contx_server_task = None
98-
self._contx_socket = None
99-
self._contx_server_started_event = None
100-
self.contx_exit_timeout = contx_exit_timeout
101-
102-
@override
103-
async def startup(self, sockets: Optional[List[socket.socket]] = None) -> None:
104-
"""The same as `uvicorn.Server.startup`."""
105-
super_return = await super().startup(sockets=sockets)
106-
self.contx_server_started_event.set()
107-
return super_return
108-
109-
@_no_override_uvicorn_server
110-
async def aenter(self) -> Self:
50+
async def __aenter__(self) -> Self:
11151
"""Launch the server."""
112-
# 在分配资源之前,先检查是否重入
113-
if self.contx_server_started_event.is_set():
114-
raise RuntimeError("DO not launch server by __aenter__ again!")
115-
11652
# FIXME: # 这个socket被设计为可被同一进程内的多个server共享,可能会引起潜在问题
117-
self._contx_socket = self.config.bind_socket()
53+
self._socket = self.config.bind_socket()
54+
self._exit_stack = AsyncExitStack()
11855

119-
self._contx_server_task = asyncio.create_task(
120-
self.serve([self._contx_socket]), name=f"Uvicorn Server Task of {self}"
56+
task_group = await self._exit_stack.enter_async_context(
57+
anyio.create_task_group()
58+
)
59+
task_group.start_soon(
60+
self.serve, [self._socket], name=f"Uvicorn Server Task of {self}"
12161
)
122-
# 在 uvicorn.Server 的实现中,Server.serve() 内部会调用 Server.startup() 完成启动
123-
# 被覆盖的 self.startup() 会在完成时调用 self.contx_server_started_event.set()
124-
await self.contx_server_started_event.wait() # 等待服务器确实启动后才返回
125-
return self
12662

127-
@_no_override_uvicorn_server
128-
async def __aenter__(self) -> Self:
129-
"""Launch the server.
63+
return self
13064

131-
The same as `self.aenter()`.
132-
"""
133-
return await self.aenter()
134-
135-
@_no_override_uvicorn_server
136-
async def aexit(
137-
self,
138-
contx_exit_timeout: Union[
139-
int, float, None, AeixtTimeoutUndefine
140-
] = aexit_timeout_undefine,
141-
) -> None:
65+
async def __aexit__(self, *_: Any, **__: Any) -> None:
14266
"""Shutdown the server."""
143-
contx_server_task = self.contx_server_task
144-
contx_socket = self.contx_socket
145-
146-
if isinstance(contx_exit_timeout, AeixtTimeoutUndefine):
147-
contx_exit_timeout = self.contx_exit_timeout
148-
14967
# 在 uvicorn.Server 的实现中,设置 should_exit 可以使得 server 任务结束
150-
assert hasattr(self, "should_exit")
68+
assert self.should_exit is False, "The server has already exited."
15169
self.should_exit = True
152-
153-
try:
154-
await asyncio.wait_for(contx_server_task, timeout=contx_exit_timeout)
155-
except asyncio.TimeoutError:
156-
print(f"{contx_server_task.get_name()} timeout!")
157-
finally:
158-
# 其实uvicorn.Server会自动关闭socket,这里是为了保险起见
159-
contx_socket.close()
160-
161-
@_no_override_uvicorn_server
162-
async def __aexit__(self, *_: Any, **__: Any) -> None:
163-
"""Shutdown the server.
164-
165-
The same as `self.aexit()`.
166-
"""
167-
return await self.aexit()
168-
169-
@property
170-
@_no_override_uvicorn_server
171-
def contx_server_started_event(self) -> asyncio.Event:
172-
"""The event that indicates the server has started.
173-
174-
When first call the property, it will instantiate a `asyncio.Event()`to
175-
`self._contx_server_started_event`.
176-
177-
Warn: This is a internal implementation detail, do not change the event manually.
178-
- please call the property in `self.aenter()` or `self.startup()` **first**.
179-
- **Never** call it outside of an async event loop first:
180-
https://stackoverflow.com/questions/53724665/using-queues-results-in-asyncio-exception-got-future-future-pending-attached
181-
"""
182-
if self._contx_server_started_event is None:
183-
self._contx_server_started_event = asyncio.Event()
184-
185-
return self._contx_server_started_event
186-
187-
@property
188-
@_no_override_uvicorn_server
189-
def contx_socket(self) -> socket.socket:
190-
"""The socket of server.
191-
192-
Note: must call `self.__aenter__()` first.
193-
"""
194-
if self._contx_socket is None:
195-
raise RuntimeError("Please call `self.__aenter__()` first.")
196-
else:
197-
return self._contx_socket
198-
199-
@property
200-
@_no_override_uvicorn_server
201-
def contx_server_task(self) -> "asyncio.Task[None]":
202-
"""The task of server.
203-
204-
Note: must call `self.__aenter__()` first.
205-
"""
206-
if self._contx_server_task is None:
207-
raise RuntimeError("Please call `self.__aenter__()` first.")
208-
else:
209-
return self._contx_server_task
210-
211-
@property
212-
@_no_override_uvicorn_server
213-
def contx_socket_getname(self) -> Any:
214-
"""Utils for calling self.contx_socket.getsockname().
215-
216-
Return:
217-
refer to: https://docs.python.org/zh-cn/3/library/socket.html#socket-families
218-
"""
219-
return self.contx_socket.getsockname()
70+
await self._exit_stack.__aexit__(*_, **__)
22071

22172
@property
222-
@_no_override_uvicorn_server
22373
def contx_socket_url(self) -> httpx.URL:
22474
"""If server is tcp socket, return the url of server.
22575
@@ -228,7 +78,8 @@ def contx_socket_url(self) -> httpx.URL:
22878
config = self.config
22979
if config.fd is not None or config.uds is not None:
23080
raise RuntimeError("Only support tcp socket.")
231-
host, port = self.contx_socket_getname[:2]
81+
# refer to: https://docs.python.org/zh-cn/3/library/socket.html#socket-families
82+
host, port = self._socket.getsockname()[:2]
23283
return httpx.URL(
23384
host=host,
23485
port=port,

tests/conftest.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
Coroutine,
1717
Literal,
1818
Protocol,
19-
Union,
2019
)
2120

2221
import pytest
@@ -64,7 +63,7 @@ class LifeAppDataclass4Test(AppDataclass4Test):
6463

6564
class UvicornServerFixture(Protocol): # noqa: D101
6665
def __call__( # noqa: D102
67-
self, config: uvicorn.Config, contx_exit_timeout: Union[int, float, None] = None
66+
self, config: uvicorn.Config
6867
) -> Coroutine[None, None, UvicornServer]: ...
6968

7069

@@ -199,11 +198,9 @@ async def uvicorn_server_fixture() -> AsyncIterator[UvicornServerFixture]:
199198
"""
200199
async with AsyncExitStack() as exit_stack:
201200

202-
async def uvicorn_server_fct(
203-
config: uvicorn.Config, contx_exit_timeout: Union[int, float, None] = None
204-
) -> UvicornServer:
201+
async def uvicorn_server_fct(config: uvicorn.Config) -> UvicornServer:
205202
uvicorn_server = await exit_stack.enter_async_context(
206-
UvicornServer(config=config, contx_exit_timeout=contx_exit_timeout)
203+
UvicornServer(config=config)
207204
)
208205
return uvicorn_server
209206

tests/test_ws.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
# noqa: D100
22

33

4-
import asyncio
54
from contextlib import AsyncExitStack
65
from multiprocessing import Process, Queue
76
from typing import Any, Dict, Literal, Optional
87

8+
import anyio
99
import httpx
1010
import httpx_ws
1111
import pytest
@@ -25,7 +25,6 @@
2525

2626
DEFAULT_HOST = "127.0.0.1"
2727
DEFAULT_PORT = 0
28-
DEFAULT_CONTX_EXIT_TIMEOUT = 5
2928

3029
# WS_BACKENDS_NEED_BE_TESTED = ("websockets", "wsproto")
3130
# # FIXME: wsproto 有问题,暂时不测试
@@ -56,14 +55,14 @@ def _subprocess_run_echo_ws_uvicorn_server(queue: "Queue[str]", **kwargs: Any):
5655
)
5756

5857
async def run():
59-
await target_ws_server.aenter()
60-
url = str(target_ws_server.contx_socket_url)
61-
queue.put(url)
62-
queue.close()
63-
while True: # run forever
64-
await asyncio.sleep(0.1)
58+
async with target_ws_server:
59+
url = str(target_ws_server.contx_socket_url)
60+
queue.put(url)
61+
queue.close()
62+
while True: # run forever
63+
await anyio.sleep(0.1)
6564

66-
asyncio.run(run())
65+
anyio.run(run)
6766

6867

6968
def _subprocess_run_httpx_ws(
@@ -96,9 +95,9 @@ async def run():
9695
queue.put("done")
9796
queue.close()
9897
while True: # run forever
99-
await asyncio.sleep(0.1)
98+
await anyio.sleep(0.1)
10099

101-
asyncio.run(run())
100+
anyio.run(run)
102101

103102

104103
class TestReverseWsProxy(AbstractTestProxy):
@@ -120,7 +119,6 @@ async def tool_4_test_fixture( # pyright: ignore[reportIncompatibleMethodOverri
120119
uvicorn.Config(
121120
echo_ws_app, port=DEFAULT_PORT, host=DEFAULT_HOST, ws=request.param
122121
),
123-
contx_exit_timeout=DEFAULT_CONTX_EXIT_TIMEOUT,
124122
)
125123

126124
target_server_base_url = str(target_ws_server.contx_socket_url)
@@ -135,7 +133,6 @@ async def tool_4_test_fixture( # pyright: ignore[reportIncompatibleMethodOverri
135133
uvicorn.Config(
136134
reverse_ws_app, port=DEFAULT_PORT, host=DEFAULT_HOST, ws=request.param
137135
),
138-
contx_exit_timeout=DEFAULT_CONTX_EXIT_TIMEOUT,
139136
)
140137

141138
proxy_server_base_url = str(proxy_ws_server.contx_socket_url)
@@ -226,7 +223,7 @@ async def test_ws_proxy(self, tool_4_test_fixture: Tool4TestFixture) -> None:
226223

227224
# 避免从队列中get导致的异步阻塞
228225
while aconnect_ws_subprocess_queue.empty():
229-
await asyncio.sleep(0.1)
226+
await anyio.sleep(0.1)
230227
_ = aconnect_ws_subprocess_queue.get() # 获取到了即代表连接建立成功
231228

232229
# force shutdown client
@@ -267,7 +264,7 @@ async def test_target_server_shutdown_abnormally(
267264

268265
# 避免从队列中get导致的异步阻塞
269266
while subprocess_queue.empty():
270-
await asyncio.sleep(0.1)
267+
await anyio.sleep(0.1)
271268
target_server_base_url = subprocess_queue.get()
272269

273270
client_for_conn_to_target_server = httpx.AsyncClient(proxies=NO_PROXIES)
@@ -300,13 +297,11 @@ async def test_target_server_shutdown_abnormally(
300297
await ws0.receive()
301298
assert exce.value.code == 1011
302299

303-
loop = asyncio.get_running_loop()
304-
305-
seconde_ws_recv_start = loop.time()
300+
seconde_ws_recv_start = anyio.current_time()
306301
with pytest.raises(httpx_ws.WebSocketDisconnect) as exce:
307302
await ws1.receive()
308303
assert exce.value.code == 1011
309-
seconde_ws_recv_end = loop.time()
304+
seconde_ws_recv_end = anyio.current_time()
310305

311306
# HACK: 由于收到关闭代码需要40s,目前无法确定是什么原因,
312307
# 所以目前会同时测试两个客户端的连接,

0 commit comments

Comments
 (0)