@@ -124,12 +124,9 @@ pub(super) enum ClientConnectionMethod {
124124 // This ensures that the prompt termination is ordered with respect
125125 // to the other notifications that are routed to that same session.
126126 Prompt ( acp:: PromptRequest ) ,
127-
127+
128128 #[ allow( dead_code) ] // Will be used when client-side cancellation is implemented
129- Cancel (
130- acp:: CancelNotification ,
131- oneshot:: Sender < Result < ( ) , acp:: Error > > ,
132- ) ,
129+ Cancel ( acp:: CancelNotification , oneshot:: Sender < Result < ( ) , acp:: Error > > ) ,
133130}
134131
135132impl AcpClientConnectionHandle {
@@ -159,6 +156,7 @@ impl AcpClientConnectionHandle {
159156 acp:: ClientSideConnection :: new ( callbacks, outgoing_bytes, incoming_bytes, |fut| {
160157 tokio:: task:: spawn_local ( fut) ;
161158 } ) ;
159+ let client_conn = Arc :: new ( client_conn) ;
162160
163161 // Start the client I/O handler
164162 tokio:: task:: spawn_local ( async move {
@@ -168,42 +166,55 @@ impl AcpClientConnectionHandle {
168166 } ) ;
169167
170168 while let Some ( method) = client_rx. recv ( ) . await {
171- tracing:: debug!( actor= "client_connection" , event= "message received" , ?method) ;
169+ tracing:: debug!( actor = "client_connection" , event = "message received" , ?method) ;
172170
173171 match method {
174172 ClientConnectionMethod :: Initialize ( initialize_request, sender) => {
175173 let response = client_conn. initialize ( initialize_request) . await ;
176- tracing:: debug!( actor= "client_connection" , event= "sending response" , ?response) ;
174+ tracing:: debug!( actor = "client_connection" , event = "sending response" , ?response) ;
177175 ignore_error ( sender. send ( response) ) ;
178176 } ,
179177 ClientConnectionMethod :: NewSession ( new_session_request, sender) => {
180178 match client_conn. new_session ( new_session_request) . await {
181179 Ok ( session_info) => {
182- let result = AcpClientSessionHandle :: new (
183- session_info,
184- & client_dispatch,
185- client_tx. clone ( ) ,
186- )
187- . await
188- . map_err ( |_err| acp:: Error :: internal_error ( ) ) ;
189- tracing:: debug!( actor="client_connection" , event="sending response" , ?result) ;
180+ let result =
181+ AcpClientSessionHandle :: new ( session_info, & client_dispatch, client_tx. clone ( ) )
182+ . await
183+ . map_err ( |_err| acp:: Error :: internal_error ( ) ) ;
184+ tracing:: debug!( actor = "client_connection" , event = "sending response" , ?result) ;
190185 ignore_error ( sender. send ( result) ) ;
191186 } ,
192187 Err ( err) => {
193- tracing:: debug!( actor= "client_connection" , event= "sending response" , ?err) ;
188+ tracing:: debug!( actor = "client_connection" , event = "sending response" , ?err) ;
194189 ignore_error ( sender. send ( Err ( err) ) ) ;
195190 } ,
196191 }
197192 } ,
198193 ClientConnectionMethod :: Prompt ( prompt_request) => {
199194 let session_id = prompt_request. session_id . clone ( ) ;
200- let response = client_conn. prompt ( prompt_request) . await ;
201- tracing:: debug!( actor="client_connection" , event="sending response" , ?session_id, ?response) ;
202- client_dispatch. client_callback ( ClientCallback :: PromptResponse ( session_id, response) ) ;
195+
196+ // Spawn off the call to prompt so it runs concurrently.
197+ //
198+ // This way if the user tries to cancel, that message can be received
199+ // and sent to the server. That will cause the server to cancel this prompt call.
200+ tokio:: task:: spawn_local ( {
201+ let client_conn = client_conn. clone ( ) ;
202+ let client_dispatch = client_dispatch. clone ( ) ;
203+ async move {
204+ let response = client_conn. prompt ( prompt_request) . await ;
205+ tracing:: debug!(
206+ actor = "client_connection" ,
207+ event = "sending response" ,
208+ ?session_id,
209+ ?response
210+ ) ;
211+ client_dispatch. client_callback ( ClientCallback :: PromptResponse ( session_id, response) ) ;
212+ }
213+ } ) ;
203214 } ,
204215 ClientConnectionMethod :: Cancel ( cancel_notification, sender) => {
205216 let response = client_conn. cancel ( cancel_notification) . await ;
206- tracing:: debug!( actor= "client_connection" , event= "sending response" , ?response) ;
217+ tracing:: debug!( actor = "client_connection" , event = "sending response" , ?response) ;
207218 ignore_error ( sender. send ( response) ) ;
208219 } ,
209220 }
@@ -229,12 +240,10 @@ impl AcpClientConnectionHandle {
229240 Ok ( rx. await ??)
230241 }
231242
232- #[ allow( dead_code) ] // Will be used when client-side cancellation is implemented
243+ #[ cfg_attr ( not ( test ) , allow( dead_code) ) ] // Will be used when client-side cancellation is implemented
233244 pub async fn cancel ( & self , args : acp:: CancelNotification ) -> Result < ( ) > {
234245 let ( tx, rx) = tokio:: sync:: oneshot:: channel ( ) ;
235- self . client_tx
236- . send ( ClientConnectionMethod :: Cancel ( args, tx) )
237- . await ?;
246+ self . client_tx . send ( ClientConnectionMethod :: Cancel ( args, tx) ) . await ?;
238247 Ok ( rx. await ??)
239248 }
240249}
@@ -258,7 +267,10 @@ impl acp::Client for AcpClientForward {
258267 todo ! ( )
259268 }
260269
261- async fn write_text_file ( & self , _args : acp:: WriteTextFileRequest ) -> Result < acp:: WriteTextFileResponse , acp:: Error > {
270+ async fn write_text_file (
271+ & self ,
272+ _args : acp:: WriteTextFileRequest ,
273+ ) -> Result < acp:: WriteTextFileResponse , acp:: Error > {
262274 todo ! ( )
263275 }
264276
@@ -267,12 +279,16 @@ impl acp::Client for AcpClientForward {
267279 }
268280
269281 async fn session_notification ( & self , args : acp:: SessionNotification ) -> Result < ( ) , acp:: Error > {
270- tracing:: debug!( actor= "client_connection" , event= "session_notification" , ?args) ;
282+ tracing:: debug!( actor = "client_connection" , event = "session_notification" , ?args) ;
271283 let ( tx, rx) = oneshot:: channel ( ) ;
272284 self . client_dispatch
273285 . client_callback ( ClientCallback :: Notification ( args, tx) ) ;
274286 let result = rx. await ;
275- tracing:: debug!( actor="client_connection" , event="session_notification complete" , ?result) ;
287+ tracing:: debug!(
288+ actor = "client_connection" ,
289+ event = "session_notification complete" ,
290+ ?result
291+ ) ;
276292 result. map_err ( acp:: Error :: into_internal_error) ?
277293 }
278294
@@ -342,4 +358,3 @@ impl ClientCallback {
342358 }
343359 }
344360}
345-
0 commit comments