1+ from httpcore ._sync .http_proxy import HTTPProxy , TunnelHTTPConnection , merge_headers , logger
2+ from httpcore ._sync .connection_pool import ConnectionPool
3+ from httpcore ._sync .http11 import HTTP11Connection
4+ from httpcore ._models import URL , Request
5+ from httpcore ._exceptions import ProxyError
6+ from httpcore ._ssl import default_ssl_context
7+ from httpcore ._trace import Trace
8+ from httpx import HTTPTransport
9+ from httpx ._config import DEFAULT_LIMITS , Proxy , create_ssl_context
10+
11+ class ProxyTunnelHTTPConnection (TunnelHTTPConnection ):
12+ # Unfortunately the only way to get connect_response.headers into the Response
13+ # is to override this whole method
14+ def handle_request (self , request ):
15+ timeouts = request .extensions .get ("timeout" , {})
16+ timeout = timeouts .get ("connect" , None )
17+
18+ with self ._connect_lock :
19+ if not self ._connected :
20+ target = b"%b:%d" % (self ._remote_origin .host , self ._remote_origin .port )
21+
22+ connect_url = URL (
23+ scheme = self ._proxy_origin .scheme ,
24+ host = self ._proxy_origin .host ,
25+ port = self ._proxy_origin .port ,
26+ target = target ,
27+ )
28+ connect_headers = merge_headers (
29+ [(b"Host" , target ), (b"Accept" , b"*/*" )], self ._proxy_headers
30+ )
31+ connect_request = Request (
32+ method = b"CONNECT" ,
33+ url = connect_url ,
34+ headers = connect_headers ,
35+ extensions = request .extensions ,
36+ )
37+ connect_response = self ._connection .handle_request (
38+ connect_request
39+ )
40+
41+ if connect_response .status < 200 or connect_response .status > 299 :
42+ reason_bytes = connect_response .extensions .get ("reason_phrase" , b"" )
43+ reason_str = reason_bytes .decode ("ascii" , errors = "ignore" )
44+ msg = "%d %s" % (connect_response .status , reason_str )
45+ self ._connection .close ()
46+ raise ProxyError (msg )
47+
48+ stream = connect_response .extensions ["network_stream" ]
49+
50+ # Upgrade the stream to SSL
51+ ssl_context = (
52+ default_ssl_context ()
53+ if self ._ssl_context is None
54+ else self ._ssl_context
55+ )
56+ alpn_protocols = ["http/1.1" , "h2" ] if self ._http2 else ["http/1.1" ]
57+ ssl_context .set_alpn_protocols (alpn_protocols )
58+
59+ kwargs = {
60+ "ssl_context" : ssl_context ,
61+ "server_hostname" : self ._remote_origin .host .decode ("ascii" ),
62+ "timeout" : timeout ,
63+ }
64+ with Trace ("start_tls" , logger , request , kwargs ) as trace :
65+ stream = stream .start_tls (** kwargs )
66+ trace .return_value = stream
67+
68+ # Determine if we should be using HTTP/1.1 or HTTP/2
69+ ssl_object = stream .get_extra_info ("ssl_object" )
70+ http2_negotiated = (
71+ ssl_object is not None
72+ and ssl_object .selected_alpn_protocol () == "h2"
73+ )
74+
75+ # Create the HTTP/1.1 or HTTP/2 connection
76+ if http2_negotiated or (self ._http2 and not self ._http1 ):
77+ from httpcore ._sync .http2 import HTTP2Connection
78+
79+ self ._connection = HTTP2Connection (
80+ origin = self ._remote_origin ,
81+ stream = stream ,
82+ keepalive_expiry = self ._keepalive_expiry ,
83+ )
84+ else :
85+ self ._connection = HTTP11Connection (
86+ origin = self ._remote_origin ,
87+ stream = stream ,
88+ keepalive_expiry = self ._keepalive_expiry ,
89+ )
90+
91+ self ._connected = True
92+ # this is the only modification
93+ response = self ._connection .handle_request (request )
94+ response .headers = merge_headers (response .headers , connect_response .headers )
95+ return response
96+
97+ class HTTPProxyHeaders (HTTPProxy ):
98+ def create_connection (self , origin ):
99+ if origin .scheme == b"http" :
100+ return super ().create_connection (origin )
101+ return ProxyTunnelHTTPConnection (
102+ proxy_origin = self ._proxy_url .origin ,
103+ proxy_headers = self ._proxy_headers ,
104+ remote_origin = origin ,
105+ ssl_context = self ._ssl_context ,
106+ proxy_ssl_context = self ._proxy_ssl_context ,
107+ keepalive_expiry = self ._keepalive_expiry ,
108+ http1 = self ._http1 ,
109+ http2 = self ._http2 ,
110+ network_backend = self ._network_backend ,
111+ )
112+
113+ # class ProxyConnectionPool(ConnectionPool):
114+ # def create_connection(self, origin):
115+ # if self._proxy is not None:
116+ # if self._proxy.url.scheme in (b"socks5", b"socks5h"):
117+ # return super().create_connection(origin)
118+ # elif origin.scheme == b"http":
119+ # return super().create_connection(origin)
120+
121+ # return ProxyTunnelHTTPConnection(
122+ # proxy_origin=self._proxy.url.origin,
123+ # proxy_headers=self._proxy.headers,
124+ # proxy_ssl_context=self._proxy.ssl_context,
125+ # remote_origin=origin,
126+ # ssl_context=self._ssl_context,
127+ # keepalive_expiry=self._keepalive_expiry,
128+ # http1=self._http1,
129+ # http2=self._http2,
130+ # network_backend=self._network_backend,
131+ # )
132+
133+ # return super().create_connection(origin)
134+
135+ class HTTPProxyTransport (HTTPTransport ):
136+ def __init__ (
137+ self ,
138+ verify = True ,
139+ cert = None ,
140+ trust_env : bool = True ,
141+ http1 : bool = True ,
142+ http2 : bool = False ,
143+ limits = DEFAULT_LIMITS ,
144+ proxy = None ,
145+ uds : str | None = None ,
146+ local_address : str | None = None ,
147+ retries : int = 0 ,
148+ socket_options = None ,
149+ ) -> None :
150+ proxy = Proxy (url = proxy ) if isinstance (proxy , (str , URL )) else proxy
151+ ssl_context = create_ssl_context (verify = verify , cert = cert , trust_env = trust_env )
152+
153+ if proxy and proxy .url .scheme in ("http" , "https" ):
154+ self ._pool = HTTPProxyHeaders (
155+ proxy_url = URL (
156+ scheme = proxy .url .raw_scheme ,
157+ host = proxy .url .raw_host ,
158+ port = proxy .url .port ,
159+ target = proxy .url .raw_path ,
160+ ),
161+ proxy_auth = proxy .raw_auth ,
162+ proxy_headers = proxy .headers .raw ,
163+ ssl_context = ssl_context ,
164+ proxy_ssl_context = proxy .ssl_context ,
165+ max_connections = limits .max_connections ,
166+ max_keepalive_connections = limits .max_keepalive_connections ,
167+ keepalive_expiry = limits .keepalive_expiry ,
168+ http1 = http1 ,
169+ http2 = http2 ,
170+ socket_options = socket_options ,
171+ )
172+ else :
173+ super ().__init__ (verify , cert , trust_env , http1 , http2 , limits , proxy , uds , local_address , retries , socket_options )
0 commit comments