Skip to content

Commit c7ee6e4

Browse files
benjipelletierfacebook-github-bot
authored andcommitted
Add channel pair metrics (meta-pytorch#1809)
Summary: We want to track throughput, error, latency, etc for channel pair metrics. Add those to `net.rs` Attributes for channel metrics: ``` "error_type" => human readable error type if an error occured "source" => source address (populated when available) "dest" => destination address ``` Reviewed By: vidhyav Differential Revision: D86425779
1 parent b5c0b9f commit c7ee6e4

File tree

2 files changed

+177
-33
lines changed

2 files changed

+177
-33
lines changed

hyperactor/src/channel/net.rs

Lines changed: 143 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -329,14 +329,18 @@ impl<M: RemoteMessage> NetTx<M> {
329329
next_seq: u64,
330330
deque: MessageDeque<M>,
331331
log_id: &'a str,
332+
dest_addr: &'a ChannelAddr,
333+
session_id: u64,
332334
}
333335

334336
impl<'a, M: RemoteMessage> Outbox<'a, M> {
335-
fn new(log_id: &'a str) -> Self {
337+
fn new(log_id: &'a str, dest_addr: &'a ChannelAddr, session_id: u64) -> Self {
336338
Self {
337339
next_seq: 0,
338340
deque: MessageDeque(VecDeque::new()),
339341
log_id,
342+
dest_addr,
343+
session_id,
340344
}
341345
}
342346

@@ -383,7 +387,24 @@ impl<M: RemoteMessage> NetTx<M> {
383387
let frame = Frame::Message(self.next_seq, message);
384388
let message =
385389
serialize_bincode(&frame).map_err(|e| format!("serialization error: {e}"))?;
386-
metrics::REMOTE_MESSAGE_SEND_SIZE.record(message.frame_len() as f64, &[]);
390+
let message_size = message.frame_len();
391+
metrics::REMOTE_MESSAGE_SEND_SIZE.record(message_size as f64, &[]);
392+
393+
// Track throughput for this channel pair
394+
metrics::CHANNEL_THROUGHPUT_BYTES.add(
395+
message_size as u64,
396+
hyperactor_telemetry::kv_pairs!(
397+
"dest" => self.dest_addr.to_string(),
398+
"session_id" => self.session_id.to_string(),
399+
),
400+
);
401+
metrics::CHANNEL_THROUGHPUT_MESSAGES.add(
402+
1,
403+
hyperactor_telemetry::kv_pairs!(
404+
"dest" => self.dest_addr.to_string(),
405+
"session_id" => self.session_id.to_string(),
406+
),
407+
);
387408

388409
self.deque.push_back(QueuedMessage {
389410
seq: self.next_seq,
@@ -498,7 +519,13 @@ impl<M: RemoteMessage> NetTx<M> {
498519
}
499520

500521
/// Remove acked messages from the deque.
501-
fn prune(&mut self, acked: u64, acked_at: Instant) {
522+
fn prune(
523+
&mut self,
524+
acked: u64,
525+
acked_at: Instant,
526+
dest_addr: &ChannelAddr,
527+
session_id: u64,
528+
) {
502529
assert!(
503530
self.largest_acked.as_ref().map_or(0, |i| i.0) <= acked,
504531
"{}: received out-of-order ack; received: {}; stored largest: {}",
@@ -513,7 +540,16 @@ impl<M: RemoteMessage> NetTx<M> {
513540
let deque = &mut self.deque;
514541
while let Some(msg) = deque.front() {
515542
if msg.seq <= acked {
516-
deque.pop_front();
543+
let msg: QueuedMessage<M> = deque.pop_front().unwrap();
544+
// Track latency: time from when message was first received to when it was acked
545+
let latency_micros = msg.received_at.elapsed().as_micros() as i64;
546+
metrics::CHANNEL_LATENCY_MICROS.record(
547+
latency_micros as f64,
548+
hyperactor_telemetry::kv_pairs!(
549+
"dest" => dest_addr.to_string(),
550+
"session_id" => session_id.to_string(),
551+
),
552+
);
517553
} else {
518554
// Messages in the deque are orderd by seq in ascending
519555
// order. So we could return early once we encounter
@@ -588,9 +624,9 @@ impl<M: RemoteMessage> NetTx<M> {
588624
}
589625

590626
impl<'a, M: RemoteMessage> State<'a, M> {
591-
fn init(log_id: &'a str) -> Self {
627+
fn init(log_id: &'a str, dest_addr: &'a ChannelAddr, session_id: u64) -> Self {
592628
Self::Running(Deliveries {
593-
outbox: Outbox::new(log_id),
629+
outbox: Outbox::new(log_id, dest_addr, session_id),
594630
unacked: Unacked::new(None, log_id),
595631
})
596632
}
@@ -638,8 +674,9 @@ impl<M: RemoteMessage> NetTx<M> {
638674
}
639675

640676
let session_id = rand::random();
641-
let log_id = format!("session {}.{}", link.dest(), session_id);
642-
let mut state = State::init(&log_id);
677+
let dest = link.dest();
678+
let log_id = format!("session {}.{}", dest, session_id);
679+
let mut state = State::init(&log_id, &dest, session_id);
643680
let mut conn = Conn::reconnect_with_default();
644681

645682
let (state, conn) = loop {
@@ -760,7 +797,7 @@ impl<M: RemoteMessage> NetTx<M> {
760797
Ok(response) => {
761798
match response {
762799
NetRxResponse::Ack(ack) => {
763-
unacked.prune(ack, RealClock.now());
800+
unacked.prune(ack, RealClock.now(), &dest, session_id);
764801
(State::Running(Deliveries { outbox, unacked }), Conn::Connected { reader, write_state })
765802
}
766803
NetRxResponse::Reject => {
@@ -821,6 +858,15 @@ impl<M: RemoteMessage> NetTx<M> {
821858
"{log_id}: outbox send error: {err}; message size: {}",
822859
outbox.front_size().expect("outbox should not be empty"),
823860
);
861+
// Track error for this channel pair
862+
metrics::CHANNEL_ERRORS.add(
863+
1,
864+
hyperactor_telemetry::kv_pairs!(
865+
"dest" => dest.to_string(),
866+
"session_id" => session_id.to_string(),
867+
"error_type" => metrics::ChannelErrorType::SendError.as_str(),
868+
),
869+
);
824870
(State::Running(Deliveries { outbox, unacked }), Conn::reconnect_with_default())
825871
}
826872
}
@@ -919,6 +965,18 @@ impl<M: RemoteMessage> NetTx<M> {
919965

920966
// Need to resend unacked after reconnecting.
921967
let largest_acked = unacked.largest_acked;
968+
let num_retries = unacked.deque.len();
969+
if num_retries > 0 {
970+
// Track reconnection for this channel pair
971+
metrics::CHANNEL_RECONNECTIONS.add(
972+
1,
973+
hyperactor_telemetry::kv_pairs!(
974+
"dest" => dest.to_string(),
975+
"transport" => dest.transport().to_string(),
976+
"reason" => "reconnect_with_unacked",
977+
),
978+
);
979+
}
922980
outbox.requeue_unacked(unacked.deque);
923981
(
924982
State::Running(Deliveries {
@@ -950,6 +1008,15 @@ impl<M: RemoteMessage> NetTx<M> {
9501008
session_id,
9511009
err
9521010
);
1011+
// Track connection error for this channel pair
1012+
metrics::CHANNEL_ERRORS.add(
1013+
1,
1014+
hyperactor_telemetry::kv_pairs!(
1015+
"dest" => dest.to_string(),
1016+
"session_id" => session_id.to_string(),
1017+
"error_type" => metrics::ChannelErrorType::ConnectionError.as_str(),
1018+
),
1019+
);
9531020
(
9541021
State::Running(Deliveries { outbox, unacked }),
9551022
Conn::reconnect(backoff),
@@ -1357,18 +1424,31 @@ impl<S: AsyncRead + AsyncWrite + Send + 'static + Unpin> ServerConn<S> {
13571424
};
13581425

13591426
// De-frame the multi-part message.
1427+
let bytes_len = bytes.len();
13601428
let message = match serde_multipart::Message::from_framed(bytes) {
13611429
Ok(message) => message,
1362-
Err(err) => break (
1363-
next,
1364-
Err::<(), anyhow::Error>(err.into()).context(
1365-
format!(
1366-
"{log_id}: de-frame message with M = {}",
1367-
type_name::<M>(),
1368-
)
1369-
),
1370-
false
1371-
),
1430+
Err(err) => {
1431+
// Track deframing error for this channel pair
1432+
metrics::CHANNEL_ERRORS.add(
1433+
1,
1434+
hyperactor_telemetry::kv_pairs!(
1435+
"source" => self.source.to_string(),
1436+
"dest" => self.dest.to_string(),
1437+
"session_id" => session_id.to_string(),
1438+
"error_type" => metrics::ChannelErrorType::DeframeError.as_str(),
1439+
),
1440+
);
1441+
break (
1442+
next,
1443+
Err::<(), anyhow::Error>(err.into()).context(
1444+
format!(
1445+
"{log_id}: de-frame message with M = {}",
1446+
type_name::<M>(),
1447+
)
1448+
),
1449+
false
1450+
)
1451+
},
13721452
};
13731453

13741454
// Finally decode the message. This assembles the M-typed message
@@ -1396,6 +1476,23 @@ impl<S: AsyncRead + AsyncWrite + Send + 'static + Unpin> ServerConn<S> {
13961476
}
13971477
match self.send_with_buffer_metric(&log_id, &tx, message).await {
13981478
Ok(()) => {
1479+
// Track throughput for this channel pair
1480+
metrics::CHANNEL_THROUGHPUT_BYTES.add(
1481+
bytes_len as u64,
1482+
hyperactor_telemetry::kv_pairs!(
1483+
"source" => self.source.to_string(),
1484+
"dest" => self.dest.to_string(),
1485+
"session_id" => session_id.to_string(),
1486+
),
1487+
);
1488+
metrics::CHANNEL_THROUGHPUT_MESSAGES.add(
1489+
1,
1490+
hyperactor_telemetry::kv_pairs!(
1491+
"source" => self.source.to_string(),
1492+
"dest" => self.dest.to_string(),
1493+
"session_id" => session_id.to_string(),
1494+
),
1495+
);
13991496
// In channel's contract, "delivered" means the message
14001497
// is sent to the NetRx object. Therefore, we could bump
14011498
// `next_seq` as far as the message is put on the mspc
@@ -1412,16 +1509,28 @@ impl<S: AsyncRead + AsyncWrite + Send + 'static + Unpin> ServerConn<S> {
14121509
}
14131510
}
14141511
},
1415-
Err(err) => break (
1416-
next,
1417-
Err::<(), anyhow::Error>(err.into()).context(
1418-
format!(
1419-
"{log_id}: deserialize message with M = {}",
1420-
type_name::<M>(),
1421-
)
1422-
),
1423-
false
1424-
),
1512+
Err(err) => {
1513+
// Track deserialization error for this channel pair
1514+
metrics::CHANNEL_ERRORS.add(
1515+
1,
1516+
hyperactor_telemetry::kv_pairs!(
1517+
"source" => self.source.to_string(),
1518+
"dest" => self.dest.to_string(),
1519+
"session_id" => session_id.to_string(),
1520+
"error_type" => metrics::ChannelErrorType::DeserializeError.as_str(),
1521+
),
1522+
);
1523+
break (
1524+
next,
1525+
Err::<(), anyhow::Error>(err.into()).context(
1526+
format!(
1527+
"{log_id}: deserialize message with M = {}",
1528+
type_name::<M>(),
1529+
)
1530+
),
1531+
false
1532+
)
1533+
},
14251534
}
14261535
},
14271536
}
@@ -1743,11 +1852,13 @@ where
17431852
};
17441853

17451854
if let Err(ref err) = res {
1746-
metrics::CHANNEL_CONNECTION_ERRORS.add(
1855+
metrics::CHANNEL_ERRORS.add(
17471856
1,
17481857
hyperactor_telemetry::kv_pairs!(
17491858
"transport" => dest.transport().to_string(),
17501859
"error" => err.to_string(),
1860+
"error_type" => metrics::ChannelErrorType::ConnectionError.as_str(),
1861+
"dest" => dest.to_string(),
17511862
),
17521863
);
17531864

@@ -1766,12 +1877,13 @@ where
17661877
});
17671878
}
17681879
Err(err) => {
1769-
metrics::CHANNEL_CONNECTION_ERRORS.add(
1880+
metrics::CHANNEL_ERRORS.add(
17701881
1,
17711882
hyperactor_telemetry::kv_pairs!(
17721883
"transport" => listener_channel_addr.transport().to_string(),
17731884
"operation" => "accept",
17741885
"error" => err.to_string(),
1886+
"error_type" => metrics::ChannelErrorType::ConnectionError.as_str(),
17751887
),
17761888
);
17771889

hyperactor/src/metrics.rs

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,31 @@ use hyperactor_telemetry::declare_static_histogram;
1515
use hyperactor_telemetry::declare_static_timer;
1616
use hyperactor_telemetry::declare_static_up_down_counter;
1717

18+
/// Error types for channel-related errors. Only used for telemetry.
19+
#[derive(Debug, Clone, Copy)]
20+
pub enum ChannelErrorType {
21+
/// Error occurred while sending a message.
22+
SendError,
23+
/// Error occurred while connecting to a channel.
24+
ConnectionError,
25+
/// Error occurred while deframing a message.
26+
DeframeError,
27+
/// Error occurred while deserializing a message.
28+
DeserializeError,
29+
}
30+
31+
impl ChannelErrorType {
32+
/// Returns the string representation of the error type.
33+
pub fn as_str(&self) -> &'static str {
34+
match self {
35+
ChannelErrorType::SendError => "send_error",
36+
ChannelErrorType::ConnectionError => "connection_error",
37+
ChannelErrorType::DeframeError => "deframe_error",
38+
ChannelErrorType::DeserializeError => "deserialize_error",
39+
}
40+
}
41+
}
42+
1843
// MAILBOX
1944
// Tracks messages that couldn't be delivered to their destination and were returned as undeliverable
2045
declare_static_counter!(
@@ -44,16 +69,23 @@ declare_static_timer!(
4469
declare_static_histogram!(REMOTE_MESSAGE_SEND_SIZE, "channel.remote_message_send_size");
4570
// Tracks the number of new channel connections established (client and server)
4671
declare_static_counter!(CHANNEL_CONNECTIONS, "channel.connections");
47-
// Tracks errors that occur when establishing channel connections
48-
declare_static_counter!(CHANNEL_CONNECTION_ERRORS, "channel.connection_errors");
4972
// Tracks the number of channel reconnection attempts
5073
declare_static_counter!(CHANNEL_RECONNECTIONS, "channel.reconnections");
74+
// Tracks errors for each channel pair
75+
declare_static_counter!(CHANNEL_ERRORS, "channel.errors");
5176
// Tracks the number of NetRx encountering full buffer, i.e. its mspc channel.
5277

5378
// This metric counts how often the NetRx→client mpsc channel remains full,
5479
// incrementing once per CHANNEL_NET_RX_BUFFER_FULL_CHECK_INTERVAL while blocked.
5580
declare_static_counter!(CHANNEL_NET_RX_BUFFER_FULL, "channel.net_rx_buffer_full");
5681

82+
// Tracks throughput (bytes sent)
83+
declare_static_counter!(CHANNEL_THROUGHPUT_BYTES, "channel.throughput.bytes");
84+
// Tracks throughput (message count)
85+
declare_static_counter!(CHANNEL_THROUGHPUT_MESSAGES, "channel.throughput.messages");
86+
// Tracks message latency for each channel pair in microseconds
87+
declare_static_histogram!(CHANNEL_LATENCY_MICROS, "channel.latency.us");
88+
5789
// PROC MESH
5890
// Tracks the number of active processes in the process mesh
5991
declare_static_counter!(PROC_MESH_ALLOCATION, "proc_mesh.active_procs");

0 commit comments

Comments
 (0)