Skip to content

Commit 5301dd6

Browse files
committed
Better management of disconnecting peers
Rather than send a message when a peer disconnects and then handle that message asynchronously, we now wait for the task itself to exit and then return a `DisconnectedPeer` from `ConnectionManager::step`. We only return the `PeerId` for the `DisconnectedPeer` if it is still the existing `established connection` for the given peer and it hasn't been replaced by a newer connection. This prevents calling `Node::on_disconnect` for the stale connection when it might have already received an `on_connect` call for the new connection.
1 parent d563f8d commit 5301dd6

File tree

4 files changed

+107
-74
lines changed

4 files changed

+107
-74
lines changed

trust-quorum/src/connection_manager.rs

Lines changed: 44 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,19 @@ pub enum MainToConnMsg {
5555
Msg(WireMsg),
5656
}
5757

58+
/// The task for this sprockets connection just exited. If
59+
/// `ConnectionManager::step` returns this value and `peer_id` is `Some` than
60+
/// it means no new connection for the peer has yet been established. It is
61+
/// safe to cleanup state for the given `peer_id`, by, for instance, calling
62+
/// `Node::on_disconnect`.
63+
///
64+
/// By always returning the `task_id`, we allow cleanup of proxy requests for
65+
/// stale nodes that will never complete.
66+
pub struct DisconnectedPeer {
67+
pub task_id: task::Id,
68+
pub peer_id: Option<BaseboardId>,
69+
}
70+
5871
/// All possible messages sent over established connections
5972
///
6073
/// This include trust quorum related `PeerMsg`s, but also ancillary network
@@ -106,7 +119,6 @@ pub enum ConnToMainMsgInner {
106119
Connected { addr: SocketAddrV6, peer_id: BaseboardId },
107120
Received { from: BaseboardId, msg: PeerMsg },
108121
ReceivedNetworkConfig { from: BaseboardId, config: NetworkConfig },
109-
Disconnected { peer_id: BaseboardId },
110122
ProxyRequestReceived { from: BaseboardId, req: proxy::WireRequest },
111123
ProxyResponseReceived { from: BaseboardId, rsp: proxy::WireResponse },
112124
}
@@ -154,7 +166,7 @@ impl BiHashItem for TaskHandle {
154166
}
155167

156168
pub struct EstablishedTaskHandle {
157-
baseboard_id: BaseboardId,
169+
pub baseboard_id: BaseboardId,
158170
task_handle: TaskHandle,
159171
}
160172

