Skip to content

Commit 31d0ee3

Browse files
committed
feat: add high-level methods to request to make the request based api simple
1 parent f955176 commit 31d0ee3

File tree

5 files changed

+347
-187
lines changed

5 files changed

+347
-187
lines changed

Cargo.lock

Lines changed: 11 additions & 10 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

irpc-iroh/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,4 @@ n0-future = { workspace = true }
2727
tracing-subscriber = { workspace = true, features = ["fmt"] }
2828
irpc-derive = { version = "0.5.0", path = "../irpc-derive" }
2929
clap = { version = "4.5.41", features = ["derive"] }
30+
rand = "0.8"

irpc-iroh/examples/auth.rs

Lines changed: 91 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
//! * Manually implementing the connection loop
44
//! * Authenticating peers
55
6+
use std::time::Duration;
7+
68
use anyhow::Result;
7-
use iroh::{protocol::Router, Endpoint, Watcher};
9+
use iroh::{protocol::Router, Endpoint, NodeAddr, SecretKey, Watcher};
810

911
use self::storage::{StorageClient, StorageServer};
1012

@@ -17,20 +19,28 @@ async fn main() -> Result<()> {
1719
}
1820

1921
async fn remote() -> Result<()> {
20-
let (server_router, server_addr) = {
21-
let endpoint = Endpoint::builder().discovery_n0().bind().await?;
22+
let server_secret_key = SecretKey::generate(&mut rand::rngs::OsRng);
23+
let server_addr = NodeAddr::new(server_secret_key.public());
24+
let start_server = async move || {
25+
let endpoint = Endpoint::builder()
26+
.secret_key(server_secret_key.clone())
27+
.discovery_n0()
28+
.bind()
29+
.await?;
2230
let server = StorageServer::new("secret".to_string());
2331
let router = Router::builder(endpoint.clone())
2432
.accept(StorageServer::ALPN, server.clone())
2533
.spawn();
26-
let addr = endpoint.node_addr().initialized().await;
27-
(router, addr)
34+
let _ = endpoint.home_relay().initialized().await;
35+
// wait a bit for publishing to complete..
36+
tokio::time::sleep(Duration::from_millis(500)).await;
37+
anyhow::Ok(router)
2838
};
39+
let mut server_router = (start_server)().await?;
2940

3041
// correct authentication
31-
let client_endpoint = Endpoint::builder().bind().await?;
32-
let api = StorageClient::connect(client_endpoint, server_addr.clone());
33-
api.auth("secret").await?;
42+
let client_endpoint = Endpoint::builder().discovery_n0().bind().await?;
43+
let api = StorageClient::connect(client_endpoint, server_addr.clone(), "secret");
3444
api.set("hello".to_string(), "world".to_string()).await?;
3545
api.set("goodbye".to_string(), "world".to_string()).await?;
3646
let value = api.get("hello".to_string()).await?;
@@ -40,15 +50,21 @@ async fn remote() -> Result<()> {
4050
println!("list value = {value:?}");
4151
}
4252

43-
// invalid authentication
44-
let client_endpoint = Endpoint::builder().bind().await?;
45-
let api = StorageClient::connect(client_endpoint, server_addr.clone());
46-
assert!(api.auth("bad").await.is_err());
47-
assert!(api.get("hello".to_string()).await.is_err());
53+
// restart server
54+
server_router.shutdown().await?;
55+
server_router = (start_server)().await?;
56+
57+
// reconnections work: client will transparently reauthenticate
58+
println!("restarting server");
59+
let value = api.get("hello".to_string()).await?;
60+
println!("value = {value:?}");
61+
api.set("hello".to_string(), "world".to_string()).await?;
62+
let value = api.get("hello".to_string()).await?;
63+
println!("value = {value:?}");
4864

49-
// no authentication
65+
// invalid authentication
5066
let client_endpoint = Endpoint::builder().bind().await?;
51-
let api = StorageClient::connect(client_endpoint, server_addr);
67+
let api = StorageClient::connect(client_endpoint, server_addr.clone(), "bad");
5268
assert!(api.get("hello".to_string()).await.is_err());
5369

5470
drop(server_router);
@@ -65,17 +81,16 @@ mod storage {
6581
sync::{Arc, Mutex},
6682
};
6783

68-
use anyhow::Result;
84+
use anyhow::{anyhow, Result};
6985
use iroh::{
7086
endpoint::Connection,
7187
protocol::{AcceptError, ProtocolHandler},
7288
Endpoint,
7389
};
7490
use irpc::{
7591
channel::{mpsc, oneshot},
76-
rpc_requests, Client, WithChannels,
92+
rpc_requests, Client, Request, RequestError, WithChannels,
7793
};
78-
// Import the macro
7994
use irpc_iroh::{read_request, IrohRemoteConnection};
8095
use serde::{Deserialize, Serialize};
8196
use tracing::info;
@@ -109,7 +124,8 @@ mod storage {
109124
#[rpc_requests(message = StorageMessage)]
110125
#[derive(Serialize, Deserialize, Debug)]
111126
enum StorageProtocol {
112-
#[rpc(tx=oneshot::Sender<Result<(), String>>)]
127+
// Connection will be closed if auth fails.
128+
#[rpc(tx=oneshot::Sender<()>)]
113129
Auth(Auth),
114130
#[rpc(tx=oneshot::Sender<Option<String>>)]
115131
Get(Get),
@@ -129,31 +145,29 @@ mod storage {
129145

130146
impl ProtocolHandler for StorageServer {
131147
async fn accept(&self, conn: Connection) -> Result<(), AcceptError> {
132-
let mut authed = false;
133-
while let Some(msg) = read_request::<StorageProtocol>(&conn).await? {
134-
match msg {
135-
StorageMessage::Auth(msg) => {
136-
let WithChannels { inner, tx, .. } = msg;
137-
if authed {
138-
conn.close(1u32.into(), b"invalid message");
139-
break;
140-
} else if inner.token != self.auth_token {
141-
conn.close(1u32.into(), b"permission denied");
142-
break;
143-
} else {
144-
authed = true;
145-
tx.send(Ok(())).await.ok();
146-
}
147-
}
148-
msg => {
149-
if !authed {
150-
conn.close(1u32.into(), b"permission denied");
151-
break;
152-
} else {
153-
self.handle_authenticated(msg).await;
154-
}
155-
}
148+
// read first message: must be auth!
149+
let msg = read_request::<StorageProtocol>(&conn).await?;
150+
let auth_ok = if let Some(StorageMessage::Auth(msg)) = msg {
151+
let WithChannels { inner, tx, .. } = msg;
152+
if inner.token == self.auth_token {
153+
tx.send(()).await.ok();
154+
true
155+
} else {
156+
false
156157
}
158+
} else {
159+
false
160+
};
161+
162+
// if not authenticated: close connection immediately.
163+
if !auth_ok {
164+
conn.close(1u32.into(), b"permission denied");
165+
return Ok(());
166+
}
167+
168+
// now the connection is authenticated and we can handle all subsequent requests.
169+
while let Some(msg) = read_request::<StorageProtocol>(&conn).await? {
170+
self.handle_request(msg).await;
157171
}
158172
conn.closed().await;
159173
Ok(())
@@ -170,7 +184,7 @@ mod storage {
170184
}
171185
}
172186

173-
async fn handle_authenticated(&self, msg: StorageMessage) {
187+
async fn handle_request(&self, msg: StorageMessage) {
174188
match msg {
175189
StorageMessage::Auth(_) => unreachable!("handled in ProtocolHandler::accept"),
176190
StorageMessage::Get(get) => {
@@ -218,39 +232,63 @@ mod storage {
218232
}
219233

220234
pub struct StorageClient {
235+
api_token: String,
221236
inner: Client<StorageProtocol>,
222237
}
223238

224239
impl StorageClient {
225240
pub const ALPN: &[u8] = ALPN;
226241

227-
pub fn connect(endpoint: Endpoint, addr: impl Into<iroh::NodeAddr>) -> StorageClient {
242+
pub fn connect(
243+
endpoint: Endpoint,
244+
addr: impl Into<iroh::NodeAddr>,
245+
api_token: &str,
246+
) -> StorageClient {
228247
let conn = IrohRemoteConnection::new(endpoint, addr.into(), Self::ALPN.to_vec());
229248
StorageClient {
249+
api_token: api_token.to_string(),
230250
inner: Client::boxed(conn),
231251
}
232252
}
233253

234-
pub async fn auth(&self, token: &str) -> Result<(), anyhow::Error> {
235-
self.inner
254+
async fn authenticated_request(&self) -> Result<Request<StorageProtocol>, irpc::Error> {
255+
let request = self.inner.request().await?;
256+
257+
// if the connection is not new: no need to reauthenticate.
258+
if !request.is_new_connection() {
259+
return Ok(request);
260+
}
261+
262+
// if this is a new connection: use this request to send an auth message.
263+
request
236264
.rpc(Auth {
237-
token: token.to_string(),
265+
token: self.api_token.clone(),
238266
})
239-
.await?
240-
.map_err(|err| anyhow::anyhow!(err))
267+
.await?;
268+
// and create a new request for the actual call.
269+
let request = self.inner.request().await?;
270+
// if this *again* created a new connection, we error out.
271+
if request.is_new_connection() {
272+
Err(RequestError::Other(anyhow!("Connection is reconnecting too often")).into())
273+
} else {
274+
Ok(request)
275+
}
241276
}
242277

243278
pub async fn get(&self, key: String) -> Result<Option<String>, irpc::Error> {
244-
self.inner.rpc(Get { key }).await
279+
self.authenticated_request().await?.rpc(Get { key }).await
245280
}
246281

247282
pub async fn list(&self) -> Result<mpsc::Receiver<String>, irpc::Error> {
248-
self.inner.server_streaming(List, 10).await
283+
self.authenticated_request()
284+
.await?
285+
.server_streaming(List, 10)
286+
.await
249287
}
250288

251289
pub async fn set(&self, key: String, value: String) -> Result<(), irpc::Error> {
252290
let msg = Set { key, value };
253-
self.inner.rpc(msg).await
291+
self.authenticated_request().await?.rpc(msg).await
254292
}
255293
}
256294
}

irpc-iroh/src/lib.rs

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ use iroh::{
1010
use irpc::{
1111
channel::RecvError,
1212
rpc::{
13-
Handler, RemoteConnection, RemoteService, ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED,
14-
MAX_MESSAGE_SIZE,
13+
Handler, RemoteConnection, RemoteService, RemoteStreams,
14+
ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED, MAX_MESSAGE_SIZE,
1515
},
1616
util::AsyncReadVarintExt,
1717
LocalSender, RequestError,
@@ -60,27 +60,32 @@ impl RemoteConnection for IrohRemoteConnection {
6060
Box::new(self.clone())
6161
}
6262

63-
fn open_bi(&self) -> BoxFuture<std::result::Result<(SendStream, RecvStream), RequestError>> {
63+
fn open_bi(&self) -> BoxFuture<std::result::Result<RemoteStreams, RequestError>> {
6464
let this = self.0.clone();
6565
Box::pin(async move {
6666
let mut guard = this.connection.lock().await;
6767
let pair = match guard.as_mut() {
6868
Some(conn) => {
6969
// try to reuse the connection
7070
match conn.open_bi().await {
71-
Ok(pair) => pair,
71+
Ok(pair) => RemoteStreams::with_reused(pair),
7272
Err(_) => {
7373
// try with a new connection, just once
7474
*guard = None;
75-
connect_and_open_bi(&this.endpoint, &this.addr, &this.alpn, guard)
76-
.await
77-
.map_err(RequestError::Other)?
75+
let pair =
76+
connect_and_open_bi(&this.endpoint, &this.addr, &this.alpn, guard)
77+
.await
78+
.map_err(RequestError::Other)?;
79+
RemoteStreams::with_new(pair)
7880
}
7981
}
8082
}
81-
None => connect_and_open_bi(&this.endpoint, &this.addr, &this.alpn, guard)
82-
.await
83-
.map_err(RequestError::Other)?,
83+
None => {
84+
let pair = connect_and_open_bi(&this.endpoint, &this.addr, &this.alpn, guard)
85+
.await
86+
.map_err(RequestError::Other)?;
87+
RemoteStreams::with_new(pair)
88+
}
8489
};
8590
Ok(pair)
8691
})

0 commit comments

Comments
 (0)