@@ -29,6 +29,7 @@ use crate::common::get_or_generate_secret_key;
2929#[ derive( Debug , Parser ) ]
3030#[ command( version, about) ]
3131pub enum Args {
32+ /// Limit requests by node id
3233 ByNodeId {
3334 /// Path for files to add
3435 paths : Vec < PathBuf > ,
@@ -38,16 +39,19 @@ pub enum Args {
3839 #[ clap( long, default_value_t = 1 ) ]
3940 secrets : usize ,
4041 } ,
42+ /// Limit requests by hash, only first hash is allowed
4143 ByHash {
4244 /// Path for files to add
4345 paths : Vec < PathBuf > ,
4446 } ,
47+ /// Throttle requests
4548 Throttle {
4649 /// Path for files to add
4750 paths : Vec < PathBuf > ,
4851 #[ clap( long, default_value = "100" ) ]
4952 delay_ms : u64 ,
5053 } ,
54+ /// Limit maximum number of connections.
5155 MaxConnections {
5256 /// Path for files to add
5357 paths : Vec < PathBuf > ,
@@ -140,20 +144,39 @@ fn throttle(delay_ms: u64) -> EventSender {
140144}
141145
142146fn limit_max_connections ( max_connections : usize ) -> EventSender {
147+ #[ derive( Default , Debug , Clone ) ]
148+ struct ConnectionCounter ( Arc < ( AtomicUsize , usize ) > ) ;
149+
150+ impl ConnectionCounter {
151+ fn new ( max : usize ) -> Self {
152+ Self ( Arc :: new ( ( Default :: default ( ) , max) ) )
153+ }
154+
155+ fn inc ( & self ) -> Result < usize , usize > {
156+ let ( c, max) = & * self . 0 ;
157+ c. fetch_update ( Ordering :: SeqCst , Ordering :: SeqCst , |n| {
158+ if n >= * max {
159+ None
160+ } else {
161+ Some ( n + 1 )
162+ }
163+ } )
164+ }
165+
166+ fn dec ( & self ) {
167+ let ( c, _) = & * self . 0 ;
168+ c. fetch_sub ( 1 , Ordering :: SeqCst ) ;
169+ }
170+ }
171+
143172 let ( tx, mut rx) = tokio:: sync:: mpsc:: channel ( 32 ) ;
144173 n0_future:: task:: spawn ( async move {
145- let requests = Arc :: new ( AtomicUsize :: new ( 0 ) ) ;
174+ let requests = ConnectionCounter :: new ( max_connections ) ;
146175 while let Some ( msg) = rx. recv ( ) . await {
147176 if let ProviderMessage :: GetRequestReceived ( mut msg) = msg {
148177 let connection_id = msg. connection_id ;
149178 let request_id = msg. request_id ;
150- let res = requests. fetch_update ( Ordering :: SeqCst , Ordering :: SeqCst , |n| {
151- if n >= max_connections {
152- None
153- } else {
154- Some ( n + 1 )
155- }
156- } ) ;
179+ let res = requests. inc ( ) ;
157180 match res {
158181 Ok ( n) => {
159182 println ! ( "Accepting request {n}, id ({connection_id},{request_id})" ) ;
@@ -170,9 +193,12 @@ fn limit_max_connections(max_connections: usize) -> EventSender {
170193 let requests = requests. clone ( ) ;
171194 n0_future:: task:: spawn ( async move {
172195 // just drain the per request events
196+ //
197+ // Note that we have requested updates for the request, now we also need to process them
198+ // otherwise the request will be aborted!
173199 while let Ok ( Some ( _) ) = msg. rx . recv ( ) . await { }
174200 println ! ( "Stopping request, id ({connection_id},{request_id})" ) ;
175- requests. fetch_sub ( 1 , Ordering :: SeqCst ) ;
201+ requests. dec ( ) ;
176202 } ) ;
177203 }
178204 }
0 commit comments