Skip to content

Commit ce012aa

Browse files
committed
Fix panic in multiListener
If Close() is called concurrently, both goroutines could pass the select’s default case before either of them executes close(), resulting in the channel being closed twice
1 parent 4c0f3b2 commit ce012aa

File tree

2 files changed

+6
-16
lines changed

2 files changed

+6
-16
lines changed

net/multi_listen.go

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import (
2121
"fmt"
2222
"net"
2323
"sync"
24+
"sync/atomic"
2425
)
2526

2627
// connErrPair pairs conn and error which is returned by accept on sub-listeners.
@@ -38,6 +39,7 @@ type multiListener struct {
3839
connCh chan connErrPair
3940
// stopCh communicates from parent to child listeners.
4041
stopCh chan struct{}
42+
closed atomic.Bool
4143
}
4244

4345
// compile time check to ensure *multiListener implements net.Listener
@@ -59,9 +61,8 @@ func MultiListen(ctx context.Context, network string, addrs ...string) (net.List
5961
ctx,
6062
network,
6163
addrs,
62-
func(ctx context.Context, network, address string) (net.Listener, error) {
63-
return lc.Listen(ctx, network, address)
64-
})
64+
lc.Listen,
65+
)
6566
}
6667

6768
// multiListen implements MultiListen by consuming stdlib functions as dependency allowing
@@ -150,10 +151,8 @@ func (ml *multiListener) Accept() (net.Conn, error) {
150151
// the go-routines to exit.
151152
func (ml *multiListener) Close() error {
152153
// Make sure this can be called repeatedly without explosions.
153-
select {
154-
case <-ml.stopCh:
154+
if !ml.closed.CompareAndSwap(false, true) {
155155
return fmt.Errorf("use of closed network connection")
156-
default:
157156
}
158157

159158
// Tell all sub-listeners to stop.
@@ -169,12 +168,6 @@ func (ml *multiListener) Close() error {
169168
ml.wg.Wait()
170169
close(ml.connCh)
171170

172-
// Drain any already-queued connections.
173-
for connErr := range ml.connCh {
174-
if connErr.conn != nil {
175-
_ = connErr.conn.Close()
176-
}
177-
}
178171
return nil
179172
}
180173

@@ -187,7 +180,7 @@ func (ml *multiListener) Addr() net.Addr {
187180

188181
// Addrs is like Addr, but returns the address for all registered listeners.
189182
func (ml *multiListener) Addrs() []net.Addr {
190-
var ret []net.Addr
183+
ret := make([]net.Addr, 0, len(ml.listeners))
191184
for _, l := range ml.listeners {
192185
ret = append(ret, l.Addr())
193186
}

net/multi_listen_test.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,9 +116,6 @@ func listenFuncFactory(listeners []*fakeListener) func(_ context.Context, networ
116116
IP: ParseIPSloppy(host),
117117
Port: port,
118118
}
119-
if err != nil {
120-
return nil, err
121-
}
122119
listener.addr = addr
123120
index++
124121

0 commit comments

Comments
 (0)