@@ -57,6 +57,21 @@ impl PallasChainReader {
5757 . with_context ( || "PallasChainReader failed to get a client" )
5858 }
5959
60+ #[ cfg( test) ]
61+ /// Check if the client already exists (test only).
62+ fn has_client ( & self ) -> bool {
63+ self . client . is_some ( )
64+ }
65+
66+ /// Drops the client by aborting the connection and setting it to `None`.
67+ fn drop_client ( & mut self ) {
68+ if let Some ( client) = self . client . take ( ) {
69+ tokio:: spawn ( async move {
70+ let _ = client. abort ( ) . await ;
71+ } ) ;
72+ }
73+ }
74+
6075 /// Intersects the point of the chain with the given point.
6176 async fn find_intersect_point ( & mut self , point : & RawCardanoPoint ) -> StdResult < ( ) > {
6277 let logger = self . logger . clone ( ) ;
@@ -99,30 +114,38 @@ impl PallasChainReader {
99114
100115impl Drop for PallasChainReader {
101116 fn drop ( & mut self ) {
102- if let Some ( client) = self . client . take ( ) {
103- tokio:: spawn ( async move {
104- let _ = client. abort ( ) . await ;
105- } ) ;
106- }
117+ self . drop_client ( ) ;
107118 }
108119}
109120
110121#[ async_trait]
111122impl ChainBlockReader for PallasChainReader {
112123 async fn set_chain_point ( & mut self , point : & RawCardanoPoint ) -> StdResult < ( ) > {
113- self . find_intersect_point ( point) . await
124+ match self . find_intersect_point ( point) . await {
125+ Ok ( ( ) ) => Ok ( ( ) ) ,
126+ Err ( err) => {
127+ self . drop_client ( ) ;
128+
129+ return Err ( err) ;
130+ }
131+ }
114132 }
115133
116134 async fn get_next_chain_block ( & mut self ) -> StdResult < Option < ChainBlockNextAction > > {
117135 let client = self . get_client ( ) . await ?;
118136 let chainsync = client. chainsync ( ) ;
119-
120137 let next = match chainsync. has_agency ( ) {
121- true => chainsync. request_next ( ) . await ? ,
122- false => chainsync. recv_while_must_reply ( ) . await ? ,
138+ true => chainsync. request_next ( ) . await ,
139+ false => chainsync. recv_while_must_reply ( ) . await ,
123140 } ;
141+ match next {
142+ Ok ( next) => self . process_chain_block_next_action ( next) . await ,
143+ Err ( err) => {
144+ self . drop_client ( ) ;
124145
125- self . process_chain_block_next_action ( next) . await
146+ return Err ( err. into ( ) ) ;
147+ }
148+ }
126149 }
127150}
128151
@@ -142,6 +165,7 @@ mod tests {
142165 use super :: * ;
143166
144167 use crate :: test_utils:: TestLogger ;
168+ use crate :: * ;
145169 use crate :: { entities:: BlockNumber , test_utils:: TempDir } ;
146170
147171 /// Enum representing the action to be performed by the server.
@@ -201,7 +225,7 @@ mod tests {
201225 socket_path : PathBuf ,
202226 action : ServerAction ,
203227 has_agency : HasAgency ,
204- ) -> tokio:: task:: JoinHandle < ( ) > {
228+ ) -> tokio:: task:: JoinHandle < NodeServer > {
205229 tokio:: spawn ( {
206230 async move {
207231 if socket_path. exists ( ) {
@@ -249,14 +273,15 @@ mod tests {
249273 . unwrap ( ) ;
250274 }
251275 }
276+
277+ server
252278 }
253279 } )
254280 }
255281
256282 #[ tokio:: test]
257283 async fn get_next_chain_block_rolls_backward ( ) {
258- let socket_path =
259- create_temp_dir ( "get_next_chain_block_rolls_backward" ) . join ( "node.socket" ) ;
284+ let socket_path = create_temp_dir ( current_function ! ( ) ) . join ( "node.socket" ) ;
260285 let known_point = get_fake_specific_point ( ) ;
261286 let server = setup_server (
262287 socket_path. clone ( ) ,
@@ -291,7 +316,7 @@ mod tests {
291316
292317 #[ tokio:: test]
293318 async fn get_next_chain_block_rolls_forward ( ) {
294- let socket_path = create_temp_dir ( "get_next_chain_block_rolls_forward" ) . join ( "node.socket" ) ;
319+ let socket_path = create_temp_dir ( current_function ! ( ) ) . join ( "node.socket" ) ;
295320 let known_point = get_fake_specific_point ( ) ;
296321 let server = setup_server (
297322 socket_path. clone ( ) ,
@@ -326,7 +351,7 @@ mod tests {
326351
327352 #[ tokio:: test]
328353 async fn get_next_chain_block_has_no_agency ( ) {
329- let socket_path = create_temp_dir ( "get_next_chain_block_has_no_agency" ) . join ( "node.socket" ) ;
354+ let socket_path = create_temp_dir ( current_function ! ( ) ) . join ( "node.socket" ) ;
330355 let known_point = get_fake_specific_point ( ) ;
331356 let server = setup_server (
332357 socket_path. clone ( ) ,
@@ -375,4 +400,57 @@ mod tests {
375400 _ => panic ! ( "Unexpected chain block action" ) ,
376401 }
377402 }
403+
404+ #[ tokio:: test]
405+ async fn cached_client_is_dropped_when_returning_error ( ) {
406+ let socket_path = create_temp_dir ( current_function ! ( ) ) . join ( "node.socket" ) ;
407+ let socket_path_clone = socket_path. clone ( ) ;
408+ let known_point = get_fake_specific_point ( ) ;
409+ let server = setup_server (
410+ socket_path. clone ( ) ,
411+ ServerAction :: RollForward ,
412+ HasAgency :: Yes ,
413+ )
414+ . await ;
415+ let client = tokio:: spawn ( async move {
416+ let mut chain_reader = PallasChainReader :: new (
417+ socket_path_clone. as_path ( ) ,
418+ CardanoNetwork :: TestNet ( 10 ) ,
419+ TestLogger :: stdout ( ) ,
420+ ) ;
421+
422+ chain_reader
423+ . set_chain_point ( & RawCardanoPoint :: from ( known_point. clone ( ) ) )
424+ . await
425+ . unwrap ( ) ;
426+
427+ chain_reader. get_next_chain_block ( ) . await . unwrap ( ) . unwrap ( ) ;
428+
429+ chain_reader
430+ } ) ;
431+
432+ let ( server_res, client_res) = tokio:: join!( server, client) ;
433+ let chain_reader = client_res. expect ( "Client failed to get chain reader" ) ;
434+ let server = server_res. expect ( "Server failed to get server" ) ;
435+ server. abort ( ) . await ;
436+
437+ let client = tokio:: spawn ( async move {
438+ let mut chain_reader = chain_reader;
439+
440+ assert ! ( chain_reader. has_client( ) , "Client should exist" ) ;
441+
442+ chain_reader
443+ . get_next_chain_block ( )
444+ . await
445+ . expect_err ( "Chain reader get_next_chain_block should fail" ) ;
446+
447+ assert ! (
448+ !chain_reader. has_client( ) ,
449+ "Client should have been dropped after error"
450+ ) ;
451+
452+ chain_reader
453+ } ) ;
454+ client. await . unwrap ( ) ;
455+ }
378456}
0 commit comments