@@ -246,7 +258,7 @@ pub struct ConnMgrStatus {
246258

247259
/// The state of a proxy connection
248260
pub enum ProxyConnState {
249-
Connected,
261+
Connected(task::Id),
250262
Disconnected,
251263
}
252264

@@ -447,7 +459,7 @@ impl ConnMgr {
447459
if let Some(h) = self.established.get1(destination) {
448460
info!(self.log, "Sending {req:?}"; "peer_id" => %destination);
449461
h.send(req).await;
450-
ProxyConnState::Connected
462+
ProxyConnState::Connected(h.task_id())
451463
} else {
452464
ProxyConnState::Disconnected
453465
}
@@ -471,32 +483,37 @@ impl ConnMgr {
471483

472484
/// Perform any polling related operations that the connection
473485
/// manager must perform concurrently.
486+
///
487+
/// Return `Ok(Some(DisconnectedPeer))` if an `EstablishedConnectionTask`
488+
/// that was still the exclusive connection task for a specific peer has
489+
/// just exited.
474490
pub async fn step(
475491
&mut self,
476492
corpus: Vec<Utf8PathBuf>,
477-
) -> Result<(), AcceptError> {
478-
tokio::select! {
493+
) -> Result<Option<DisconnectedPeer>, AcceptError> {
494+
let disconnected_peer = tokio::select! {
479495
acceptor = self.server.accept(corpus.clone()) => {
480496
self.accept(acceptor?).await?;
497+
None
481498
}
482499
Some(res) = self.join_set.join_next_with_id() => {
483500
match res {
484501
Ok((task_id, _)) => {
485-
self.on_task_exit(task_id).await;
502+
Some(self.on_task_exit(task_id))
486503
}
487504
Err(err) => {
488505
warn!(self.log, "Connection task panic: {err}");
489-
self.on_task_exit(err.id()).await;
506+
Some(self.on_task_exit(err.id()))
490507
}
491-
492508
}
493509
}
494510
_ = self.reconnect_interval.tick() => {
495511
self.reconnect(corpus.clone()).await;
512+
None
496513
}
497-
}
514+
};
498515

499-
Ok(())
516+
Ok(disconnected_peer)
500517
}
501518

502519
pub async fn accept(
@@ -686,22 +703,6 @@ impl ConnMgr {
686703
}
687704
}
688705

689-
/// The established connection task has asynchronously exited.
690-
pub async fn on_disconnected(
691-
&mut self,
692-
task_id: task::Id,
693-
peer_id: BaseboardId,
694-
) {
695-
if let Some(established_task_handle) = self.established.get1(&peer_id) {
696-
if established_task_handle.task_id() != task_id {
697-
// This was a stale disconnect
698-
return;
699-
}
700-
}
701-
warn!(self.log, "peer disconnected"; "peer_id" => %peer_id);
702-
let _ = self.established.remove1(&peer_id);
703-
}
704-
705706
/// Initiate connections if a corresponding task doesn't already exist. This
706707
/// must be called periodically to handle transient disconnections which
707708
/// cause tasks to exit.
@@ -740,9 +741,9 @@ impl ConnMgr {
740741
&mut self,
741742
addrs: BTreeSet<SocketAddrV6>,
742743
corpus: Vec<Utf8PathBuf>,
743-
) -> BTreeSet<BaseboardId> {
744+
) -> Vec<EstablishedTaskHandle> {
744745
if self.bootstrap_addrs == addrs {
745-
return BTreeSet::new();
746+
return vec![];
746747
}
747748

748749
// We don't try to compare addresses from accepted nodes. If DDMD
@@ -770,10 +771,10 @@ impl ConnMgr {
770771
self.connect_client(corpus.clone(), addr).await;
771772
}
772773

773-
let mut disconnected_peers = BTreeSet::new();
774+
let mut disconnected_peers = Vec::new();
774775
for addr in to_disconnect {
775-
if let Some(peer_id) = self.disconnect_client(addr).await {
776-
disconnected_peers.insert(peer_id);
776+
if let Some(handle) = self.disconnect_client(addr).await {
777+
disconnected_peers.push(handle);
777778
}
778779
}
779780
disconnected_peers
@@ -861,7 +862,7 @@ impl ConnMgr {
861862
async fn disconnect_client(
862863
&mut self,
863864
addr: SocketAddrV6,
864-
) -> Option<BaseboardId> {
865+
) -> Option<EstablishedTaskHandle> {
865866
if let Some(handle) = self.connecting.remove2(&addr) {
866867
// The connection has not yet completed its handshake
867868
info!(
@@ -880,15 +881,17 @@ impl ConnMgr {
880881
"peer_id" => %handle.baseboard_id
881882
);
882883
handle.abort();
883-
Some(handle.baseboard_id)
884+
Some(handle)
884885
} else {
885886
None
886887
}
887888
}
888889
}
889890

890891
/// Remove any references to the given task
891-
async fn on_task_exit(&mut self, task_id: task::Id) {
892+
///
893+
/// Return a `DisconnectedPeer` for the given `task_id`.
894+
fn on_task_exit(&mut self, task_id: task::Id) -> DisconnectedPeer {
892895
// We're most likely to find the task as established so we start with that
893896
if let Some(handle) = self.established.remove2(&task_id) {
894897
info!(
@@ -898,6 +901,10 @@ impl ConnMgr {
898901
"peer_addr" => %handle.addr(),
899902
"peer_id" => %handle.baseboard_id
900903
);
904+
return DisconnectedPeer {
905+
task_id,
906+
peer_id: Some(handle.baseboard_id),
907+
};
901908
} else if let Some(handle) = self.accepting.remove1(&task_id) {
902909
info!(
903910
self.log,
@@ -919,6 +926,8 @@ impl ConnMgr {
919926
"task_id" => ?task_id
920927
);
921928
}
929+
930+
DisconnectedPeer { task_id, peer_id: None }
922931
}
923932
}
924933

trust-quorum/src/established_conn.rs

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -155,14 +155,6 @@ impl EstablishedConn {
155155
}
156156

157157
async fn close(&mut self) {
158-
if let Err(_) = self.main_tx.try_send(ConnToMainMsg {
159-
task_id: self.task_id,
160-
msg: ConnToMainMsgInner::Disconnected {
161-
peer_id: self.peer_id.clone(),
162-
},
163-
}) {
164-
warn!(self.log, "Failed to send to main task");
165-
}
166158
let _ = self.writer.shutdown().await;
167159
}
168160

trust-quorum/src/proxy.rs

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use omicron_uuid_kinds::RackUuid;
2323
use serde::{Deserialize, Serialize};
2424
use slog_error_chain::{InlineErrorChain, SlogInlineError};
2525
use tokio::sync::{mpsc, oneshot};
26+
use tokio::task;
2627
use trust_quorum_protocol::{
2728
BaseboardId, CommitError, Configuration, Epoch, PrepareAndCommitError,
2829
};
@@ -246,20 +247,27 @@ pub enum TrackerError {
246247
/// A trackable in-flight proxy request, owned by the `Tracker`
247248
#[derive(Debug)]
248249
pub struct TrackableRequest {
250+
/// Each `TrackableRequest` is bound to a
251+
/// [`crate::established_conn::EstablishedConn`] that is uniquely identified
252+
/// by its `tokio::task::Id`. This is useful because it disambiguates
253+
/// connect and disconnect operations for the same `destination` such that
254+
/// they don't have to be totally ordered. It is enough to know that a
255+
/// disconnect for a given `task_id` only occurs after a connect.
256+
task_id: task::Id,
257+
/// A unique id for a given proxy request
249258
request_id: Uuid,
250-
destination: BaseboardId,
251259
// The option exists so we can take the sender out in `on_disconnect`, when
252260
// the request is borrowed, but about to be discarded.
253261
tx: DebugIgnore<Option<oneshot::Sender<Result<WireValue, TrackerError>>>>,
254262
}
255263

256264
impl TrackableRequest {
257265
pub fn new(
258-
destination: BaseboardId,
266+
task_id: task::Id,
259267
request_id: Uuid,
260268
tx: oneshot::Sender<Result<WireValue, TrackerError>>,
261269
) -> TrackableRequest {
262-
TrackableRequest { request_id, destination, tx: DebugIgnore(Some(tx)) }
270+
TrackableRequest { task_id, request_id, tx: DebugIgnore(Some(tx)) }
263271
}
264272
}
265273

@@ -306,9 +314,9 @@ impl Tracker {
306314
}
307315

308316
/// A remote peer has disconnected
309-
pub fn on_disconnect(&mut self, from: &BaseboardId) {
317+
pub fn on_disconnect(&mut self, task_id: task::Id) {
310318
self.ops.retain(|mut req| {
311-
if &req.destination == from {
319+
if req.task_id == task_id {
312320
let tx = req.tx.take().unwrap();
313321
let _ = tx.send(Err(TrackerError::Disconnected));
314322
false
@@ -335,15 +343,15 @@ mod tests {
335343
async fn recv_and_insert(
336344
rx: &mut mpsc::Receiver<NodeApiRequest>,
337345
tracker: &mut Tracker,
346+
task_id: task::Id,
338347
) {
339-
let Some(NodeApiRequest::Proxy { destination, wire_request, tx }) =
348+
let Some(NodeApiRequest::Proxy { wire_request, tx, .. }) =
340349
rx.recv().await
341350
else {
342351
panic!("Invalid NodeApiRequest")
343352
};
344353

345-
let req =
346-
TrackableRequest::new(destination, wire_request.request_id, tx);
354+
let req = TrackableRequest::new(task_id, wire_request.request_id, tx);
347355
tracker.insert(req);
348356
}
349357

@@ -355,6 +363,11 @@ mod tests {
355363
};
356364
let rack_id = RackUuid::new_v4();
357365

366+
// In real code, the `tokio::task::ID` is the id of the
367+
// `EstablishedConnectionTask`. However, we are simulating those
368+
// connections here, so just use an ID of an arbitrary task.
369+
let task_id = task::spawn(async {}).id();
370+
358371
// Test channel where the sender is usually cloned from the [`crate::NodeTaskHandle`],
359372
// and the receiver is owned by the local [`crate::NodeTask`]
360373
let (tx, mut rx) = mpsc::channel(5);
@@ -378,7 +391,7 @@ mod tests {
378391
assert_eq!(tracker.len(), 0);
379392

380393
// Simulate receiving a request by the [`NodeTask`]
381-
recv_and_insert(&mut rx, &mut tracker).await;
394+
recv_and_insert(&mut rx, &mut tracker, task_id).await;
382395

383396
// We now have a request in the tracker
384397
assert_eq!(tracker.len(), 1);
@@ -401,7 +414,7 @@ mod tests {
401414
});
402415

403416
// Simulate receiving a request by the [`NodeTask`]
404-
recv_and_insert(&mut rx, &mut tracker).await;
417+
recv_and_insert(&mut rx, &mut tracker, task_id).await;
405418
assert_eq!(tracker.len(), 2);
406419

407420
// We still haven't actually completed any operations yet
@@ -456,6 +469,11 @@ mod tests {
456469
let proxy = Proxy::new(tx.clone());
457470
let mut tracker = Tracker::new();
458471

472+
// In real code, the `tokio::task::ID` is the id of the
473+
// `EstablishedConnectionTask`. However, we are simulating those
474+
// connections here, so just use an ID of an arbitrary task.
475+
let task_id = task::spawn(async {}).id();
476+
459477
let requests_completed = Arc::new(AtomicUsize::new(0));
460478

461479
// This is the first "user" task that will issue proxy operations
@@ -473,7 +491,7 @@ mod tests {
473491
assert_eq!(tracker.len(), 0);
474492

475493
// Simulate receiving a request by the [`NodeTask`]
476-
recv_and_insert(&mut rx, &mut tracker).await;
494+
recv_and_insert(&mut rx, &mut tracker, task_id).await;
477495

478496
// We now have a request in the tracker
479497
assert_eq!(tracker.len(), 1);
@@ -517,6 +535,11 @@ mod tests {
517535
};
518536
let rack_id = RackUuid::new_v4();
519537

538+
// In real code, the `tokio::task::ID` is the id of the
539+
// `EstablishedConnectionTask`. However, we are simulating those
540+
// connections here, so just use an ID of an arbitrary task.
541+
let task_id = task::spawn(async {}).id();
542+
520543
// Test channel where the sender is usually cloned from the [`crate::NodeTaskHandle`],
521544
// and the receiver is owned by the local [`crate::NodeTask`]
522545
let (tx, mut rx) = mpsc::channel(5);
@@ -540,7 +563,7 @@ mod tests {
540563
assert_eq!(tracker.len(), 0);
541564

542565
// Simulate receiving a request by the [`NodeTask`]
543-
recv_and_insert(&mut rx, &mut tracker).await;
566+
recv_and_insert(&mut rx, &mut tracker, task_id).await;
544567

545568
// We now have a request in the tracker
546569
assert_eq!(tracker.len(), 1);
@@ -560,14 +583,15 @@ mod tests {
560583
});
561584

562585
// Simulate receiving a request by the [`NodeTask`]
563-
recv_and_insert(&mut rx, &mut tracker).await;
586+
recv_and_insert(&mut rx, &mut tracker, task_id).await;
564587
assert_eq!(tracker.len(), 2);
565588

566589
// We still haven't actually completed any operations yet
567590
assert_eq!(requests_completed.load(Ordering::Relaxed), 0);
568591

569-
// Now simulate a disconnection to the proxy destination
570-
tracker.on_disconnect(&destination);
592+
// Now simulate a disconnection to the proxy destination for this
593+
// specific connection task.
594+
tracker.on_disconnect(task_id);
571595

572596
// Now wait for both responses to be processed
573597
wait_for_condition(

0 commit comments

Comments
 (0)