Skip to content

Commit d9cf7a1

Browse files
committed
refactor(aggregator-client): minor adjustement following first usage in follower aggregator
- use an arc for the `ApiVersionProvider` so it fit with our current DI systems
1 parent 7d33e89 commit d9cf7a1

File tree

2 files changed

+27
-20
lines changed

2 files changed

+27
-20
lines changed

internal/mithril-aggregator-client/src/builder.rs

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use anyhow::Context;
22
use reqwest::{Client, IntoUrl, Proxy, Url};
33
use slog::{Logger, o};
44
use std::collections::HashMap;
5+
use std::sync::Arc;
56
use std::time::Duration;
67

78
use mithril_common::StdResult;
@@ -12,7 +13,7 @@ use crate::client::AggregatorClient;
1213
/// A builder of [AggregatorClient]
1314
pub struct AggregatorClientBuilder {
1415
aggregator_url_result: reqwest::Result<Url>,
15-
api_version_provider: Option<APIVersionProvider>,
16+
api_version_provider: Option<Arc<APIVersionProvider>>,
1617
additional_headers: Option<HashMap<String, String>>,
1718
timeout_duration: Option<Duration>,
1819
relay_endpoint: Option<String>,
@@ -41,7 +42,10 @@ impl AggregatorClientBuilder {
4142
}
4243

4344
/// Set the [APIVersionProvider] to use.
44-
pub fn with_api_version_provider(mut self, api_version_provider: APIVersionProvider) -> Self {
45+
pub fn with_api_version_provider(
46+
mut self,
47+
api_version_provider: Arc<APIVersionProvider>,
48+
) -> Self {
4549
self.api_version_provider = Some(api_version_provider);
4650
self
4751
}
@@ -59,8 +63,8 @@ impl AggregatorClientBuilder {
5963
}
6064

6165
/// Set the address of the relay
62-
pub fn with_relay_endpoint(mut self, relay_endpoint: String) -> Self {
63-
self.relay_endpoint = Some(relay_endpoint);
66+
pub fn with_relay_endpoint(mut self, relay_endpoint: Option<String>) -> Self {
67+
self.relay_endpoint = relay_endpoint;
6468
self
6569
}
6670

internal/mithril-aggregator-client/src/client.rs

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use anyhow::{Context, anyhow};
22
use reqwest::{IntoUrl, Response, Url, header::HeaderMap};
33
use semver::Version;
44
use slog::{Logger, error, warn};
5+
use std::sync::Arc;
56
use std::time::Duration;
67

78
use mithril_common::MITHRIL_API_VERSION_HEADER;
@@ -17,7 +18,7 @@ const API_VERSION_MISMATCH_WARNING_MESSAGE: &str = "OpenAPI version may be incom
1718
/// A client to send HTTP requests to a Mithril Aggregator
1819
pub struct AggregatorClient {
1920
pub(super) aggregator_endpoint: Url,
20-
pub(super) api_version_provider: APIVersionProvider,
21+
pub(super) api_version_provider: Arc<APIVersionProvider>,
2122
pub(super) additional_headers: HeaderMap,
2223
pub(super) timeout_duration: Option<Duration>,
2324
pub(super) client: reqwest::Client,
@@ -240,8 +241,9 @@ mod tests {
240241
#[tokio::test]
241242
async fn test_query_send_mithril_api_version_header() {
242243
let (server, mut client) = setup_server_and_client();
243-
client.api_version_provider =
244-
APIVersionProvider::new_with_default_version(Version::parse("1.2.9").unwrap());
244+
client.api_version_provider = Arc::new(APIVersionProvider::new_with_default_version(
245+
Version::parse("1.2.9").unwrap(),
246+
));
245247
server.mock(|when, then| {
246248
when.method(httpmock::Method::GET)
247249
.header(MITHRIL_API_VERSION_HEADER, "1.2.9");
@@ -254,8 +256,9 @@ mod tests {
254256
#[tokio::test]
255257
async fn test_query_send_additional_header_and_dont_override_mithril_api_version_header() {
256258
let (server, mut client) = setup_server_and_client();
257-
client.api_version_provider =
258-
APIVersionProvider::new_with_default_version(Version::parse("1.2.9").unwrap());
259+
client.api_version_provider = Arc::new(APIVersionProvider::new_with_default_version(
260+
Version::parse("1.2.9").unwrap(),
261+
));
259262
client.additional_headers = {
260263
let mut headers = HeaderMap::new();
261264
headers.insert(MITHRIL_API_VERSION_HEADER, "9.4.5".parse().unwrap());
@@ -336,9 +339,9 @@ mod tests {
336339
let (logger, log_inspector) = TestLogger::memory();
337340
let client = AggregatorClient::builder("http://whatever")
338341
.with_logger(logger)
339-
.with_api_version_provider(APIVersionProvider::new_with_default_version(
342+
.with_api_version_provider(Arc::new(APIVersionProvider::new_with_default_version(
340343
Version::parse(client_version).unwrap(),
341-
))
344+
)))
342345
.build()
343346
.unwrap();
344347
let response =
@@ -360,9 +363,9 @@ mod tests {
360363
let (logger, log_inspector) = TestLogger::memory();
361364
let client = AggregatorClient::builder("http://whatever")
362365
.with_logger(logger)
363-
.with_api_version_provider(APIVersionProvider::new_with_default_version(
366+
.with_api_version_provider(Arc::new(APIVersionProvider::new_with_default_version(
364367
Version::parse(version).unwrap(),
365-
))
368+
)))
366369
.build()
367370
.unwrap();
368371
let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, version);
@@ -379,9 +382,9 @@ mod tests {
379382
let (logger, log_inspector) = TestLogger::memory();
380383
let client = AggregatorClient::builder("http://whatever")
381384
.with_logger(logger)
382-
.with_api_version_provider(APIVersionProvider::new_with_default_version(
385+
.with_api_version_provider(Arc::new(APIVersionProvider::new_with_default_version(
383386
Version::parse(client_version).unwrap(),
384-
))
387+
)))
385388
.build()
386389
.unwrap();
387390
let response =
@@ -402,7 +405,7 @@ mod tests {
402405
let (logger, log_inspector) = TestLogger::memory();
403406
let client = AggregatorClient::builder("http://whatever")
404407
.with_logger(logger)
405-
.with_api_version_provider(APIVersionProvider::default())
408+
.with_api_version_provider(Arc::new(APIVersionProvider::default()))
406409
.build()
407410
.unwrap();
408411
let response =
@@ -418,7 +421,7 @@ mod tests {
418421
let (logger, log_inspector) = TestLogger::memory();
419422
let client = AggregatorClient::builder("http://whatever")
420423
.with_logger(logger)
421-
.with_api_version_provider(APIVersionProvider::default())
424+
.with_api_version_provider(Arc::new(APIVersionProvider::default()))
422425
.build()
423426
.unwrap();
424427
let response =
@@ -434,7 +437,7 @@ mod tests {
434437
let (logger, log_inspector) = TestLogger::memory();
435438
let client = AggregatorClient::builder("http://whatever")
436439
.with_logger(logger)
437-
.with_api_version_provider(APIVersionProvider::new_failing())
440+
.with_api_version_provider(Arc::new(APIVersionProvider::new_failing()))
438441
.build()
439442
.unwrap();
440443
let response = build_fake_response_with_header(MITHRIL_API_VERSION_HEADER, "1.0.0");
@@ -450,9 +453,9 @@ mod tests {
450453
let client_version = "1.0.0";
451454
let (server, mut client) = setup_server_and_client();
452455
let (logger, log_inspector) = TestLogger::memory();
453-
client.api_version_provider = APIVersionProvider::new_with_default_version(
456+
client.api_version_provider = Arc::new(APIVersionProvider::new_with_default_version(
454457
Version::parse(client_version).unwrap(),
455-
);
458+
));
456459
client.logger = logger;
457460
server.mock(|_, then| {
458461
then.status(StatusCode::CREATED.as_u16())

0 commit comments

Comments
 (0)