Skip to content

Commit a91104e

Browse files
committed
WIP fix cancellation
1 parent bbb166b commit a91104e

File tree

8 files changed

+465
-112
lines changed

8 files changed

+465
-112
lines changed

crates/chat-cli/src/cli/acp/client_connection.rs

Lines changed: 43 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -124,12 +124,9 @@ pub(super) enum ClientConnectionMethod {
124124
// This ensures that the prompt termination is ordered with respect
125125
// to the other notifications that are routed to that same session.
126126
Prompt(acp::PromptRequest),
127-
127+
128128
#[allow(dead_code)] // Will be used when client-side cancellation is implemented
129-
Cancel(
130-
acp::CancelNotification,
131-
oneshot::Sender<Result<(), acp::Error>>,
132-
),
129+
Cancel(acp::CancelNotification, oneshot::Sender<Result<(), acp::Error>>),
133130
}
134131

135132
impl AcpClientConnectionHandle {
@@ -159,6 +156,7 @@ impl AcpClientConnectionHandle {
159156
acp::ClientSideConnection::new(callbacks, outgoing_bytes, incoming_bytes, |fut| {
160157
tokio::task::spawn_local(fut);
161158
});
159+
let client_conn = Arc::new(client_conn);
162160

163161
// Start the client I/O handler
164162
tokio::task::spawn_local(async move {
@@ -168,42 +166,55 @@ impl AcpClientConnectionHandle {
168166
});
169167

170168
while let Some(method) = client_rx.recv().await {
171-
tracing::debug!(actor="client_connection", event="message received", ?method);
169+
tracing::debug!(actor = "client_connection", event = "message received", ?method);
172170

173171
match method {
174172
ClientConnectionMethod::Initialize(initialize_request, sender) => {
175173
let response = client_conn.initialize(initialize_request).await;
176-
tracing::debug!(actor="client_connection", event="sending response", ?response);
174+
tracing::debug!(actor = "client_connection", event = "sending response", ?response);
177175
ignore_error(sender.send(response));
178176
},
179177
ClientConnectionMethod::NewSession(new_session_request, sender) => {
180178
match client_conn.new_session(new_session_request).await {
181179
Ok(session_info) => {
182-
let result = AcpClientSessionHandle::new(
183-
session_info,
184-
&client_dispatch,
185-
client_tx.clone(),
186-
)
187-
.await
188-
.map_err(|_err| acp::Error::internal_error());
189-
tracing::debug!(actor="client_connection", event="sending response", ?result);
180+
let result =
181+
AcpClientSessionHandle::new(session_info, &client_dispatch, client_tx.clone())
182+
.await
183+
.map_err(|_err| acp::Error::internal_error());
184+
tracing::debug!(actor = "client_connection", event = "sending response", ?result);
190185
ignore_error(sender.send(result));
191186
},
192187
Err(err) => {
193-
tracing::debug!(actor="client_connection", event="sending response", ?err);
188+
tracing::debug!(actor = "client_connection", event = "sending response", ?err);
194189
ignore_error(sender.send(Err(err)));
195190
},
196191
}
197192
},
198193
ClientConnectionMethod::Prompt(prompt_request) => {
199194
let session_id = prompt_request.session_id.clone();
200-
let response = client_conn.prompt(prompt_request).await;
201-
tracing::debug!(actor="client_connection", event="sending response", ?session_id, ?response);
202-
client_dispatch.client_callback(ClientCallback::PromptResponse(session_id, response));
195+
196+
// Spawn off the call to prompt so it runs concurrently.
197+
//
198+
// This way if the user tries to cancel, that message can be received
199+
// and sent to the server. That will cause the server to cancel this prompt call.
200+
tokio::task::spawn_local({
201+
let client_conn = client_conn.clone();
202+
let client_dispatch = client_dispatch.clone();
203+
async move {
204+
let response = client_conn.prompt(prompt_request).await;
205+
tracing::debug!(
206+
actor = "client_connection",
207+
event = "sending response",
208+
?session_id,
209+
?response
210+
);
211+
client_dispatch.client_callback(ClientCallback::PromptResponse(session_id, response));
212+
}
213+
});
203214
},
204215
ClientConnectionMethod::Cancel(cancel_notification, sender) => {
205216
let response = client_conn.cancel(cancel_notification).await;
206-
tracing::debug!(actor="client_connection", event="sending response", ?response);
217+
tracing::debug!(actor = "client_connection", event = "sending response", ?response);
207218
ignore_error(sender.send(response));
208219
},
209220
}
@@ -229,12 +240,10 @@ impl AcpClientConnectionHandle {
229240
Ok(rx.await??)
230241
}
231242

232-
#[allow(dead_code)] // Will be used when client-side cancellation is implemented
243+
#[cfg_attr(not(test), allow(dead_code))] // Will be used when client-side cancellation is implemented
233244
pub async fn cancel(&self, args: acp::CancelNotification) -> Result<()> {
234245
let (tx, rx) = tokio::sync::oneshot::channel();
235-
self.client_tx
236-
.send(ClientConnectionMethod::Cancel(args, tx))
237-
.await?;
246+
self.client_tx.send(ClientConnectionMethod::Cancel(args, tx)).await?;
238247
Ok(rx.await??)
239248
}
240249
}
@@ -258,7 +267,10 @@ impl acp::Client for AcpClientForward {
258267
todo!()
259268
}
260269

261-
async fn write_text_file(&self, _args: acp::WriteTextFileRequest) -> Result<acp::WriteTextFileResponse, acp::Error> {
270+
async fn write_text_file(
271+
&self,
272+
_args: acp::WriteTextFileRequest,
273+
) -> Result<acp::WriteTextFileResponse, acp::Error> {
262274
todo!()
263275
}
264276

@@ -267,12 +279,16 @@ impl acp::Client for AcpClientForward {
267279
}
268280

269281
async fn session_notification(&self, args: acp::SessionNotification) -> Result<(), acp::Error> {
270-
tracing::debug!(actor="client_connection", event="session_notification", ?args);
282+
tracing::debug!(actor = "client_connection", event = "session_notification", ?args);
271283
let (tx, rx) = oneshot::channel();
272284
self.client_dispatch
273285
.client_callback(ClientCallback::Notification(args, tx));
274286
let result = rx.await;
275-
tracing::debug!(actor="client_connection", event="session_notification complete", ?result);
287+
tracing::debug!(
288+
actor = "client_connection",
289+
event = "session_notification complete",
290+
?result
291+
);
276292
result.map_err(acp::Error::into_internal_error)?
277293
}
278294

@@ -342,4 +358,3 @@ impl ClientCallback {
342358
}
343359
}
344360
}
345-

crates/chat-cli/src/cli/acp/client_session.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,11 @@ impl AcpClientSessionHandle {
3535
})
3636
}
3737

