diff --git a/hyperactor/src/channel.rs b/hyperactor/src/channel.rs index d463868c6..f769014a3 100644 --- a/hyperactor/src/channel.rs +++ b/hyperactor/src/channel.rs @@ -119,11 +119,13 @@ pub trait Tx: std::fmt::Debug { /// message is either delivered, or we eventually discover that /// the channel has failed and it will be sent back on `return_channel`. #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SendError`. + #[hyperactor::instrument_infallible] fn try_post(&self, message: M, return_channel: oneshot::Sender>) { self.do_post(message, Some(return_channel)); } /// Enqueue a message to be sent on the channel. + #[hyperactor::instrument_infallible] fn post(&self, message: M) { self.do_post(message, None); } @@ -803,6 +805,7 @@ enum ChannelRxKind { #[async_trait] impl Rx for ChannelRx { + #[hyperactor::instrument] async fn recv(&mut self) -> Result { match &mut self.inner { ChannelRxKind::Local(rx) => rx.recv().await, diff --git a/hyperactor/src/channel/net/framed.rs b/hyperactor/src/channel/net/framed.rs index e6cd0c3db..31cb64239 100644 --- a/hyperactor/src/channel/net/framed.rs +++ b/hyperactor/src/channel/net/framed.rs @@ -68,6 +68,7 @@ impl FrameReader { /// `max_frame_length`. **This error is fatal:** once returned, /// the `FrameReader` must be dropped; the underlying connection /// is no longer valid. + #[tracing::instrument(skip_all)] pub async fn next(&mut self) -> io::Result> { loop { match &mut self.state { @@ -229,6 +230,7 @@ impl FrameWrite { /// returned futures at any time. Upon completion, the frame is guaranteed to be /// written, unless an error was encountered, in which case the underlying stream /// is in an undefined state. + #[tracing::instrument(skip_all)] pub async fn send(&mut self) -> io::Result<()> { loop { if self.len_buf.has_remaining() { diff --git a/hyperactor/src/channel/net/server.rs b/hyperactor/src/channel/net/server.rs index 6a642365f..94fb49a47 100644 --- a/hyperactor/src/channel/net/server.rs +++ b/hyperactor/src/channel/net/server.rs @@ -27,8 +27,12 @@ use tokio::sync::mpsc; use tokio::task::JoinError; use tokio::task::JoinHandle; use tokio::task::JoinSet; +use tokio::time::Duration; use tokio_util::net::Listener; use tokio_util::sync::CancellationToken; +use tracing::Instrument; +use tracing::Level; +use tracing::Span; use super::serialize_response; use crate::RemoteMessage; @@ -46,6 +50,36 @@ use crate::clock::RealClock; use crate::config; use crate::metrics; use crate::sync::mvar::MVar; + +fn process_state_span( + source: &ChannelAddr, + dest: &ChannelAddr, + session_id: u64, + next: &Next, + rcv_raw_frame_count: u64, + last_ack_time: tokio::time::Instant, +) -> Span { + let since_last_ack_str = humantime::format_duration(last_ack_time.elapsed()).to_string(); + + let pending_ack_count = if next.seq > next.ack { + next.seq - next.ack - 1 + } else { + 0 + }; + + tracing::span!( + Level::ERROR, + "net i/o loop", + session_id = format!("{}.{}", dest, session_id), + source = source.to_string(), + next_seq = next.seq, + last_ack = next.ack, + pending_ack_count = pending_ack_count, + rcv_raw_frame_count = rcv_raw_frame_count, + since_last_ack = since_last_ack_str.as_str(), + ) +} + pub(super) struct ServerConn { reader: FrameReader>, write_state: WriteState, Bytes, u64>, @@ -78,184 +112,267 @@ impl ServerConn { Ok(session_id) } - /// Handles a server side stream created during the `listen` loop. - async fn process( + async fn process_step( &mut self, - session_id: u64, - tx: mpsc::Sender, - cancel_token: CancellationToken, - mut next: Next, - ) -> (Next, Result<(), anyhow::Error>) { - let log_id = format!("session {}.{}<-{}", self.dest, session_id, self.source); - let initial_next: Next = next.clone(); - let mut rcv_raw_frame_count = 0u64; - let mut last_ack_time = RealClock.now(); - - let ack_time_interval = config::global::get(config::MESSAGE_ACK_TIME_INTERVAL); - let ack_msg_interval = config::global::get(config::MESSAGE_ACK_EVERY_N_MESSAGES); + tx: &mpsc::Sender, + cancel_token: &CancellationToken, + next: &Next, + last_ack_time: &mut tokio::time::Instant, + rcv_raw_frame_count: &mut u64, + ack_time_interval: Duration, + ack_msg_interval: u64, + log_id: &str, + ) -> (Next, Option<(Result<(), anyhow::Error>, bool)>) { + let mut next = next.clone(); - let (mut final_next, final_result, reject_conn) = loop { - if self.write_state.is_idle() - && (next.ack + ack_msg_interval <= next.seq - || (next.ack < next.seq && last_ack_time.elapsed() > ack_time_interval)) - { - let Ok(writer) = replace(&mut self.write_state, WriteState::Broken).into_idle() - else { - panic!("illegal state"); - }; - let ack = match serialize_response(NetRxResponse::Ack(next.seq - 1)) { - Ok(ack) => ack, - Err(err) => { - break ( - next, + if self.write_state.is_idle() + && (next.ack + ack_msg_interval <= next.seq + || (next.ack < next.seq && last_ack_time.elapsed() > ack_time_interval)) + { + let Ok(writer) = replace(&mut self.write_state, WriteState::Broken).into_idle() else { + panic!("illegal state"); + }; + let ack = match serialize_response(NetRxResponse::Ack(next.seq - 1)) { + Ok(ack) => ack, + Err(err) => { + return ( + next, + Some(( Err::<(), anyhow::Error>(err.into()) .context(format!("{log_id}: serializing ack")), false, - ); - } - }; - match FrameWrite::new( - writer, - ack, - config::global::get(config::CODEC_MAX_FRAME_LENGTH), - ) { - Ok(fw) => { - self.write_state = WriteState::Writing(fw, next.seq); - } - Err((writer, e)) => { - debug_assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); - tracing::error!("failed to create ack frame (should be tiny): {e}"); - self.write_state = WriteState::Idle(writer); - } + )), + ); + } + }; + match FrameWrite::new( + writer, + ack, + config::global::get(config::CODEC_MAX_FRAME_LENGTH), + ) { + Ok(fw) => { + self.write_state = WriteState::Writing(fw, next.seq); + } + Err((writer, e)) => { + debug_assert_eq!(e.kind(), std::io::ErrorKind::InvalidData); + tracing::error!("failed to create ack frame (should be tiny): {e}"); + self.write_state = WriteState::Idle(writer); } } + } - tokio::select! { - // Prioritize ack, and then shutdown. Leave read last because - // there could be a large volume of messages to read, which - // subsequently starves the other select! branches. - biased; - - // We have to be careful to manage the ack write state here, so that we do not - // write partial acks in the presence of cancellation. - ack_result = self.write_state.send() => { - match ack_result { - Ok(acked_seq) => { - last_ack_time = RealClock.now(); - next.ack = acked_seq; - } - Err(err) => { - let v = self.write_state.value(); - break ( - next, + tokio::select! { + // Prioritize ack, and then shutdown. Leave read last because + // there could be a large volume of messages to read, which + // subsequently starves the other select! branches. + biased; + + // We have to be careful to manage the ack write state here, so that we do not + // write partial acks in the presence of cancellation. + ack_result = self.write_state.send() => { + match ack_result { + Ok(acked_seq) => { + tracing::info!("sent ack {acked_seq}"); + *last_ack_time = RealClock.now(); + next.ack = acked_seq; + } + Err(err) => { + let v = self.write_state.value(); + return ( + next, + Some(( Err::<(), anyhow::Error>(err.into()) .context(format!("{log_id}: acking peer message: {v:?}")), - false - ); - } + false, + )), + ); } - }, - // Have a tick to abort select! call to make sure the ack for the last message can get the chance - // to be sent as a result of time interval being reached. - _ = RealClock.sleep_until(last_ack_time + ack_time_interval), if next.ack < next.seq => {}, - _ = cancel_token.cancelled() => break (next, Ok(()), false), - bytes_result = self.reader.next() => { - rcv_raw_frame_count += 1; - // First handle transport-level I/O errors, and EOFs. - let bytes = match bytes_result { - Ok(Some(bytes)) => bytes, - Ok(None) => { - tracing::debug!("{log_id}: EOF"); - break (next, Ok(()), false); - } - Err(err) => break ( + } + }, + // Have a tick to abort select! call to make sure the ack for the last message can get the chance + // to be sent as a result of time interval being reached. + _ = RealClock.sleep_until(*last_ack_time + ack_time_interval), if next.ack < next.seq => {}, + + _ = cancel_token.cancelled() => return (next, Some((Ok(()), false))), + + bytes_result = self.reader.next() => { + *rcv_raw_frame_count += 1; + // First handle transport-level I/O errors, and EOFs. + let bytes = match bytes_result { + Ok(Some(bytes)) => bytes, + Ok(None) => { + tracing::debug!("{log_id}: EOF"); + return (next, Some((Ok(()), false))); + } + Err(err) => { + return ( next, - Err::<(), anyhow::Error>(err.into()).context( - format!( + Some(( + Err::<(), anyhow::Error>(err.into()).context(format!( "{log_id}: reading into Frame with M = {}", type_name::(), - ) - ), - false - ), - }; + )), + false, + )), + ) + } + }; - // De-frame the multi-part message. - let message = match serde_multipart::Message::from_framed(bytes) { - Ok(message) => message, - Err(err) => break ( + // De-frame the multi-part message. + let message = match serde_multipart::Message::from_framed(bytes) { + Ok(message) => message, + Err(err) => { + return ( next, - Err::<(), anyhow::Error>(err.into()).context( - format!( + Some(( + Err::<(), anyhow::Error>(err.into()).context(format!( "{log_id}: de-frame message with M = {}", type_name::(), - ) - ), - false - ), - }; - - // Finally decode the message. This assembles the M-typed message - // from its constituent parts. - match serde_multipart::deserialize_bincode(message) { - Ok(Frame::Init(_)) => { - break (next, Err(anyhow::anyhow!("{log_id}: unexpected init frame")), true) - }, - // Ignore retransmits. - Ok(Frame::Message(seq, _)) if seq < next.seq => { - tracing::debug!( - "{log_id}: ignoring retransmit; retransmit seq: {}; expected next seq: {}", - seq, - next.seq, - ); - }, - // The following segment ensures exactly-once semantics. - // That means No out-of-order delivery and no duplicate delivery. - Ok(Frame::Message(seq, message)) => { - // received seq should be equal to next seq. Else error out! - if seq > next.seq { - let msg = format!("{log_id}: out-of-sequence message, expected seq {}, got {}", next.seq, seq); - tracing::error!(msg); - break (next, Err(anyhow::anyhow!(msg)), true) + )), + false, + )), + ) + } + }; + + // Finally decode the message. This assembles the M-typed message + // from its constituent parts. + match serde_multipart::deserialize_bincode(message) { + Ok(Frame::Init(_)) => { + return ( + next, + Some((Err(anyhow::anyhow!("{log_id}: unexpected init frame")), true)), + ) + }, + // Ignore retransmits. + Ok(Frame::Message(seq, _)) if seq < next.seq => { + tracing::debug!( + "{log_id}: ignoring retransmit; retransmit seq: {}; expected next seq: {}", + seq, + next.seq, + ); + }, + // The following segment ensures exactly-once semantics. + // That means No out-of-order delivery and no duplicate delivery. + Ok(Frame::Message(seq, message)) => { + // received seq should be equal to next seq. Else error out! + if seq > next.seq { + let msg = format!("{log_id}: out-of-sequence message, expected seq {}, got {}", next.seq, seq); + tracing::error!(msg); + return (next, Some((Err(anyhow::anyhow!(msg)), true))) + } + match self.send_with_buffer_metric(log_id, tx, message) + .instrument(tracing::info_span!( + "send_with_buffer_metric", + seq = seq, + )) + .await + { + Ok(()) => { + // In channel's contract, "delivered" means the message + // is sent to the NetRx object. Therefore, we could bump + // `next_seq` as far as the message is put on the mpsc + // channel. + // + // Note that when/how the messages in NetRx are processed + // is not covered by channel's contract. For example, + // the message might never be taken out of netRx, but + // channel still considers those messages delivered. + next.seq = seq+1; } - match self.send_with_buffer_metric(&log_id, &tx, message).await { - Ok(()) => { - // In channel's contract, "delivered" means the message - // is sent to the NetRx object. Therefore, we could bump - // `next_seq` as far as the message is put on the mpsc - // channel. - // - // Note that when/how the messages in NetRx are processed - // is not covered by channel's contract. For example, - // the message might never be taken out of netRx, but - // channel still considers those messages delivered. - next.seq = seq+1; - } - Err(err) => { - break (next, Err::<(), anyhow::Error>(err).context(format!("{log_id}: relaying message to mpsc channel")), false) - } + Err(err) => { + return ( + next, + Some(( + Err::<(), anyhow::Error>(err) + .context(format!("{log_id}: relaying message to mpsc channel")), + false, + )), + ) } - }, - Err(err) => break ( + } + }, + Err(err) => { + return ( next, - Err::<(), anyhow::Error>(err.into()).context( - format!( + Some(( + Err::<(), anyhow::Error>(err.into()).context(format!( "{log_id}: deserialize message with M = {}", type_name::(), - ) - ), - false - ), + )), + false, + )), + ) } - }, + } + }, + } + + (next, None) + } + + /// Handles a server side stream created during the `listen` loop. + async fn process( + &mut self, + session_id: u64, + tx: mpsc::Sender, + cancel_token: CancellationToken, + mut next: Next, + ) -> (Next, Result<(), anyhow::Error>) { + let log_id = format!("session {}.{}<-{}", self.dest, session_id, self.source); + let initial_next: Next = next.clone(); + let mut rcv_raw_frame_count = 0u64; + let mut last_ack_time = RealClock.now(); + + let ack_time_interval = config::global::get(config::MESSAGE_ACK_TIME_INTERVAL); + let ack_msg_interval = config::global::get(config::MESSAGE_ACK_EVERY_N_MESSAGES); + + let (mut final_next, final_result, reject_conn) = loop { + let span = process_state_span( + &self.source, + &self.dest, + session_id, + &next, + rcv_raw_frame_count, + last_ack_time, + ); + + let (new_next, break_info) = self + .process_step( + &tx, + &cancel_token, + &next, + &mut last_ack_time, + &mut rcv_raw_frame_count, + ack_time_interval, + ack_msg_interval, + &log_id, + ) + .instrument(span) + .await; + + next = new_next; + + if let Some((result, reject_conn)) = break_info { + break (next, result, reject_conn); } }; + let span = process_state_span( + &self.source, + &self.dest, + session_id, + &final_next, + rcv_raw_frame_count, + last_ack_time, + ); + // Note: // 1. processed seq/ack is Next-1; // 2. rcv_raw_frame_count contains the last frame which might not be // desrializable, e.g. EOF, error, etc. tracing::debug!( + parent: &span, "{log_id}: NetRx::process exited its loop with states: initial Next \ was {initial_next}; final Next is {final_next}; since acked: {}sec; \ rcv raw frame count is {rcv_raw_frame_count}; final result: {:?}", diff --git a/hyperactor/src/mailbox.rs b/hyperactor/src/mailbox.rs index 88b0f6618..36cb5bf33 100644 --- a/hyperactor/src/mailbox.rs +++ b/hyperactor/src/mailbox.rs @@ -1185,6 +1185,7 @@ impl MailboxClient { } impl MailboxSender for MailboxClient { + #[hyperactor::instrument_infallible] fn post_unchecked( &self, envelope: MessageEnvelope, diff --git a/hyperactor_mesh/src/actor_mesh.rs b/hyperactor_mesh/src/actor_mesh.rs index 2dd6aa4c5..83b49b7a6 100644 --- a/hyperactor_mesh/src/actor_mesh.rs +++ b/hyperactor_mesh/src/actor_mesh.rs @@ -77,6 +77,7 @@ declare_attrs! { /// Common implementation for `ActorMesh`s and `ActorMeshRef`s to cast /// an `M`-typed message #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`. +#[hyperactor::instrument] pub(crate) fn actor_mesh_cast( cx: &impl context::Actor, actor_mesh_id: ActorMeshId, diff --git a/hyperactor_mesh/src/comm.rs b/hyperactor_mesh/src/comm.rs index bbbc63f4d..f7b9dac43 100644 --- a/hyperactor_mesh/src/comm.rs +++ b/hyperactor_mesh/src/comm.rs @@ -350,6 +350,7 @@ impl Handler for CommActor { // TODO(T218630526): reliable casting for mutable topology #[async_trait] impl Handler for CommActor { + #[hyperactor::instrument] async fn handle(&mut self, cx: &Context, cast_message: CastMessage) -> Result<()> { // Always forward the message to the root rank of the slice, casting starts from there. let slice = cast_message.dest.slice.clone(); @@ -380,6 +381,7 @@ impl Handler for CommActor { #[async_trait] impl Handler for CommActor { + #[hyperactor::instrument] async fn handle(&mut self, cx: &Context, fwd_message: ForwardMessage) -> Result<()> { let ForwardMessage { sender, diff --git a/monarch_hyperactor/src/actor.rs b/monarch_hyperactor/src/actor.rs index 76939eb71..c13ca3e1e 100644 --- a/monarch_hyperactor/src/actor.rs +++ b/monarch_hyperactor/src/actor.rs @@ -771,6 +771,7 @@ impl PanicFlag { #[async_trait] impl Handler for PythonActor { + #[hyperactor::instrument] async fn handle( &mut self, cx: &Context, diff --git a/monarch_hyperactor/src/actor_mesh.rs b/monarch_hyperactor/src/actor_mesh.rs index a52b7b5be..b71dac0d8 100644 --- a/monarch_hyperactor/src/actor_mesh.rs +++ b/monarch_hyperactor/src/actor_mesh.rs @@ -137,6 +137,7 @@ pub(crate) fn to_hy_sel(selection: &str) -> PyResult { #[pymethods] impl PythonActorMesh { + #[hyperactor::instrument] fn cast( &self, message: &PythonMessage,