Skip to content

Commit 8a8a4d7

Browse files
committed
protocol: limit use of buffered reader
Fix bug introduced in pires#116 where io.MultiReader only reads from buffered reader. Move buffer reader management to readHeader(). Remove Conn.bufReader, and make Conn.reader nil until readHeader() is called.
1 parent bac82fd commit 8a8a4d7

File tree

1 file changed

+22
-32
lines changed

1 file changed

+22
-32
lines changed

protocol.go

Lines changed: 22 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package proxyproto
22

33
import (
44
"bufio"
5+
"bytes"
56
"errors"
67
"fmt"
78
"io"
@@ -51,7 +52,6 @@ type Conn struct {
5152
once sync.Once
5253
readErr error
5354
conn net.Conn
54-
bufReader *bufio.Reader
5555
reader io.Reader
5656
header *Header
5757
ProxyHeaderPolicy Policy
@@ -151,16 +151,8 @@ func (p *Listener) Addr() net.Addr {
151151
// NewConn is used to wrap a net.Conn that may be speaking
152152
// the proxy protocol into a proxyproto.Conn
153153
func NewConn(conn net.Conn, opts ...func(*Conn)) *Conn {
154-
// For v1 the header length is at most 108 bytes.
155-
// For v2 the header length is at most 52 bytes plus the length of the TLVs.
156-
// We use 256 bytes to be safe.
157-
const bufSize = 256
158-
br := bufio.NewReaderSize(conn, bufSize)
159-
160154
pConn := &Conn{
161-
bufReader: br,
162-
reader: io.MultiReader(br, conn),
163-
conn: conn,
155+
conn: conn,
164156
}
165157

166158
for _, opt := range opts {
@@ -297,7 +289,23 @@ func (p *Conn) readHeader() error {
297289
}
298290
}
299291

300-
header, err := Read(p.bufReader)
292+
// For v1 the header length is at most 108 bytes.
293+
// For v2 the header length is at most 52 bytes plus the length of the TLVs.
294+
// We use 256 bytes to be safe.
295+
const bufSize = 256
296+
br := bufio.NewReaderSize(p.conn, bufSize)
297+
298+
header, err := Read(br)
299+
300+
if br.Buffered() != 0 {
301+
buf := make([]byte, br.Buffered())
302+
if _, err := br.Read(buf); err != nil {
303+
return err // this should never as we read buffered data
304+
}
305+
p.reader = io.MultiReader(bytes.NewReader(buf), p.conn)
306+
} else {
307+
p.reader = p.conn
308+
}
301309

302310
// If the connection's readHeaderTimeout is more than 0, undo the change to the
303311
// deadline that we made above. Because we retain the readDeadline as part of our
@@ -364,26 +372,8 @@ func (p *Conn) WriteTo(w io.Writer) (int64, error) {
364372
return 0, p.readErr
365373
}
366374

367-
b := make([]byte, p.bufReader.Buffered())
368-
if _, err := p.bufReader.Read(b); err != nil {
369-
return 0, err // this should never as we read buffered data
370-
}
371-
372-
var n int64
373-
{
374-
nn, err := w.Write(b)
375-
n += int64(nn)
376-
if err != nil {
377-
return n, err
378-
}
375+
if wt, ok := p.reader.(io.WriterTo); ok {
376+
return wt.WriteTo(w)
379377
}
380-
{
381-
nn, err := io.Copy(w, p.conn)
382-
n += nn
383-
if err != nil {
384-
return n, err
385-
}
386-
}
387-
388-
return n, nil
378+
return io.Copy(w, p.reader)
389379
}

0 commit comments

Comments
 (0)