Skip to content

Commit cfa8717

Browse files
committed
refactor: dependency injection on grpc and BeeMsg handlers
* Accept a trait impl (`App*` / `AppAll`) to instead of conrete `Context` type on handlers * Implement a runtime `App` that satifies the above and is passed to the handlers at runtime * Add a `TestApp` implementation with rudimentary functionality for testing the handlers
1 parent 38f08d2 commit cfa8717

File tree

22 files changed

+998
-592
lines changed

22 files changed

+998
-592
lines changed

mgmtd/src/app.rs

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
//! Interfaces and implementations for in-app interaction between tasks or threads.
2+
3+
mod runtime;
4+
#[cfg(test)]
5+
pub(crate) mod test;
6+
7+
use crate::StaticInfo;
8+
use anyhow::Result;
9+
use protobuf::license::GetCertDataResult;
10+
pub(crate) use runtime::App;
11+
use rusqlite::{Connection, Transaction};
12+
use shared::bee_msg::Msg;
13+
use shared::bee_serde::{Deserializable, Serializable};
14+
use shared::types::{NodeId, NodeType, Uid};
15+
use std::fmt::Debug;
16+
use std::future::Future;
17+
use std::net::SocketAddr;
18+
use std::path::Path;
19+
use std::sync::Arc;
20+
21+
pub(crate) trait AppInfo {
22+
fn static_info(&self) -> &StaticInfo;
23+
}
24+
25+
pub(crate) trait AppDb {
26+
fn read_tx<T: Send + 'static + FnOnce(&Transaction) -> Result<R>, R: Send + 'static>(
27+
&self,
28+
op: T,
29+
) -> impl Future<Output = Result<R>> + Send;
30+
fn write_tx<T: Send + 'static + FnOnce(&Transaction) -> Result<R>, R: Send + 'static>(
31+
&self,
32+
op: T,
33+
) -> impl Future<Output = Result<R>> + Send;
34+
fn write_tx_no_sync<T: Send + 'static + FnOnce(&Transaction) -> Result<R>, R: Send + 'static>(
35+
&self,
36+
op: T,
37+
) -> impl Future<Output = Result<R>> + Send;
38+
fn conn<T: Send + 'static + FnOnce(&mut Connection) -> Result<R>, R: Send + 'static>(
39+
&self,
40+
op: T,
41+
) -> impl Future<Output = Result<R>> + Send;
42+
}
43+
44+
pub(crate) trait AppConn {
45+
/// Sends a [Msg] to a node and receives the response.
46+
fn request<M: Msg + Serializable, R: Msg + Deserializable>(
47+
&self,
48+
node_uid: Uid,
49+
msg: &M,
50+
) -> impl Future<Output = Result<R>> + Send;
51+
fn send_notifications<M: Msg + Serializable>(
52+
&self,
53+
node_types: &'static [NodeType],
54+
msg: &M,
55+
) -> impl Future<Output = ()> + Send;
56+
fn replace_node_addrs(&self, node_uid: Uid, new_addrs: impl Into<Arc<[SocketAddr]>>);
57+
}
58+
59+
pub(crate) trait AppRunState {
60+
fn pre_shutdown(&self) -> bool;
61+
fn notify_client_pulled_state(&self, node_type: NodeType, node_id: NodeId);
62+
}
63+
64+
pub(crate) trait AppLicense {
65+
fn load_and_verify_cert(&self, cert_path: &Path)
66+
-> impl Future<Output = Result<String>> + Send;
67+
fn get_cert_data(&self) -> Result<GetCertDataResult>;
68+
fn get_num_machines(&self) -> Result<u32>;
69+
}
70+
71+
pub(crate) trait AppAll:
72+
AppDb + AppConn + AppRunState + AppLicense + AppInfo + Debug + Clone + Send + Sync + 'static
73+
{
74+
}
75+
impl<T> AppAll for T where
76+
T: AppDb + AppConn + AppRunState + AppLicense + AppInfo + Debug + Clone + Send + Sync + 'static
77+
{
78+
}

