Skip to content

Commit 6d86e4f

Browse files
committed
Add limit example
This shows how to limit serving content in various ways - by node id - by content hash - throttling - limiting max number of connections
1 parent b26aefb commit 6d86e4f

File tree

1 file changed

+341
-0
lines changed

1 file changed

+341
-0
lines changed

examples/limit.rs

Lines changed: 341 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,341 @@
1+
/// Example how to limit blob requests by hash and node id, and to add
2+
/// restrictions on limited content.
3+
mod common;
4+
use std::{
5+
collections::{HashMap, HashSet},
6+
path::PathBuf,
7+
sync::{
8+
atomic::{AtomicUsize, Ordering},
9+
Arc,
10+
},
11+
};
12+
13+
use clap::Parser;
14+
use common::setup_logging;
15+
use iroh::{NodeId, SecretKey, Watcher};
16+
use iroh_blobs::{
17+
provider::events::{
18+
AbortReason, ConnectMode, EventMask, EventSender, ProviderMessage, RequestMode,
19+
ThrottleMode,
20+
},
21+
store::mem::MemStore,
22+
ticket::BlobTicket,
23+
BlobsProtocol, Hash,
24+
};
25+
use rand::thread_rng;
26+
27+
use crate::common::get_or_generate_secret_key;
28+
29+
#[derive(Debug, Parser)]
30+
#[command(version, about)]
31+
pub enum Args {
32+
ByNodeId {
33+
/// Path for files to add
34+
paths: Vec<PathBuf>,
35+
#[clap(long("allow"))]
36+
/// Nodes that are allowed to download content.
37+
allowed_nodes: Vec<NodeId>,
38+
#[clap(long, default_value_t = 1)]
39+
secrets: usize,
40+
},
41+
ByHash {
42+
/// Path for files to add
43+
paths: Vec<PathBuf>,
44+
},
45+
Throttle {
46+
/// Path for files to add
47+
paths: Vec<PathBuf>,
48+
#[clap(long, default_value = "100")]
49+
delay_ms: u64,
50+
},
51+
MaxConnections {
52+
/// Path for files to add
53+
paths: Vec<PathBuf>,
54+
#[clap(long, default_value = "1")]
55+
max_connections: usize,
56+
},
57+
Get {
58+
/// Ticket for the blob to download
59+
ticket: BlobTicket,
60+
},
61+
}
62+
63+
fn limit_by_node_id(allowed_nodes: HashSet<NodeId>) -> EventSender {
64+
let (tx, mut rx) = tokio::sync::mpsc::channel(32);
65+
n0_future::task::spawn(async move {
66+
while let Some(msg) = rx.recv().await {
67+
match msg {
68+
ProviderMessage::ClientConnected(msg) => {
69+
let node_id = msg.node_id;
70+
let res = if allowed_nodes.contains(&node_id) {
71+
println!("Client connected: {node_id}");
72+
Ok(())
73+
} else {
74+
println!("Client rejected: {node_id}");
75+
Err(AbortReason::Permission)
76+
};
77+
msg.tx.send(res).await.ok();
78+
}
79+
_ => {}
80+
}
81+
}
82+
});
83+
EventSender::new(
84+
tx,
85+
EventMask {
86+
connected: ConnectMode::Request,
87+
..EventMask::DEFAULT
88+
},
89+
)
90+
}
91+
92+
fn limit_by_hash(allowed_hashes: HashSet<Hash>) -> EventSender {
93+
let (tx, mut rx) = tokio::sync::mpsc::channel(32);
94+
n0_future::task::spawn(async move {
95+
while let Some(msg) = rx.recv().await {
96+
match msg {
97+
ProviderMessage::GetRequestReceived(msg) => {
98+
let res = if !msg.request.ranges.is_blob() {
99+
println!("HashSeq request not allowed");
100+
Err(AbortReason::Permission)
101+
} else if !allowed_hashes.contains(&msg.request.hash) {
102+
println!("Request for hash {} not allowed", msg.request.hash);
103+
Err(AbortReason::Permission)
104+
} else {
105+
println!("Request for hash {} allowed", msg.request.hash);
106+
Ok(())
107+
};
108+
msg.tx.send(res).await.ok();
109+
}
110+
_ => {}
111+
}
112+
}
113+
});
114+
EventSender::new(
115+
tx,
116+
EventMask {
117+
get: RequestMode::Request,
118+
..EventMask::DEFAULT
119+
},
120+
)
121+
}
122+
123+
fn throttle(delay_ms: u64) -> EventSender {
124+
let (tx, mut rx) = tokio::sync::mpsc::channel(32);
125+
n0_future::task::spawn(async move {
126+
while let Some(msg) = rx.recv().await {
127+
match msg {
128+
ProviderMessage::Throttle(msg) => {
129+
n0_future::task::spawn(async move {
130+
println!(
131+
"Throttling {} {}, {}ms",
132+
msg.connection_id, msg.request_id, delay_ms
133+
);
134+
tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
135+
msg.tx.send(Ok(())).await.ok();
136+
});
137+
}
138+
_ => {}
139+
}
140+
}
141+
});
142+
EventSender::new(
143+
tx,
144+
EventMask {
145+
throttle: ThrottleMode::Throttle,
146+
..EventMask::DEFAULT
147+
},
148+
)
149+
}
150+
151+
fn limit_max_connections(max_connections: usize) -> EventSender {
152+
let (tx, mut rx) = tokio::sync::mpsc::channel(32);
153+
n0_future::task::spawn(async move {
154+
let requests = Arc::new(AtomicUsize::new(0));
155+
while let Some(msg) = rx.recv().await {
156+
match msg {
157+
ProviderMessage::GetRequestReceived(mut msg) => {
158+
let connection_id = msg.connection_id;
159+
let request_id = msg.request_id;
160+
let res = requests.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |n| {
161+
if n >= max_connections {
162+
None
163+
} else {
164+
Some(n + 1)
165+
}
166+
});
167+
match res {
168+
Ok(n) => {
169+
println!("Accepting request {n}, id ({connection_id},{request_id})");
170+
msg.tx.send(Ok(())).await.ok();
171+
}
172+
Err(_) => {
173+
println!(
174+
"Connection limit of {} exceeded, rejecting request",
175+
max_connections
176+
);
177+
msg.tx.send(Err(AbortReason::RateLimited)).await.ok();
178+
continue;
179+
}
180+
}
181+
let requests = requests.clone();
182+
n0_future::task::spawn(async move {
183+
// just drain the per request events
184+
while let Ok(Some(_)) = msg.rx.recv().await {}
185+
println!("Stopping request, id ({connection_id},{request_id})");
186+
requests.fetch_sub(1, Ordering::SeqCst);
187+
});
188+
}
189+
_ => {}
190+
}
191+
}
192+
});
193+
EventSender::new(
194+
tx,
195+
EventMask {
196+
get: RequestMode::RequestLog,
197+
..EventMask::DEFAULT
198+
},
199+
)
200+
}
201+
202+
#[tokio::main]
203+
async fn main() -> anyhow::Result<()> {
204+
setup_logging();
205+
let args = Args::parse();
206+
match args {
207+
Args::Get { ticket } => {
208+
let secret = get_or_generate_secret_key()?;
209+
let endpoint = iroh::Endpoint::builder()
210+
.secret_key(secret)
211+
.discovery_n0()
212+
.bind()
213+
.await?;
214+
let connection = endpoint
215+
.connect(ticket.node_addr().clone(), iroh_blobs::ALPN)
216+
.await?;
217+
let (data, stats) = iroh_blobs::get::request::get_blob(connection, ticket.hash())
218+
.bytes_and_stats()
219+
.await?;
220+
println!("Downloaded {} bytes", data.len());
221+
println!("Stats: {:?}", stats);
222+
}
223+
Args::ByNodeId {
224+
paths,
225+
allowed_nodes,
226+
secrets,
227+
} => {
228+
let mut allowed_nodes = allowed_nodes.into_iter().collect::<HashSet<_>>();
229+
if secrets > 0 {
230+
println!("Generating {secrets} new secret keys for allowed nodes:");
231+
let mut rand = thread_rng();
232+
for _ in 0..secrets {
233+
let secret = SecretKey::generate(&mut rand);
234+
let public = secret.public();
235+
allowed_nodes.insert(public);
236+
println!("IROH_SECRET={}", hex::encode(secret.to_bytes()));
237+
}
238+
}
239+
let endpoint = iroh::Endpoint::builder().discovery_n0().bind().await?;
240+
let store = MemStore::new();
241+
let mut hashes = HashMap::new();
242+
for path in paths {
243+
let tag = store.add_path(&path).await?;
244+
hashes.insert(path, tag.hash);
245+
}
246+
let _ = endpoint.home_relay().initialized().await;
247+
let addr = endpoint.node_addr().initialized().await;
248+
let events = limit_by_node_id(allowed_nodes.clone());
249+
let blobs = BlobsProtocol::new(&store, endpoint.clone(), Some(events));
250+
let router = iroh::protocol::Router::builder(endpoint)
251+
.accept(iroh_blobs::ALPN, blobs)
252+
.spawn();
253+
println!("Node id: {}\n", router.endpoint().node_id());
254+
for id in &allowed_nodes {
255+
println!("Allowed node: {id}");
256+
}
257+
println!();
258+
for (path, hash) in &hashes {
259+
let ticket = BlobTicket::new(addr.clone(), *hash, iroh_blobs::BlobFormat::Raw);
260+
println!("{}: {ticket}", path.display());
261+
}
262+
tokio::signal::ctrl_c().await?;
263+
router.shutdown().await?;
264+
}
265+
Args::ByHash { paths } => {
266+
let endpoint = iroh::Endpoint::builder().discovery_n0().bind().await?;
267+
let store = MemStore::new();
268+
let mut hashes = HashMap::new();
269+
let mut allowed_hashes = HashSet::new();
270+
for (i, path) in paths.into_iter().enumerate() {
271+
let tag = store.add_path(&path).await?;
272+
hashes.insert(path, tag.hash);
273+
if i == 0 {
274+
allowed_hashes.insert(tag.hash);
275+
}
276+
}
277+
let _ = endpoint.home_relay().initialized().await;
278+
let addr = endpoint.node_addr().initialized().await;
279+
let events = limit_by_hash(allowed_hashes.clone());
280+
let blobs = BlobsProtocol::new(&store, endpoint.clone(), Some(events));
281+
let router = iroh::protocol::Router::builder(endpoint)
282+
.accept(iroh_blobs::ALPN, blobs)
283+
.spawn();
284+
for (i, (path, hash)) in hashes.iter().enumerate() {
285+
let ticket = BlobTicket::new(addr.clone(), *hash, iroh_blobs::BlobFormat::Raw);
286+
let permitted = if i == 0 { "" } else { "limited" };
287+
println!("{}: {ticket} ({permitted})", path.display());
288+
}
289+
tokio::signal::ctrl_c().await?;
290+
router.shutdown().await?;
291+
}
292+
Args::Throttle { paths, delay_ms } => {
293+
let endpoint = iroh::Endpoint::builder().discovery_n0().bind().await?;
294+
let store = MemStore::new();
295+
let mut hashes = HashMap::new();
296+
for path in paths {
297+
let tag = store.add_path(&path).await?;
298+
hashes.insert(path, tag.hash);
299+
}
300+
let _ = endpoint.home_relay().initialized().await;
301+
let addr = endpoint.node_addr().initialized().await;
302+
let events = throttle(delay_ms);
303+
let blobs = BlobsProtocol::new(&store, endpoint.clone(), Some(events));
304+
let router = iroh::protocol::Router::builder(endpoint)
305+
.accept(iroh_blobs::ALPN, blobs)
306+
.spawn();
307+
for (path, hash) in hashes {
308+
let ticket = BlobTicket::new(addr.clone(), hash, iroh_blobs::BlobFormat::Raw);
309+
println!("{}: {ticket}", path.display());
310+
}
311+
tokio::signal::ctrl_c().await?;
312+
router.shutdown().await?;
313+
}
314+
Args::MaxConnections {
315+
paths,
316+
max_connections,
317+
} => {
318+
let endpoint = iroh::Endpoint::builder().discovery_n0().bind().await?;
319+
let store = MemStore::new();
320+
let mut hashes = HashMap::new();
321+
for path in paths {
322+
let tag = store.add_path(&path).await?;
323+
hashes.insert(path, tag.hash);
324+
}
325+
let _ = endpoint.home_relay().initialized().await;
326+
let addr = endpoint.node_addr().initialized().await;
327+
let events = limit_max_connections(max_connections);
328+
let blobs = BlobsProtocol::new(&store, endpoint.clone(), Some(events));
329+
let router = iroh::protocol::Router::builder(endpoint)
330+
.accept(iroh_blobs::ALPN, blobs)
331+
.spawn();
332+
for (path, hash) in hashes {
333+
let ticket = BlobTicket::new(addr.clone(), hash, iroh_blobs::BlobFormat::Raw);
334+
println!("{}: {ticket}", path.display());
335+
}
336+
tokio::signal::ctrl_c().await?;
337+
router.shutdown().await?;
338+
}
339+
}
340+
Ok(())
341+
}

0 commit comments

Comments
 (0)