11use anyhow:: { Context , anyhow} ;
2- use reqwest:: { IntoUrl , Response , Url } ;
2+ use reqwest:: { IntoUrl , Response , Url , header :: HeaderMap } ;
33use semver:: Version ;
44use slog:: { Logger , error, warn} ;
55use std:: time:: Duration ;
@@ -18,6 +18,7 @@ const API_VERSION_MISMATCH_WARNING_MESSAGE: &str = "OpenAPI version may be incom
1818pub struct AggregatorClient {
1919 pub ( super ) aggregator_endpoint : Url ,
2020 pub ( super ) api_version_provider : APIVersionProvider ,
21+ pub ( super ) additional_headers : HeaderMap ,
2122 pub ( super ) timeout_duration : Option < Duration > ,
2223 pub ( super ) client : reqwest:: Client ,
2324 pub ( super ) logger : Logger ,
@@ -39,6 +40,7 @@ impl AggregatorClient {
3940 QueryMethod :: Get => self . client . get ( self . join_aggregator_endpoint ( & query. route ( ) ) ?) ,
4041 QueryMethod :: Post => self . client . post ( self . join_aggregator_endpoint ( & query. route ( ) ) ?) ,
4142 }
43+ . headers ( self . additional_headers . clone ( ) )
4244 . header ( MITHRIL_API_VERSION_HEADER , current_api_version. to_string ( ) ) ;
4345
4446 if let Some ( body) = query. body ( ) {
@@ -244,6 +246,29 @@ mod tests {
244246 client. send ( TestGetQuery ) . await . expect ( "should not fail" ) ;
245247 }
246248
249+ #[ tokio:: test]
250+ async fn test_get_query_send_additional_header_and_dont_override_mithril_api_version_header ( )
251+ {
252+ let ( server, mut client) = setup_server_and_client ( ) ;
253+ client. api_version_provider =
254+ APIVersionProvider :: new_with_default_version ( Version :: parse ( "1.2.9" ) . unwrap ( ) ) ;
255+ client. additional_headers = {
256+ let mut headers = HeaderMap :: new ( ) ;
257+ headers. insert ( MITHRIL_API_VERSION_HEADER , "9.4.5" . parse ( ) . unwrap ( ) ) ;
258+ headers. insert ( "foo" , "bar" . parse ( ) . unwrap ( ) ) ;
259+ headers
260+ } ;
261+
262+ server. mock ( |when, then| {
263+ when. method ( httpmock:: Method :: GET )
264+ . header ( MITHRIL_API_VERSION_HEADER , "1.2.9" )
265+ . header ( "foo" , "bar" ) ;
266+ then. status ( 200 ) . body ( r#"{"foo": "a", "bar": 1}"# ) ;
267+ } ) ;
268+
269+ client. send ( TestGetQuery ) . await . expect ( "should not fail" ) ;
270+ }
271+
247272 #[ tokio:: test]
248273 async fn test_get_query_timeout ( ) {
249274 let ( server, mut client) = setup_server_and_client ( ) ;
@@ -276,14 +301,12 @@ mod tests {
276301 then. status ( 201 ) ;
277302 } ) ;
278303
279- let response = client
304+ client
280305 . send ( TestPostQuery {
281306 body : TestBody :: new ( "miaouss" , 5 ) ,
282307 } )
283308 . await
284309 . unwrap ( ) ;
285-
286- assert_eq ! ( response, ( ) )
287310 }
288311
289312 #[ tokio:: test]
@@ -305,6 +328,34 @@ mod tests {
305328 . expect ( "should not fail" ) ;
306329 }
307330
331+ #[ tokio:: test]
332+ async fn test_post_query_send_additional_header_and_dont_override_mithril_api_version_header ( )
333+ {
334+ let ( server, mut client) = setup_server_and_client ( ) ;
335+ client. api_version_provider =
336+ APIVersionProvider :: new_with_default_version ( Version :: parse ( "1.2.9" ) . unwrap ( ) ) ;
337+ client. additional_headers = {
338+ let mut headers = HeaderMap :: new ( ) ;
339+ headers. insert ( MITHRIL_API_VERSION_HEADER , "9.4.5" . parse ( ) . unwrap ( ) ) ;
340+ headers. insert ( "foo" , "bar" . parse ( ) . unwrap ( ) ) ;
341+ headers
342+ } ;
343+
344+ server. mock ( |when, then| {
345+ when. method ( httpmock:: Method :: POST )
346+ . header ( MITHRIL_API_VERSION_HEADER , "1.2.9" )
347+ . header ( "foo" , "bar" ) ;
348+ then. status ( 201 ) . body ( r#"{"foo": "a", "bar": 1}"# ) ;
349+ } ) ;
350+
351+ client
352+ . send ( TestPostQuery {
353+ body : TestBody :: new ( "miaouss" , 3 ) ,
354+ } )
355+ . await
356+ . expect ( "should not fail" ) ;
357+ }
358+
308359 #[ tokio:: test]
309360 async fn test_post_query_timeout ( ) {
310361 let ( server, mut client) = setup_server_and_client ( ) ;
0 commit comments