33//! * Manually implementing the connection loop
44//! * Authenticating peers
55
6+ use std:: time:: Duration ;
7+
68use anyhow:: Result ;
7- use iroh:: { protocol:: Router , Endpoint , Watcher } ;
9+ use iroh:: { protocol:: Router , Endpoint , NodeAddr , SecretKey , Watcher } ;
810
911use self :: storage:: { StorageClient , StorageServer } ;
1012
@@ -17,20 +19,28 @@ async fn main() -> Result<()> {
1719}
1820
1921async fn remote ( ) -> Result < ( ) > {
20- let ( server_router, server_addr) = {
21- let endpoint = Endpoint :: builder ( ) . discovery_n0 ( ) . bind ( ) . await ?;
22+ let server_secret_key = SecretKey :: generate ( & mut rand:: rngs:: OsRng ) ;
23+ let server_addr = NodeAddr :: new ( server_secret_key. public ( ) ) ;
24+ let start_server = async move || {
25+ let endpoint = Endpoint :: builder ( )
26+ . secret_key ( server_secret_key. clone ( ) )
27+ . discovery_n0 ( )
28+ . bind ( )
29+ . await ?;
2230 let server = StorageServer :: new ( "secret" . to_string ( ) ) ;
2331 let router = Router :: builder ( endpoint. clone ( ) )
2432 . accept ( StorageServer :: ALPN , server. clone ( ) )
2533 . spawn ( ) ;
26- let addr = endpoint. node_addr ( ) . initialized ( ) . await ;
27- ( router, addr)
34+ let _ = endpoint. home_relay ( ) . initialized ( ) . await ;
35+ // wait a bit for publishing to complete..
36+ tokio:: time:: sleep ( Duration :: from_millis ( 500 ) ) . await ;
37+ anyhow:: Ok ( router)
2838 } ;
39+ let mut server_router = ( start_server) ( ) . await ?;
2940
3041 // correct authentication
31- let client_endpoint = Endpoint :: builder ( ) . bind ( ) . await ?;
32- let api = StorageClient :: connect ( client_endpoint, server_addr. clone ( ) ) ;
33- api. auth ( "secret" ) . await ?;
42+ let client_endpoint = Endpoint :: builder ( ) . discovery_n0 ( ) . bind ( ) . await ?;
43+ let api = StorageClient :: connect ( client_endpoint, server_addr. clone ( ) , "secret" ) ;
3444 api. set ( "hello" . to_string ( ) , "world" . to_string ( ) ) . await ?;
3545 api. set ( "goodbye" . to_string ( ) , "world" . to_string ( ) ) . await ?;
3646 let value = api. get ( "hello" . to_string ( ) ) . await ?;
@@ -40,15 +50,21 @@ async fn remote() -> Result<()> {
4050 println ! ( "list value = {value:?}" ) ;
4151 }
4252
43- // invalid authentication
44- let client_endpoint = Endpoint :: builder ( ) . bind ( ) . await ?;
45- let api = StorageClient :: connect ( client_endpoint, server_addr. clone ( ) ) ;
46- assert ! ( api. auth( "bad" ) . await . is_err( ) ) ;
47- assert ! ( api. get( "hello" . to_string( ) ) . await . is_err( ) ) ;
53+ // restart server
54+ server_router. shutdown ( ) . await ?;
55+ server_router = ( start_server) ( ) . await ?;
56+
57+ // reconnections work: client will transparently reauthenticate
58+ println ! ( "restarting server" ) ;
59+ let value = api. get ( "hello" . to_string ( ) ) . await ?;
60+ println ! ( "value = {value:?}" ) ;
61+ api. set ( "hello" . to_string ( ) , "world" . to_string ( ) ) . await ?;
62+ let value = api. get ( "hello" . to_string ( ) ) . await ?;
63+ println ! ( "value = {value:?}" ) ;
4864
49- // no authentication
65+ // invalid authentication
5066 let client_endpoint = Endpoint :: builder ( ) . bind ( ) . await ?;
51- let api = StorageClient :: connect ( client_endpoint, server_addr) ;
67+ let api = StorageClient :: connect ( client_endpoint, server_addr. clone ( ) , "bad" ) ;
5268 assert ! ( api. get( "hello" . to_string( ) ) . await . is_err( ) ) ;
5369
5470 drop ( server_router) ;
@@ -65,15 +81,15 @@ mod storage {
6581 sync:: { Arc , Mutex } ,
6682 } ;
6783
68- use anyhow:: Result ;
84+ use anyhow:: { anyhow , Result } ;
6985 use iroh:: {
7086 endpoint:: Connection ,
7187 protocol:: { AcceptError , ProtocolHandler } ,
7288 Endpoint ,
7389 } ;
7490 use irpc:: {
7591 channel:: { mpsc, oneshot} ,
76- Client , WithChannels ,
92+ Client , Request , RequestError , WithChannels ,
7793 } ;
7894 // Import the macro
7995 use irpc_derive:: rpc_requests;
@@ -110,7 +126,8 @@ mod storage {
110126 #[ rpc_requests( message = StorageMessage ) ]
111127 #[ derive( Serialize , Deserialize , Debug ) ]
112128 enum StorageProtocol {
113- #[ rpc( tx=oneshot:: Sender <Result <( ) , String >>) ]
129+ // Connection will be closed if auth fails.
130+ #[ rpc( tx=oneshot:: Sender <( ) >) ]
114131 Auth ( Auth ) ,
115132 #[ rpc( tx=oneshot:: Sender <Option <String >>) ]
116133 Get ( Get ) ,
@@ -130,31 +147,29 @@ mod storage {
130147
131148 impl ProtocolHandler for StorageServer {
132149 async fn accept ( & self , conn : Connection ) -> Result < ( ) , AcceptError > {
133- let mut authed = false ;
134- while let Some ( msg) = read_request :: < StorageProtocol > ( & conn) . await ? {
135- match msg {
136- StorageMessage :: Auth ( msg) => {
137- let WithChannels { inner, tx, .. } = msg;
138- if authed {
139- conn. close ( 1u32 . into ( ) , b"invalid message" ) ;
140- break ;
141- } else if inner. token != self . auth_token {
142- conn. close ( 1u32 . into ( ) , b"permission denied" ) ;
143- break ;
144- } else {
145- authed = true ;
146- tx. send ( Ok ( ( ) ) ) . await . ok ( ) ;
147- }
148- }
149- msg => {
150- if !authed {
151- conn. close ( 1u32 . into ( ) , b"permission denied" ) ;
152- break ;
153- } else {
154- self . handle_authenticated ( msg) . await ;
155- }
156- }
150+ // read first message: must be auth!
151+ let msg = read_request :: < StorageProtocol > ( & conn) . await ?;
152+ let auth_ok = if let Some ( StorageMessage :: Auth ( msg) ) = msg {
153+ let WithChannels { inner, tx, .. } = msg;
154+ if inner. token == self . auth_token {
155+ tx. send ( ( ) ) . await . ok ( ) ;
156+ true
157+ } else {
158+ false
157159 }
160+ } else {
161+ false
162+ } ;
163+
164+ // if not authenticated: close connection immediately.
165+ if !auth_ok {
166+ conn. close ( 1u32 . into ( ) , b"permission denied" ) ;
167+ return Ok ( ( ) ) ;
168+ }
169+
170+ // now the connection is authenticated and we can handle all subsequent requests.
171+ while let Some ( msg) = read_request :: < StorageProtocol > ( & conn) . await ? {
172+ self . handle_request ( msg) . await ;
158173 }
159174 conn. closed ( ) . await ;
160175 Ok ( ( ) )
@@ -171,7 +186,7 @@ mod storage {
171186 }
172187 }
173188
174- async fn handle_authenticated ( & self , msg : StorageMessage ) {
189+ async fn handle_request ( & self , msg : StorageMessage ) {
175190 match msg {
176191 StorageMessage :: Auth ( _) => unreachable ! ( "handled in ProtocolHandler::accept" ) ,
177192 StorageMessage :: Get ( get) => {
@@ -219,39 +234,63 @@ mod storage {
219234 }
220235
221236 pub struct StorageClient {
237+ api_token : String ,
222238 inner : Client < StorageProtocol > ,
223239 }
224240
225241 impl StorageClient {
226242 pub const ALPN : & [ u8 ] = ALPN ;
227243
228- pub fn connect ( endpoint : Endpoint , addr : impl Into < iroh:: NodeAddr > ) -> StorageClient {
244+ pub fn connect (
245+ endpoint : Endpoint ,
246+ addr : impl Into < iroh:: NodeAddr > ,
247+ api_token : & str ,
248+ ) -> StorageClient {
229249 let conn = IrohRemoteConnection :: new ( endpoint, addr. into ( ) , Self :: ALPN . to_vec ( ) ) ;
230250 StorageClient {
251+ api_token : api_token. to_string ( ) ,
231252 inner : Client :: boxed ( conn) ,
232253 }
233254 }
234255
235- pub async fn auth ( & self , token : & str ) -> Result < ( ) , anyhow:: Error > {
236- self . inner
256+ async fn authenticated_request ( & self ) -> Result < Request < StorageProtocol > , irpc:: Error > {
257+ let request = self . inner . request ( ) . await ?;
258+
259+ // if the connection is not new: no need to reauthenticate.
260+ if !request. is_new_connection ( ) {
261+ return Ok ( request) ;
262+ }
263+
264+ // if this is a new connection: use this request to send an auth message.
265+ request
237266 . rpc ( Auth {
238- token : token . to_string ( ) ,
267+ token : self . api_token . clone ( ) ,
239268 } )
240- . await ?
241- . map_err ( |err| anyhow:: anyhow!( err) )
269+ . await ?;
270+ // and create a new request for the actual call.
271+ let request = self . inner . request ( ) . await ?;
272+ // if this *again* created a new connection, we error out.
273+ if request. is_new_connection ( ) {
274+ Err ( RequestError :: Other ( anyhow ! ( "Connection is reconnecting too often" ) ) . into ( ) )
275+ } else {
276+ Ok ( request)
277+ }
242278 }
243279
244280 pub async fn get ( & self , key : String ) -> Result < Option < String > , irpc:: Error > {
245- self . inner . rpc ( Get { key } ) . await
281+ self . authenticated_request ( ) . await ? . rpc ( Get { key } ) . await
246282 }
247283
248284 pub async fn list ( & self ) -> Result < mpsc:: Receiver < String > , irpc:: Error > {
249- self . inner . server_streaming ( List , 10 ) . await
285+ self . authenticated_request ( )
286+ . await ?
287+ . server_streaming ( List , 10 )
288+ . await
250289 }
251290
252291 pub async fn set ( & self , key : String , value : String ) -> Result < ( ) , irpc:: Error > {
253292 let msg = Set { key, value } ;
254- self . inner . rpc ( msg) . await
293+ self . authenticated_request ( ) . await ? . rpc ( msg) . await
255294 }
256295 }
257296}
0 commit comments