@@ -12,6 +12,7 @@ import (
1212 "fmt"
1313 "io"
1414 "net"
15+ "net/http"
1516 "testing"
1617 "time"
1718)
@@ -82,7 +83,6 @@ func TestRequiredWithReadHeaderTimeout(t *testing.T) {
8283 start := time .Now ()
8384
8485 l , err := net .Listen ("tcp" , "127.0.0.1:0" )
85-
8686 if err != nil {
8787 t .Fatalf ("err: %v" , err )
8888 }
@@ -137,7 +137,6 @@ func TestUseWithReadHeaderTimeout(t *testing.T) {
137137 start := time .Now ()
138138
139139 l , err := net .Listen ("tcp" , "127.0.0.1:0" )
140-
141140 if err != nil {
142141 t .Fatalf ("err: %v" , err )
143142 }
@@ -847,6 +846,7 @@ func TestReadingIsRefusedWhenProxyHeaderPresentButNotAllowed(t *testing.T) {
847846 t .Fatalf ("client error: %v" , err )
848847 }
849848}
849+
850850func TestIgnorePolicyIgnoresIpFromProxyHeader (t * testing.T ) {
851851 l , err := net .Listen ("tcp" , "127.0.0.1:0" )
852852 if err != nil {
@@ -1274,6 +1274,48 @@ func Test_ConnectionErrorsWhenHeaderValidationFails(t *testing.T) {
12741274 }
12751275}
12761276
1277+ func Test_ConnectionHandlesInvalidUpstreamError (t * testing.T ) {
1278+ l , err := net .Listen ("tcp" , "localhost:8080" )
1279+ if err != nil {
1280+ t .Fatalf ("error creating listener: %v" , err )
1281+ }
1282+
1283+ times := 0
1284+
1285+ newLn := & Listener {
1286+ Listener : l ,
1287+ ConnPolicy : func (_ ConnPolicyOptions ) (Policy , error ) {
1288+ // Return the invalid upstream error on the first call, the listener
1289+ // should remain open and accepting.
1290+ if times == 0 {
1291+ times ++
1292+ return REJECT , ErrInvalidUpstream
1293+ }
1294+
1295+ return REJECT , ErrNoProxyProtocol
1296+ },
1297+ }
1298+
1299+ // Kick off the listener and capture any error.
1300+ var listenerErr error
1301+ go func (t * testing.T ) {
1302+ _ , listenerErr = newLn .Accept ()
1303+ }(t )
1304+
1305+ // Make two calls to trigger the listener's accept, the first should experience
1306+ // the ErrInvalidUpstream and keep the listener open, the second should experience
1307+ // a different error which will cause the listener to close.
1308+ _ , _ = http .Get ("http://localhost:8080" )
1309+ if listenerErr != nil {
1310+ t .Fatalf ("invalid upstream shouldn't return an error: %v" , listenerErr )
1311+ }
1312+
1313+ _ , _ = http .Get ("http://localhost:8080" )
1314+ if listenerErr == nil {
1315+ t .Fatalf ("errors other than invalid upstream should error" )
1316+ }
1317+ }
1318+
12771319type TestTLSServer struct {
12781320 Listener net.Listener
12791321
@@ -1482,9 +1524,11 @@ func (c *testConn) ReadFrom(r io.Reader) (int64, error) {
14821524 b , err := io .ReadAll (r )
14831525 return int64 (len (b )), err
14841526}
1527+
14851528func (c * testConn ) Write (p []byte ) (int , error ) {
14861529 return len (p ), nil
14871530}
1531+
14881532func (c * testConn ) Read (p []byte ) (int , error ) {
14891533 if c .reads == 0 {
14901534 return 0 , io .EOF
@@ -1533,7 +1577,7 @@ func TestCopyFromWrappedConnectionToWrappedConnection(t *testing.T) {
15331577}
15341578
15351579func benchmarkTCPProxy (size int , b * testing.B ) {
1536- //create and start the echo backend
1580+ // create and start the echo backend
15371581 backend , err := net .Listen ("tcp" , "127.0.0.1:0" )
15381582 if err != nil {
15391583 b .Fatalf ("err: %v" , err )
@@ -1554,7 +1598,7 @@ func benchmarkTCPProxy(size int, b *testing.B) {
15541598 }
15551599 }()
15561600
1557- //start the proxyprotocol enabled tcp proxy
1601+ // start the proxyprotocol enabled tcp proxy
15581602 l , err := net .Listen ("tcp" , "127.0.0.1:0" )
15591603 if err != nil {
15601604 b .Fatalf ("err: %v" , err )
@@ -1603,7 +1647,7 @@ func benchmarkTCPProxy(size int, b *testing.B) {
16031647 },
16041648 }
16051649
1606- //now for the actual benchmark
1650+ // now for the actual benchmark
16071651 b .ResetTimer ()
16081652 for n := 0 ; n < b .N ; n ++ {
16091653 conn , err := net .Dial ("tcp" , pl .Addr ().String ())
@@ -1614,16 +1658,15 @@ func benchmarkTCPProxy(size int, b *testing.B) {
16141658 if _ , err := header .WriteTo (conn ); err != nil {
16151659 b .Fatalf ("err: %v" , err )
16161660 }
1617- //send data
1661+ // send data
16181662 go func () {
16191663 _ , err = conn .Write (data )
16201664 _ = conn .(* net.TCPConn ).CloseWrite ()
16211665 if err != nil {
16221666 panic (fmt .Sprintf ("Failed to write data: %v" , err ))
16231667 }
1624-
16251668 }()
1626- //receive data
1669+ // receive data
16271670 n , err := io .Copy (io .Discard , conn )
16281671 if n != int64 (len (data )) {
16291672 b .Fatalf ("Expected to receive %d bytes, got %d" , len (data ), n )
@@ -1638,24 +1681,31 @@ func benchmarkTCPProxy(size int, b *testing.B) {
16381681func BenchmarkTCPProxy16KB (b * testing.B ) {
16391682 benchmarkTCPProxy (16 * 1024 , b )
16401683}
1684+
16411685func BenchmarkTCPProxy32KB (b * testing.B ) {
16421686 benchmarkTCPProxy (32 * 1024 , b )
16431687}
1688+
16441689func BenchmarkTCPProxy64KB (b * testing.B ) {
16451690 benchmarkTCPProxy (64 * 1024 , b )
16461691}
1692+
16471693func BenchmarkTCPProxy128KB (b * testing.B ) {
16481694 benchmarkTCPProxy (128 * 1024 , b )
16491695}
1696+
16501697func BenchmarkTCPProxy256KB (b * testing.B ) {
16511698 benchmarkTCPProxy (256 * 1024 , b )
16521699}
1700+
16531701func BenchmarkTCPProxy512KB (b * testing.B ) {
16541702 benchmarkTCPProxy (512 * 1024 , b )
16551703}
1704+
16561705func BenchmarkTCPProxy1024KB (b * testing.B ) {
16571706 benchmarkTCPProxy (1024 * 1024 , b )
16581707}
1708+
16591709func BenchmarkTCPProxy2048KB (b * testing.B ) {
16601710 benchmarkTCPProxy (2048 * 1024 , b )
16611711}
0 commit comments