11# noqa: D100
22
3- import asyncio
4- import socket
3+ from contextlib import AsyncExitStack
54from dataclasses import dataclass
6- from typing import Any , Callable , List , Optional , Type , TypedDict , TypeVar , Union
5+ from typing import Any , TypedDict , Union
76
7+ import anyio
88import httpx
99import uvicorn
1010from fastapi import FastAPI
1111from starlette .requests import Request
1212from starlette .websockets import WebSocket
13- from typing_extensions import Self , override
14-
15- _Decoratable_T = TypeVar ("_Decoratable_T" , bound = Union [Callable [..., Any ], Type [Any ]])
13+ from typing_extensions import Self
1614
1715ServerRecvRequestsTypes = Union [Request , WebSocket ]
1816
@@ -46,180 +44,32 @@ def get_request(self) -> ServerRecvRequestsTypes:
4644 return server_recv_request
4745
4846
49- def _no_override_uvicorn_server (_method : _Decoratable_T ) -> _Decoratable_T :
50- """Check if the method is already in `uvicorn.Server`."""
51- assert not hasattr (
52- uvicorn .Server , _method .__name__
53- ), f"Override method of `uvicorn.Server` cls : { _method .__name__ } "
54- return _method
55-
56-
57- class AeixtTimeoutUndefine :
58- """Didn't set `contx_exit_timeout` in `aexit()`."""
59-
60-
61- aexit_timeout_undefine = AeixtTimeoutUndefine ()
62-
63-
64- # HACK: 不能继承 AbstractAsyncContextManager[Self]
65- # 目前有问题,继承 AbstractAsyncContextManager 的话pyright也推测不出来类型
66- # 只能依靠 __aenter__ 和 __aexit__ 的类型注解
6747class UvicornServer (uvicorn .Server ):
68- """subclass of `uvicorn.Server` which can use AsyncContext to launch and shutdown automatically.
69-
70- Attributes:
71- contx_server_task: The task of server.
72- contx_socket: The socket of server.
73-
74- other attributes are same as `uvicorn.Server`:
75- - config: The config arg that be passed in.
76- ...
77- """
48+ """subclass of `uvicorn.Server` which can use AsyncContext to launch and shutdown automatically."""
7849
79- _contx_server_task : Union ["asyncio.Task[None]" , None ]
80- assert not hasattr (uvicorn .Server , "_contx_server_task" )
81-
82- _contx_socket : Union [socket .socket , None ]
83- assert not hasattr (uvicorn .Server , "_contx_socket" )
84-
85- _contx_server_started_event : Union [asyncio .Event , None ]
86- assert not hasattr (uvicorn .Server , "_contx_server_started_event" )
87-
88- contx_exit_timeout : Union [int , float , None ]
89- assert not hasattr (uvicorn .Server , "contx_exit_timeout" )
90-
91- @override
92- def __init__ (
93- self , config : uvicorn .Config , contx_exit_timeout : Union [int , float , None ] = None
94- ) -> None :
95- """The same as `uvicorn.Server.__init__`."""
96- super ().__init__ (config = config )
97- self ._contx_server_task = None
98- self ._contx_socket = None
99- self ._contx_server_started_event = None
100- self .contx_exit_timeout = contx_exit_timeout
101-
102- @override
103- async def startup (self , sockets : Optional [List [socket .socket ]] = None ) -> None :
104- """The same as `uvicorn.Server.startup`."""
105- super_return = await super ().startup (sockets = sockets )
106- self .contx_server_started_event .set ()
107- return super_return
108-
109- @_no_override_uvicorn_server
110- async def aenter (self ) -> Self :
50+ async def __aenter__ (self ) -> Self :
11151 """Launch the server."""
112- # 在分配资源之前,先检查是否重入
113- if self .contx_server_started_event .is_set ():
114- raise RuntimeError ("DO not launch server by __aenter__ again!" )
115-
11652 # FIXME: # 这个socket被设计为可被同一进程内的多个server共享,可能会引起潜在问题
117- self ._contx_socket = self .config .bind_socket ()
53+ self ._socket = self .config .bind_socket ()
54+ self ._exit_stack = AsyncExitStack ()
11855
119- self ._contx_server_task = asyncio .create_task (
120- self .serve ([self ._contx_socket ]), name = f"Uvicorn Server Task of { self } "
56+ task_group = await self ._exit_stack .enter_async_context (
57+ anyio .create_task_group ()
58+ )
59+ task_group .start_soon (
60+ self .serve , [self ._socket ], name = f"Uvicorn Server Task of { self } "
12161 )
122- # 在 uvicorn.Server 的实现中,Server.serve() 内部会调用 Server.startup() 完成启动
123- # 被覆盖的 self.startup() 会在完成时调用 self.contx_server_started_event.set()
124- await self .contx_server_started_event .wait () # 等待服务器确实启动后才返回
125- return self
12662
127- @_no_override_uvicorn_server
128- async def __aenter__ (self ) -> Self :
129- """Launch the server.
63+ return self
13064
131- The same as `self.aenter()`.
132- """
133- return await self .aenter ()
134-
135- @_no_override_uvicorn_server
136- async def aexit (
137- self ,
138- contx_exit_timeout : Union [
139- int , float , None , AeixtTimeoutUndefine
140- ] = aexit_timeout_undefine ,
141- ) -> None :
65+ async def __aexit__ (self , * _ : Any , ** __ : Any ) -> None :
14266 """Shutdown the server."""
143- contx_server_task = self .contx_server_task
144- contx_socket = self .contx_socket
145-
146- if isinstance (contx_exit_timeout , AeixtTimeoutUndefine ):
147- contx_exit_timeout = self .contx_exit_timeout
148-
14967 # 在 uvicorn.Server 的实现中,设置 should_exit 可以使得 server 任务结束
150- assert hasattr ( self , "should_exit" )
68+ assert self . should_exit is False , "The server has already exited."
15169 self .should_exit = True
152-
153- try :
154- await asyncio .wait_for (contx_server_task , timeout = contx_exit_timeout )
155- except asyncio .TimeoutError :
156- print (f"{ contx_server_task .get_name ()} timeout!" )
157- finally :
158- # 其实uvicorn.Server会自动关闭socket,这里是为了保险起见
159- contx_socket .close ()
160-
161- @_no_override_uvicorn_server
162- async def __aexit__ (self , * _ : Any , ** __ : Any ) -> None :
163- """Shutdown the server.
164-
165- The same as `self.aexit()`.
166- """
167- return await self .aexit ()
168-
169- @property
170- @_no_override_uvicorn_server
171- def contx_server_started_event (self ) -> asyncio .Event :
172- """The event that indicates the server has started.
173-
174- When first call the property, it will instantiate a `asyncio.Event()`to
175- `self._contx_server_started_event`.
176-
177- Warn: This is a internal implementation detail, do not change the event manually.
178- - please call the property in `self.aenter()` or `self.startup()` **first**.
179- - **Never** call it outside of an async event loop first:
180- https://stackoverflow.com/questions/53724665/using-queues-results-in-asyncio-exception-got-future-future-pending-attached
181- """
182- if self ._contx_server_started_event is None :
183- self ._contx_server_started_event = asyncio .Event ()
184-
185- return self ._contx_server_started_event
186-
187- @property
188- @_no_override_uvicorn_server
189- def contx_socket (self ) -> socket .socket :
190- """The socket of server.
191-
192- Note: must call `self.__aenter__()` first.
193- """
194- if self ._contx_socket is None :
195- raise RuntimeError ("Please call `self.__aenter__()` first." )
196- else :
197- return self ._contx_socket
198-
199- @property
200- @_no_override_uvicorn_server
201- def contx_server_task (self ) -> "asyncio.Task[None]" :
202- """The task of server.
203-
204- Note: must call `self.__aenter__()` first.
205- """
206- if self ._contx_server_task is None :
207- raise RuntimeError ("Please call `self.__aenter__()` first." )
208- else :
209- return self ._contx_server_task
210-
211- @property
212- @_no_override_uvicorn_server
213- def contx_socket_getname (self ) -> Any :
214- """Utils for calling self.contx_socket.getsockname().
215-
216- Return:
217- refer to: https://docs.python.org/zh-cn/3/library/socket.html#socket-families
218- """
219- return self .contx_socket .getsockname ()
70+ await self ._exit_stack .__aexit__ (* _ , ** __ )
22071
22172 @property
222- @_no_override_uvicorn_server
22373 def contx_socket_url (self ) -> httpx .URL :
22474 """If server is tcp socket, return the url of server.
22575
@@ -228,7 +78,8 @@ def contx_socket_url(self) -> httpx.URL:
22878 config = self .config
22979 if config .fd is not None or config .uds is not None :
23080 raise RuntimeError ("Only support tcp socket." )
231- host , port = self .contx_socket_getname [:2 ]
81+ # refer to: https://docs.python.org/zh-cn/3/library/socket.html#socket-families
82+ host , port = self ._socket .getsockname ()[:2 ]
23283 return httpx .URL (
23384 host = host ,
23485 port = port ,
0 commit comments