11from httpcore ._sync .http_proxy import HTTPProxy , TunnelHTTPConnection , merge_headers , logger
2- from httpcore ._sync .connection_pool import ConnectionPool
32from httpcore ._sync .http11 import HTTP11Connection
3+ from httpcore ._async .http_proxy import AsyncHTTPProxy , AsyncTunnelHTTPConnection
4+ from httpcore ._async .http11 import AsyncHTTP11Connection
45from httpcore ._models import URL , Request
56from httpcore ._exceptions import ProxyError
67from httpcore ._ssl import default_ssl_context
78from httpcore ._trace import Trace
8- from httpx import HTTPTransport
9+ from httpx import AsyncHTTPTransport , HTTPTransport
910from httpx ._config import DEFAULT_LIMITS , Proxy , create_ssl_context
1011
1112class ProxyTunnelHTTPConnection (TunnelHTTPConnection ):
@@ -94,6 +95,90 @@ def handle_request(self, request):
9495 response .headers = merge_headers (response .headers , connect_response .headers )
9596 return response
9697
98+ class AsyncProxyTunnelHTTPConnection (AsyncTunnelHTTPConnection ):
99+ async def handle_async_request (self , request ):
100+ timeouts = request .extensions .get ("timeout" , {})
101+ timeout = timeouts .get ("connect" , None )
102+
103+ async with self ._connect_lock :
104+ if not self ._connected :
105+ target = b"%b:%d" % (self ._remote_origin .host , self ._remote_origin .port )
106+
107+ connect_url = URL (
108+ scheme = self ._proxy_origin .scheme ,
109+ host = self ._proxy_origin .host ,
110+ port = self ._proxy_origin .port ,
111+ target = target ,
112+ )
113+ connect_headers = merge_headers (
114+ [(b"Host" , target ), (b"Accept" , b"*/*" )], self ._proxy_headers
115+ )
116+ connect_request = Request (
117+ method = b"CONNECT" ,
118+ url = connect_url ,
119+ headers = connect_headers ,
120+ extensions = request .extensions ,
121+ )
122+ connect_response = await self ._connection .handle_async_request (
123+ connect_request
124+ )
125+
126+ if connect_response .status < 200 or connect_response .status > 299 :
127+ reason_bytes = connect_response .extensions .get ("reason_phrase" , b"" )
128+ reason_str = reason_bytes .decode ("ascii" , errors = "ignore" )
129+ msg = "%d %s" % (connect_response .status , reason_str )
130+ await self ._connection .aclose ()
131+ raise ProxyError (msg )
132+
133+ stream = connect_response .extensions ["network_stream" ]
134+
135+ # Upgrade the stream to SSL
136+ ssl_context = (
137+ default_ssl_context ()
138+ if self ._ssl_context is None
139+ else self ._ssl_context
140+ )
141+ alpn_protocols = ["http/1.1" , "h2" ] if self ._http2 else ["http/1.1" ]
142+ ssl_context .set_alpn_protocols (alpn_protocols )
143+
144+ kwargs = {
145+ "ssl_context" : ssl_context ,
146+ "server_hostname" : self ._remote_origin .host .decode ("ascii" ),
147+ "timeout" : timeout ,
148+ }
149+ async with Trace ("start_tls" , logger , request , kwargs ) as trace :
150+ stream = await stream .start_tls (** kwargs )
151+ trace .return_value = stream
152+
153+ # Determine if we should be using HTTP/1.1 or HTTP/2
154+ ssl_object = stream .get_extra_info ("ssl_object" )
155+ http2_negotiated = (
156+ ssl_object is not None
157+ and ssl_object .selected_alpn_protocol () == "h2"
158+ )
159+
160+ # Create the HTTP/1.1 or HTTP/2 connection
161+ if http2_negotiated or (self ._http2 and not self ._http1 ):
162+ from httpcore ._async .http2 import AsyncHTTP2Connection
163+
164+ self ._connection = AsyncHTTP2Connection (
165+ origin = self ._remote_origin ,
166+ stream = stream ,
167+ keepalive_expiry = self ._keepalive_expiry ,
168+ )
169+ else :
170+ self ._connection = AsyncHTTP11Connection (
171+ origin = self ._remote_origin ,
172+ stream = stream ,
173+ keepalive_expiry = self ._keepalive_expiry ,
174+ )
175+
176+ self ._connected = True
177+ # this is the only modification
178+ response = await self ._connection .handle_async_request (request )
179+ response .headers = merge_headers (response .headers , connect_response .headers )
180+ return response
181+
97182class HTTPProxyHeaders (HTTPProxy ):
98183 def create_connection (self , origin ):
99184 if origin .scheme == b"http" :
@@ -110,6 +195,22 @@ def create_connection(self, origin):
110195 network_backend = self ._network_backend ,
111196 )
112197
198+ class AsyncHTTPProxyHeaders (AsyncHTTPProxy ):
199+ def create_connection (self , origin ):
200+ if origin .scheme == b"http" :
201+ return super ().create_connection (origin )
202+ return AsyncProxyTunnelHTTPConnection (
203+ proxy_origin = self ._proxy_url .origin ,
204+ proxy_headers = self ._proxy_headers ,
205+ remote_origin = origin ,
206+ ssl_context = self ._ssl_context ,
207+ proxy_ssl_context = self ._proxy_ssl_context ,
208+ keepalive_expiry = self ._keepalive_expiry ,
209+ http1 = self ._http1 ,
210+ http2 = self ._http2 ,
211+ network_backend = self ._network_backend ,
212+ )
213+
113214# class ProxyConnectionPool(ConnectionPool):
114215# def create_connection(self, origin):
115216# if self._proxy is not None:
@@ -169,5 +270,45 @@ def __init__(
169270 http2 = http2 ,
170271 socket_options = socket_options ,
171272 )
273+ else :
274+ super ().__init__ (verify , cert , trust_env , http1 , http2 , limits , proxy , uds , local_address , retries , socket_options )
275+
276+ class AsyncHTTPProxyTransport (AsyncHTTPTransport ):
277+ def __init__ (
278+ self ,
279+ verify = True ,
280+ cert = None ,
281+ trust_env : bool = True ,
282+ http1 : bool = True ,
283+ http2 : bool = False ,
284+ limits = DEFAULT_LIMITS ,
285+ proxy = None ,
286+ uds : str | None = None ,
287+ local_address : str | None = None ,
288+ retries : int = 0 ,
289+ socket_options = None ,
290+ ) -> None :
291+ proxy = Proxy (url = proxy ) if isinstance (proxy , (str , URL )) else proxy
292+ ssl_context = create_ssl_context (verify = verify , cert = cert , trust_env = trust_env )
293+
294+ if proxy and proxy .url .scheme in ("http" , "https" ):
295+ self ._pool = AsyncHTTPProxyHeaders (
296+ proxy_url = URL (
297+ scheme = proxy .url .raw_scheme ,
298+ host = proxy .url .raw_host ,
299+ port = proxy .url .port ,
300+ target = proxy .url .raw_path ,
301+ ),
302+ proxy_auth = proxy .raw_auth ,
303+ proxy_headers = proxy .headers .raw ,
304+ proxy_ssl_context = proxy .ssl_context ,
305+ ssl_context = ssl_context ,
306+ max_connections = limits .max_connections ,
307+ max_keepalive_connections = limits .max_keepalive_connections ,
308+ keepalive_expiry = limits .keepalive_expiry ,
309+ http1 = http1 ,
310+ http2 = http2 ,
311+ socket_options = socket_options ,
312+ )
172313 else :
173314 super ().__init__ (verify , cert , trust_env , http1 , http2 , limits , proxy , uds , local_address , retries , socket_options )
0 commit comments