@@ -456,14 +456,15 @@ where
456456
457457 pub async fn begin_execute (
458458 & self ,
459+ token : CancellationToken ,
459460 request : ExecuteRequest ,
460461 ) -> Result < ActiveExecution , ExecuteError > {
461462 use execute_error:: * ;
462463
463464 self . select_channel ( request. channel )
464465 . await
465466 . context ( CouldNotStartContainerSnafu ) ?
466- . begin_execute ( request)
467+ . begin_execute ( token , request)
467468 . await
468469 }
469470
@@ -482,14 +483,15 @@ where
482483
483484 pub async fn begin_compile (
484485 & self ,
486+ token : CancellationToken ,
485487 request : CompileRequest ,
486488 ) -> Result < ActiveCompilation , CompileError > {
487489 use compile_error:: * ;
488490
489491 self . select_channel ( request. channel )
490492 . await
491493 . context ( CouldNotStartContainerSnafu ) ?
492- . begin_compile ( request)
494+ . begin_compile ( token , request)
493495 . await
494496 }
495497
@@ -603,12 +605,14 @@ impl Container {
603605 & self ,
604606 request : ExecuteRequest ,
605607 ) -> Result < WithOutput < ExecuteResponse > , ExecuteError > {
608+ let token = Default :: default ( ) ;
609+
606610 let ActiveExecution {
607611 task,
608612 stdin_tx,
609613 stdout_rx,
610614 stderr_rx,
611- } = self . begin_execute ( request) . await ?;
615+ } = self . begin_execute ( token , request) . await ?;
612616
613617 drop ( stdin_tx) ;
614618 WithOutput :: try_absorb ( task, stdout_rx, stderr_rx) . await
@@ -617,6 +621,7 @@ impl Container {
617621 #[ instrument( skip_all) ]
618622 async fn begin_execute (
619623 & self ,
624+ token : CancellationToken ,
620625 request : ExecuteRequest ,
621626 ) -> Result < ActiveExecution , ExecuteError > {
622627 use execute_error:: * ;
@@ -642,7 +647,7 @@ impl Container {
642647 stdout_rx,
643648 stderr_rx,
644649 } = self
645- . spawn_cargo_task ( execute_cargo)
650+ . spawn_cargo_task ( token , execute_cargo)
646651 . await
647652 . context ( CouldNotStartCargoSnafu ) ?;
648653
@@ -673,18 +678,21 @@ impl Container {
673678 & self ,
674679 request : CompileRequest ,
675680 ) -> Result < WithOutput < CompileResponse > , CompileError > {
681+ let token = Default :: default ( ) ;
682+
676683 let ActiveCompilation {
677684 task,
678685 stdout_rx,
679686 stderr_rx,
680- } = self . begin_compile ( request) . await ?;
687+ } = self . begin_compile ( token , request) . await ?;
681688
682689 WithOutput :: try_absorb ( task, stdout_rx, stderr_rx) . await
683690 }
684691
685692 #[ instrument( skip_all) ]
686693 async fn begin_compile (
687694 & self ,
695+ token : CancellationToken ,
688696 request : CompileRequest ,
689697 ) -> Result < ActiveCompilation , CompileError > {
690698 use compile_error:: * ;
@@ -715,7 +723,7 @@ impl Container {
715723 stdout_rx,
716724 stderr_rx,
717725 } = self
718- . spawn_cargo_task ( execute_cargo)
726+ . spawn_cargo_task ( token , execute_cargo)
719727 . await
720728 . context ( CouldNotStartCargoSnafu ) ?;
721729
@@ -761,6 +769,7 @@ impl Container {
761769
762770 async fn spawn_cargo_task (
763771 & self ,
772+ token : CancellationToken ,
764773 execute_cargo : ExecuteCommandRequest ,
765774 ) -> Result < SpawnCargo , SpawnCargoError > {
766775 use spawn_cargo_error:: * ;
@@ -777,10 +786,19 @@ impl Container {
777786
778787 let task = tokio:: spawn ( {
779788 async move {
789+ let mut already_cancelled = false ;
780790 let mut stdin_open = true ;
781791
782792 loop {
783793 select ! {
794+ ( ) = token. cancelled( ) , if !already_cancelled => {
795+ already_cancelled = true ;
796+
797+ let msg = CoordinatorMessage :: Kill ;
798+ trace!( "processing {msg:?}" ) ;
799+ to_worker_tx. send( msg) . await . context( KillSnafu ) ?;
800+ } ,
801+
784802 stdin = stdin_rx. recv( ) , if stdin_open => {
785803 let msg = match stdin {
786804 Some ( stdin) => {
@@ -952,6 +970,9 @@ pub enum SpawnCargoError {
952970
953971 #[ snafu( display( "Unable to send stdin message" ) ) ]
954972 Stdin { source : MultiplexedSenderError } ,
973+
974+ #[ snafu( display( "Unable to send kill message" ) ) ]
975+ Kill { source : MultiplexedSenderError } ,
955976}
956977
957978#[ derive( Debug , Clone ) ]
@@ -1787,12 +1808,13 @@ mod tests {
17871808 ..ARBITRARY_EXECUTE_REQUEST
17881809 } ;
17891810
1811+ let token = Default :: default ( ) ;
17901812 let ActiveExecution {
17911813 task,
17921814 stdin_tx,
17931815 stdout_rx,
17941816 stderr_rx,
1795- } = coordinator. begin_execute ( request) . await . unwrap ( ) ;
1817+ } = coordinator. begin_execute ( token , request) . await . unwrap ( ) ;
17961818
17971819 stdin_tx. send ( "this is stdin\n " . into ( ) ) . await . unwrap ( ) ;
17981820 // Purposefully not dropping stdin_tx early -- a user might forget
@@ -1836,12 +1858,13 @@ mod tests {
18361858 ..ARBITRARY_EXECUTE_REQUEST
18371859 } ;
18381860
1861+ let token = Default :: default ( ) ;
18391862 let ActiveExecution {
18401863 task,
18411864 stdin_tx,
18421865 stdout_rx,
18431866 stderr_rx,
1844- } = coordinator. begin_execute ( request) . await . unwrap ( ) ;
1867+ } = coordinator. begin_execute ( token , request) . await . unwrap ( ) ;
18451868
18461869 for i in 0 ..3 {
18471870 stdin_tx. send ( format ! ( "line {i}\n " ) ) . await . unwrap ( ) ;
@@ -1870,6 +1893,62 @@ mod tests {
18701893 Ok ( ( ) )
18711894 }
18721895
1896+ #[ tokio:: test]
1897+ #[ snafu:: report]
1898+ async fn execute_kill ( ) -> Result < ( ) > {
1899+ let coordinator = new_coordinator ( ) . await ;
1900+
1901+ let request = ExecuteRequest {
1902+ code : r#"
1903+ fn main() {
1904+ println!("Before");
1905+ loop {
1906+ std::thread::sleep(std::time::Duration::from_secs(1));
1907+ }
1908+ println!("After");
1909+ }
1910+ "#
1911+ . into ( ) ,
1912+ ..ARBITRARY_EXECUTE_REQUEST
1913+ } ;
1914+
1915+ let token = CancellationToken :: new ( ) ;
1916+ let ActiveExecution {
1917+ task,
1918+ stdin_tx : _,
1919+ mut stdout_rx,
1920+ stderr_rx,
1921+ } = coordinator
1922+ . begin_execute ( token. clone ( ) , request)
1923+ . await
1924+ . unwrap ( ) ;
1925+
1926+ // Wait for some output before killing
1927+ let early_stdout = stdout_rx. recv ( ) . await . unwrap ( ) ;
1928+
1929+ token. cancel ( ) ;
1930+
1931+ let WithOutput {
1932+ response,
1933+ stdout,
1934+ stderr,
1935+ } = WithOutput :: try_absorb ( task, stdout_rx, stderr_rx)
1936+ . with_timeout ( )
1937+ . await
1938+ . unwrap ( ) ;
1939+
1940+ assert ! ( !response. success, "{stderr}" ) ;
1941+ assert_contains ! ( response. exit_detail, "kill" ) ;
1942+
1943+ assert_contains ! ( early_stdout, "Before" ) ;
1944+ assert_not_contains ! ( stdout, "Before" ) ;
1945+ assert_not_contains ! ( stdout, "After" ) ;
1946+
1947+ coordinator. shutdown ( ) . await ?;
1948+
1949+ Ok ( ( ) )
1950+ }
1951+
18731952 const HELLO_WORLD_CODE : & str = r#"fn main() { println!("Hello World!"); }"# ;
18741953
18751954 const ARBITRARY_COMPILE_REQUEST : CompileRequest = CompileRequest {
@@ -1914,11 +1993,12 @@ mod tests {
19141993 ..ARBITRARY_COMPILE_REQUEST
19151994 } ;
19161995
1996+ let token = Default :: default ( ) ;
19171997 let ActiveCompilation {
19181998 task,
19191999 stdout_rx,
19202000 stderr_rx,
1921- } = coordinator. begin_compile ( req) . await . unwrap ( ) ;
2001+ } = coordinator. begin_compile ( token , req) . await . unwrap ( ) ;
19222002
19232003 let WithOutput {
19242004 response,
0 commit comments