diff --git a/lightning-net-tokio/src/lib.rs b/lightning-net-tokio/src/lib.rs index 068f77a84bb..cf66311444f 100644 --- a/lightning-net-tokio/src/lib.rs +++ b/lightning-net-tokio/src/lib.rs @@ -660,8 +660,8 @@ mod tests { } fn handle_channel_update( &self, _their_node_id: Option, _msg: &ChannelUpdate, - ) -> Result { - Ok(false) + ) -> Result, LightningError> { + Ok(None) } fn get_next_channel_announcement( &self, _starting_point: u64, diff --git a/lightning/src/ln/channel_open_tests.rs b/lightning/src/ln/channel_open_tests.rs index 3fd546aaff7..9a3f9704069 100644 --- a/lightning/src/ln/channel_open_tests.rs +++ b/lightning/src/ln/channel_open_tests.rs @@ -2337,7 +2337,7 @@ pub fn test_funding_and_commitment_tx_confirm_same_block() { } else { panic!(); } - if let MessageSendEvent::BroadcastChannelUpdate { ref msg } = msg_events.remove(0) { + if let MessageSendEvent::BroadcastChannelUpdate { ref msg, .. } = msg_events.remove(0) { assert_eq!(msg.contents.channel_flags & 2, 2); } else { panic!(); diff --git a/lightning/src/ln/channelmanager.rs b/lightning/src/ln/channelmanager.rs index 644920557d2..21a99260c28 100644 --- a/lightning/src/ln/channelmanager.rs +++ b/lightning/src/ln/channelmanager.rs @@ -111,6 +111,7 @@ use crate::onion_message::messenger::{ MessageRouter, MessageSendInstructions, Responder, ResponseInstruction, }; use crate::onion_message::offers::{OffersMessage, OffersMessageHandler}; +use crate::routing::gossip::NodeId; use crate::routing::router::{ BlindedTail, FixedRouter, InFlightHtlcs, Path, Payee, PaymentParameters, Route, RouteParameters, RouteParametersConfig, Router, @@ -942,7 +943,7 @@ impl Into for FailureCode { struct MsgHandleErrInternal { err: msgs::LightningError, closes_channel: bool, - shutdown_finish: Option<(ShutdownResult, Option)>, + shutdown_finish: Option<(ShutdownResult, Option<(msgs::ChannelUpdate, NodeId, NodeId)>)>, tx_abort: Option, } impl MsgHandleErrInternal { @@ -966,7 +967,7 @@ impl MsgHandleErrInternal { fn from_finish_shutdown( err: String, channel_id: ChannelId, shutdown_res: ShutdownResult, - channel_update: Option, + channel_update: Option<(msgs::ChannelUpdate, NodeId, NodeId)>, ) -> Self { let err_msg = msgs::ErrorMessage { channel_id, data: err.clone() }; let action = if shutdown_res.monitor_update.is_some() { @@ -3244,10 +3245,10 @@ macro_rules! handle_error { log_error!(logger, "Closing channel: {}", err.err); $self.finish_close_channel(shutdown_res); - if let Some(update) = update_option { + if let Some((update, node_id_1, node_id_2)) = update_option { let mut pending_broadcast_messages = $self.pending_broadcast_messages.lock().unwrap(); pending_broadcast_messages.push(MessageSendEvent::BroadcastChannelUpdate { - msg: update + msg: update, node_id_1, node_id_2 }); } } else { @@ -3574,7 +3575,7 @@ macro_rules! handle_monitor_update_completion { // channel_update later through the announcement_signatures process for public // channels, but there's no reason not to just inform our counterparty of our fees // now. - if let Ok(msg) = $self.get_channel_update_for_unicast($chan) { + if let Ok((msg, _, _)) = $self.get_channel_update_for_unicast($chan) { Some(MessageSendEvent::SendChannelUpdate { node_id: counterparty_node_id, msg, @@ -5125,7 +5126,9 @@ where } } - /// Gets the current [`channel_update`] for the given channel. This first checks if the channel is + /// Gets the current [`channel_update`] for the given channel (as well as our and our + /// counterparty's [`NodeId`], which is needed for the + /// [`MessageSendEvent::BroadcastChannelUpdate`]). This first checks if the channel is /// public, and thus should be called whenever the result is going to be passed out in a /// [`MessageSendEvent::BroadcastChannelUpdate`] event. /// @@ -5137,7 +5140,7 @@ where /// [`internal_closing_signed`]: Self::internal_closing_signed fn get_channel_update_for_broadcast( &self, chan: &FundedChannel, - ) -> Result { + ) -> Result<(msgs::ChannelUpdate, NodeId, NodeId), LightningError> { if !chan.context.should_announce() { return Err(LightningError { err: "Cannot broadcast a channel_update for a private channel".to_owned(), @@ -5159,10 +5162,11 @@ where self.get_channel_update_for_unicast(chan) } - /// Gets the current [`channel_update`] for the given channel. This does not check if the channel - /// is public (only returning an `Err` if the channel does not yet have an assigned SCID), - /// and thus MUST NOT be called unless the recipient of the resulting message has already - /// provided evidence that they know about the existence of the channel. + /// Gets the current [`channel_update`] for the given channel (as well as our and our + /// counterparty's [`NodeId`]). This does not check if the channel is public (only returning an + /// `Err` if the channel does not yet have an assigned SCID), and thus MUST NOT be called + /// unless the recipient of the resulting message has already provided evidence that they know + /// about the existence of the channel. /// /// Note that through [`internal_closing_signed`], this function is called without the /// `peer_state` corresponding to the channel's counterparty locked, as the channel been @@ -5171,7 +5175,9 @@ where /// [`channel_update`]: msgs::ChannelUpdate /// [`internal_closing_signed`]: Self::internal_closing_signed #[rustfmt::skip] - fn get_channel_update_for_unicast(&self, chan: &FundedChannel) -> Result { + fn get_channel_update_for_unicast( + &self, chan: &FundedChannel, + ) -> Result<(msgs::ChannelUpdate, NodeId, NodeId), LightningError> { let logger = WithChannelContext::from(&self.logger, &chan.context, None); log_trace!(logger, "Attempting to generate channel update for channel {}", chan.context.channel_id()); let short_channel_id = match chan.funding.get_short_channel_id().or(chan.context.latest_inbound_scid_alias()) { @@ -5181,7 +5187,9 @@ where let logger = WithChannelContext::from(&self.logger, &chan.context, None); log_trace!(logger, "Generating channel update for channel {}", chan.context.channel_id()); - let were_node_one = self.our_network_pubkey.serialize()[..] < chan.context.get_counterparty_node_id().serialize()[..]; + let our_node_id = NodeId::from_pubkey(&self.our_network_pubkey); + let their_node_id = NodeId::from_pubkey(&chan.context.get_counterparty_node_id()); + let were_node_one = our_node_id < their_node_id; let enabled = chan.context.is_enabled(); let unsigned = msgs::UnsignedChannelUpdate { @@ -5203,10 +5211,14 @@ where // channel. let sig = self.node_signer.sign_gossip_message(msgs::UnsignedGossipMessage::ChannelUpdate(&unsigned)).unwrap(); - Ok(msgs::ChannelUpdate { - signature: sig, - contents: unsigned - }) + Ok(( + msgs::ChannelUpdate { + signature: sig, + contents: unsigned + }, + if were_node_one { our_node_id } else { their_node_id }, + if were_node_one { their_node_id } else { our_node_id }, + )) } #[cfg(any(test, feature = "_externalize_tests"))] @@ -6649,11 +6661,11 @@ where continue; } if let Some(channel) = channel.as_funded() { - if let Ok(msg) = self.get_channel_update_for_broadcast(channel) { + if let Ok((msg, node_id_1, node_id_2)) = self.get_channel_update_for_broadcast(channel) { let mut pending_broadcast_messages = self.pending_broadcast_messages.lock().unwrap(); - pending_broadcast_messages.push(MessageSendEvent::BroadcastChannelUpdate { msg }); + pending_broadcast_messages.push(MessageSendEvent::BroadcastChannelUpdate { msg, node_id_1, node_id_2 }); } else if peer_state.is_connected { - if let Ok(msg) = self.get_channel_update_for_unicast(channel) { + if let Ok((msg, _, _)) = self.get_channel_update_for_unicast(channel) { peer_state.pending_msg_events.push(MessageSendEvent::SendChannelUpdate { node_id: channel.context.get_counterparty_node_id(), msg, @@ -8177,10 +8189,10 @@ where n += 1; if n >= DISABLE_GOSSIP_TICKS { funded_chan.set_channel_update_status(ChannelUpdateStatus::Disabled); - if let Ok(update) = self.get_channel_update_for_broadcast(&funded_chan) { + if let Ok((update, node_id_1, node_id_2)) = self.get_channel_update_for_broadcast(&funded_chan) { let mut pending_broadcast_messages = self.pending_broadcast_messages.lock().unwrap(); pending_broadcast_messages.push(MessageSendEvent::BroadcastChannelUpdate { - msg: update + msg: update, node_id_1, node_id_2 }); } should_persist = NotifyOption::DoPersist; @@ -8192,10 +8204,10 @@ where n += 1; if n >= ENABLE_GOSSIP_TICKS { funded_chan.set_channel_update_status(ChannelUpdateStatus::Enabled); - if let Ok(update) = self.get_channel_update_for_broadcast(&funded_chan) { + if let Ok((update, node_id_1, node_id_2)) = self.get_channel_update_for_broadcast(&funded_chan) { let mut pending_broadcast_messages = self.pending_broadcast_messages.lock().unwrap(); pending_broadcast_messages.push(MessageSendEvent::BroadcastChannelUpdate { - msg: update + msg: update, node_id_1, node_id_2 }); } should_persist = NotifyOption::DoPersist; @@ -10821,7 +10833,7 @@ This indicates a bug inside LDK. Please report this error at https://github.com/ // channel_update here if the channel is not public, i.e. we're not sending an // announcement_signatures. log_trace!(logger, "Sending private initial channel_update for our counterparty on channel {}", chan.context.channel_id()); - if let Ok(msg) = self.get_channel_update_for_unicast(chan) { + if let Ok((msg, _, _)) = self.get_channel_update_for_unicast(chan) { peer_state.pending_msg_events.push(MessageSendEvent::SendChannelUpdate { node_id: counterparty_node_id.clone(), msg, @@ -11620,7 +11632,7 @@ This indicates a bug inside LDK. Please report this error at https://github.com/ msg: try_channel_entry!(self, peer_state, res, chan_entry), // Note that announcement_signatures fails if the channel cannot be announced, // so get_channel_update_for_broadcast will never fail by the time we get here. - update_msg: Some(self.get_channel_update_for_broadcast(chan).unwrap()), + update_msg: Some(self.get_channel_update_for_broadcast(chan).unwrap().0), }); } else { return try_channel_entry!(self, peer_state, Err(ChannelError::close( @@ -11729,7 +11741,7 @@ This indicates a bug inside LDK. Please report this error at https://github.com/ // If the channel is in a usable state (ie the channel is not being shut // down), send a unicast channel_update to our counterparty to make sure // they have the latest channel parameters. - if let Ok(msg) = self.get_channel_update_for_unicast(chan) { + if let Ok((msg, _, _)) = self.get_channel_update_for_unicast(chan) { channel_update = Some(MessageSendEvent::SendChannelUpdate { node_id: chan.context.get_counterparty_node_id(), msg, @@ -14340,7 +14352,7 @@ where send_channel_ready!(self, pending_msg_events, funded_channel, channel_ready); if funded_channel.context.is_usable() && peer_state.is_connected { log_trace!(logger, "Sending channel_ready with private initial channel_update for our counterparty on channel {}", channel_id); - if let Ok(msg) = self.get_channel_update_for_unicast(funded_channel) { + if let Ok((msg, _, _)) = self.get_channel_update_for_unicast(funded_channel) { pending_msg_events.push(MessageSendEvent::SendChannelUpdate { node_id: funded_channel.context.get_counterparty_node_id(), msg, @@ -14433,7 +14445,7 @@ where // if the channel cannot be announced, so // get_channel_update_for_broadcast will never fail // by the time we get here. - update_msg: Some(self.get_channel_update_for_broadcast(funded_channel).unwrap()), + update_msg: Some(self.get_channel_update_for_broadcast(funded_channel).unwrap().0), }); } } diff --git a/lightning/src/ln/functional_test_utils.rs b/lightning/src/ln/functional_test_utils.rs index 271d458bcc8..6bea16dbc8b 100644 --- a/lightning/src/ln/functional_test_utils.rs +++ b/lightning/src/ln/functional_test_utils.rs @@ -2182,7 +2182,7 @@ macro_rules! get_closing_signed_broadcast { assert!(events.len() == 1 || events.len() == 2); ( match events[events.len() - 1] { - MessageSendEvent::BroadcastChannelUpdate { ref msg } => { + MessageSendEvent::BroadcastChannelUpdate { ref msg, .. } => { assert_eq!(msg.contents.channel_flags & 2, 2); msg.clone() }, @@ -2253,7 +2253,7 @@ pub fn check_closed_broadcast( .into_iter() .filter_map(|msg_event| { match msg_event { - MessageSendEvent::BroadcastChannelUpdate { ref msg } => { + MessageSendEvent::BroadcastChannelUpdate { ref msg, .. } => { assert_eq!(msg.contents.channel_flags & 2, 2); None }, @@ -4875,7 +4875,7 @@ pub fn handle_announce_close_broadcast_events<'a, 'b, 'c>( let events_1 = nodes[a].node.get_and_clear_pending_msg_events(); assert_eq!(events_1.len(), 2); let as_update = match events_1[1] { - MessageSendEvent::BroadcastChannelUpdate { ref msg } => msg.clone(), + MessageSendEvent::BroadcastChannelUpdate { ref msg, .. } => msg.clone(), _ => panic!("Unexpected event"), }; match events_1[0] { @@ -4912,7 +4912,7 @@ pub fn handle_announce_close_broadcast_events<'a, 'b, 'c>( let events_2 = nodes[b].node.get_and_clear_pending_msg_events(); assert_eq!(events_2.len(), if needs_err_handle { 1 } else { 2 }); let bs_update = match events_2.last().unwrap() { - MessageSendEvent::BroadcastChannelUpdate { ref msg } => msg.clone(), + MessageSendEvent::BroadcastChannelUpdate { ref msg, .. } => msg.clone(), _ => panic!("Unexpected event"), }; if !needs_err_handle { diff --git a/lightning/src/ln/functional_tests.rs b/lightning/src/ln/functional_tests.rs index c161a9664c0..db229b4e0aa 100644 --- a/lightning/src/ln/functional_tests.rs +++ b/lightning/src/ln/functional_tests.rs @@ -717,7 +717,7 @@ pub fn channel_monitor_network_test() { let events = nodes[3].node.get_and_clear_pending_msg_events(); assert_eq!(events.len(), 2); let close_chan_update_1 = match events[1] { - MessageSendEvent::BroadcastChannelUpdate { ref msg } => msg.clone(), + MessageSendEvent::BroadcastChannelUpdate { ref msg, .. } => msg.clone(), _ => panic!("Unexpected event"), }; match events[0] { @@ -752,7 +752,7 @@ pub fn channel_monitor_network_test() { let events = nodes[4].node.get_and_clear_pending_msg_events(); assert_eq!(events.len(), 2); let close_chan_update_2 = match events[1] { - MessageSendEvent::BroadcastChannelUpdate { ref msg } => msg.clone(), + MessageSendEvent::BroadcastChannelUpdate { ref msg, .. } => msg.clone(), _ => panic!("Unexpected event"), }; match events[0] { @@ -2167,7 +2167,7 @@ fn do_test_commitment_revoked_fail_backward_exhaustive( // Ensure that the last remaining message event is the BroadcastChannelUpdate msg for chan_2 match events[0] { - MessageSendEvent::BroadcastChannelUpdate { msg: msgs::ChannelUpdate { .. } } => {}, + MessageSendEvent::BroadcastChannelUpdate { msg: msgs::ChannelUpdate { .. }, .. } => {}, _ => panic!("Unexpected event"), } @@ -6026,7 +6026,7 @@ pub fn test_announce_disable_channels() { let mut chans_disabled = new_hash_map(); for e in msg_events { match e { - MessageSendEvent::BroadcastChannelUpdate { ref msg } => { + MessageSendEvent::BroadcastChannelUpdate { ref msg, .. } => { assert_eq!(msg.contents.channel_flags & (1 << 1), 1 << 1); // The "channel disabled" bit should be set // Check that each channel gets updated exactly once if chans_disabled @@ -6077,7 +6077,7 @@ pub fn test_announce_disable_channels() { assert_eq!(msg_events.len(), 3); for e in msg_events { match e { - MessageSendEvent::BroadcastChannelUpdate { ref msg } => { + MessageSendEvent::BroadcastChannelUpdate { ref msg, .. } => { assert_eq!(msg.contents.channel_flags & (1 << 1), 0); // The "channel disabled" bit should be off match chans_disabled.remove(&msg.contents.short_channel_id) { // Each update should have a higher timestamp than the previous one, replacing @@ -7995,13 +7995,13 @@ pub fn test_error_chans_closed() { let events = nodes[0].node.get_and_clear_pending_msg_events(); assert_eq!(events.len(), 2); match events[0] { - MessageSendEvent::BroadcastChannelUpdate { ref msg } => { + MessageSendEvent::BroadcastChannelUpdate { ref msg, .. } => { assert_eq!(msg.contents.channel_flags & 2, 2); }, _ => panic!("Unexpected event"), } match events[1] { - MessageSendEvent::BroadcastChannelUpdate { ref msg } => { + MessageSendEvent::BroadcastChannelUpdate { ref msg, .. } => { assert_eq!(msg.contents.channel_flags & 2, 2); }, _ => panic!("Unexpected event"), diff --git a/lightning/src/ln/msgs.rs b/lightning/src/ln/msgs.rs index c0c8239f621..36069b7c957 100644 --- a/lightning/src/ln/msgs.rs +++ b/lightning/src/ln/msgs.rs @@ -1917,6 +1917,16 @@ pub enum MessageSendEvent { BroadcastChannelUpdate { /// The channel_update which should be sent. msg: ChannelUpdate, + /// The node_id of the first endpoint of the channel. + /// + /// This is not used in the message broadcast, but rather is useful for deciding which + /// peer(s) to send the update to. + node_id_1: NodeId, + /// The node_id of the second endpoint of the channel. + /// + /// This is not used in the message broadcast, but rather is useful for deciding which + /// peer(s) to send the update to. + node_id_2: NodeId, }, /// Used to indicate that a node_announcement should be broadcast to all peers. BroadcastNodeAnnouncement { @@ -2189,13 +2199,13 @@ pub trait RoutingMessageHandler: BaseMessageHandler { fn handle_channel_announcement( &self, their_node_id: Option, msg: &ChannelAnnouncement, ) -> Result; - /// Handle an incoming `channel_update` message, returning true if it should be forwarded on, - /// `false` or returning an `Err` otherwise. + /// Handle an incoming `channel_update` message, returning the node IDs of the channel + /// participants if the message should be forwarded on, `None` or returning an `Err` otherwise. /// /// If `their_node_id` is `None`, the message was generated by our own local node. fn handle_channel_update( &self, their_node_id: Option, msg: &ChannelUpdate, - ) -> Result; + ) -> Result, LightningError>; /// Gets channel announcements and updates required to dump our routing table to a remote node, /// starting at the `short_channel_id` indicated by `starting_point` and including announcements /// for a single channel. diff --git a/lightning/src/ln/onion_route_tests.rs b/lightning/src/ln/onion_route_tests.rs index f4cfb9eda00..067b4092315 100644 --- a/lightning/src/ln/onion_route_tests.rs +++ b/lightning/src/ln/onion_route_tests.rs @@ -1662,7 +1662,7 @@ fn do_test_onion_failure_stale_channel_update(announce_for_forwarding: bool) { return None; } let new_update = match &events[0] { - MessageSendEvent::BroadcastChannelUpdate { msg } => { + MessageSendEvent::BroadcastChannelUpdate { msg, .. } => { assert!(announce_for_forwarding); msg.clone() }, diff --git a/lightning/src/ln/peer_handler.rs b/lightning/src/ln/peer_handler.rs index 74f081b03ae..4c379c29dd1 100644 --- a/lightning/src/ln/peer_handler.rs +++ b/lightning/src/ln/peer_handler.rs @@ -157,8 +157,8 @@ impl RoutingMessageHandler for IgnoringMessageHandler { } fn handle_channel_update( &self, _their_node_id: Option, _msg: &msgs::ChannelUpdate, - ) -> Result { - Ok(false) + ) -> Result, LightningError> { + Ok(None) } fn get_next_channel_announcement( &self, _starting_point: u64, @@ -611,6 +611,19 @@ where pub send_only_message_handler: SM, } +/// A gossip message to be forwarded to all peers. +enum BroadcastGossipMessage { + ChannelAnnouncement(msgs::ChannelAnnouncement), + NodeAnnouncement(msgs::NodeAnnouncement), + ChannelUpdate { + msg: msgs::ChannelUpdate, + /// One of the two channel endpoints. + node_id_1: NodeId, + /// One of the two channel endpoints. + node_id_2: NodeId, + }, +} + /// Provides an object which can be used to send data to and which uniquely identifies a connection /// to a remote host. You will need to be able to generate multiple of these which meet Eq and /// implement Hash to meet the PeerManager API. @@ -1101,6 +1114,7 @@ pub struct PeerManager< gossip_processing_backlog_lifted: AtomicBool, node_signer: NS, + our_node_id: NodeId, logger: L, secp_ctx: Secp256k1, @@ -1315,6 +1329,9 @@ where let ephemeral_hash = Sha256::from_engine(ephemeral_key_midstate.clone()).to_byte_array(); secp_ctx.seeded_randomize(&ephemeral_hash); + let our_node_pubkey = + node_signer.get_node_id(Recipient::Node).expect("node_id must be available"); + PeerManager { message_handler, peers: FairRwLock::new(new_hash_map()), @@ -1326,6 +1343,7 @@ where gossip_processing_backlog_lifted: AtomicBool::new(false), last_node_announcement_serial: AtomicU32::new(current_time), logger, + our_node_id: NodeId::from_pubkey(&our_node_pubkey), node_signer, secp_ctx, } @@ -2045,10 +2063,7 @@ where message: wire::Message< <::Target as wire::CustomMessageReader>::CustomMessage, >, - ) -> Result< - Option::Target as wire::CustomMessageReader>::CustomMessage>>, - MessageHandlingError, - > { + ) -> Result, MessageHandlingError> { let their_node_id = peer_lock .their_node_id .expect("We know the peer's public key by the time we receive messages") @@ -2390,10 +2405,7 @@ where <::Target as wire::CustomMessageReader>::CustomMessage, >, their_node_id: PublicKey, logger: &WithContext<'a, L>, - ) -> Result< - Option::Target as wire::CustomMessageReader>::CustomMessage>>, - MessageHandlingError, - > { + ) -> Result, MessageHandlingError> { if is_gossip_msg(message.type_id()) { log_gossip!(logger, "Received message {:?} from {}", message, their_node_id); } else { @@ -2575,7 +2587,7 @@ where .handle_channel_announcement(Some(their_node_id), &msg) .map_err(|e| -> MessageHandlingError { e.into() })? { - should_forward = Some(wire::Message::ChannelAnnouncement(msg)); + should_forward = Some(BroadcastGossipMessage::ChannelAnnouncement(msg)); } self.update_gossip_backlogged(); }, @@ -2585,7 +2597,7 @@ where .handle_node_announcement(Some(their_node_id), &msg) .map_err(|e| -> MessageHandlingError { e.into() })? { - should_forward = Some(wire::Message::NodeAnnouncement(msg)); + should_forward = Some(BroadcastGossipMessage::NodeAnnouncement(msg)); } self.update_gossip_backlogged(); }, @@ -2594,11 +2606,12 @@ where chan_handler.handle_channel_update(their_node_id, &msg); let route_handler = &self.message_handler.route_handler; - if route_handler + if let Some((node_id_1, node_id_2)) = route_handler .handle_channel_update(Some(their_node_id), &msg) .map_err(|e| -> MessageHandlingError { e.into() })? { - should_forward = Some(wire::Message::ChannelUpdate(msg)); + should_forward = + Some(BroadcastGossipMessage::ChannelUpdate { msg, node_id_1, node_id_2 }); } self.update_gossip_backlogged(); }, @@ -2652,20 +2665,23 @@ where /// unless `allow_large_buffer` is set, in which case the message will be treated as critical /// and delivered no matter the available buffer space. fn forward_broadcast_msg( - &self, peers: &HashMap>, - msg: &wire::Message<<::Target as wire::CustomMessageReader>::CustomMessage>, + &self, peers: &HashMap>, msg: &BroadcastGossipMessage, except_node: Option<&PublicKey>, allow_large_buffer: bool, ) { match msg { - wire::Message::ChannelAnnouncement(ref msg) => { + BroadcastGossipMessage::ChannelAnnouncement(ref msg) => { log_gossip!(self.logger, "Sending message to all peers except {:?} or the announced channel's counterparties: {:?}", except_node, msg); let encoded_msg = encode_msg!(msg); + let our_channel = self.our_node_id == msg.contents.node_id_1 + || self.our_node_id == msg.contents.node_id_2; for (_, peer_mutex) in peers.iter() { let mut peer = peer_mutex.lock().unwrap(); - if !peer.handshake_complete() - || !peer.should_forward_channel_announcement(msg.contents.short_channel_id) - { + if !peer.handshake_complete() { + continue; + } + let scid = msg.contents.short_channel_id; + if !our_channel && !peer.should_forward_channel_announcement(scid) { continue; } debug_assert!(peer.their_node_id.is_some()); @@ -2696,7 +2712,7 @@ where peer.gossip_broadcast_buffer.push_back(encoded_message); } }, - wire::Message::NodeAnnouncement(ref msg) => { + BroadcastGossipMessage::NodeAnnouncement(ref msg) => { log_gossip!( self.logger, "Sending message to all peers except {:?} or the announced node: {:?}", @@ -2704,12 +2720,15 @@ where msg ); let encoded_msg = encode_msg!(msg); + let our_announcement = self.our_node_id == msg.contents.node_id; for (_, peer_mutex) in peers.iter() { let mut peer = peer_mutex.lock().unwrap(); - if !peer.handshake_complete() - || !peer.should_forward_node_announcement(msg.contents.node_id) - { + if !peer.handshake_complete() { + continue; + } + let node_id = msg.contents.node_id; + if !our_announcement && !peer.should_forward_node_announcement(node_id) { continue; } debug_assert!(peer.their_node_id.is_some()); @@ -2738,7 +2757,7 @@ where peer.gossip_broadcast_buffer.push_back(encoded_message); } }, - wire::Message::ChannelUpdate(ref msg) => { + BroadcastGossipMessage::ChannelUpdate { msg, node_id_1, node_id_2 } => { log_gossip!( self.logger, "Sending message to all peers except {:?}: {:?}", @@ -2746,12 +2765,15 @@ where msg ); let encoded_msg = encode_msg!(msg); + let our_channel = self.our_node_id == *node_id_1 || self.our_node_id == *node_id_2; for (_, peer_mutex) in peers.iter() { let mut peer = peer_mutex.lock().unwrap(); - if !peer.handshake_complete() - || !peer.should_forward_channel_announcement(msg.contents.short_channel_id) - { + if !peer.handshake_complete() { + continue; + } + let scid = msg.contents.short_channel_id; + if !our_channel && !peer.should_forward_channel_announcement(scid) { continue; } debug_assert!(peer.their_node_id.is_some()); @@ -2775,9 +2797,6 @@ where peer.gossip_broadcast_buffer.push_back(encoded_message); } }, - _ => { - debug_assert!(false, "We shouldn't attempt to forward anything but gossip messages") - }, } } @@ -3129,13 +3148,15 @@ where }, MessageSendEvent::BroadcastChannelAnnouncement { msg, update_msg } => { log_debug!(self.logger, "Handling BroadcastChannelAnnouncement event in peer_handler for short channel id {}", msg.contents.short_channel_id); + let node_id_1 = msg.contents.node_id_1; + let node_id_2 = msg.contents.node_id_2; match route_handler.handle_channel_announcement(None, &msg) { Ok(_) | Err(LightningError { action: msgs::ErrorAction::IgnoreDuplicateGossip, .. }) => { - let forward = wire::Message::ChannelAnnouncement(msg); + let forward = BroadcastGossipMessage::ChannelAnnouncement(msg); self.forward_broadcast_msg( peers, &forward, @@ -3152,7 +3173,11 @@ where action: msgs::ErrorAction::IgnoreDuplicateGossip, .. }) => { - let forward = wire::Message::ChannelUpdate(msg); + let forward = BroadcastGossipMessage::ChannelUpdate { + msg, + node_id_1, + node_id_2, + }; self.forward_broadcast_msg( peers, &forward, @@ -3164,7 +3189,7 @@ where } } }, - MessageSendEvent::BroadcastChannelUpdate { msg } => { + MessageSendEvent::BroadcastChannelUpdate { msg, node_id_1, node_id_2 } => { log_debug!(self.logger, "Handling BroadcastChannelUpdate event in peer_handler for contents {:?}", msg.contents); match route_handler.handle_channel_update(None, &msg) { Ok(_) @@ -3172,7 +3197,11 @@ where action: msgs::ErrorAction::IgnoreDuplicateGossip, .. }) => { - let forward = wire::Message::ChannelUpdate(msg); + let forward = BroadcastGossipMessage::ChannelUpdate { + msg, + node_id_1, + node_id_2, + }; self.forward_broadcast_msg( peers, &forward, @@ -3191,7 +3220,7 @@ where action: msgs::ErrorAction::IgnoreDuplicateGossip, .. }) => { - let forward = wire::Message::NodeAnnouncement(msg); + let forward = BroadcastGossipMessage::NodeAnnouncement(msg); self.forward_broadcast_msg( peers, &forward, @@ -3668,7 +3697,7 @@ where let _ = self.message_handler.route_handler.handle_node_announcement(None, &msg); self.forward_broadcast_msg( &*self.peers.read().unwrap(), - &wire::Message::NodeAnnouncement(msg), + &BroadcastGossipMessage::NodeAnnouncement(msg), None, true, ); @@ -4409,8 +4438,6 @@ mod tests { #[test] fn test_forward_while_syncing() { - use crate::ln::peer_handler::tests::test_utils::get_dummy_channel_update; - // Test forwarding new channel announcements while we're doing syncing. let cfgs = create_peermgr_cfgs(2); cfgs[0].routing_handler.request_full_sync.store(true, Ordering::Release); @@ -4457,11 +4484,19 @@ mod tests { // At this point we should have sent channel announcements up to roughly SCID 150. Now // build an updated update for SCID 100 and SCID 5000 and make sure only the one for SCID // 100 gets forwarded - let msg_100 = get_dummy_channel_update(100); - let msg_ev_100 = MessageSendEvent::BroadcastChannelUpdate { msg: msg_100.clone() }; + let msg_100 = test_utils::get_dummy_channel_update(100); + let msg_ev_100 = MessageSendEvent::BroadcastChannelUpdate { + msg: msg_100.clone(), + node_id_1: NodeId::from_slice(&[2; 33]).unwrap(), + node_id_2: NodeId::from_slice(&[3; 33]).unwrap(), + }; - let msg_5000 = get_dummy_channel_update(5000); - let msg_ev_5000 = MessageSendEvent::BroadcastChannelUpdate { msg: msg_5000 }; + let msg_5000 = test_utils::get_dummy_channel_update(5000); + let msg_ev_5000 = MessageSendEvent::BroadcastChannelUpdate { + msg: msg_5000, + node_id_1: NodeId::from_slice(&[2; 33]).unwrap(), + node_id_2: NodeId::from_slice(&[3; 33]).unwrap(), + }; fd_a.hang_writes.store(true, Ordering::Relaxed); @@ -4870,6 +4905,71 @@ mod tests { assert_eq!(filter_addresses(None), None); } + #[test] + fn test_forward_gossip_for_our_channels_ignores_peer_filter() { + // Tests that gossip for channels where we are one of the endpoints is forwarded to all + // peers, regardless of any gossip filters they may have set. This ensures that updates + // for our own channels always propagate to all connected peers. + + let cfgs = create_peermgr_cfgs(2); + let peers = create_network(2, &cfgs); + + let id_0 = peers[0].node_signer.get_node_id(Recipient::Node).unwrap(); + + // Connect the peers and exchange the initial connection handshake (but not the final Init + // message). + let (mut fd_0_1, mut fd_1_0) = establish_connection(&peers[0], &peers[1]); + + // Once peer 1 receives the Init message in the last read_event, it'll generate a + // `GossipTimestampFilter` which will request gossip. Instead we drop it here. + cfgs[1] + .routing_handler + .pending_events + .lock() + .unwrap() + .retain(|ev| !matches!(ev, MessageSendEvent::SendGossipTimestampFilter { .. })); + + peers[1].process_events(); + let data_1_0 = fd_1_0.outbound_data.lock().unwrap().split_off(0); + peers[0].read_event(&mut fd_0_1, &data_1_0).unwrap(); // Init message + + peers[0].process_events(); + assert!(fd_0_1.outbound_data.lock().unwrap().is_empty()); + assert!(fd_1_0.outbound_data.lock().unwrap().is_empty()); + + let mut check_message_received = |expected_received: bool| { + let initial_count = cfgs[1].routing_handler.chan_upds_recvd.load(Ordering::Acquire); + + peers[0].process_events(); + let data_0_1 = fd_0_1.outbound_data.lock().unwrap().split_off(0); + assert_eq!(data_0_1.is_empty(), !expected_received); + peers[1].read_event(&mut fd_1_0, &data_0_1).unwrap(); + + let final_count = cfgs[1].routing_handler.chan_upds_recvd.load(Ordering::Acquire); + assert_eq!(final_count > initial_count, expected_received); + }; + + // Broadcast a gossip message that is unrelated to us and check that it doesn't get relayed + let unrelated_msg_ev = MessageSendEvent::BroadcastChannelUpdate { + msg: test_utils::get_dummy_channel_update(43), + node_id_1: NodeId::from_slice(&[2; 33]).unwrap(), + node_id_2: NodeId::from_slice(&[3; 33]).unwrap(), + }; + cfgs[0].routing_handler.pending_events.lock().unwrap().push(unrelated_msg_ev); + + check_message_received(false); + + // Broadcast a gossip message that we're a party to and check that its relayed + let our_channel_msg_ev = MessageSendEvent::BroadcastChannelUpdate { + msg: test_utils::get_dummy_channel_update(43), + node_id_1: NodeId::from_pubkey(&id_0), + node_id_2: NodeId::from_slice(&[3; 33]).unwrap(), + }; + cfgs[0].routing_handler.pending_events.lock().unwrap().push(our_channel_msg_ev); + + check_message_received(true); + } + #[test] #[cfg(feature = "std")] fn test_process_events_multithreaded() { diff --git a/lightning/src/ln/shutdown_tests.rs b/lightning/src/ln/shutdown_tests.rs index 437298afc22..8fbb22b40b4 100644 --- a/lightning/src/ln/shutdown_tests.rs +++ b/lightning/src/ln/shutdown_tests.rs @@ -1402,7 +1402,7 @@ fn do_test_closing_signed_reinit_timeout(timeout_step: TimeoutStep) { let events = nodes[1].node.get_and_clear_pending_msg_events(); assert_eq!(events.len(), 1); match events[0] { - MessageSendEvent::BroadcastChannelUpdate { ref msg } => { + MessageSendEvent::BroadcastChannelUpdate { ref msg, .. } => { assert_eq!(msg.contents.channel_flags & 2, 2); }, _ => panic!("Unexpected event"), diff --git a/lightning/src/routing/gossip.rs b/lightning/src/routing/gossip.rs index 80ffbf9fb6c..ae317ad1ac3 100644 --- a/lightning/src/routing/gossip.rs +++ b/lightning/src/routing/gossip.rs @@ -394,7 +394,7 @@ where *update_msg = None; } }, - MessageSendEvent::BroadcastChannelUpdate { msg } => { + MessageSendEvent::BroadcastChannelUpdate { msg, .. } => { if msg.contents.excess_data.len() > MAX_EXCESS_BYTES_FOR_RELAY { return; } @@ -556,9 +556,12 @@ where fn handle_channel_update( &self, _their_node_id: Option, msg: &msgs::ChannelUpdate, - ) -> Result { - self.network_graph.update_channel(msg)?; - Ok(msg.contents.excess_data.len() <= MAX_EXCESS_BYTES_FOR_RELAY) + ) -> Result, LightningError> { + match self.network_graph.update_channel(msg) { + Ok(nodes) if msg.contents.excess_data.len() <= MAX_EXCESS_BYTES_FOR_RELAY => Ok(nodes), + Ok(_) => Ok(None), + Err(e) => Err(e), + } } fn get_next_channel_announcement( @@ -2433,7 +2436,11 @@ where /// /// If not built with `std`, any updates with a timestamp more than two weeks in the past or /// materially in the future will be rejected. - pub fn update_channel(&self, msg: &msgs::ChannelUpdate) -> Result<(), LightningError> { + /// + /// Returns the [`NodeId`]s of both sides of the channel if it was applied. + pub fn update_channel( + &self, msg: &msgs::ChannelUpdate, + ) -> Result, LightningError> { self.update_channel_internal(&msg.contents, Some(&msg), Some(&msg.signature), false) } @@ -2443,9 +2450,11 @@ where /// /// If not built with `std`, any updates with a timestamp more than two weeks in the past or /// materially in the future will be rejected. + /// + /// Returns the [`NodeId`]s of both sides of the channel if it was applied. pub fn update_channel_unsigned( &self, msg: &msgs::UnsignedChannelUpdate, - ) -> Result<(), LightningError> { + ) -> Result, LightningError> { self.update_channel_internal(msg, None, None, false) } @@ -2456,13 +2465,14 @@ where /// If not built with `std`, any updates with a timestamp more than two weeks in the past or /// materially in the future will be rejected. pub fn verify_channel_update(&self, msg: &msgs::ChannelUpdate) -> Result<(), LightningError> { - self.update_channel_internal(&msg.contents, Some(&msg), Some(&msg.signature), true) + self.update_channel_internal(&msg.contents, Some(&msg), Some(&msg.signature), true)?; + Ok(()) } fn update_channel_internal( &self, msg: &msgs::UnsignedChannelUpdate, full_msg: Option<&msgs::ChannelUpdate>, sig: Option<&secp256k1::ecdsa::Signature>, only_verify: bool, - ) -> Result<(), LightningError> { + ) -> Result, LightningError> { let chan_enabled = msg.channel_flags & (1 << 1) != (1 << 1); if msg.chain_hash != self.chain_hash { @@ -2602,7 +2612,7 @@ where } if only_verify { - return Ok(()); + return Ok(None); } let mut channels = self.channels.write().unwrap(); @@ -2633,9 +2643,11 @@ where } else { channel.one_to_two = new_channel_info; } - } - Ok(()) + Ok(Some((channel.node_one, channel.node_two))) + } else { + Ok(None) + } } fn remove_channel_in_nodes_callback)>( @@ -3180,7 +3192,7 @@ pub(crate) mod tests { let valid_channel_update = get_signed_channel_update(|_| {}, node_1_privkey, &secp_ctx); network_graph.verify_channel_update(&valid_channel_update).unwrap(); match gossip_sync.handle_channel_update(Some(node_1_pubkey), &valid_channel_update) { - Ok(res) => assert!(res), + Ok(res) => assert!(res.is_some()), _ => panic!(), }; @@ -3202,9 +3214,9 @@ pub(crate) mod tests { node_1_privkey, &secp_ctx, ); - // Return false because contains excess data + // Update is accepted but won't be relayed because contains excess data match gossip_sync.handle_channel_update(Some(node_1_pubkey), &valid_channel_update) { - Ok(res) => assert!(!res), + Ok(res) => assert!(res.is_none()), _ => panic!(), }; diff --git a/lightning/src/routing/test_utils.rs b/lightning/src/routing/test_utils.rs index ab2b24c19e0..c5c35c9ce77 100644 --- a/lightning/src/routing/test_utils.rs +++ b/lightning/src/routing/test_utils.rs @@ -111,7 +111,7 @@ pub(crate) fn update_channel( }; match gossip_sync.handle_channel_update(Some(node_pubkey), &valid_channel_update) { - Ok(res) => assert!(res), + Ok(res) => assert!(res.is_some()), Err(e) => panic!("{e:?}") }; } diff --git a/lightning/src/routing/utxo.rs b/lightning/src/routing/utxo.rs index 4968d6cd7b4..4299dffb90f 100644 --- a/lightning/src/routing/utxo.rs +++ b/lightning/src/routing/utxo.rs @@ -233,6 +233,10 @@ impl UtxoFuture { // Note that we ignore errors as we don't disconnect peers anyway, so there's nothing to do // with them. let resolver = UtxoResolver(result); + let (node_id_1, node_id_2) = match &announcement { + ChannelAnnouncement::Full(signed_msg) => (signed_msg.contents.node_id_1, signed_msg.contents.node_id_2), + ChannelAnnouncement::Unsigned(msg) => (msg.node_id_1, msg.node_id_2), + }; match announcement { ChannelAnnouncement::Full(signed_msg) => { if graph.update_channel_from_announcement(&signed_msg, &Some(&resolver)).is_ok() { @@ -270,6 +274,8 @@ impl UtxoFuture { if graph.update_channel(&signed_msg).is_ok() { res[res_idx] = Some(MessageSendEvent::BroadcastChannelUpdate { msg: signed_msg, + node_id_1, + node_id_2, }); res_idx += 1; } diff --git a/lightning/src/util/test_utils.rs b/lightning/src/util/test_utils.rs index ad8ea224205..6e664d300de 100644 --- a/lightning/src/util/test_utils.rs +++ b/lightning/src/util/test_utils.rs @@ -1522,9 +1522,9 @@ impl msgs::RoutingMessageHandler for TestRoutingMessageHandler { } fn handle_channel_update( &self, _their_node_id: Option, _msg: &msgs::ChannelUpdate, - ) -> Result { + ) -> Result, msgs::LightningError> { self.chan_upds_recvd.fetch_add(1, Ordering::AcqRel); - Ok(true) + Ok(Some((NodeId::from_slice(&[2; 33]).unwrap(), NodeId::from_slice(&[3; 33]).unwrap()))) } fn get_next_channel_announcement( &self, starting_point: u64,