22
33from contextlib import AsyncExitStack
44from dataclasses import dataclass
5- from typing import Any , Literal , TypedDict , Union
5+ from typing import Any , Literal , Optional , TypedDict , Union
66
77import anyio
88import httpx
1111from fastapi import FastAPI
1212from hypercorn import Config as HyperConfig
1313from hypercorn .asyncio .run import (
14- worker_serve as hyper_aio_serve , # pyright: ignore[reportUnknownVariableType]
14+ worker_serve as hyper_aio_worker_serve , # pyright: ignore[reportUnknownVariableType]
1515)
1616from hypercorn .trio .run import (
17- worker_serve as hyper_trio_serve , # pyright: ignore[reportUnknownVariableType]
17+ worker_serve as hyper_trio_worker_serve , # pyright: ignore[reportUnknownVariableType]
18+ )
19+ from hypercorn .utils import (
20+ repr_socket_addr , # pyright: ignore[reportUnknownVariableType]
1821)
1922from hypercorn .utils import (
2023 wrap_app as hyper_wrap_app , # pyright: ignore[reportUnknownVariableType]
2124)
2225from starlette .requests import Request
2326from starlette .websockets import WebSocket
24- from typing_extensions import Self
27+ from typing_extensions import Self , assert_never
2528
2629ServerRecvRequestsTypes = Union [Request , WebSocket ]
2730
@@ -55,7 +58,7 @@ def get_request(self) -> ServerRecvRequestsTypes:
5558 return server_recv_request
5659
5760
58- class UvicornServer (uvicorn .Server ):
61+ class _UvicornServer (uvicorn .Server ):
5962 """subclass of `uvicorn.Server` which can use AsyncContext to launch and shutdown automatically."""
6063
6164 async def __aenter__ (self ) -> Self :
@@ -89,7 +92,9 @@ def contx_socket_url(self) -> httpx.URL:
8992 config = self .config
9093 if config .fd is not None or config .uds is not None :
9194 raise RuntimeError ("Only support tcp socket." )
92- # refer to: https://docs.python.org/zh-cn/3/library/socket.html#socket-families
95+
96+ # Implement ref:
97+ # https://github.com/encode/uvicorn/blob/a2219eb2ed2bbda4143a0fb18c4b0578881b1ae8/uvicorn/server.py#L201-L220
9398 host , port = self ._socket .getsockname ()[:2 ]
9499 return httpx .URL (
95100 host = host ,
@@ -99,25 +104,42 @@ def contx_socket_url(self) -> httpx.URL:
99104 )
100105
101106
102- class HypercornServer :
107+ class _HypercornServer :
103108 """An AsyncContext to launch and shutdown Hypercorn server automatically."""
104109
105- def __init__ (self , app : FastAPI , config : HyperConfig ): # noqa: D107
110+ def __init__ (self , app : FastAPI , config : HyperConfig ):
106111 self .config = config
107112 self .app = app
108113 self .should_exit = anyio .Event ()
109114
110115 async def __aenter__ (self ) -> Self :
111116 """Launch the server."""
112- self ._sockets = self .config .create_sockets ()
113117 self ._exit_stack = AsyncExitStack ()
114118
115119 self .current_async_lib = sniffio .current_async_library ()
116120
117121 if self .current_async_lib == "asyncio" :
118- serve_func = hyper_aio_serve # pyright: ignore[reportUnknownVariableType]
122+ serve_func = ( # pyright: ignore[reportUnknownVariableType]
123+ hyper_aio_worker_serve
124+ )
125+
126+ # Implement ref:
127+ # https://github.com/pgjones/hypercorn/blob/3fbd5f245e5dfeaba6ad852d9135d6a32b228d05/src/hypercorn/asyncio/run.py#L89-L90
128+ self ._sockets = self .config .create_sockets ()
129+
119130 elif self .current_async_lib == "trio" :
120- serve_func = hyper_trio_serve # pyright: ignore[reportUnknownVariableType]
131+ serve_func = ( # pyright: ignore[reportUnknownVariableType]
132+ hyper_trio_worker_serve
133+ )
134+
135+ # Implement ref:
136+ # https://github.com/pgjones/hypercorn/blob/3fbd5f245e5dfeaba6ad852d9135d6a32b228d05/src/hypercorn/trio/run.py#L51-L56
137+ self ._sockets = self .config .create_sockets ()
138+ for sock in self ._sockets .secure_sockets :
139+ sock .listen (self .config .backlog )
140+ for sock in self ._sockets .insecure_sockets :
141+ sock .listen (self .config .backlog )
142+
121143 else :
122144 raise RuntimeError (f"Unsupported async library { self .current_async_lib !r} " )
123145
@@ -133,6 +155,7 @@ async def serve() -> None:
133155 ),
134156 self .config ,
135157 shutdown_trigger = self .should_exit .wait ,
158+ sockets = self ._sockets ,
136159 )
137160
138161 task_group = await self ._exit_stack .enter_async_context (
@@ -154,13 +177,32 @@ def contx_socket_url(self) -> httpx.URL:
154177 Note: The path of url is explicitly set to "/".
155178 """
156179 config = self .config
180+ sockets = self ._sockets
181+
182+ # Implement ref:
183+ # https://github.com/pgjones/hypercorn/blob/3fbd5f245e5dfeaba6ad852d9135d6a32b228d05/src/hypercorn/asyncio/run.py#L112-L149
184+ # https://github.com/pgjones/hypercorn/blob/3fbd5f245e5dfeaba6ad852d9135d6a32b228d05/src/hypercorn/trio/run.py#L61-L82
185+
186+ # We only run on one socket each time,
187+ # so we raise `RuntimeError` to avoid other unknown errors during testing.
188+ if sockets .insecure_sockets :
189+ if len (sockets .insecure_sockets ) > 1 :
190+ raise RuntimeError ("Hypercorn test: Multiple insecure_sockets found." )
191+ socket = sockets .insecure_sockets [0 ]
192+ elif sockets .secure_sockets :
193+ if len (sockets .secure_sockets ) > 1 :
194+ raise RuntimeError ("Hypercorn test: secure_sockets sockets found." )
195+ socket = sockets .secure_sockets [0 ]
196+ else :
197+ raise RuntimeError ("Hypercorn test: No socket found." )
157198
158- bind = config . bind [ 0 ]
199+ bind = repr_socket_addr ( socket . family , socket . getsockname ())
159200 if bind .startswith (("unix:" , "fd://" )):
160201 raise RuntimeError ("Only support tcp socket." )
161202
162- # refer to: https://docs.python.org/zh-cn/3/library/socket.html#socket-families
163- host , port = config .bind [0 ].split (":" )
203+ # Implement ref:
204+ # https://docs.python.org/zh-cn/3/library/socket.html#socket-families
205+ host , port = bind .split (":" )
164206 port = int (port )
165207
166208 return httpx .URL (
@@ -179,12 +221,16 @@ def __init__(
179221 app : FastAPI ,
180222 host : str ,
181223 port : int ,
182- server_type : Literal ["uvicorn" , "hypercorn" ] = "hypercorn" ,
224+ server_type : Optional [ Literal ["uvicorn" , "hypercorn" ]] = None ,
183225 ):
184226 """Only support ipv4 address.
185227
186228 If use uvicorn, it only support asyncio backend.
229+
230+ If `host` == 0, then use random port.
187231 """
232+ server_type = server_type if server_type is not None else "hypercorn"
233+
188234 self .app = app
189235 self .host = host
190236 self .port = port
@@ -195,10 +241,12 @@ def __init__(
195241 config .bind = f"{ host } :{ port } "
196242
197243 self .config = config
198- self .server = HypercornServer (app , config )
199- else :
244+ self .server = _HypercornServer (app , config )
245+ elif self . server_type == "uvicorn" :
200246 self .config = uvicorn .Config (app , host = host , port = port )
201- self .server = UvicornServer (self .config )
247+ self .server = _UvicornServer (self .config )
248+ else :
249+ assert_never (self .server_type )
202250
203251 async def __aenter__ (self ) -> Self :
204252 """Launch the server."""
0 commit comments