@@ -2,7 +2,7 @@ use crate::{
22 metrics:: { DURATION_WS , LIVE_WS } ,
33 parse_channel, parse_crate_type, parse_edition, parse_mode,
44 sandbox:: { self , Sandbox } ,
5- Error , ExecutionSnafu , Result , SandboxCreationSnafu ,
5+ Error , ExecutionSnafu , Result , SandboxCreationSnafu , WebSocketTaskPanicSnafu ,
66} ;
77
88use axum:: extract:: ws:: { Message , WebSocket } ;
@@ -11,7 +11,7 @@ use std::{
1111 convert:: { TryFrom , TryInto } ,
1212 time:: Instant ,
1313} ;
14- use tokio:: sync:: mpsc;
14+ use tokio:: { sync:: mpsc, task :: JoinSet } ;
1515
1616#[ derive( serde:: Deserialize ) ]
1717#[ serde( tag = "type" ) ]
@@ -109,6 +109,9 @@ pub async fn handle(mut socket: WebSocket) {
109109 let start = Instant :: now ( ) ;
110110
111111 let ( tx, mut rx) = mpsc:: channel ( 3 ) ;
112+ let mut tasks = JoinSet :: new ( ) ;
113+
114+ // TODO: Implement some kind of timeout to shutdown running work?
112115
113116 loop {
114117 tokio:: select! {
@@ -118,7 +121,7 @@ pub async fn handle(mut socket: WebSocket) {
118121 // browser disconnected
119122 break ;
120123 }
121- Some ( Ok ( Message :: Text ( txt) ) ) => handle_msg( txt, & tx) . await ,
124+ Some ( Ok ( Message :: Text ( txt) ) ) => handle_msg( txt, & tx, & mut tasks ) . await ,
122125 Some ( Ok ( _) ) => {
123126 // unknown message type
124127 continue ;
@@ -128,10 +131,31 @@ pub async fn handle(mut socket: WebSocket) {
128131 } ,
129132 resp = rx. recv( ) => {
130133 let resp = resp. expect( "The rx should never close as we have a tx" ) ;
131- let resp = resp. unwrap_or_else( |e| WSMessageResponse :: Error ( WSError { error: e. to_string( ) } ) ) ;
132- const LAST_CHANCE_ERROR : & str = r#"{ "type": "WEBSOCKET_ERROR", "error": "Unable to serialize JSON" }"# ;
133- let resp = serde_json:: to_string( & resp) . unwrap_or_else( |_| LAST_CHANCE_ERROR . into( ) ) ;
134- let resp = Message :: Text ( resp) ;
134+ let resp = resp. unwrap_or_else( error_to_response) ;
135+ let resp = response_to_message( resp) ;
136+
137+ if let Err ( _) = socket. send( resp) . await {
138+ // We can't send a response
139+ break ;
140+ }
141+ } ,
142+ // We don't care if there are no running tasks
143+ Some ( task) = tasks. join_next( ) => {
144+ let Err ( error) = task else { continue } ;
145+ // The task was cancelled; no need to report
146+ let Ok ( panic) = error. try_into_panic( ) else { continue } ;
147+
148+ let text = match panic. downcast:: <String >( ) {
149+ Ok ( text) => * text,
150+ Err ( panic) => match panic. downcast:: <& str >( ) {
151+ Ok ( text) => text. to_string( ) ,
152+ _ => "An unknown panic occurred" . into( ) ,
153+ }
154+ } ;
155+ let error = WebSocketTaskPanicSnafu { text } . build( ) ;
156+
157+ let resp = error_to_response( error) ;
158+ let resp = response_to_message( resp) ;
135159
136160 if let Err ( _) = socket. send( resp) . await {
137161 // We can't send a response
@@ -141,22 +165,42 @@ pub async fn handle(mut socket: WebSocket) {
141165 }
142166 }
143167
168+ drop ( ( tx, rx, socket) ) ;
169+ tasks. shutdown ( ) . await ;
170+
144171 LIVE_WS . dec ( ) ;
145172 let elapsed = start. elapsed ( ) ;
146173 DURATION_WS . observe ( elapsed. as_secs_f64 ( ) ) ;
147174}
148175
149- async fn handle_msg ( txt : String , tx : & mpsc:: Sender < Result < WSMessageResponse > > ) {
176+ fn error_to_response ( error : Error ) -> WSMessageResponse {
177+ let error = error. to_string ( ) ;
178+ WSMessageResponse :: Error ( WSError { error } )
179+ }
180+
181+ fn response_to_message ( response : WSMessageResponse ) -> Message {
182+ const LAST_CHANCE_ERROR : & str =
183+ r#"{ "type": "WEBSOCKET_ERROR", "error": "Unable to serialize JSON" }"# ;
184+ let resp = serde_json:: to_string ( & response) . unwrap_or_else ( |_| LAST_CHANCE_ERROR . into ( ) ) ;
185+ Message :: Text ( resp)
186+ }
187+
188+ async fn handle_msg (
189+ txt : String ,
190+ tx : & mpsc:: Sender < Result < WSMessageResponse > > ,
191+ tasks : & mut JoinSet < Result < ( ) > > ,
192+ ) {
150193 use WSMessageRequest :: * ;
151194
152195 let msg = serde_json:: from_str ( & txt) . context ( crate :: DeserializationSnafu ) ;
153196
154197 match msg {
155198 Ok ( WSExecuteRequest ( req) ) => {
156199 let tx = tx. clone ( ) ;
157- tokio :: spawn ( async move {
200+ tasks . spawn ( async move {
158201 let resp = handle_execute ( req) . await ;
159202 tx. send ( resp) . await . ok ( /* We don't care if the channel is closed */ ) ;
203+ Ok ( ( ) )
160204 } ) ;
161205 }
162206 Err ( e) => {
0 commit comments