|
5 | 5 |
|
6 | 6 | use async_trait::async_trait; |
7 | 7 | use base64::{engine::general_purpose as b64, Engine as _}; |
| 8 | +use futures::{stream::FuturesUnordered, StreamExt}; |
8 | 9 | use serde::Serialize; |
9 | 10 | use sha2::{Digest, Sha256}; |
10 | 11 | use std::{str::FromStr, time::Duration}; |
11 | 12 | use sysinfo::Pid; |
12 | 13 |
|
13 | 14 | use super::{ |
14 | 15 | args::{ |
15 | | - AuthProvider, CliCore, ExistingTunnelArgs, TunnelRenameArgs, TunnelServeArgs, |
16 | | - TunnelServiceSubCommands, TunnelUserSubCommands, |
| 16 | + AuthProvider, CliCore, CommandShellArgs, ExistingTunnelArgs, TunnelRenameArgs, |
| 17 | + TunnelServeArgs, TunnelServiceSubCommands, TunnelUserSubCommands, |
17 | 18 | }, |
18 | 19 | CommandContext, |
19 | 20 | }; |
20 | 21 |
|
21 | 22 | use crate::{ |
| 23 | + async_pipe::{get_socket_name, listen_socket_rw_stream, socket_stream_split}, |
22 | 24 | auth::Auth, |
23 | 25 | constants::{APPLICATION_NAME, TUNNEL_CLI_LOCK_NAME, TUNNEL_SERVICE_LOCK_NAME}, |
24 | 26 | log, |
@@ -59,20 +61,31 @@ impl From<AuthProvider> for crate::auth::AuthProvider { |
59 | 61 | } |
60 | 62 | } |
61 | 63 |
|
62 | | -impl From<ExistingTunnelArgs> for Option<dev_tunnels::ExistingTunnel> { |
63 | | - fn from(d: ExistingTunnelArgs) -> Option<dev_tunnels::ExistingTunnel> { |
64 | | - if let (Some(tunnel_id), Some(tunnel_name), Some(cluster), Some(host_token)) = |
65 | | - (d.tunnel_id, d.tunnel_name, d.cluster, d.host_token) |
66 | | - { |
| 64 | +fn fulfill_existing_tunnel_args( |
| 65 | + d: ExistingTunnelArgs, |
| 66 | + name_arg: &Option<String>, |
| 67 | +) -> Option<dev_tunnels::ExistingTunnel> { |
| 68 | + let tunnel_name = d.tunnel_name.or_else(|| name_arg.clone()); |
| 69 | + |
| 70 | + match (d.tunnel_id, d.cluster, d.host_token) { |
| 71 | + (Some(tunnel_id), None, Some(host_token)) => { |
| 72 | + let i = tunnel_id.find('.')?; |
67 | 73 | Some(dev_tunnels::ExistingTunnel { |
68 | | - tunnel_id, |
| 74 | + tunnel_id: tunnel_id[..i].to_string(), |
| 75 | + cluster: tunnel_id[i + 1..].to_string(), |
69 | 76 | tunnel_name, |
70 | 77 | host_token, |
71 | | - cluster, |
72 | 78 | }) |
73 | | - } else { |
74 | | - None |
75 | 79 | } |
| 80 | + |
| 81 | + (Some(tunnel_id), Some(cluster), Some(host_token)) => Some(dev_tunnels::ExistingTunnel { |
| 82 | + tunnel_id, |
| 83 | + tunnel_name, |
| 84 | + host_token, |
| 85 | + cluster, |
| 86 | + }), |
| 87 | + |
| 88 | + _ => None, |
76 | 89 | } |
77 | 90 | } |
78 | 91 |
|
@@ -109,23 +122,55 @@ impl ServiceContainer for TunnelServiceContainer { |
109 | 122 | } |
110 | 123 | } |
111 | 124 |
|
112 | | -pub async fn command_shell(ctx: CommandContext) -> Result<i32, AnyError> { |
| 125 | +pub async fn command_shell(ctx: CommandContext, args: CommandShellArgs) -> Result<i32, AnyError> { |
113 | 126 | let platform = PreReqChecker::new().verify().await?; |
114 | | - serve_stream( |
115 | | - tokio::io::stdin(), |
116 | | - tokio::io::stderr(), |
117 | | - ServeStreamParams { |
118 | | - log: ctx.log, |
119 | | - launcher_paths: ctx.paths, |
120 | | - platform, |
121 | | - requires_auth: true, |
122 | | - exit_barrier: ShutdownRequest::create_rx([ShutdownRequest::CtrlC]), |
123 | | - code_server_args: (&ctx.args).into(), |
124 | | - }, |
125 | | - ) |
126 | | - .await; |
| 127 | + let mut params = ServeStreamParams { |
| 128 | + log: ctx.log, |
| 129 | + launcher_paths: ctx.paths, |
| 130 | + platform, |
| 131 | + requires_auth: true, |
| 132 | + exit_barrier: ShutdownRequest::create_rx([ShutdownRequest::CtrlC]), |
| 133 | + code_server_args: (&ctx.args).into(), |
| 134 | + }; |
127 | 135 |
|
128 | | - Ok(0) |
| 136 | + if !args.on_socket { |
| 137 | + serve_stream(tokio::io::stdin(), tokio::io::stderr(), params).await; |
| 138 | + return Ok(0); |
| 139 | + } |
| 140 | + |
| 141 | + let socket = get_socket_name(); |
| 142 | + let mut listener = listen_socket_rw_stream(&socket) |
| 143 | + .await |
| 144 | + .map_err(|e| wrap(e, "error listening on socket"))?; |
| 145 | + |
| 146 | + params |
| 147 | + .log |
| 148 | + .result(format!("Listening on {}", socket.display())); |
| 149 | + |
| 150 | + let mut servers = FuturesUnordered::new(); |
| 151 | + |
| 152 | + loop { |
| 153 | + tokio::select! { |
| 154 | + Some(_) = servers.next() => {}, |
| 155 | + socket = listener.accept() => { |
| 156 | + match socket { |
| 157 | + Ok(s) => { |
| 158 | + let (read, write) = socket_stream_split(s); |
| 159 | + servers.push(serve_stream(read, write, params.clone())); |
| 160 | + }, |
| 161 | + Err(e) => { |
| 162 | + error!(params.log, &format!("Error accepting connection: {}", e)); |
| 163 | + return Ok(1); |
| 164 | + } |
| 165 | + } |
| 166 | + }, |
| 167 | + _ = params.exit_barrier.wait() => { |
| 168 | + // wait for all servers to finish up: |
| 169 | + while (servers.next().await).is_some() { } |
| 170 | + return Ok(0); |
| 171 | + } |
| 172 | + } |
| 173 | + } |
129 | 174 | } |
130 | 175 |
|
131 | 176 | pub async fn service( |
@@ -412,8 +457,10 @@ async fn serve_with_csa( |
412 | 457 | let auth = Auth::new(&paths, log.clone()); |
413 | 458 | let mut dt = dev_tunnels::DevTunnels::new(&log, auth, &paths); |
414 | 459 | loop { |
415 | | - let tunnel = if let Some(d) = gateway_args.tunnel.clone().into() { |
416 | | - dt.start_existing_tunnel(d).await |
| 460 | + let tunnel = if let Some(t) = |
| 461 | + fulfill_existing_tunnel_args(gateway_args.tunnel.clone(), &gateway_args.name) |
| 462 | + { |
| 463 | + dt.start_existing_tunnel(t).await |
417 | 464 | } else { |
418 | 465 | dt.start_new_launcher_tunnel(gateway_args.name.as_deref(), gateway_args.random_name) |
419 | 466 | .await |
|
0 commit comments