@@ -41,21 +41,24 @@ func NewClient(config ClientConfig) *Client {
4141}
4242
4343func waitLim (ctx context.Context , rl ratelimit.Limiter ) error {
44+ // Quick context check before any blocking operation
4445 select {
4546 case <- ctx .Done ():
4647 return ctx .Err ()
4748 default :
48- done := make (chan struct {})
49- go func () {
50- rl .Take ()
51- close (done )
52- }()
53- select {
54- case <- done :
55- return nil
56- case <- ctx .Done ():
57- return ctx .Err ()
58- }
49+ }
50+
51+ done := make (chan struct {})
52+ go func () {
53+ defer close (done )
54+ rl .Take ()
55+ }()
56+
57+ select {
58+ case <- done :
59+ return nil
60+ case <- ctx .Done ():
61+ return ctx .Err ()
5962 }
6063}
6164
@@ -135,12 +138,17 @@ func (c *Client) connectPersistent(ctx context.Context, addrport string) error {
135138 eg , ctx := errgroup .WithContext (ctx )
136139 for i := 0 ; i < int (c .config .Connections ); i ++ {
137140 eg .Go (func () error {
138- conn , err := dialer .Dial ( "tcp" , addrport )
141+ conn , err := dialer .DialContext ( ctx , "tcp" , addrport )
139142 if err != nil {
140143 return fmt .Errorf ("dialing %q: %w" , addrport , err )
141144 }
142145 defer conn .Close ()
143146
147+ // Set deadlines based on context to make Read/Write operations interruptible
148+ if deadline , ok := ctx .Deadline (); ok {
149+ conn .SetDeadline (deadline )
150+ }
151+
144152 msgsTotal := int64 (c .config .Rate ) * int64 (c .config .Duration .Seconds ())
145153 limiter := ratelimit .New (int (c .config .Rate ))
146154
@@ -197,17 +205,25 @@ func (c *Client) connectEphemeral(ctx context.Context, addrport string) error {
197205 limiter := ratelimit .New (int (c .config .Rate ))
198206
199207 eg , ctx := errgroup .WithContext (ctx )
208+ ephemeralLoop:
200209 for i := int64 (0 ); i < connTotal ; i ++ {
210+ // Check for context cancellation at the start of each iteration
211+ select {
212+ case <- ctx .Done ():
213+ break ephemeralLoop
214+ default :
215+ }
216+
201217 if err := waitLim (ctx , limiter ); err != nil {
202218 if errors .Is (err , context .Canceled ) || errors .Is (err , context .DeadlineExceeded ) {
203- break
219+ break ephemeralLoop
204220 }
205221 continue
206222 }
207223
208224 eg .Go (func () error {
209225 return measureTime (addrport , c .config .MergeResultsEachHost , func () error {
210- conn , err := dialer .Dial ( "tcp" , addrport )
226+ conn , err := dialer .DialContext ( ctx , "tcp" , addrport )
211227 if err != nil {
212228 if errors .Is (err , syscall .ETIMEDOUT ) {
213229 slog .Warn ("connection timeout" , "addr" , addrport )
@@ -217,6 +233,11 @@ func (c *Client) connectEphemeral(ctx context.Context, addrport string) error {
217233 }
218234 defer conn .Close ()
219235
236+ // Set deadlines based on context to make Read/Write operations interruptible
237+ if deadline , ok := ctx .Deadline (); ok {
238+ conn .SetDeadline (deadline )
239+ }
240+
220241 if err := SetQuickAck (conn ); err != nil {
221242 return fmt .Errorf ("setting quick ack: %w" , err )
222243 }
@@ -267,22 +288,36 @@ func (c *Client) connectUDP(ctx context.Context, addrport string) error {
267288 }
268289
269290 eg , ctx := errgroup .WithContext (ctx )
291+ udpLoop:
270292 for i := int64 (0 ); i < connTotal ; i ++ {
293+ // Check for context cancellation at the start of each iteration
294+ select {
295+ case <- ctx .Done ():
296+ break udpLoop
297+ default :
298+ }
299+
271300 if err := waitLim (ctx , limiter ); err != nil {
272301 if errors .Is (err , context .Canceled ) || errors .Is (err , context .DeadlineExceeded ) {
273- break
302+ break udpLoop
274303 }
275304 continue
276305 }
277306
278307 eg .Go (func () error {
279308 return measureTime (addrport , c .config .MergeResultsEachHost , func () error {
280- conn , err := net .Dial ("udp4" , addrport )
309+ var dialer net.Dialer
310+ conn , err := dialer .DialContext (ctx , "udp4" , addrport )
281311 if err != nil {
282312 return fmt .Errorf ("dialing UDP %q: %w" , addrport , err )
283313 }
284314 defer conn .Close ()
285315
316+ // Set deadlines based on context to make Read/Write operations interruptible
317+ if deadline , ok := ctx .Deadline (); ok {
318+ conn .SetDeadline (deadline )
319+ }
320+
286321 msgPtr := bufUDPPool .Get ().(* []byte )
287322 msg := * msgPtr
288323 defer bufUDPPool .Put (msgPtr )
0 commit comments