Skip to content

Commit 4bddf77

Browse files
committed
nicer connection counter
1 parent f992a44 commit 4bddf77

File tree

1 file changed

+35
-9
lines changed

1 file changed

+35
-9
lines changed

examples/limit.rs

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ use crate::common::get_or_generate_secret_key;
2929
#[derive(Debug, Parser)]
3030
#[command(version, about)]
3131
pub enum Args {
32+
/// Limit requests by node id
3233
ByNodeId {
3334
/// Path for files to add
3435
paths: Vec<PathBuf>,
@@ -38,16 +39,19 @@ pub enum Args {
3839
#[clap(long, default_value_t = 1)]
3940
secrets: usize,
4041
},
42+
/// Limit requests by hash, only first hash is allowed
4143
ByHash {
4244
/// Path for files to add
4345
paths: Vec<PathBuf>,
4446
},
47+
/// Throttle requests
4548
Throttle {
4649
/// Path for files to add
4750
paths: Vec<PathBuf>,
4851
#[clap(long, default_value = "100")]
4952
delay_ms: u64,
5053
},
54+
/// Limit maximum number of connections.
5155
MaxConnections {
5256
/// Path for files to add
5357
paths: Vec<PathBuf>,
@@ -140,20 +144,39 @@ fn throttle(delay_ms: u64) -> EventSender {
140144
}
141145

142146
fn limit_max_connections(max_connections: usize) -> EventSender {
147+
#[derive(Default, Debug, Clone)]
148+
struct ConnectionCounter(Arc<(AtomicUsize, usize)>);
149+
150+
impl ConnectionCounter {
151+
fn new(max: usize) -> Self {
152+
Self(Arc::new((Default::default(), max)))
153+
}
154+
155+
fn inc(&self) -> Result<usize, usize> {
156+
let (c, max) = &*self.0;
157+
c.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |n| {
158+
if n >= *max {
159+
None
160+
} else {
161+
Some(n + 1)
162+
}
163+
})
164+
}
165+
166+
fn dec(&self) {
167+
let (c, _) = &*self.0;
168+
c.fetch_sub(1, Ordering::SeqCst);
169+
}
170+
}
171+
143172
let (tx, mut rx) = tokio::sync::mpsc::channel(32);
144173
n0_future::task::spawn(async move {
145-
let requests = Arc::new(AtomicUsize::new(0));
174+
let requests = ConnectionCounter::new(max_connections);
146175
while let Some(msg) = rx.recv().await {
147176
if let ProviderMessage::GetRequestReceived(mut msg) = msg {
148177
let connection_id = msg.connection_id;
149178
let request_id = msg.request_id;
150-
let res = requests.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |n| {
151-
if n >= max_connections {
152-
None
153-
} else {
154-
Some(n + 1)
155-
}
156-
});
179+
let res = requests.inc();
157180
match res {
158181
Ok(n) => {
159182
println!("Accepting request {n}, id ({connection_id},{request_id})");
@@ -170,9 +193,12 @@ fn limit_max_connections(max_connections: usize) -> EventSender {
170193
let requests = requests.clone();
171194
n0_future::task::spawn(async move {
172195
// just drain the per request events
196+
//
197+
// Note that we have requested updates for the request, now we also need to process them
198+
// otherwise the request will be aborted!
173199
while let Ok(Some(_)) = msg.rx.recv().await {}
174200
println!("Stopping request, id ({connection_id},{request_id})");
175-
requests.fetch_sub(1, Ordering::SeqCst);
201+
requests.dec();
176202
});
177203
}
178204
}

0 commit comments

Comments
 (0)