@@ -2,6 +2,7 @@ use anyhow::{Context, anyhow};
22use reqwest:: { IntoUrl , Response , Url } ;
33use semver:: Version ;
44use slog:: { Logger , error, warn} ;
5+ use std:: time:: Duration ;
56
67use mithril_common:: MITHRIL_API_VERSION_HEADER ;
78use mithril_common:: api_version:: APIVersionProvider ;
@@ -16,6 +17,7 @@ const API_VERSION_MISMATCH_WARNING_MESSAGE: &str = "OpenAPI version may be incom
1617pub struct AggregatorClient {
1718 pub ( super ) aggregator_endpoint : Url ,
1819 pub ( super ) api_version_provider : APIVersionProvider ,
20+ pub ( super ) timeout_duration : Option < Duration > ,
1921 pub ( super ) client : reqwest:: Client ,
2022 pub ( super ) logger : Logger ,
2123}
@@ -41,6 +43,10 @@ impl AggregatorClient {
4143 request_builder = request_builder. json ( & body) ;
4244 }
4345
46+ if let Some ( timeout) = self . timeout_duration {
47+ request_builder = request_builder. timeout ( timeout) ;
48+ }
49+
4450 match request_builder. send ( ) . await {
4551 Ok ( response) => {
4652 self . warn_if_api_version_mismatch ( & response) ;
@@ -159,6 +165,15 @@ mod tests {
159165 chu : u8 ,
160166 }
161167
168+ impl TestBody {
169+ fn new < P : Into < String > > ( pika : P , chu : u8 ) -> Self {
170+ Self {
171+ pika : pika. into ( ) ,
172+ chu,
173+ }
174+ }
175+ }
176+
162177 struct TestPostQuery {
163178 body : TestBody ,
164179 }
@@ -226,6 +241,23 @@ mod tests {
226241
227242 client. send ( TestGetQuery ) . await . expect ( "should not fail" ) ;
228243 }
244+
245+ #[ tokio:: test]
246+ async fn test_get_query_timeout ( ) {
247+ let ( server, mut client) = setup_server_and_client ( ) ;
248+ client. timeout_duration = Some ( Duration :: from_millis ( 10 ) ) ;
249+ let _server_mock = server. mock ( |when, then| {
250+ when. method ( httpmock:: Method :: GET ) ;
251+ then. delay ( Duration :: from_millis ( 100 ) ) ;
252+ } ) ;
253+
254+ let error = client. send ( TestGetQuery ) . await . expect_err ( "should not fail" ) ;
255+
256+ assert ! (
257+ matches!( error, AggregatorClientError :: RemoteServerUnreachable ( _) ) ,
258+ "unexpected error type: {error:?}"
259+ ) ;
260+ }
229261 }
230262
231263 mod post {
@@ -238,22 +270,13 @@ mod tests {
238270 when. method ( httpmock:: Method :: POST )
239271 . path ( "/dummy-post-route" )
240272 . header ( "content-type" , "application/json" )
241- . body (
242- serde_json:: to_string ( & TestBody {
243- pika : "miaouss" . to_string ( ) ,
244- chu : 5 ,
245- } )
246- . unwrap ( ) ,
247- ) ;
273+ . body ( serde_json:: to_string ( & TestBody :: new ( "miaouss" , 5 ) ) . unwrap ( ) ) ;
248274 then. status ( 201 ) ;
249275 } ) ;
250276
251277 let response = client
252278 . send ( TestPostQuery {
253- body : TestBody {
254- pika : "miaouss" . to_string ( ) ,
255- chu : 5 ,
256- } ,
279+ body : TestBody :: new ( "miaouss" , 5 ) ,
257280 } )
258281 . await
259282 . unwrap ( ) ;
@@ -274,14 +297,33 @@ mod tests {
274297
275298 client
276299 . send ( TestPostQuery {
277- body : TestBody {
278- pika : "a" . to_string ( ) ,
279- chu : 3 ,
280- } ,
300+ body : TestBody :: new ( "miaouss" , 3 ) ,
281301 } )
282302 . await
283303 . expect ( "should not fail" ) ;
284304 }
305+
306+ #[ tokio:: test]
307+ async fn test_post_query_timeout ( ) {
308+ let ( server, mut client) = setup_server_and_client ( ) ;
309+ client. timeout_duration = Some ( Duration :: from_millis ( 10 ) ) ;
310+ let _server_mock = server. mock ( |when, then| {
311+ when. method ( httpmock:: Method :: POST ) ;
312+ then. delay ( Duration :: from_millis ( 100 ) ) ;
313+ } ) ;
314+
315+ let error = client
316+ . send ( TestPostQuery {
317+ body : TestBody :: new ( "miaouss" , 3 ) ,
318+ } )
319+ . await
320+ . expect_err ( "should not fail" ) ;
321+
322+ assert ! (
323+ matches!( error, AggregatorClientError :: RemoteServerUnreachable ( _) ) ,
324+ "unexpected error type: {error:?}"
325+ ) ;
326+ }
285327 }
286328
287329 mod warn_if_api_version_mismatch {
@@ -481,10 +523,7 @@ mod tests {
481523
482524 client
483525 . send ( TestPostQuery {
484- body : TestBody {
485- pika : "miaouss" . to_string ( ) ,
486- chu : 5 ,
487- } ,
526+ body : TestBody :: new ( "miaouss" , 3 ) ,
488527 } )
489528 . await
490529 . unwrap ( ) ;
0 commit comments