Skip to content

Commit 27d7f77

Browse files
committed
feat: add high-level methods to request to make the request based api simple
1 parent 9ee8f35 commit 27d7f77

File tree

5 files changed

+347
-186
lines changed

5 files changed

+347
-186
lines changed

Cargo.lock

Lines changed: 11 additions & 10 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

irpc-iroh/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,4 @@ n0-future = { workspace = true }
2727
tracing-subscriber = { workspace = true, features = ["fmt"] }
2828
irpc-derive = { version = "0.5.0", path = "../irpc-derive" }
2929
clap = { version = "4.5.41", features = ["derive"] }
30+
rand = "0.8"

irpc-iroh/examples/auth.rs

Lines changed: 91 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
//! * Manually implementing the connection loop
44
//! * Authenticating peers
55
6+
use std::time::Duration;
7+
68
use anyhow::Result;
7-
use iroh::{protocol::Router, Endpoint, Watcher};
9+
use iroh::{protocol::Router, Endpoint, NodeAddr, SecretKey, Watcher};
810

911
use self::storage::{StorageClient, StorageServer};
1012

@@ -17,20 +19,28 @@ async fn main() -> Result<()> {
1719
}
1820

1921
async fn remote() -> Result<()> {
20-
let (server_router, server_addr) = {
21-
let endpoint = Endpoint::builder().discovery_n0().bind().await?;
22+
let server_secret_key = SecretKey::generate(&mut rand::rngs::OsRng);
23+
let server_addr = NodeAddr::new(server_secret_key.public());
24+
let start_server = async move || {
25+
let endpoint = Endpoint::builder()
26+
.secret_key(server_secret_key.clone())
27+
.discovery_n0()
28+
.bind()
29+
.await?;
2230
let server = StorageServer::new("secret".to_string());
2331
let router = Router::builder(endpoint.clone())
2432
.accept(StorageServer::ALPN, server.clone())
2533
.spawn();
26-
let addr = endpoint.node_addr().initialized().await;
27-
(router, addr)
34+
let _ = endpoint.home_relay().initialized().await;
35+
// wait a bit for publishing to complete..
36+
tokio::time::sleep(Duration::from_millis(500)).await;
37+
anyhow::Ok(router)
2838
};
39+
let mut server_router = (start_server)().await?;
2940

3041
// correct authentication
31-
let client_endpoint = Endpoint::builder().bind().await?;
32-
let api = StorageClient::connect(client_endpoint, server_addr.clone());
33-
api.auth("secret").await?;
42+
let client_endpoint = Endpoint::builder().discovery_n0().bind().await?;
43+
let api = StorageClient::connect(client_endpoint, server_addr.clone(), "secret");
3444
api.set("hello".to_string(), "world".to_string()).await?;
3545
api.set("goodbye".to_string(), "world".to_string()).await?;
3646
let value = api.get("hello".to_string()).await?;
@@ -40,15 +50,21 @@ async fn remote() -> Result<()> {
4050
println!("list value = {value:?}");
4151
}
4252

43-
// invalid authentication
44-
let client_endpoint = Endpoint::builder().bind().await?;
45-
let api = StorageClient::connect(client_endpoint, server_addr.clone());
46-
assert!(api.auth("bad").await.is_err());
47-
assert!(api.get("hello".to_string()).await.is_err());
53+
// restart server
54+
server_router.shutdown().await?;
55+
server_router = (start_server)().await?;
56+
57+
// reconnections work: client will transparently reauthenticate
58+
println!("restarting server");
59+
let value = api.get("hello".to_string()).await?;
60+
println!("value = {value:?}");
61+
api.set("hello".to_string(), "world".to_string()).await?;
62+
let value = api.get("hello".to_string()).await?;
63+
println!("value = {value:?}");
4864

49-
// no authentication
65+
// invalid authentication
5066
let client_endpoint = Endpoint::builder().bind().await?;
51-
let api = StorageClient::connect(client_endpoint, server_addr);
67+
let api = StorageClient::connect(client_endpoint, server_addr.clone(), "bad");
5268
assert!(api.get("hello".to_string()).await.is_err());
5369

5470
drop(server_router);
@@ -65,15 +81,15 @@ mod storage {
6581
sync::{Arc, Mutex},
6682
};
6783

68-
use anyhow::Result;
84+
use anyhow::{anyhow, Result};
6985
use iroh::{
7086
endpoint::Connection,
7187
protocol::{AcceptError, ProtocolHandler},
7288
Endpoint,
7389
};
7490
use irpc::{
7591
channel::{mpsc, oneshot},
76-
Client, WithChannels,
92+
Client, Request, RequestError, WithChannels,
7793
};
7894
// Import the macro
7995
use irpc_derive::rpc_requests;
@@ -110,7 +126,8 @@ mod storage {
110126
#[rpc_requests(message = StorageMessage)]
111127
#[derive(Serialize, Deserialize, Debug)]
112128
enum StorageProtocol {
113-
#[rpc(tx=oneshot::Sender<Result<(), String>>)]
129+
// Connection will be closed if auth fails.
130+
#[rpc(tx=oneshot::Sender<()>)]
114131
Auth(Auth),
115132
#[rpc(tx=oneshot::Sender<Option<String>>)]
116133
Get(Get),
@@ -130,31 +147,29 @@ mod storage {
130147

131148
impl ProtocolHandler for StorageServer {
132149
async fn accept(&self, conn: Connection) -> Result<(), AcceptError> {
133-
let mut authed = false;
134-
while let Some(msg) = read_request::<StorageProtocol>(&conn).await? {
135-
match msg {
136-
StorageMessage::Auth(msg) => {
137-
let WithChannels { inner, tx, .. } = msg;
138-
if authed {
139-
conn.close(1u32.into(), b"invalid message");
140-
break;
141-
} else if inner.token != self.auth_token {
142-
conn.close(1u32.into(), b"permission denied");
143-
break;
144-
} else {
145-
authed = true;
146-
tx.send(Ok(())).await.ok();
147-
}
148-
}
149-
msg => {
150-
if !authed {
151-
conn.close(1u32.into(), b"permission denied");
152-
break;
153-
} else {
154-
self.handle_authenticated(msg).await;
155-
}
156-
}
150+
// read first message: must be auth!
151+
let msg = read_request::<StorageProtocol>(&conn).await?;
152+
let auth_ok = if let Some(StorageMessage::Auth(msg)) = msg {
153+
let WithChannels { inner, tx, .. } = msg;
154+
if inner.token == self.auth_token {
155+
tx.send(()).await.ok();
156+
true
157+
} else {
158+
false
157159
}
160+
} else {
161+
false
162+
};
163+
164+
// if not authenticated: close connection immediately.
165+
if !auth_ok {
166+
conn.close(1u32.into(), b"permission denied");
167+
return Ok(());
168+
}
169+
170+
// now the connection is authenticated and we can handle all subsequent requests.
171+
while let Some(msg) = read_request::<StorageProtocol>(&conn).await? {
172+
self.handle_request(msg).await;
158173
}
159174
conn.closed().await;
160175
Ok(())
@@ -171,7 +186,7 @@ mod storage {
171186
}
172187
}
173188

174-
async fn handle_authenticated(&self, msg: StorageMessage) {
189+
async fn handle_request(&self, msg: StorageMessage) {
175190
match msg {
176191
StorageMessage::Auth(_) => unreachable!("handled in ProtocolHandler::accept"),
177192
StorageMessage::Get(get) => {
@@ -219,39 +234,63 @@ mod storage {
219234
}
220235

221236
pub struct StorageClient {
237+
api_token: String,
222238
inner: Client<StorageProtocol>,
223239
}
224240

225241
impl StorageClient {
226242
pub const ALPN: &[u8] = ALPN;
227243

228-
pub fn connect(endpoint: Endpoint, addr: impl Into<iroh::NodeAddr>) -> StorageClient {
244+
pub fn connect(
245+
endpoint: Endpoint,
246+
addr: impl Into<iroh::NodeAddr>,
247+
api_token: &str,
248+
) -> StorageClient {
229249
let conn = IrohRemoteConnection::new(endpoint, addr.into(), Self::ALPN.to_vec());
230250
StorageClient {
251+
api_token: api_token.to_string(),
231252
inner: Client::boxed(conn),
232253
}
233254
}
234255

235-
pub async fn auth(&self, token: &str) -> Result<(), anyhow::Error> {
236-
self.inner
256+
async fn authenticated_request(&self) -> Result<Request<StorageProtocol>, irpc::Error> {
257+
let request = self.inner.request().await?;
258+
259+
// if the connection is not new: no need to reauthenticate.
260+
if !request.is_new_connection() {
261+
return Ok(request);
262+
}
263+
264+
// if this is a new connection: use this request to send an auth message.
265+
request
237266
.rpc(Auth {
238-
token: token.to_string(),
267+
token: self.api_token.clone(),
239268
})
240-
.await?
241-
.map_err(|err| anyhow::anyhow!(err))
269+
.await?;
270+
// and create a new request for the actual call.
271+
let request = self.inner.request().await?;
272+
// if this *again* created a new connection, we error out.
273+
if request.is_new_connection() {
274+
Err(RequestError::Other(anyhow!("Connection is reconnecting too often")).into())
275+
} else {
276+
Ok(request)
277+
}
242278
}
243279

244280
pub async fn get(&self, key: String) -> Result<Option<String>, irpc::Error> {
245-
self.inner.rpc(Get { key }).await
281+
self.authenticated_request().await?.rpc(Get { key }).await
246282
}
247283

248284
pub async fn list(&self) -> Result<mpsc::Receiver<String>, irpc::Error> {
249-
self.inner.server_streaming(List, 10).await
285+
self.authenticated_request()
286+
.await?
287+
.server_streaming(List, 10)
288+
.await
250289
}
251290

252291
pub async fn set(&self, key: String, value: String) -> Result<(), irpc::Error> {
253292
let msg = Set { key, value };
254-
self.inner.rpc(msg).await
293+
self.authenticated_request().await?.rpc(msg).await
255294
}
256295
}
257296
}

irpc-iroh/src/lib.rs

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ use iroh::{
1010
use irpc::{
1111
channel::RecvError,
1212
rpc::{
13-
Handler, RemoteConnection, RemoteService, ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED,
14-
MAX_MESSAGE_SIZE,
13+
Handler, RemoteConnection, RemoteService, RemoteStreams,
14+
ERROR_CODE_MAX_MESSAGE_SIZE_EXCEEDED, MAX_MESSAGE_SIZE,
1515
},
1616
util::AsyncReadVarintExt,
1717
LocalSender, RequestError,
@@ -59,27 +59,32 @@ impl RemoteConnection for IrohRemoteConnection {
5959
Box::new(self.clone())
6060
}
6161

62-
fn open_bi(&self) -> BoxFuture<std::result::Result<(SendStream, RecvStream), RequestError>> {
62+
fn open_bi(&self) -> BoxFuture<std::result::Result<RemoteStreams, RequestError>> {
6363
let this = self.0.clone();
6464
Box::pin(async move {
6565
let mut guard = this.connection.lock().await;
6666
let pair = match guard.as_mut() {
6767
Some(conn) => {
6868
// try to reuse the connection
6969
match conn.open_bi().await {
70-
Ok(pair) => pair,
70+
Ok(pair) => RemoteStreams::with_reused(pair),
7171
Err(_) => {
7272
// try with a new connection, just once
7373
*guard = None;
74-
connect_and_open_bi(&this.endpoint, &this.addr, &this.alpn, guard)
75-
.await
76-
.map_err(RequestError::Other)?
74+
let pair =
75+
connect_and_open_bi(&this.endpoint, &this.addr, &this.alpn, guard)
76+
.await
77+
.map_err(RequestError::Other)?;
78+
RemoteStreams::with_new(pair)
7779
}
7880
}
7981
}
80-
None => connect_and_open_bi(&this.endpoint, &this.addr, &this.alpn, guard)
81-
.await
82-
.map_err(RequestError::Other)?,
82+
None => {
83+
let pair = connect_and_open_bi(&this.endpoint, &this.addr, &this.alpn, guard)
84+
.await
85+
.map_err(RequestError::Other)?;
86+
RemoteStreams::with_new(pair)
87+
}
8388
};
8489
Ok(pair)
8590
})

0 commit comments

Comments
 (0)