@@ -101,6 +101,10 @@ var proxyHandler = http.HandlerFunc(func(w http.ResponseWriter, req *http.Reques
101101 <- done
102102})
103103
104+ // Permutation 1
105+ //
106+ // Proxy: HTTP
107+ // Backend: HTTP
104108func TestHTTPProxyAndBackend (t * testing.T ) {
105109 // Start the websocket server, which echoes data back to sender.
106110 websocketServer := httptest .NewServer (websocketEchoHandler )
@@ -151,6 +155,11 @@ func TestHTTPProxyAndBackend(t *testing.T) {
151155 }
152156}
153157
158+ // Permutation 2
159+ //
160+ // Proxy: HTTP
161+ // Backend: HTTP
162+ // DialFn: NetDial (dials proxy)
154163func TestHTTPProxyWithNetDial (t * testing.T ) {
155164 // Start the websocket server, which echoes data back to sender.
156165 websocketServer := httptest .NewServer (websocketEchoHandler )
@@ -209,6 +218,11 @@ func TestHTTPProxyWithNetDial(t *testing.T) {
209218 }
210219}
211220
221+ // Permutation 3
222+ //
223+ // Proxy: HTTP
224+ // Backend: HTTP
225+ // DialFn: NetDialContext (dials proxy)
212226func TestHTTPProxyWithNetDialContext (t * testing.T ) {
213227 // Start the websocket server, which echoes data back to sender.
214228 websocketServer := httptest .NewServer (websocketEchoHandler )
@@ -267,6 +281,11 @@ func TestHTTPProxyWithNetDialContext(t *testing.T) {
267281 }
268282}
269283
284+ // Permutation 4
285+ //
286+ // Proxy: HTTPS
287+ // Backend: HTTPS
288+ // TLS Config: set (used for both proxy and backend TLS)
270289func TestHTTPSProxyAndBackend (t * testing.T ) {
271290 // Start the websocket server running TLS.
272291 cert , err := tls .X509KeyPair (localhostCert , localhostKey )
@@ -335,6 +354,12 @@ func TestHTTPSProxyAndBackend(t *testing.T) {
335354 }
336355}
337356
357+ // Permutation 5
358+ //
359+ // Proxy: HTTPS
360+ // Backend: HTTPS
361+ // DialFn: NetDial (used to dial proxy)
362+ // TLS Config: set (used for both proxy and backend TLS)
338363func TestHTTPSProxyUsingNetDial (t * testing.T ) {
339364 // Start the websocket server running TLS.
340365 cert , err := tls .X509KeyPair (localhostCert , localhostKey )
@@ -413,6 +438,12 @@ func TestHTTPSProxyUsingNetDial(t *testing.T) {
413438 }
414439}
415440
441+ // Permutation 6
442+ //
443+ // Proxy: HTTPS
444+ // Backend: HTTPS
445+ // DialFn: NetDialContext (used to dial proxy)
446+ // TLS Config: set (used for both proxy and backend TLS)
416447func TestHTTPSProxyUsingNetDialContext (t * testing.T ) {
417448 // Start the websocket server running TLS.
418449 cert , err := tls .X509KeyPair (localhostCert , localhostKey )
@@ -491,6 +522,168 @@ func TestHTTPSProxyUsingNetDialContext(t *testing.T) {
491522 }
492523}
493524
525+ // Permutation 7
526+ //
527+ // Proxy: HTTPS
528+ // Backend: HTTPS
529+ // DialFn: NetDialTLSContext (used for proxy TLS)
530+ // TLS Config: set (used for backend TLS)
531+ func TestHTTPSProxyUsingNetDialTLSContext (t * testing.T ) {
532+ // Start the websocket server running TLS.
533+ cert , err := tls .X509KeyPair (localhostCert , localhostKey )
534+ if err != nil {
535+ t .Fatalf ("error creating TLS key pair: %v" , err )
536+ }
537+ websocketServer := httptest .NewUnstartedServer (websocketEchoHandler )
538+ websocketServer .TLS = & tls.Config {
539+ Certificates : []tls.Certificate {cert },
540+ }
541+ websocketServer .StartTLS ()
542+ defer websocketServer .Close ()
543+ websocketURL , err := url .Parse (websocketServer .URL )
544+ if err != nil {
545+ t .Fatalf ("error parsing websocket server URL: %v" , err )
546+ }
547+ // Start the proxy server running TLS.
548+ var proxyCalled atomic.Int64
549+ proxyServer := httptest .NewUnstartedServer (http .HandlerFunc (func (w http.ResponseWriter , req * http.Request ) {
550+ proxyCalled .Add (1 )
551+ proxyHandler .ServeHTTP (w , req )
552+ }))
553+ proxyServer .TLS = & tls.Config {
554+ Certificates : []tls.Certificate {cert },
555+ }
556+ proxyServer .StartTLS ()
557+ defer proxyServer .Close ()
558+ proxyServerURL , err := url .Parse (proxyServer .URL )
559+ if err != nil {
560+ t .Fatalf ("error parsing websocket server URL: %v" , err )
561+ }
562+ // Dial the websocket server to create the websocket connection through
563+ // the proxy. The "NetDialTLSContext" function to dials the proxy and
564+ // performs the TLS handshake. NOTE: Subsequent TLS handshake to backend
565+ // (over proxied connection) uses TLSClientConfig for handshake.
566+ certPool := x509 .NewCertPool ()
567+ certPool .AppendCertsFromPEM (localhostCert )
568+ tlsConfig := & tls.Config {RootCAs : certPool }
569+ var netDialCalled atomic.Int64
570+ dialer := Dialer {
571+ Proxy : http .ProxyURL (proxyServerURL ),
572+ // Dial and TLS handshake function to proxy.
573+ NetDialTLSContext : func (ctx context.Context , network , addr string ) (net.Conn , error ) {
574+ netDialCalled .Add (1 )
575+ return tls .Dial (network , addr , tlsConfig )
576+ },
577+ // Used for second TLS handshake to backend server over previously
578+ // established proxied connection.
579+ TLSClientConfig : tlsConfig ,
580+ Subprotocols : []string {subprotocolv1 },
581+ }
582+ websocketURL .Scheme = "wss"
583+ wsClient , _ , err := dialer .Dial (websocketURL .String (), nil )
584+ if err != nil {
585+ t .Fatalf ("websocket dial error: %v" , err )
586+ }
587+ // Generate random data to send/receive over websocket connection.
588+ randomSize := 128 * 1024
589+ randomData := make ([]byte , randomSize )
590+ if _ , err := rand .Read (randomData ); err != nil {
591+ t .Errorf ("unexpected error reading random data: %v" , err )
592+ }
593+ err = wsClient .WriteMessage (BinaryMessage , randomData )
594+ if err != nil {
595+ t .Errorf ("websocket write error: %v" , err )
596+ }
597+ // Read all the data from the websocket connection, then verify
598+ _ , received , err := wsClient .ReadMessage ()
599+ if ! bytes .Equal (randomData , received ) {
600+ t .Errorf ("unexpected data received: %d bytes sent, %d bytes received" ,
601+ len (received ), len (randomData ))
602+ }
603+ if e , a := int64 (1 ), netDialCalled .Load (); e != a {
604+ t .Errorf ("netDial not called" )
605+ }
606+ if e , a := int64 (1 ), proxyCalled .Load (); e != a {
607+ t .Errorf ("proxy not called" )
608+ }
609+ }
610+
611+ // Permutation 8
612+ //
613+ // Proxy: HTTPS
614+ // Backend: HTTP
615+ // DialFn: NetDialTLSContext (used for proxy TLS)
616+ func TestHTTPSProxyUsingNetDialTLSContextWithHTTPBackend (t * testing.T ) {
617+ // Start the websocket server.
618+ websocketServer := httptest .NewUnstartedServer (websocketEchoHandler )
619+ websocketServer .Start ()
620+ defer websocketServer .Close ()
621+ websocketURL , err := url .Parse (websocketServer .URL )
622+ if err != nil {
623+ t .Fatalf ("error parsing websocket server URL: %v" , err )
624+ }
625+ // Start the proxy server running TLS.
626+ cert , err := tls .X509KeyPair (localhostCert , localhostKey )
627+ if err != nil {
628+ t .Fatalf ("error creating TLS key pair: %v" , err )
629+ }
630+ var proxyCalled atomic.Int64
631+ proxyServer := httptest .NewUnstartedServer (http .HandlerFunc (func (w http.ResponseWriter , req * http.Request ) {
632+ proxyCalled .Add (1 )
633+ proxyHandler .ServeHTTP (w , req )
634+ }))
635+ proxyServer .TLS = & tls.Config {
636+ Certificates : []tls.Certificate {cert },
637+ }
638+ proxyServer .StartTLS ()
639+ defer proxyServer .Close ()
640+ proxyServerURL , err := url .Parse (proxyServer .URL )
641+ if err != nil {
642+ t .Fatalf ("error parsing websocket server URL: %v" , err )
643+ }
644+ // Dials websocket backend through HTTPS proxy, using NetDialTLSContext.
645+ certPool := x509 .NewCertPool ()
646+ certPool .AppendCertsFromPEM (localhostCert )
647+ tlsConfig := & tls.Config {RootCAs : certPool }
648+ var netDialCalled atomic.Int64
649+ dialer := Dialer {
650+ Proxy : http .ProxyURL (proxyServerURL ),
651+ // Dial and TLS handshake function to proxy.
652+ NetDialTLSContext : func (ctx context.Context , network , addr string ) (net.Conn , error ) {
653+ netDialCalled .Add (1 )
654+ return tls .Dial (network , addr , tlsConfig )
655+ },
656+ Subprotocols : []string {subprotocolv1 },
657+ }
658+ websocketURL .Scheme = "ws"
659+ wsClient , _ , err := dialer .Dial (websocketURL .String (), nil )
660+ if err != nil {
661+ t .Fatalf ("websocket dial error: %v" , err )
662+ }
663+ // Generate random data to send/receive over websocket connection.
664+ randomSize := 128 * 1024
665+ randomData := make ([]byte , randomSize )
666+ if _ , err := rand .Read (randomData ); err != nil {
667+ t .Errorf ("unexpected error reading random data: %v" , err )
668+ }
669+ err = wsClient .WriteMessage (BinaryMessage , randomData )
670+ if err != nil {
671+ t .Errorf ("websocket write error: %v" , err )
672+ }
673+ // Read all the data from the websocket connection, then verify
674+ _ , received , err := wsClient .ReadMessage ()
675+ if ! bytes .Equal (randomData , received ) {
676+ t .Errorf ("unexpected data received: %d bytes sent, %d bytes received" ,
677+ len (received ), len (randomData ))
678+ }
679+ if e , a := int64 (1 ), netDialCalled .Load (); e != a {
680+ t .Errorf ("netDial not called" )
681+ }
682+ if e , a := int64 (1 ), proxyCalled .Load (); e != a {
683+ t .Errorf ("proxy not called" )
684+ }
685+ }
686+
494687// localhostCert was generated from crypto/tls/generate_cert.go with the following command:
495688//
496689// go run generate_cert.go --rsa-bits 2048 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
0 commit comments