@@ -21,6 +21,7 @@ import (
2121 "fmt"
2222 "net"
2323 "sync"
24+ "sync/atomic"
2425)
2526
2627// connErrPair pairs conn and error which is returned by accept on sub-listeners.
@@ -38,6 +39,7 @@ type multiListener struct {
3839 connCh chan connErrPair
3940 // stopCh communicates from parent to child listeners.
4041 stopCh chan struct {}
42+ closed atomic.Bool
4143}
4244
4345// compile time check to ensure *multiListener implements net.Listener
@@ -59,9 +61,8 @@ func MultiListen(ctx context.Context, network string, addrs ...string) (net.List
5961 ctx ,
6062 network ,
6163 addrs ,
62- func (ctx context.Context , network , address string ) (net.Listener , error ) {
63- return lc .Listen (ctx , network , address )
64- })
64+ lc .Listen ,
65+ )
6566}
6667
6768// multiListen implements MultiListen by consuming stdlib functions as dependency allowing
@@ -150,10 +151,8 @@ func (ml *multiListener) Accept() (net.Conn, error) {
150151// the go-routines to exit.
151152func (ml * multiListener ) Close () error {
152153 // Make sure this can be called repeatedly without explosions.
153- select {
154- case <- ml .stopCh :
154+ if ! ml .closed .CompareAndSwap (false , true ) {
155155 return fmt .Errorf ("use of closed network connection" )
156- default :
157156 }
158157
159158 // Tell all sub-listeners to stop.
@@ -169,12 +168,6 @@ func (ml *multiListener) Close() error {
169168 ml .wg .Wait ()
170169 close (ml .connCh )
171170
172- // Drain any already-queued connections.
173- for connErr := range ml .connCh {
174- if connErr .conn != nil {
175- _ = connErr .conn .Close ()
176- }
177- }
178171 return nil
179172}
180173
@@ -187,7 +180,7 @@ func (ml *multiListener) Addr() net.Addr {
187180
188181// Addrs is like Addr, but returns the address for all registered listeners.
189182func (ml * multiListener ) Addrs () []net.Addr {
190- var ret []net.Addr
183+ ret := make ( []net.Addr , 0 , len ( ml . listeners ))
191184 for _ , l := range ml .listeners {
192185 ret = append (ret , l .Addr ())
193186 }
0 commit comments