@@ -85,6 +85,51 @@ func newTLSServer(t *testing.T) *cstServer {
8585 return & s
8686}
8787
88+ type cstProxyServer struct {}
89+
90+ func (s * cstProxyServer ) ServeHTTP (w http.ResponseWriter , req * http.Request ) {
91+ if req .Method != http .MethodConnect {
92+ http .Error (w , "method not allowed" , http .StatusMethodNotAllowed )
93+ return
94+ }
95+
96+ conn , _ , err := w .(http.Hijacker ).Hijack ()
97+ if err != nil {
98+ http .Error (w , err .Error (), http .StatusInternalServerError )
99+ return
100+ }
101+ defer conn .Close ()
102+
103+ upstream , err := (& net.Dialer {}).DialContext (req .Context (), "tcp" , req .URL .Host )
104+ if err != nil {
105+ _ , _ = fmt .Fprintf (conn , "HTTP/1.1 502 Bad Gateway\r \n \r \n " )
106+ return
107+ }
108+ defer upstream .Close ()
109+
110+ _ , _ = fmt .Fprintf (conn , "HTTP/1.1 200 Connection established\r \n \r \n " )
111+
112+ wg := sync.WaitGroup {}
113+ wg .Add (2 )
114+ go func () {
115+ defer wg .Done ()
116+ _ , _ = io .Copy (upstream , conn )
117+ }()
118+ go func () {
119+ defer wg .Done ()
120+ _ , _ = io .Copy (conn , upstream )
121+ }()
122+ wg .Wait ()
123+ }
124+
125+ func newProxyServer () * httptest.Server {
126+ return httptest .NewServer (& cstProxyServer {})
127+ }
128+
129+ func newTLSProxyServer () * httptest.Server {
130+ return httptest .NewTLSServer (& cstProxyServer {})
131+ }
132+
88133func (t cstHandler ) ServeHTTP (w http.ResponseWriter , r * http.Request ) {
89134 // Because tests wait for a response from a server, we are guaranteed that
90135 // the wait group count is incremented before the test waits on the group
@@ -165,7 +210,6 @@ func sendRecv(t *testing.T, ws *Conn) {
165210}
166211
167212func TestProxyDial (t * testing.T ) {
168-
169213 s := newServer (t )
170214 defer s .Close ()
171215
@@ -202,6 +246,106 @@ func TestProxyDial(t *testing.T) {
202246 sendRecv (t , ws )
203247}
204248
249+ func TestProxyDialer (t * testing.T ) {
250+ testcases := []struct {
251+ name string
252+ isTLS bool
253+ tlsServerName string // optional host for tls ServerName
254+ insecureSkipVerify bool
255+ netDialTLSContext func (ctx context.Context , network , addr string ) (net.Conn , error )
256+ }{{
257+ name : "http" ,
258+ isTLS : false ,
259+ }, {
260+ name : "https" ,
261+ isTLS : true ,
262+ }, {
263+ name : "https with ServerName" ,
264+ isTLS : true ,
265+ tlsServerName : "example.com" ,
266+ }, {
267+ name : "https with insecureSkipVerify" ,
268+ isTLS : true ,
269+ insecureSkipVerify : true ,
270+ }, {
271+ name : "https with netDialTLSContext" ,
272+ isTLS : true ,
273+ netDialTLSContext : func (ctx context.Context , network , addr string ) (net.Conn , error ) {
274+ dialer := & tls.Dialer {
275+ Config : & tls.Config {
276+ InsecureSkipVerify : true ,
277+ },
278+ }
279+ return dialer .DialContext (ctx , network , addr )
280+ },
281+ }}
282+
283+ for _ , tc := range testcases {
284+ t .Run (tc .name , func (tt * testing.T ) {
285+ s := newServer (tt )
286+ defer s .Close ()
287+
288+ var ps * httptest.Server
289+ if tc .isTLS {
290+ ps = newTLSProxyServer ()
291+ } else {
292+ ps = newProxyServer ()
293+ }
294+
295+ psurl , _ := url .Parse (ps .URL )
296+
297+ netDialCalled := false
298+
299+ cstDialer := cstDialer // make local copy for modification on next line.
300+ cstDialer .Proxy = http .ProxyURL (psurl )
301+ if tc .isTLS {
302+ cstDialer .TLSClientConfig = & tls.Config {
303+ RootCAs : rootCAs (tt , ps ),
304+ ServerName : tc .tlsServerName ,
305+ InsecureSkipVerify : tc .insecureSkipVerify ,
306+ }
307+ if tc .netDialTLSContext != nil {
308+ cstDialer .NetDialTLSContext = func (ctx context.Context , network , addr string ) (net.Conn , error ) {
309+ netDialCalled = true
310+ return tc .netDialTLSContext (ctx , network , addr )
311+ }
312+ } else {
313+ netDialCalled = true
314+ }
315+ } else {
316+ netDialCalled = true
317+ }
318+
319+ connect := false
320+ origHandler := ps .Config .Handler
321+
322+ // Capture the request Host header.
323+ ps .Config .Handler = http .HandlerFunc (
324+ func (w http.ResponseWriter , r * http.Request ) {
325+ if r .Method == http .MethodConnect {
326+ connect = true
327+ }
328+
329+ origHandler .ServeHTTP (w , r )
330+ })
331+
332+ ws , _ , err := cstDialer .Dial (s .URL , nil )
333+ if err != nil {
334+ tt .Fatalf ("Dial: %v" , err )
335+ }
336+ defer ws .Close ()
337+ sendRecv (tt , ws )
338+
339+ if ! connect {
340+ tt .Error ("connect not received" )
341+ }
342+ if ! netDialCalled {
343+ tt .Error ("netDialTLSContext not called" )
344+ }
345+ })
346+ }
347+ }
348+
205349func TestProxyAuthorizationDial (t * testing.T ) {
206350 s := newServer (t )
207351 defer s .Close ()
@@ -652,7 +796,7 @@ func TestHost(t *testing.T) {
652796 server * httptest.Server // server to use
653797 url string // host for request URI
654798 header string // optional request host header
655- tls string // optional host for tls ServerName
799+ tls string // optional host for tlsServerName ServerName
656800 wantAddr string // expected host for dial
657801 wantHeader string // expected request header on server
658802 insecureSkipVerify bool
@@ -759,7 +903,7 @@ func TestHost(t *testing.T) {
759903 }
760904
761905 check := func (protos map [* httptest.Server ]string ) {
762- name := fmt .Sprintf ("%d: %s%s/ header[Host]=%q, tls .ServerName=%q" , i + 1 , protos [tt .server ], tt .url , tt .header , tt .tls )
906+ name := fmt .Sprintf ("%d: %s%s/ header[Host]=%q, tlsServerName .ServerName=%q" , i + 1 , protos [tt .server ], tt .url , tt .header , tt .tls )
763907 if gotAddr != tt .wantAddr {
764908 t .Errorf ("%s: got addr %s, want %s" , name , gotAddr , tt .wantAddr )
765909 }
0 commit comments