22
33from contextlib import AsyncExitStack
44from dataclasses import dataclass
5- from typing import Any , TypedDict , Union
5+ from typing import Any , Literal , TypedDict , Union
66
77import anyio
88import httpx
9+ import sniffio
910import uvicorn
1011from fastapi import FastAPI
12+ from hypercorn import Config as HyperConfig
13+ from hypercorn .asyncio .run import (
14+ worker_serve as hyper_aio_serve , # pyright: ignore[reportUnknownVariableType]
15+ )
16+ from hypercorn .trio .run import (
17+ worker_serve as hyper_trio_serve , # pyright: ignore[reportUnknownVariableType]
18+ )
19+ from hypercorn .utils import (
20+ wrap_app as hyper_wrap_app , # pyright: ignore[reportUnknownVariableType]
21+ )
1122from starlette .requests import Request
1223from starlette .websockets import WebSocket
1324from typing_extensions import Self
@@ -65,7 +76,7 @@ async def __aenter__(self) -> Self:
6576 async def __aexit__ (self , * _ : Any , ** __ : Any ) -> None :
6677 """Shutdown the server."""
6778 # 在 uvicorn.Server 的实现中,设置 should_exit 可以使得 server 任务结束
68- assert self .should_exit is False , "The server has already exited."
79+ assert not self .should_exit , "The server has already exited."
6980 self .should_exit = True
7081 await self ._exit_stack .__aexit__ (* _ , ** __ )
7182
@@ -86,3 +97,129 @@ def contx_socket_url(self) -> httpx.URL:
8697 scheme = "https" if config .is_ssl else "http" ,
8798 path = "/" ,
8899 )
100+
101+
102+ class HypercornServer :
103+ """An AsyncContext to launch and shutdown Hypercorn server automatically."""
104+
105+ def __init__ (self , app : FastAPI , config : HyperConfig ): # noqa: D107
106+ self .config = config
107+ self .app = app
108+ self .should_exit = anyio .Event ()
109+
110+ async def __aenter__ (self ) -> Self :
111+ """Launch the server."""
112+ self ._sockets = self .config .create_sockets ()
113+ self ._exit_stack = AsyncExitStack ()
114+
115+ self .current_async_lib = sniffio .current_async_library ()
116+
117+ if self .current_async_lib == "asyncio" :
118+ serve_func = hyper_aio_serve # pyright: ignore[reportUnknownVariableType]
119+ elif self .current_async_lib == "trio" :
120+ serve_func = hyper_trio_serve # pyright: ignore[reportUnknownVariableType]
121+ else :
122+ raise RuntimeError (f"Unsupported async library { self .current_async_lib !r} " )
123+
124+ async def serve () -> None :
125+ # Implement ref:
126+ # https://github.com/pgjones/hypercorn/blob/3fbd5f245e5dfeaba6ad852d9135d6a32b228d05/src/hypercorn/asyncio/__init__.py#L12-L46
127+ # https://github.com/pgjones/hypercorn/blob/3fbd5f245e5dfeaba6ad852d9135d6a32b228d05/src/hypercorn/trio/__init__.py#L14-L52
128+ await serve_func (
129+ hyper_wrap_app (
130+ self .app , # pyright: ignore[reportArgumentType]
131+ self .config .wsgi_max_body_size ,
132+ mode = None ,
133+ ),
134+ self .config ,
135+ shutdown_trigger = self .should_exit .wait ,
136+ )
137+
138+ task_group = await self ._exit_stack .enter_async_context (
139+ anyio .create_task_group ()
140+ )
141+ task_group .start_soon (serve , name = f"Hypercorn Server Task of { self } " )
142+ return self
143+
144+ async def __aexit__ (self , * _ : Any , ** __ : Any ) -> None :
145+ """Shutdown the server."""
146+ assert not self .should_exit .is_set (), "The server has already exited."
147+ self .should_exit .set ()
148+ await self ._exit_stack .__aexit__ (* _ , ** __ )
149+
150+ @property
151+ def contx_socket_url (self ) -> httpx .URL :
152+ """If server is tcp socket, return the url of server.
153+
154+ Note: The path of url is explicitly set to "/".
155+ """
156+ config = self .config
157+
158+ bind = config .bind [0 ]
159+ if bind .startswith (("unix:" , "fd://" )):
160+ raise RuntimeError ("Only support tcp socket." )
161+
162+ # refer to: https://docs.python.org/zh-cn/3/library/socket.html#socket-families
163+ host , port = config .bind [0 ].split (":" )
164+ port = int (port )
165+
166+ return httpx .URL (
167+ host = host ,
168+ port = port ,
169+ scheme = "https" if config .ssl_enabled else "http" ,
170+ path = "/" ,
171+ )
172+
173+
174+ class TestServer :
175+ """An AsyncContext to launch and shutdown Hypercorn or Uvicorn server automatically."""
176+
177+ def __init__ (
178+ self ,
179+ app : FastAPI ,
180+ host : str ,
181+ port : int ,
182+ server_type : Literal ["uvicorn" , "hypercorn" ] = "hypercorn" ,
183+ ):
184+ """Only support ipv4 address.
185+
186+ If use uvicorn, it only support asyncio backend.
187+ """
188+ self .app = app
189+ self .host = host
190+ self .port = port
191+ self .server_type = server_type
192+
193+ if self .server_type == "hypercorn" :
194+ config = HyperConfig ()
195+ config .bind = f"{ host } :{ port } "
196+
197+ self .config = config
198+ self .server = HypercornServer (app , config )
199+ else :
200+ self .config = uvicorn .Config (app , host = host , port = port )
201+ self .server = UvicornServer (self .config )
202+
203+ async def __aenter__ (self ) -> Self :
204+ """Launch the server."""
205+ if (
206+ self .server_type == "uvicorn"
207+ and sniffio .current_async_library () != "asyncio"
208+ ):
209+ raise RuntimeError ("Uvicorn server does not support trio backend." )
210+
211+ self ._exit_stack = AsyncExitStack ()
212+ await self ._exit_stack .enter_async_context (self .server )
213+ return self
214+
215+ async def __aexit__ (self , * _ : Any , ** __ : Any ) -> None :
216+ """Shutdown the server."""
217+ await self ._exit_stack .__aexit__ (* _ , ** __ )
218+
219+ @property
220+ def contx_socket_url (self ) -> httpx .URL :
221+ """If server is tcp socket, return the url of server.
222+
223+ Note: The path of url is explicitly set to "/".
224+ """
225+ return self .server .contx_socket_url
0 commit comments