Skip to content

Commit a3c7c37

Browse files
committed
fix: solve code blocking on disabled relay
1 parent e8c797c commit a3c7c37

File tree

1 file changed

+16
-25
lines changed

1 file changed

+16
-25
lines changed

src/main.rs

Lines changed: 16 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ fn parse_alpn(alpn: &str) -> Result<Vec<u8>> {
155155
}
156156

157157
/// Available command line options for configuring relays.
158-
#[derive(Clone, Debug)]
158+
#[derive(Clone, Debug, PartialEq)]
159159
pub enum RelayModeOption {
160160
/// Disables relays altogether.
161161
Disabled,
@@ -351,6 +351,8 @@ async fn create_endpoint(
351351
builder = builder.bind_addr_v6(addr);
352352
}
353353
let endpoint = builder.bind().await?;
354+
endpoint.online().await;
355+
354356
Ok(endpoint)
355357
}
356358

@@ -396,11 +398,10 @@ async fn forward_bidi(
396398
async fn listen_stdio(args: ListenArgs) -> Result<()> {
397399
let secret_key = get_or_create_secret()?;
398400
let endpoint = create_endpoint(secret_key, &args.common, vec![args.common.alpn()?]).await?;
399-
// wait for the endpoint to figure out its home relay and addresses before making a ticket
400-
endpoint.online().await;
401-
let node = endpoint.node_addr();
402-
let mut short = node.clone();
403-
let ticket = NodeTicket::new(node);
401+
// wait for the endpoint to figure out its address before making a ticket
402+
let node_addr = endpoint.node_addr();
403+
let mut short = node_addr.clone();
404+
let ticket = NodeTicket::new(node_addr);
404405
short.direct_addresses.clear();
405406
let short = NodeTicket::new(short);
406407

@@ -457,10 +458,10 @@ async fn listen_stdio(args: ListenArgs) -> Result<()> {
457458
async fn connect_stdio(args: ConnectArgs) -> Result<()> {
458459
let secret_key = get_or_create_secret()?;
459460
let endpoint = create_endpoint(secret_key, &args.common, vec![]).await?;
460-
let addr = args.ticket.node_addr();
461-
let remote_node_id = addr.node_id;
461+
let node_addr = endpoint.node_addr();
462+
let remote_node_id = node_addr.node_id;
462463
// connect to the node, try only once
463-
let connection = endpoint.connect(addr.clone(), &args.common.alpn()?).await?;
464+
let connection = endpoint.connect(node_addr.clone(), &args.common.alpn()?).await?;
464465
tracing::info!("connected to {}", remote_node_id);
465466
// open a bidi stream, try only once
466467
let (mut s, r) = connection.open_bi().await.e()?;
@@ -495,9 +496,6 @@ async fn connect_tcp(args: ConnectTcpArgs) -> Result<()> {
495496
.context("unable to bind endpoint")?;
496497
tracing::info!("tcp listening on {:?}", addrs);
497498

498-
// Wait for our own endpoint to be ready before trying to connect.
499-
endpoint.online().await;
500-
501499
let tcp_listener = match tokio::net::TcpListener::bind(addrs.as_slice()).await {
502500
Ok(tcp_listener) => tcp_listener,
503501
Err(cause) => {
@@ -534,7 +532,7 @@ async fn connect_tcp(args: ConnectTcpArgs) -> Result<()> {
534532
forward_bidi(tcp_recv, tcp_send, endpoint_recv, endpoint_send).await?;
535533
Ok::<_, n0_snafu::Error>(())
536534
}
537-
let addr = args.ticket.node_addr();
535+
let node_addr = args.ticket.node_addr();
538536
loop {
539537
// also wait for ctrl-c here so we can use it before accepting a connection
540538
let next = tokio::select! {
@@ -545,11 +543,11 @@ async fn connect_tcp(args: ConnectTcpArgs) -> Result<()> {
545543
}
546544
};
547545
let endpoint = endpoint.clone();
548-
let addr = addr.clone();
546+
let node_addr = node_addr.clone();
549547
let handshake = !args.common.is_custom_alpn();
550548
let alpn = args.common.alpn()?;
551549
tokio::spawn(async move {
552-
if let Err(cause) = handle_tcp_accept(next, addr, endpoint, handshake, &alpn).await {
550+
if let Err(cause) = handle_tcp_accept(next, node_addr, endpoint, handshake, &alpn).await {
553551
// log error at warn level
554552
//
555553
// we should know about it, but it's not fatal
@@ -568,8 +566,6 @@ async fn listen_tcp(args: ListenTcpArgs) -> Result<()> {
568566
};
569567
let secret_key = get_or_create_secret()?;
570568
let endpoint = create_endpoint(secret_key, &args.common, vec![args.common.alpn()?]).await?;
571-
// wait for the endpoint to figure out its address before making a ticket
572-
endpoint.online().await;
573569
let node_addr = endpoint.node_addr();
574570
let mut short = node_addr.clone();
575571
let ticket = NodeTicket::new(node_addr);
@@ -650,8 +646,6 @@ async fn listen_unix(args: ListenUnixArgs) -> Result<()> {
650646
let socket_path = args.socket_path.clone();
651647
let secret_key = get_or_create_secret()?;
652648
let endpoint = create_endpoint(secret_key, &args.common, vec![args.common.alpn()?]).await?;
653-
// wait for the endpoint to figure out its address before making a ticket
654-
endpoint.online().await;
655649
let node_addr = endpoint.node_addr();
656650
let mut short = node_addr.clone();
657651
let ticket = NodeTicket::new(node_addr);
@@ -765,20 +759,17 @@ async fn connect_unix(args: ConnectUnixArgs) -> Result<()> {
765759
.context("unable to bind endpoint")?;
766760
tracing::info!("unix listening on {:?}", socket_path);
767761

768-
// Wait for our own endpoint to be ready before trying to connect.
769-
endpoint.online().await;
770-
771762
// Remove existing socket file if it exists
772763
if let Err(e) = tokio::fs::remove_file(&socket_path).await {
773764
if e.kind() != io::ErrorKind::NotFound {
774765
snafu::whatever!("failed to remove existing socket file: {}", e);
775766
}
776767
}
777768

778-
let addr = args.ticket.node_addr();
779-
tracing::info!("connecting to remote node: {:?}", addr);
769+
let node_addr = args.ticket.node_addr();
770+
tracing::info!("connecting to remote node: {:?}", node_addr);
780771
let connection = endpoint
781-
.connect(addr.clone(), &args.common.alpn()?)
772+
.connect(node_addr.clone(), &args.common.alpn()?)
782773
.await
783774
.context("failed to connect to remote node")?;
784775
tracing::info!("connected to remote node successfully");

0 commit comments

Comments
 (0)