38+
/// Get the session ID for this session
39+
pub fn session_id(&self) -> &acp::SessionId {
40+
&self.session_info.session_id
41+
}
42+
3843
/// Send a message to the agent and read the complete response
3944
pub async fn prompt(&mut self, message: impl IntoPrompt) -> Result<String> {
4045
// Construct the prompt
@@ -53,6 +58,7 @@ impl AcpClientSessionHandle {
5358
// Read notifications until we get the prompt response, then we can return.
5459
let mut response_text = String::new();
5560
while let Some(client_callback) = self.callback_rx.recv().await {
61+
tracing::debug!(actor="client_session", event="callback received", "session_id"=?self.session_info.session_id, ?client_callback);
5662
match client_callback {
5763
ClientCallback::Notification(notification, tx) => {
5864
self.handle_notification(notification, tx, &mut response_text)

crates/chat-cli/src/cli/acp/server.rs

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use agent_client_protocol as acp;
77
use serde_json::value::RawValue;
88
use tokio::sync::{mpsc, oneshot};
99

10-
use crate::os::Os;
10+
use crate::{cli::acp::util::ignore_error, os::Os};
1111
use super::{server_session::AcpServerSessionHandle, server_connection::AcpServerConnectionHandle};
1212

1313
/// Convert channel errors to ACP errors
@@ -30,6 +30,7 @@ pub struct AcpServerHandle {
3030
/// Each variant contains:
3131
/// - Request parameters (the input)
3232
/// - oneshot::Sender (the "return address" where the actor sends the response back)
33+
#[derive(Debug)]
3334
enum ServerMethod {
3435
Initialize(acp::InitializeRequest, oneshot::Sender<Result<acp::InitializeResponse, acp::Error>>),
3536
Authenticate(acp::AuthenticateRequest, oneshot::Sender<Result<acp::AuthenticateResponse, acp::Error>>),
@@ -50,6 +51,7 @@ impl AcpServerHandle {
5051
let mut sessions: HashMap<String, AcpServerSessionHandle> = HashMap::new();
5152

5253
while let Some(method) = server_rx.recv().await {
54+
tracing::debug!(actor="server", event="method call received", ?method);
5355
match method {
5456
ServerMethod::Initialize(args, tx) => {
5557
let response = Self::handle_initialize(args).await;
@@ -87,11 +89,7 @@ impl AcpServerHandle {
8789
}
8890
}
8991
ServerMethod::Prompt(args, tx) => {
90-
let response = Self::handle_prompt(args, &sessions).await;
91-
if tx.send(response).is_err() {
92-
tracing::debug!(actor="server", event="response receiver dropped", method="prompt");
93-
break;
94-
}
92+
Self::handle_prompt(args, tx, &sessions).await;
9593
}
9694
ServerMethod::Cancel(args, tx) => {
9795
let response = Self::handle_cancel(args, &sessions).await;
@@ -267,17 +265,20 @@ impl AcpServerHandle {
267265

268266
async fn handle_prompt(
269267
args: acp::PromptRequest,
268+
prompt_tx: oneshot::Sender<Result<acp::PromptResponse, acp::Error>>,
270269
sessions: &HashMap<String, AcpServerSessionHandle>,
271-
) -> Result<acp::PromptResponse, acp::Error> {
270+
) {
272271
let session_id = args.session_id.0.as_ref();
273272

274273
// Find the session actor
275274
if let Some(session_handle) = sessions.get(session_id) {
276-
// Forward to session actor
277-
session_handle.prompt(args).await
275+
// Forward to session actor. Importantly, this actor is responsible
276+
// for sending the final result from the prompt to `prompt_tx` -- we just
277+
// return immediately.
278+
session_handle.prompt(args, prompt_tx).await;
278279
} else {
279280
tracing::warn!("Session not found for prompt: {}", session_id);
280-
Err(acp::Error::invalid_params())
281+
ignore_error(prompt_tx.send(Err(acp::Error::invalid_params())))
281282
}
282283
}
283284

crates/chat-cli/src/cli/acp/server_connection.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,16 @@ impl AcpServerConnectionHandle {
5555

5656
// Launch the "transport actor", which owns the connection.
5757
tokio::task::spawn_local(async move {
58-
tracing::debug!(actor="transport", event="started");
58+
tracing::debug!(actor="server_connection", event="started");
5959

6060
while let Some(method) = transport_rx.recv().await {
61-
tracing::debug!(actor="transport", event="message received", ?method);
61+
tracing::debug!(actor="server_connection", event="message received", ?method);
6262
match method {
6363
TransportMethod::SessionNotification(notification, tx) => {
6464
let result = connection.session_notification(notification).await;
65-
tracing::debug!(actor="transport", event="notification delivered");
65+
tracing::debug!(actor="server_connection", event="notification delivered");
6666
if tx.send(result).is_err() {
67-
tracing::debug!(actor="transport", event="response receiver dropped");
67+
tracing::debug!(actor="server_connection", event="response receiver dropped");
6868
}
6969
},
7070
}

0 commit comments

Comments
 (0)