@@ -2,6 +2,7 @@ package proxyproto
22
33import (
44 "bufio"
5+ "bytes"
56 "errors"
67 "fmt"
78 "io"
@@ -51,7 +52,6 @@ type Conn struct {
5152 once sync.Once
5253 readErr error
5354 conn net.Conn
54- bufReader * bufio.Reader
5555 reader io.Reader
5656 header * Header
5757 ProxyHeaderPolicy Policy
@@ -151,16 +151,8 @@ func (p *Listener) Addr() net.Addr {
151151// NewConn is used to wrap a net.Conn that may be speaking
152152// the proxy protocol into a proxyproto.Conn
153153func NewConn (conn net.Conn , opts ... func (* Conn )) * Conn {
154- // For v1 the header length is at most 108 bytes.
155- // For v2 the header length is at most 52 bytes plus the length of the TLVs.
156- // We use 256 bytes to be safe.
157- const bufSize = 256
158- br := bufio .NewReaderSize (conn , bufSize )
159-
160154 pConn := & Conn {
161- bufReader : br ,
162- reader : io .MultiReader (br , conn ),
163- conn : conn ,
155+ conn : conn ,
164156 }
165157
166158 for _ , opt := range opts {
@@ -297,7 +289,23 @@ func (p *Conn) readHeader() error {
297289 }
298290 }
299291
300- header , err := Read (p .bufReader )
292+ // For v1 the header length is at most 108 bytes.
293+ // For v2 the header length is at most 52 bytes plus the length of the TLVs.
294+ // We use 256 bytes to be safe.
295+ const bufSize = 256
296+ br := bufio .NewReaderSize (p .conn , bufSize )
297+
298+ header , err := Read (br )
299+
300+ if br .Buffered () != 0 {
301+ buf := make ([]byte , br .Buffered ())
302+ if _ , err := br .Read (buf ); err != nil {
303+ return err // this should never as we read buffered data
304+ }
305+ p .reader = io .MultiReader (bytes .NewReader (buf ), p .conn )
306+ } else {
307+ p .reader = p .conn
308+ }
301309
302310 // If the connection's readHeaderTimeout is more than 0, undo the change to the
303311 // deadline that we made above. Because we retain the readDeadline as part of our
@@ -364,26 +372,8 @@ func (p *Conn) WriteTo(w io.Writer) (int64, error) {
364372 return 0 , p .readErr
365373 }
366374
367- b := make ([]byte , p .bufReader .Buffered ())
368- if _ , err := p .bufReader .Read (b ); err != nil {
369- return 0 , err // this should never as we read buffered data
370- }
371-
372- var n int64
373- {
374- nn , err := w .Write (b )
375- n += int64 (nn )
376- if err != nil {
377- return n , err
378- }
375+ if wt , ok := p .reader .(io.WriterTo ); ok {
376+ return wt .WriteTo (w )
379377 }
380- {
381- nn , err := io .Copy (w , p .conn )
382- n += nn
383- if err != nil {
384- return n , err
385- }
386- }
387-
388- return n , nil
378+ return io .Copy (w , p .reader )
389379}
0 commit comments