33//! * Manually implementing the connection loop
44//! * Authenticating peers
55
6+ use std:: time:: Duration ;
7+
68use anyhow:: Result ;
7- use iroh:: { protocol:: Router , Endpoint , Watcher } ;
9+ use iroh:: { protocol:: Router , Endpoint , NodeAddr , SecretKey , Watcher } ;
810
911use self :: storage:: { StorageClient , StorageServer } ;
1012
@@ -17,20 +19,28 @@ async fn main() -> Result<()> {
1719}
1820
1921async fn remote ( ) -> Result < ( ) > {
20- let ( server_router, server_addr) = {
21- let endpoint = Endpoint :: builder ( ) . discovery_n0 ( ) . bind ( ) . await ?;
22+ let server_secret_key = SecretKey :: generate ( & mut rand:: rngs:: OsRng ) ;
23+ let server_addr = NodeAddr :: new ( server_secret_key. public ( ) ) ;
24+ let start_server = async move || {
25+ let endpoint = Endpoint :: builder ( )
26+ . secret_key ( server_secret_key. clone ( ) )
27+ . discovery_n0 ( )
28+ . bind ( )
29+ . await ?;
2230 let server = StorageServer :: new ( "secret" . to_string ( ) ) ;
2331 let router = Router :: builder ( endpoint. clone ( ) )
2432 . accept ( StorageServer :: ALPN , server. clone ( ) )
2533 . spawn ( ) ;
26- let addr = endpoint. node_addr ( ) . initialized ( ) . await ;
27- ( router, addr)
34+ let _ = endpoint. home_relay ( ) . initialized ( ) . await ;
35+ // wait a bit for publishing to complete..
36+ tokio:: time:: sleep ( Duration :: from_millis ( 500 ) ) . await ;
37+ anyhow:: Ok ( router)
2838 } ;
39+ let mut server_router = ( start_server) ( ) . await ?;
2940
3041 // correct authentication
31- let client_endpoint = Endpoint :: builder ( ) . bind ( ) . await ?;
32- let api = StorageClient :: connect ( client_endpoint, server_addr. clone ( ) ) ;
33- api. auth ( "secret" ) . await ?;
42+ let client_endpoint = Endpoint :: builder ( ) . discovery_n0 ( ) . bind ( ) . await ?;
43+ let api = StorageClient :: connect ( client_endpoint, server_addr. clone ( ) , "secret" ) ;
3444 api. set ( "hello" . to_string ( ) , "world" . to_string ( ) ) . await ?;
3545 api. set ( "goodbye" . to_string ( ) , "world" . to_string ( ) ) . await ?;
3646 let value = api. get ( "hello" . to_string ( ) ) . await ?;
@@ -40,15 +50,21 @@ async fn remote() -> Result<()> {
4050 println ! ( "list value = {value:?}" ) ;
4151 }
4252
43- // invalid authentication
44- let client_endpoint = Endpoint :: builder ( ) . bind ( ) . await ?;
45- let api = StorageClient :: connect ( client_endpoint, server_addr. clone ( ) ) ;
46- assert ! ( api. auth( "bad" ) . await . is_err( ) ) ;
47- assert ! ( api. get( "hello" . to_string( ) ) . await . is_err( ) ) ;
53+ // restart server
54+ server_router. shutdown ( ) . await ?;
55+ server_router = ( start_server) ( ) . await ?;
56+
57+ // reconnections work: client will transparently reauthenticate
58+ println ! ( "restarting server" ) ;
59+ let value = api. get ( "hello" . to_string ( ) ) . await ?;
60+ println ! ( "value = {value:?}" ) ;
61+ api. set ( "hello" . to_string ( ) , "world" . to_string ( ) ) . await ?;
62+ let value = api. get ( "hello" . to_string ( ) ) . await ?;
63+ println ! ( "value = {value:?}" ) ;
4864
49- // no authentication
65+ // invalid authentication
5066 let client_endpoint = Endpoint :: builder ( ) . bind ( ) . await ?;
51- let api = StorageClient :: connect ( client_endpoint, server_addr) ;
67+ let api = StorageClient :: connect ( client_endpoint, server_addr. clone ( ) , "bad" ) ;
5268 assert ! ( api. get( "hello" . to_string( ) ) . await . is_err( ) ) ;
5369
5470 drop ( server_router) ;
@@ -65,17 +81,16 @@ mod storage {
6581 sync:: { Arc , Mutex } ,
6682 } ;
6783
68- use anyhow:: Result ;
84+ use anyhow:: { anyhow , Result } ;
6985 use iroh:: {
7086 endpoint:: Connection ,
7187 protocol:: { AcceptError , ProtocolHandler } ,
7288 Endpoint ,
7389 } ;
7490 use irpc:: {
7591 channel:: { mpsc, oneshot} ,
76- rpc_requests, Client , WithChannels ,
92+ rpc_requests, Client , Request , RequestError , WithChannels ,
7793 } ;
78- // Import the macro
7994 use irpc_iroh:: { read_request, IrohRemoteConnection } ;
8095 use serde:: { Deserialize , Serialize } ;
8196 use tracing:: info;
@@ -109,7 +124,8 @@ mod storage {
109124 #[ rpc_requests( message = StorageMessage ) ]
110125 #[ derive( Serialize , Deserialize , Debug ) ]
111126 enum StorageProtocol {
112- #[ rpc( tx=oneshot:: Sender <Result <( ) , String >>) ]
127+ // Connection will be closed if auth fails.
128+ #[ rpc( tx=oneshot:: Sender <( ) >) ]
113129 Auth ( Auth ) ,
114130 #[ rpc( tx=oneshot:: Sender <Option <String >>) ]
115131 Get ( Get ) ,
@@ -129,31 +145,29 @@ mod storage {
129145
130146 impl ProtocolHandler for StorageServer {
131147 async fn accept ( & self , conn : Connection ) -> Result < ( ) , AcceptError > {
132- let mut authed = false ;
133- while let Some ( msg) = read_request :: < StorageProtocol > ( & conn) . await ? {
134- match msg {
135- StorageMessage :: Auth ( msg) => {
136- let WithChannels { inner, tx, .. } = msg;
137- if authed {
138- conn. close ( 1u32 . into ( ) , b"invalid message" ) ;
139- break ;
140- } else if inner. token != self . auth_token {
141- conn. close ( 1u32 . into ( ) , b"permission denied" ) ;
142- break ;
143- } else {
144- authed = true ;
145- tx. send ( Ok ( ( ) ) ) . await . ok ( ) ;
146- }
147- }
148- msg => {
149- if !authed {
150- conn. close ( 1u32 . into ( ) , b"permission denied" ) ;
151- break ;
152- } else {
153- self . handle_authenticated ( msg) . await ;
154- }
155- }
148+ // read first message: must be auth!
149+ let msg = read_request :: < StorageProtocol > ( & conn) . await ?;
150+ let auth_ok = if let Some ( StorageMessage :: Auth ( msg) ) = msg {
151+ let WithChannels { inner, tx, .. } = msg;
152+ if inner. token == self . auth_token {
153+ tx. send ( ( ) ) . await . ok ( ) ;
154+ true
155+ } else {
156+ false
156157 }
158+ } else {
159+ false
160+ } ;
161+
162+ // if not authenticated: close connection immediately.
163+ if !auth_ok {
164+ conn. close ( 1u32 . into ( ) , b"permission denied" ) ;
165+ return Ok ( ( ) ) ;
166+ }
167+
168+ // now the connection is authenticated and we can handle all subsequent requests.
169+ while let Some ( msg) = read_request :: < StorageProtocol > ( & conn) . await ? {
170+ self . handle_request ( msg) . await ;
157171 }
158172 conn. closed ( ) . await ;
159173 Ok ( ( ) )
@@ -170,7 +184,7 @@ mod storage {
170184 }
171185 }
172186
173- async fn handle_authenticated ( & self , msg : StorageMessage ) {
187+ async fn handle_request ( & self , msg : StorageMessage ) {
174188 match msg {
175189 StorageMessage :: Auth ( _) => unreachable ! ( "handled in ProtocolHandler::accept" ) ,
176190 StorageMessage :: Get ( get) => {
@@ -218,39 +232,63 @@ mod storage {
218232 }
219233
220234 pub struct StorageClient {
235+ api_token : String ,
221236 inner : Client < StorageProtocol > ,
222237 }
223238
224239 impl StorageClient {
225240 pub const ALPN : & [ u8 ] = ALPN ;
226241
227- pub fn connect ( endpoint : Endpoint , addr : impl Into < iroh:: NodeAddr > ) -> StorageClient {
242+ pub fn connect (
243+ endpoint : Endpoint ,
244+ addr : impl Into < iroh:: NodeAddr > ,
245+ api_token : & str ,
246+ ) -> StorageClient {
228247 let conn = IrohRemoteConnection :: new ( endpoint, addr. into ( ) , Self :: ALPN . to_vec ( ) ) ;
229248 StorageClient {
249+ api_token : api_token. to_string ( ) ,
230250 inner : Client :: boxed ( conn) ,
231251 }
232252 }
233253
234- pub async fn auth ( & self , token : & str ) -> Result < ( ) , anyhow:: Error > {
235- self . inner
254+ async fn authenticated_request ( & self ) -> Result < Request < StorageProtocol > , irpc:: Error > {
255+ let request = self . inner . request ( ) . await ?;
256+
257+ // if the connection is not new: no need to reauthenticate.
258+ if !request. is_new_connection ( ) {
259+ return Ok ( request) ;
260+ }
261+
262+ // if this is a new connection: use this request to send an auth message.
263+ request
236264 . rpc ( Auth {
237- token : token . to_string ( ) ,
265+ token : self . api_token . clone ( ) ,
238266 } )
239- . await ?
240- . map_err ( |err| anyhow:: anyhow!( err) )
267+ . await ?;
268+ // and create a new request for the actual call.
269+ let request = self . inner . request ( ) . await ?;
270+ // if this *again* created a new connection, we error out.
271+ if request. is_new_connection ( ) {
272+ Err ( RequestError :: Other ( anyhow ! ( "Connection is reconnecting too often" ) ) . into ( ) )
273+ } else {
274+ Ok ( request)
275+ }
241276 }
242277
243278 pub async fn get ( & self , key : String ) -> Result < Option < String > , irpc:: Error > {
244- self . inner . rpc ( Get { key } ) . await
279+ self . authenticated_request ( ) . await ? . rpc ( Get { key } ) . await
245280 }
246281
247282 pub async fn list ( & self ) -> Result < mpsc:: Receiver < String > , irpc:: Error > {
248- self . inner . server_streaming ( List , 10 ) . await
283+ self . authenticated_request ( )
284+ . await ?
285+ . server_streaming ( List , 10 )
286+ . await
249287 }
250288
251289 pub async fn set ( & self , key : String , value : String ) -> Result < ( ) , irpc:: Error > {
252290 let msg = Set { key, value } ;
253- self . inner . rpc ( msg) . await
291+ self . authenticated_request ( ) . await ? . rpc ( msg) . await
254292 }
255293 }
256294}
0 commit comments