From 3b89f41d086c963eb6e3fc49a423dc9af93dd2ba Mon Sep 17 00:00:00 2001 From: Alexander Yastrebov Date: Thu, 10 Oct 2024 00:05:53 +0200 Subject: [PATCH 1/4] Revert "protocol: avoid double buffering" This reverts commit 2df67b4704aa55bdc60e50b75337678625b64c39. --- protocol.go | 31 +++---------------------------- 1 file changed, 3 insertions(+), 28 deletions(-) diff --git a/protocol.go b/protocol.go index 270b90d..7eda3d6 100644 --- a/protocol.go +++ b/protocol.go @@ -52,7 +52,6 @@ type Conn struct { readErr error conn net.Conn bufReader *bufio.Reader - reader io.Reader header *Header ProxyHeaderPolicy Policy Validate Validator @@ -155,11 +154,9 @@ func NewConn(conn net.Conn, opts ...func(*Conn)) *Conn { // For v2 the header length is at most 52 bytes plus the length of the TLVs. // We use 256 bytes to be safe. const bufSize = 256 - br := bufio.NewReaderSize(conn, bufSize) pConn := &Conn{ - bufReader: br, - reader: io.MultiReader(br, conn), + bufReader: bufio.NewReaderSize(conn, bufSize), conn: conn, } @@ -181,7 +178,7 @@ func (p *Conn) Read(b []byte) (int, error) { return 0, p.readErr } - return p.reader.Read(b) + return p.bufReader.Read(b) } // Write wraps original conn.Write @@ -363,27 +360,5 @@ func (p *Conn) WriteTo(w io.Writer) (int64, error) { if p.readErr != nil { return 0, p.readErr } - - b := make([]byte, p.bufReader.Buffered()) - if _, err := p.bufReader.Read(b); err != nil { - return 0, err // this should never as we read buffered data - } - - var n int64 - { - nn, err := w.Write(b) - n += int64(nn) - if err != nil { - return n, err - } - } - { - nn, err := io.Copy(w, p.conn) - n += nn - if err != nil { - return n, err - } - } - - return n, nil + return p.bufReader.WriteTo(w) } From c73338ac5dfd7e0902e05d6b80462493479d2d2b Mon Sep 17 00:00:00 2001 From: Alexander Yastrebov Date: Thu, 10 Oct 2024 02:02:09 +0200 Subject: [PATCH 2/4] Buffer only proxy header data Reverts and re-implements incorrect #116 which passed all connection data through the buffered reader. --- header.go | 1 + protocol.go | 36 +++++++++++++++++++++++++----------- v1.go | 1 + v2.go | 1 + 4 files changed, 28 insertions(+), 11 deletions(-) diff --git a/header.go b/header.go index 209c2cc..791e94e 100644 --- a/header.go +++ b/header.go @@ -42,6 +42,7 @@ type Header struct { SourceAddr net.Addr DestinationAddr net.Addr rawTLVs []byte + length int } // HeaderProxyFromAddrs creates a new PROXY header from a source and a diff --git a/protocol.go b/protocol.go index 7eda3d6..b08732c 100644 --- a/protocol.go +++ b/protocol.go @@ -2,6 +2,7 @@ package proxyproto import ( "bufio" + "bytes" "errors" "fmt" "io" @@ -51,7 +52,7 @@ type Conn struct { once sync.Once readErr error conn net.Conn - bufReader *bufio.Reader + reader io.Reader header *Header ProxyHeaderPolicy Policy Validate Validator @@ -150,14 +151,8 @@ func (p *Listener) Addr() net.Addr { // NewConn is used to wrap a net.Conn that may be speaking // the proxy protocol into a proxyproto.Conn func NewConn(conn net.Conn, opts ...func(*Conn)) *Conn { - // For v1 the header length is at most 108 bytes. - // For v2 the header length is at most 52 bytes plus the length of the TLVs. - // We use 256 bytes to be safe. - const bufSize = 256 - pConn := &Conn{ - bufReader: bufio.NewReaderSize(conn, bufSize), - conn: conn, + conn: conn, } for _, opt := range opts { @@ -178,7 +173,7 @@ func (p *Conn) Read(b []byte) (int, error) { return 0, p.readErr } - return p.bufReader.Read(b) + return p.reader.Read(b) } // Write wraps original conn.Write @@ -294,7 +289,26 @@ func (p *Conn) readHeader() error { } } - header, err := Read(p.bufReader) + // For v1 the header length is at most 108 bytes. + // For v2 the header length is at most 52 bytes plus the length of the TLVs. + // We use 256 bytes to be safe. + const bufSize = 256 + + bb := bytes.NewBuffer(make([]byte, 0, bufSize)) + tr := io.TeeReader(p.conn, bb) + br := bufio.NewReaderSize(tr, bufSize) + + header, err := Read(br) + + if err == nil { + _, err = io.CopyN(io.Discard, bb, int64(header.length)) + } + + if bb.Len() == 0 { + p.reader = p.conn + } else { + p.reader = io.MultiReader(bb, p.conn) + } // If the connection's readHeaderTimeout is more than 0, undo the change to the // deadline that we made above. Because we retain the readDeadline as part of our @@ -360,5 +374,5 @@ func (p *Conn) WriteTo(w io.Writer) (int64, error) { if p.readErr != nil { return 0, p.readErr } - return p.bufReader.WriteTo(w) + return io.Copy(w, p.reader) } diff --git a/v1.go b/v1.go index 0d34ba5..3c948eb 100644 --- a/v1.go +++ b/v1.go @@ -125,6 +125,7 @@ func parseVersion1(reader *bufio.Reader) (*Header, error) { // Command doesn't exist in v1 but set it for other parts of this library // to rely on it for determining connection details. header := initVersion1() + header.length = len(buf) // Transport protocol has been processed already. header.TransportProtocol = transportProtocol diff --git a/v2.go b/v2.go index 74bf3f0..2dde097 100644 --- a/v2.go +++ b/v2.go @@ -100,6 +100,7 @@ func parseVersion2(reader *bufio.Reader) (header *Header, err error) { if !header.validateLength(length) { return nil, ErrInvalidLength } + header.length = 16 + int(length) // Return early if the length is zero, which means that // there's no address information and TLVs present for UNSPEC. From 8d833b88f9dca074545cd9409d2add245c2ee0ae Mon Sep 17 00:00:00 2001 From: Alexander Yastrebov Date: Tue, 15 Oct 2024 23:15:03 +0200 Subject: [PATCH 3/4] Use buffered number to calculate header length --- header.go | 1 - protocol.go | 5 ++--- v1.go | 1 - v2.go | 1 - 4 files changed, 2 insertions(+), 6 deletions(-) diff --git a/header.go b/header.go index 791e94e..209c2cc 100644 --- a/header.go +++ b/header.go @@ -42,7 +42,6 @@ type Header struct { SourceAddr net.Addr DestinationAddr net.Addr rawTLVs []byte - length int } // HeaderProxyFromAddrs creates a new PROXY header from a source and a diff --git a/protocol.go b/protocol.go index b08732c..5d0c69a 100644 --- a/protocol.go +++ b/protocol.go @@ -295,13 +295,12 @@ func (p *Conn) readHeader() error { const bufSize = 256 bb := bytes.NewBuffer(make([]byte, 0, bufSize)) - tr := io.TeeReader(p.conn, bb) - br := bufio.NewReaderSize(tr, bufSize) + br := bufio.NewReaderSize(io.TeeReader(p.conn, bb), bufSize) header, err := Read(br) if err == nil { - _, err = io.CopyN(io.Discard, bb, int64(header.length)) + _, err = io.CopyN(io.Discard, bb, int64(bb.Len()-br.Buffered())) } if bb.Len() == 0 { diff --git a/v1.go b/v1.go index 3c948eb..0d34ba5 100644 --- a/v1.go +++ b/v1.go @@ -125,7 +125,6 @@ func parseVersion1(reader *bufio.Reader) (*Header, error) { // Command doesn't exist in v1 but set it for other parts of this library // to rely on it for determining connection details. header := initVersion1() - header.length = len(buf) // Transport protocol has been processed already. header.TransportProtocol = transportProtocol diff --git a/v2.go b/v2.go index 2dde097..74bf3f0 100644 --- a/v2.go +++ b/v2.go @@ -100,7 +100,6 @@ func parseVersion2(reader *bufio.Reader) (header *Header, err error) { if !header.validateLength(length) { return nil, ErrInvalidLength } - header.length = 16 + int(length) // Return early if the length is zero, which means that // there's no address information and TLVs present for UNSPEC. From 0067a883084a2b0bf527dfeec72f2a4532474db4 Mon Sep 17 00:00:00 2001 From: Alexander Yastrebov Date: Tue, 15 Oct 2024 23:21:53 +0200 Subject: [PATCH 4/4] Use Next instead of io.Discard to skip header --- protocol.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/protocol.go b/protocol.go index 5d0c69a..57ea10f 100644 --- a/protocol.go +++ b/protocol.go @@ -300,7 +300,7 @@ func (p *Conn) readHeader() error { header, err := Read(br) if err == nil { - _, err = io.CopyN(io.Discard, bb, int64(bb.Len()-br.Buffered())) + _ = bb.Next(bb.Len() - br.Buffered()) // skip header } if bb.Len() == 0 {