@@ -46,6 +46,7 @@ use tokio::{
4646 sync:: mpsc,
4747 task:: JoinSet ,
4848} ;
49+ use tokio_util:: sync:: CancellationToken ;
4950
5051use crate :: {
5152 bincode_input_closed,
@@ -57,25 +58,21 @@ use crate::{
5758 DropErrorDetailsExt ,
5859} ;
5960
60- type CommandRequest = ( Multiplexed < ExecuteCommandRequest > , MultiplexingSender ) ;
61-
6261pub async fn listen ( project_dir : impl Into < PathBuf > ) -> Result < ( ) , Error > {
6362 let project_dir = project_dir. into ( ) ;
6463
6564 let ( coordinator_msg_tx, coordinator_msg_rx) = mpsc:: channel ( 8 ) ;
6665 let ( worker_msg_tx, worker_msg_rx) = mpsc:: channel ( 8 ) ;
6766 let mut io_tasks = spawn_io_queue ( coordinator_msg_tx, worker_msg_rx) ;
6867
69- let ( cmd_tx, cmd_rx) = mpsc:: channel ( 8 ) ;
70- let ( stdin_tx, stdin_rx) = mpsc:: channel ( 8 ) ;
71- let process_task = tokio:: spawn ( manage_processes ( stdin_rx, cmd_rx, project_dir. clone ( ) ) ) ;
68+ let ( process_tx, process_rx) = mpsc:: channel ( 8 ) ;
69+ let process_task = tokio:: spawn ( manage_processes ( process_rx, project_dir. clone ( ) ) ) ;
7270
7371 let handler_task = tokio:: spawn ( handle_coordinator_message (
7472 coordinator_msg_rx,
7573 worker_msg_tx,
7674 project_dir,
77- cmd_tx,
78- stdin_tx,
75+ process_tx,
7976 ) ) ;
8077
8178 select ! {
@@ -122,8 +119,7 @@ async fn handle_coordinator_message(
122119 mut coordinator_msg_rx : mpsc:: Receiver < Multiplexed < CoordinatorMessage > > ,
123120 worker_msg_tx : mpsc:: Sender < Multiplexed < WorkerMessage > > ,
124121 project_dir : PathBuf ,
125- cmd_tx : mpsc:: Sender < CommandRequest > ,
126- stdin_tx : mpsc:: Sender < Multiplexed < String > > ,
122+ process_tx : mpsc:: Sender < Multiplexed < ProcessCommand > > ,
127123) -> Result < ( ) , HandleCoordinatorMessageError > {
128124 use handle_coordinator_message_error:: * ;
129125
@@ -177,20 +173,36 @@ async fn handle_coordinator_message(
177173 }
178174
179175 CoordinatorMessage :: ExecuteCommand ( req) => {
180- cmd_tx
181- . send( ( Multiplexed ( job_id, req) , worker_msg_tx( ) ) )
176+ process_tx
177+ . send( Multiplexed ( job_id, ProcessCommand :: Start ( req, worker_msg_tx( ) ) ) )
182178 . await
183179 . drop_error_details( )
184180 . context( UnableToSendCommandExecutionRequestSnafu ) ?;
185181 }
186182
187183 CoordinatorMessage :: StdinPacket ( data) => {
188- stdin_tx
189- . send( Multiplexed ( job_id, data) )
184+ process_tx
185+ . send( Multiplexed ( job_id, ProcessCommand :: Stdin ( data) ) )
190186 . await
191187 . drop_error_details( )
192188 . context( UnableToSendStdinPacketSnafu ) ?;
193189 }
190+
191+ CoordinatorMessage :: StdinClose => {
192+ process_tx
193+ . send( Multiplexed ( job_id, ProcessCommand :: StdinClose ) )
194+ . await
195+ . drop_error_details( )
196+ . context( UnableToSendStdinCloseSnafu ) ?;
197+ }
198+
199+ CoordinatorMessage :: Kill => {
200+ process_tx
201+ . send( Multiplexed ( job_id, ProcessCommand :: Kill ) )
202+ . await
203+ . drop_error_details( )
204+ . context( UnableToSendKillSnafu ) ?;
205+ }
194206 }
195207 }
196208
@@ -221,6 +233,12 @@ pub enum HandleCoordinatorMessageError {
221233 #[ snafu( display( "Failed to send stdin packet to the command task" ) ) ]
222234 UnableToSendStdinPacket { source : mpsc:: error:: SendError < ( ) > } ,
223235
236+ #[ snafu( display( "Failed to send stdin close request to the command task" ) ) ]
237+ UnableToSendStdinClose { source : mpsc:: error:: SendError < ( ) > } ,
238+
239+ #[ snafu( display( "Failed to send kill request to the command task" ) ) ]
240+ UnableToSendKill { source : mpsc:: error:: SendError < ( ) > } ,
241+
224242 #[ snafu( display( "A coordinator command handler background task panicked" ) ) ]
225243 TaskPanicked { source : tokio:: task:: JoinError } ,
226244}
@@ -373,63 +391,144 @@ fn parse_working_dir(cwd: Option<String>, project_path: impl Into<PathBuf>) -> P
373391 final_path
374392}
375393
394+ enum ProcessCommand {
395+ Start ( ExecuteCommandRequest , MultiplexingSender ) ,
396+ Stdin ( String ) ,
397+ StdinClose ,
398+ Kill ,
399+ }
400+
401+ struct ProcessState {
402+ project_path : PathBuf ,
403+ processes : JoinSet < Result < ( ) , ProcessError > > ,
404+ stdin_senders : HashMap < JobId , mpsc:: Sender < String > > ,
405+ stdin_shutdown_tx : mpsc:: Sender < JobId > ,
406+ kill_tokens : HashMap < JobId , CancellationToken > ,
407+ }
408+
409+ impl ProcessState {
410+ fn new ( project_path : PathBuf , stdin_shutdown_tx : mpsc:: Sender < JobId > ) -> Self {
411+ Self {
412+ project_path,
413+ processes : Default :: default ( ) ,
414+ stdin_senders : Default :: default ( ) ,
415+ stdin_shutdown_tx,
416+ kill_tokens : Default :: default ( ) ,
417+ }
418+ }
419+
420+ async fn start (
421+ & mut self ,
422+ job_id : JobId ,
423+ req : ExecuteCommandRequest ,
424+ worker_msg_tx : MultiplexingSender ,
425+ ) -> Result < ( ) , ProcessError > {
426+ use process_error:: * ;
427+
428+ let token = CancellationToken :: new ( ) ;
429+
430+ let RunningChild {
431+ child,
432+ stdin_rx,
433+ stdin,
434+ stdout,
435+ stderr,
436+ } = match process_begin ( req, & self . project_path , & mut self . stdin_senders , job_id) {
437+ Ok ( v) => v,
438+ Err ( e) => {
439+ // Should we add a message for process started
440+ // in addition to the current message which
441+ // indicates that the process has ended?
442+ worker_msg_tx
443+ . send_err ( e)
444+ . await
445+ . context ( UnableToSendExecuteCommandStartedResponseSnafu ) ?;
446+ return Ok ( ( ) ) ;
447+ }
448+ } ;
449+
450+ let task_set = stream_stdio ( worker_msg_tx. clone ( ) , stdin_rx, stdin, stdout, stderr) ;
451+
452+ self . kill_tokens . insert ( job_id, token. clone ( ) ) ;
453+
454+ self . processes . spawn ( {
455+ let stdin_shutdown_tx = self . stdin_shutdown_tx . clone ( ) ;
456+ async move {
457+ worker_msg_tx
458+ . send ( process_end ( token, child, task_set, stdin_shutdown_tx, job_id) . await )
459+ . await
460+ . context ( UnableToSendExecuteCommandResponseSnafu )
461+ }
462+ } ) ;
463+
464+ Ok ( ( ) )
465+ }
466+
467+ async fn stdin ( & mut self , job_id : JobId , packet : String ) -> Result < ( ) , ProcessError > {
468+ use process_error:: * ;
469+
470+ if let Some ( stdin_tx) = self . stdin_senders . get ( & job_id) {
471+ stdin_tx
472+ . send ( packet)
473+ . await
474+ . drop_error_details ( )
475+ . context ( UnableToSendStdinDataSnafu ) ?;
476+ }
477+
478+ Ok ( ( ) )
479+ }
480+
481+ fn stdin_close ( & mut self , job_id : JobId ) {
482+ self . stdin_senders . remove ( & job_id) ;
483+ // Should we care if we remove a sender that's already removed?
484+ }
485+
486+ async fn join_process ( & mut self ) -> Option < Result < ( ) , ProcessError > > {
487+ use process_error:: * ;
488+
489+ let process = self . processes . join_next ( ) . await ?;
490+ Some ( process. context ( ProcessTaskPanickedSnafu ) . and_then ( |e| e) )
491+ }
492+
493+ fn kill ( & mut self , job_id : JobId ) {
494+ if let Some ( token) = self . kill_tokens . get ( & job_id) {
495+ token. cancel ( ) ;
496+ }
497+ }
498+ }
499+
376500async fn manage_processes (
377- mut stdin_rx : mpsc:: Receiver < Multiplexed < String > > ,
378- mut cmd_rx : mpsc:: Receiver < CommandRequest > ,
501+ mut rx : mpsc:: Receiver < Multiplexed < ProcessCommand > > ,
379502 project_path : PathBuf ,
380503) -> Result < ( ) , ProcessError > {
381504 use process_error:: * ;
382505
383- let mut processes = JoinSet :: new ( ) ;
384- let mut stdin_senders = HashMap :: new ( ) ;
385506 let ( stdin_shutdown_tx, mut stdin_shutdown_rx) = mpsc:: channel ( 8 ) ;
507+ let mut state = ProcessState :: new ( project_path, stdin_shutdown_tx) ;
386508
387509 loop {
388510 select ! {
389- cmd_req = cmd_rx. recv( ) => {
390- let Some ( ( Multiplexed ( job_id, req) , worker_msg_tx) ) = cmd_req else { break } ;
391-
392- let RunningChild { child, stdin_rx, stdin, stdout, stderr } = match process_begin( req, & project_path, & mut stdin_senders, job_id) {
393- Ok ( v) => v,
394- Err ( e) => {
395- // Should we add a message for process started
396- // in addition to the current message which
397- // indicates that the process has ended?
398- worker_msg_tx. send_err( e) . await . context( UnableToSendExecuteCommandStartedResponseSnafu ) ?;
399- continue ;
400- }
401- } ;
511+ cmd = rx. recv( ) => {
512+ let Some ( Multiplexed ( job_id, cmd) ) = cmd else { break } ;
402513
403- let task_set = stream_stdio( worker_msg_tx. clone( ) , stdin_rx, stdin, stdout, stderr) ;
514+ match cmd {
515+ ProcessCommand :: Start ( req, worker_msg_tx) => state. start( job_id, req, worker_msg_tx) . await ?,
404516
405- processes. spawn( {
406- let stdin_shutdown_tx = stdin_shutdown_tx. clone( ) ;
407- async move {
408- worker_msg_tx
409- . send( process_end( child, task_set, stdin_shutdown_tx, job_id) . await )
410- . await
411- . context( UnableToSendExecuteCommandResponseSnafu )
412- }
413- } ) ;
414- }
517+ ProcessCommand :: Stdin ( packet) => state. stdin( job_id, packet) . await ?,
415518
416- stdin_packet = stdin_rx. recv( ) => {
417- // Dispatch stdin packet to different child by attached command id.
418- let Some ( Multiplexed ( job_id, packet) ) = stdin_packet else { break } ;
519+ ProcessCommand :: StdinClose => state. stdin_close( job_id) ,
419520
420- if let Some ( stdin_tx) = stdin_senders. get( & job_id) {
421- stdin_tx. send( packet) . await . drop_error_details( ) . context( UnableToSendStdinDataSnafu ) ?;
521+ ProcessCommand :: Kill => state. kill( job_id) ,
422522 }
423523 }
424524
425525 job_id = stdin_shutdown_rx. recv( ) => {
426526 let job_id = job_id. context( StdinShutdownReceiverEndedSnafu ) ?;
427- stdin_senders. remove( & job_id) ;
428- // Should we care if we remove a sender that's already removed?
527+ state. stdin_close( job_id) ;
429528 }
430529
431- Some ( process) = processes . join_next ( ) => {
432- process. context ( ProcessTaskPanickedSnafu ) ? ?;
530+ Some ( process) = state . join_process ( ) => {
531+ process?;
433532 }
434533 }
435534 }
@@ -488,13 +587,19 @@ fn process_begin(
488587}
489588
490589async fn process_end (
590+ token : CancellationToken ,
491591 mut child : Child ,
492592 mut task_set : JoinSet < Result < ( ) , StdioError > > ,
493593 stdin_shutdown_tx : mpsc:: Sender < JobId > ,
494594 job_id : JobId ,
495595) -> Result < ExecuteCommandResponse , ProcessError > {
496596 use process_error:: * ;
497597
598+ select ! {
599+ ( ) = token. cancelled( ) => child. kill( ) . await . context( KillChildSnafu ) ?,
600+ _ = child. wait( ) => { } ,
601+ } ;
602+
498603 let status = child. wait ( ) . await . context ( WaitChildSnafu ) ?;
499604
500605 stdin_shutdown_tx
@@ -634,6 +739,9 @@ pub enum ProcessError {
634739 #[ snafu( display( "Failed to send stdin data" ) ) ]
635740 UnableToSendStdinData { source : mpsc:: error:: SendError < ( ) > } ,
636741
742+ #[ snafu( display( "Failed to kill the child process" ) ) ]
743+ KillChild { source : std:: io:: Error } ,
744+
637745 #[ snafu( display( "Failed to wait for child process exiting" ) ) ]
638746 WaitChild { source : std:: io:: Error } ,
639747
@@ -671,10 +779,7 @@ fn stream_stdio(
671779 let mut set = JoinSet :: new ( ) ;
672780
673781 set. spawn ( async move {
674- loop {
675- let Some ( data) = stdin_rx. recv ( ) . await else {
676- break ;
677- } ;
782+ while let Some ( data) = stdin_rx. recv ( ) . await {
678783 stdin
679784 . write_all ( data. as_bytes ( ) )
680785 . await
0 commit comments