mgmtd/src/app/runtime.rs

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,182 @@
1+
use super::*;
2+
use crate::ClientPulledStateNotification;
3+
use crate::bee_msg::dispatch_request;
4+
use crate::license::LicenseVerifier;
5+
use anyhow::Result;
6+
use protobuf::license::GetCertDataResult;
7+
use rusqlite::{Connection, Transaction};
8+
use shared::conn::msg_dispatch::{DispatchRequest, Request};
9+
use shared::conn::outgoing::Pool;
10+
use shared::run_state::WeakRunStateHandle;
11+
use sqlite::Connections;
12+
use std::fmt::Debug;
13+
use std::ops::Deref;
14+
use tokio::sync::mpsc;
15+
16+
/// A collection of Handles used for interacting and accessing the different components of the app.
17+
///
18+
/// This is the actual runtime object that can be shared between tasks. Interfaces should, however,
19+
/// accept any implementation of the AppContext trait instead.
20+
#[derive(Clone, Debug)]
21+
pub(crate) struct App(Arc<InnerAppHandles>);
22+
23+
/// Stores the actual handles.
24+
#[derive(Debug)]
25+
pub(crate) struct InnerAppHandles {
26+
pub conn: Pool,
27+
pub db: Connections,
28+
pub license: LicenseVerifier,
29+
pub info: &'static StaticInfo,
30+
pub run_state: WeakRunStateHandle,
31+
shutdown_client_id: mpsc::Sender<ClientPulledStateNotification>,
32+
}
33+
34+
impl App {
35+
/// Creates a new AppHandles object.
36+
///
37+
/// Takes all the stored handles.
38+
pub(crate) fn new(
39+
conn: Pool,
40+
db: Connections,
41+
license: LicenseVerifier,
42+
info: &'static StaticInfo,
43+
run_state: WeakRunStateHandle,
44+
shutdown_client_id: mpsc::Sender<ClientPulledStateNotification>,
45+
) -> Self {
46+
Self(Arc::new(InnerAppHandles {
47+
conn,
48+
db,
49+
license,
50+
info,
51+
run_state,
52+
shutdown_client_id,
53+
}))
54+
}
55+
}
56+
57+
/// Derefs to InnerAppHandle which stores all the handles.
58+
///
59+
/// Allows transparent access.
60+
impl Deref for App {
61+
type Target = InnerAppHandles;
62+
63+
fn deref(&self) -> &Self::Target {
64+
&self.0
65+
}
66+
}
67+
68+
/// Adds BeeMsg dispatching functionality to AppHandles
69+
impl DispatchRequest for App {
70+
async fn dispatch_request(&self, req: impl Request) -> Result<()> {
71+
dispatch_request(self, req).await
72+
}
73+
}
74+
75+
impl AppInfo for App {
76+
fn static_info(&self) -> &StaticInfo {
77+
self.info
78+
}
79+
}
80+
81+
impl AppDb for App {
82+
async fn read_tx<T: Send + 'static + FnOnce(&Transaction) -> Result<R>, R: Send + 'static>(
83+
&self,
84+
op: T,
85+
) -> Result<R> {
86+
Connections::read_tx(&self.db, op).await
87+
}
88+
89+
async fn write_tx<T: Send + 'static + FnOnce(&Transaction) -> Result<R>, R: Send + 'static>(
90+
&self,
91+
op: T,
92+
) -> Result<R> {
93+
Connections::write_tx(&self.db, op).await
94+
}
95+
96+
async fn write_tx_no_sync<
97+
T: Send + 'static + FnOnce(&Transaction) -> Result<R>,
98+
R: Send + 'static,
99+
>(
100+
&self,
101+
op: T,
102+
) -> Result<R> {
103+
Connections::write_tx_no_sync(&self.db, op).await
104+
}
105+
106+
async fn conn<T: Send + 'static + FnOnce(&mut Connection) -> Result<R>, R: Send + 'static>(
107+
&self,
108+
op: T,
109+
) -> Result<R> {
110+
Connections::conn(&self.db, op).await
111+
}
112+
}
113+
114+
impl AppConn for App {
115+
async fn request<M: Msg + Serializable, R: Msg + Deserializable>(
116+
&self,
117+
node_uid: Uid,
118+
msg: &M,
119+
) -> Result<R> {
120+
Pool::request(&self.conn, node_uid, msg).await
121+
}
122+
async fn send_notifications<M: Msg + Serializable>(
123+
&self,
124+
node_types: &'static [NodeType],
125+
msg: &M,
126+
) {
127+
log::trace!("NOTIFICATION to {node_types:?}: {msg:?}");
128+
129+
for t in node_types {
130+
if let Err(err) = async {
131+
let nodes = self
132+
.read_tx(move |tx| crate::db::node::get_with_type(tx, *t))
133+
.await?;
134+
135+
self.conn
136+
.broadcast_datagram(nodes.into_iter().map(|e| e.uid), msg)
137+
.await?;
138+
139+
Ok(()) as Result<_>
140+
}
141+
.await
142+
{
143+
log::error!("Notification could not be sent to all {t} nodes: {err:#}");
144+
}
145+
}
146+
}
147+
148+
fn replace_node_addrs(&self, node_uid: Uid, new_addrs: impl Into<Arc<[SocketAddr]>>) {
149+
Pool::replace_node_addrs(&self.conn, node_uid, new_addrs)
150+
}
151+
}
152+
153+
impl AppRunState for App {
154+
fn pre_shutdown(&self) -> bool {
155+
WeakRunStateHandle::pre_shutdown(&self.run_state)
156+
}
157+
158+
fn notify_client_pulled_state(&self, node_type: NodeType, node_id: NodeId) {
159+
if self.run_state.pre_shutdown() {
160+
let tx = self.shutdown_client_id.clone();
161+
162+
// We don't want to block the task calling this and are not interested by the results
163+
tokio::spawn(async move {
164+
let _ = tx.send((node_type, node_id)).await;
165+
});
166+
}
167+
}
168+
}
169+
170+
impl AppLicense for App {
171+
async fn load_and_verify_cert(&self, cert_path: &Path) -> Result<String> {
172+
LicenseVerifier::load_and_verify_cert(&self.license, cert_path).await
173+
}
174+
175+
fn get_cert_data(&self) -> Result<GetCertDataResult> {
176+
LicenseVerifier::get_cert_data(&self.license)
177+
}
178+
179+
fn get_num_machines(&self) -> Result<u32> {
180+
LicenseVerifier::get_num_machines(&self.license)
181+
}
182+
}

0 commit comments

Comments
 (0)