diff --git a/Cargo.lock b/Cargo.lock index 74a6c9f8e5..33cadbd212 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1207,7 +1207,7 @@ checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" [[package]] name = "chat_cli" -version = "1.16.3" +version = "1.17.0" dependencies = [ "amzn-codewhisperer-client", "amzn-codewhisperer-streaming-client", @@ -5256,9 +5256,9 @@ dependencies = [ [[package]] name = "rmcp" -version = "0.6.3" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c7dd163d26e254725137b7933e4ba042ea6bf2d756a4260559aaea8b6ad4c27e" +checksum = "534fd1cd0601e798ac30545ff2b7f4a62c6f14edd4aaed1cc5eb1e85f69f09af" dependencies = [ "base64 0.22.1", "chrono", @@ -5285,9 +5285,9 @@ dependencies = [ [[package]] name = "rmcp-macros" -version = "0.6.3" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a43bb4c90a0d4b12f7315eb681a73115d335a2cee81322eca96f3467fe4cd06f" +checksum = "9ba777eb0e5f53a757e36f0e287441da0ab766564ba7201600eeb92a4753022e" dependencies = [ "darling 0.21.3", "proc-macro2", @@ -5647,7 +5647,7 @@ dependencies = [ [[package]] name = "semantic_search_client" -version = "1.16.3" +version = "1.17.0" dependencies = [ "anyhow", "bm25", diff --git a/Cargo.toml b/Cargo.toml index d2fb418eb6..65cc611f86 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ authors = ["Amazon Q CLI Team (q-cli@amazon.com)", "Chay Nabors (nabochay@amazon edition = "2024" homepage = "https://aws.amazon.com/q/" publish = false -version = "1.16.3" +version = "1.17.0" license = "MIT OR Apache-2.0" [workspace.dependencies] @@ -129,7 +129,7 @@ winnow = "=0.6.2" winreg = "0.55.0" schemars = "1.0.4" jsonschema = "0.30.0" -rmcp = { version = "0.6.3", features = ["client", "transport-sse-client-reqwest", "reqwest", "transport-streamable-http-client-reqwest", "transport-child-process", "tower", "auth"] } +rmcp = { version = "0.7.0", features = ["client", "transport-sse-client-reqwest", "reqwest", "transport-streamable-http-client-reqwest", "transport-child-process", "tower", "auth"] } [workspace.lints.rust] future_incompatible = "warn" diff --git a/crates/chat-cli/src/cli/chat/checkpoint.rs b/crates/chat-cli/src/cli/chat/checkpoint.rs index c5fb0b8183..a63a7fcf32 100644 --- a/crates/chat-cli/src/cli/chat/checkpoint.rs +++ b/crates/chat-cli/src/cli/chat/checkpoint.rs @@ -25,6 +25,7 @@ use serde::{ Deserialize, Serialize, }; +use tracing::debug; use crate::cli::ConversationState; use crate::cli::chat::conversation::HistoryEntry; @@ -36,6 +37,9 @@ pub struct CheckpointManager { /// Path to the shadow (bare) git repository pub shadow_repo_path: PathBuf, + /// Path to current working directory + pub work_tree_path: PathBuf, + /// All checkpoints in chronological order pub checkpoints: Vec, @@ -84,10 +88,10 @@ impl CheckpointManager { current_history: &VecDeque, ) -> Result { if !is_git_installed() { - bail!("Git is not installed. Checkpoints require git to function."); + bail!("Checkpoints are not available. Git is required but not installed."); } if !is_in_git_repo() { - bail!("Not in a git repository. Use '/checkpoint init' to manually enable checkpoints."); + bail!("Checkpoints are not available in this directory. Use '/checkpoint init' to enable checkpoints."); } let manager = Self::manual_init(os, shadow_path, current_history).await?; @@ -103,14 +107,17 @@ impl CheckpointManager { let path = path.as_ref(); os.fs.create_dir_all(path).await?; + let work_tree_path = + std::env::current_dir().map_err(|e| eyre!("Failed to get current working directory: {}", e))?; + // Initialize bare repository - run_git(path, false, &["init", "--bare", &path.to_string_lossy()])?; + run_git(path, None, &["init", "--bare", &path.to_string_lossy()])?; // Configure git configure_git(&path.to_string_lossy())?; // Create initial checkpoint - stage_commit_tag(&path.to_string_lossy(), "Initial state", "0")?; + stage_commit_tag(&path.to_string_lossy(), &work_tree_path, "Initial state", "0")?; let initial_checkpoint = Checkpoint { tag: "0".to_string(), @@ -126,6 +133,7 @@ impl CheckpointManager { Ok(Self { shadow_repo_path: path.to_path_buf(), + work_tree_path, checkpoints: vec![initial_checkpoint], tag_index, current_turn: 0, @@ -146,7 +154,12 @@ impl CheckpointManager { tool_name: Option, ) -> Result<()> { // Stage, commit and tag - stage_commit_tag(&self.shadow_repo_path.to_string_lossy(), description, tag)?; + stage_commit_tag( + &self.shadow_repo_path.to_string_lossy(), + &self.work_tree_path, + description, + tag, + )?; // Record checkpoint metadata let checkpoint = Checkpoint { @@ -175,9 +188,14 @@ impl CheckpointManager { if hard { // Hard: reset the whole work-tree to the tag - let output = run_git(&self.shadow_repo_path, true, &["reset", "--hard", tag])?; + let output = run_git(&self.shadow_repo_path, Some(&self.work_tree_path), &[ + "reset", "--hard", tag, + ])?; if !output.status.success() { - bail!("Failed to restore: {}", String::from_utf8_lossy(&output.stderr)); + bail!( + "Failed to restore checkpoint: {}", + String::from_utf8_lossy(&output.stderr) + ); } } else { // Soft: only restore tracked files. If the tag is an empty tree, this is a no-op. @@ -187,9 +205,14 @@ impl CheckpointManager { return Ok(()); } // Use checkout against work-tree - let output = run_git(&self.shadow_repo_path, true, &["checkout", tag, "--", "."])?; + let output = run_git(&self.shadow_repo_path, Some(&self.work_tree_path), &[ + "checkout", tag, "--", ".", + ])?; if !output.status.success() { - bail!("Failed to restore: {}", String::from_utf8_lossy(&output.stderr)); + bail!( + "Failed to restore checkpoint: {}", + String::from_utf8_lossy(&output.stderr) + ); } } @@ -205,7 +228,7 @@ impl CheckpointManager { let out = run_git( &self.shadow_repo_path, // work_tree - false, + None, &["ls-tree", "-r", "--name-only", tag], )?; Ok(!out.stdout.is_empty()) @@ -223,7 +246,7 @@ impl CheckpointManager { /// Compute file statistics between two checkpoints pub fn compute_stats_between(&self, from: &str, to: &str) -> Result { - let output = run_git(&self.shadow_repo_path, false, &["diff", "--name-status", from, to])?; + let output = run_git(&self.shadow_repo_path, None, &["diff", "--name-status", from, to])?; let mut stats = FileStats::default(); for line in String::from_utf8_lossy(&output.stdout).lines() { @@ -246,7 +269,7 @@ impl CheckpointManager { let mut result = String::new(); // Get file changes - let output = run_git(&self.shadow_repo_path, false, &["diff", "--name-status", from, to])?; + let output = run_git(&self.shadow_repo_path, None, &["diff", "--name-status", from, to])?; for line in String::from_utf8_lossy(&output.stdout).lines() { if let Some((status, file)) = line.split_once('\t') { @@ -261,7 +284,7 @@ impl CheckpointManager { } // Add statistics - let stat_output = run_git(&self.shadow_repo_path, false, &[ + let stat_output = run_git(&self.shadow_repo_path, None, &[ "diff", from, to, @@ -279,7 +302,10 @@ impl CheckpointManager { /// Check for uncommitted changes pub fn has_changes(&self) -> Result { - let output = run_git(&self.shadow_repo_path, true, &["status", "--porcelain"])?; + let output = run_git(&self.shadow_repo_path, Some(&self.work_tree_path), &[ + "status", + "--porcelain", + ])?; Ok(!output.stdout.is_empty()) } @@ -351,18 +377,18 @@ fn is_in_git_repo() -> bool { } fn configure_git(shadow_path: &str) -> Result<()> { - run_git(Path::new(shadow_path), false, &["config", "user.name", "Q"])?; - run_git(Path::new(shadow_path), false, &["config", "user.email", "qcli@local"])?; - run_git(Path::new(shadow_path), false, &["config", "core.preloadindex", "true"])?; + run_git(Path::new(shadow_path), None, &["config", "user.name", "Q"])?; + run_git(Path::new(shadow_path), None, &["config", "user.email", "qcli@local"])?; + run_git(Path::new(shadow_path), None, &["config", "core.preloadindex", "true"])?; Ok(()) } -fn stage_commit_tag(shadow_path: &str, message: &str, tag: &str) -> Result<()> { +fn stage_commit_tag(shadow_path: &str, work_tree: &Path, message: &str, tag: &str) -> Result<()> { // Stage all changes - run_git(Path::new(shadow_path), true, &["add", "-A"])?; + run_git(Path::new(shadow_path), Some(work_tree), &["add", "-A"])?; // Commit - let output = run_git(Path::new(shadow_path), true, &[ + let output = run_git(Path::new(shadow_path), Some(work_tree), &[ "commit", "--allow-empty", "--no-verify", @@ -371,33 +397,53 @@ fn stage_commit_tag(shadow_path: &str, message: &str, tag: &str) -> Result<()> { ])?; if !output.status.success() { - bail!("Git commit failed: {}", String::from_utf8_lossy(&output.stderr)); + bail!( + "Checkpoint initialization failed: {}", + String::from_utf8_lossy(&output.stderr) + ); } // Tag - let output = run_git(Path::new(shadow_path), false, &["tag", tag])?; + let output = run_git(Path::new(shadow_path), None, &["tag", tag])?; if !output.status.success() { - bail!("Git tag failed: {}", String::from_utf8_lossy(&output.stderr)); + bail!( + "Checkpoint initialization failed: {}", + String::from_utf8_lossy(&output.stderr) + ); } Ok(()) } -fn run_git(dir: &Path, with_work_tree: bool, args: &[&str]) -> Result { +fn run_git(dir: &Path, work_tree: Option<&Path>, args: &[&str]) -> Result { let mut cmd = Command::new("git"); cmd.arg(format!("--git-dir={}", dir.display())); - if with_work_tree { - cmd.arg("--work-tree=."); + if let Some(work_tree_path) = work_tree { + cmd.arg(format!("--work-tree={}", work_tree_path.display())); } cmd.args(args); + debug!("Executing git command: {:?}", cmd); let output = cmd.output()?; - if !output.status.success() && !output.stderr.is_empty() { - bail!(String::from_utf8_lossy(&output.stderr).to_string()); + + if !output.status.success() { + debug!("Git command failed with exit code: {:?}", output.status.code()); + debug!("Git stderr: {}", String::from_utf8_lossy(&output.stderr)); + debug!("Git stdout: {}", String::from_utf8_lossy(&output.stdout)); + + if !output.stderr.is_empty() { + bail!( + "Checkpoint operation failed: {}", + String::from_utf8_lossy(&output.stderr) + ); + } else { + bail!("Checkpoint operation failed unexpectedly"); + } } + debug!("Git command succeeded"); Ok(output) } diff --git a/crates/chat-cli/src/cli/chat/cli/checkpoint.rs b/crates/chat-cli/src/cli/chat/cli/checkpoint.rs index 634da119c3..a38dabb366 100644 --- a/crates/chat-cli/src/cli/chat/cli/checkpoint.rs +++ b/crates/chat-cli/src/cli/chat/cli/checkpoint.rs @@ -107,7 +107,7 @@ impl CheckpointSubcommand { session.stderr, style::SetForegroundColor(Color::Yellow), style::Print( - "⚠️ Checkpoint is disabled while in tangent mode. Disable tangent mode with: q settings -d chat.enableTangentMode.\n\n" + "⚠️ Checkpoint is disabled while in tangent mode. Please exit tangent mode if you want to use checkpoint.\n\n" ), style::SetForegroundColor(Color::Reset), )?; diff --git a/crates/chat-cli/src/cli/chat/cli/experiment.rs b/crates/chat-cli/src/cli/chat/cli/experiment.rs index 9b7c3a8cd2..41d853dd40 100644 --- a/crates/chat-cli/src/cli/chat/cli/experiment.rs +++ b/crates/chat-cli/src/cli/chat/cli/experiment.rs @@ -12,7 +12,6 @@ use crossterm::{ }; use dialoguer::Select; -use crate::cli::chat::conversation::format_tool_spec; use crate::cli::chat::{ ChatError, ChatSession, @@ -52,11 +51,7 @@ static AVAILABLE_EXPERIMENTS: &[Experiment] = &[ }, Experiment { name: "Checkpoint", - description: concat!( - "Enables workspace checkpoints to snapshot, list, expand, diff, and restore files (/checkpoint)\n", - " ", - "Cannot be used in tangent mode (to avoid mixing up conversation history)" - ), + description: "Enables workspace checkpoints to snapshot, list, expand, diff, and restore files (/checkpoint)\nNote: Cannot be used in tangent mode (to avoid mixing up conversation history)", setting_key: Setting::EnabledCheckpoint, }, Experiment { @@ -85,17 +80,21 @@ async fn select_experiment(os: &mut Os, session: &mut ChatSession) -> Result Result {}", test_file_str); #[cfg(windows)] - let command = format!("type > {}", test_file_str); + let command = format!( + "powershell -Command \"$input | Out-File -FilePath '{}'\"", + test_file_str + ); let hook = Hook { command, diff --git a/crates/chat-cli/src/cli/chat/cli/mod.rs b/crates/chat-cli/src/cli/chat/cli/mod.rs index bf951596e6..4edbc15a06 100644 --- a/crates/chat-cli/src/cli/chat/cli/mod.rs +++ b/crates/chat-cli/src/cli/chat/cli/mod.rs @@ -12,6 +12,7 @@ pub mod model; pub mod persist; pub mod profile; pub mod prompts; +pub mod reply; pub mod subscribe; pub mod tangent; pub mod todos; @@ -32,6 +33,7 @@ use model::ModelArgs; use persist::PersistSubcommand; use profile::AgentSubcommand; use prompts::PromptsArgs; +use reply::ReplyArgs; use tangent::TangentArgs; use todos::TodoSubcommand; use tools::ToolsArgs; @@ -73,6 +75,8 @@ pub enum SlashCommand { /// Open $EDITOR (defaults to vi) to compose a prompt #[command(name = "editor")] PromptEditor(EditorArgs), + /// Open $EDITOR with the most recent assistant message quoted for reply + Reply(ReplyArgs), /// Summarize the conversation to free up context space Compact(CompactArgs), /// View tools and permissions @@ -104,7 +108,11 @@ pub enum SlashCommand { Persist(PersistSubcommand), // #[command(flatten)] // Root(RootSubcommand), - #[command(subcommand)] + #[command( + about = "(Beta) Manage workspace checkpoints (init, list, restore, expand, diff, clean)\nExperimental features may be changed or removed at any time", + hide = true, + subcommand + )] Checkpoint(CheckpointSubcommand), /// View, manage, and resume to-do lists #[command(subcommand)] @@ -143,6 +151,7 @@ impl SlashCommand { Self::Context(args) => args.execute(os, session).await, Self::Knowledge(subcommand) => subcommand.execute(os, session).await, Self::PromptEditor(args) => args.execute(session).await, + Self::Reply(args) => args.execute(session).await, Self::Compact(args) => args.execute(os, session).await, Self::Tools(args) => args.execute(session).await, Self::Issue(args) => { @@ -187,6 +196,7 @@ impl SlashCommand { Self::Context(_) => "context", Self::Knowledge(_) => "knowledge", Self::PromptEditor(_) => "editor", + Self::Reply(_) => "reply", Self::Compact(_) => "compact", Self::Tools(_) => "tools", Self::Issue(_) => "issue", diff --git a/crates/chat-cli/src/cli/chat/cli/prompts.rs b/crates/chat-cli/src/cli/chat/cli/prompts.rs index 53ed8dac91..7b6dc1ce6e 100644 --- a/crates/chat-cli/src/cli/chat/cli/prompts.rs +++ b/crates/chat-cli/src/cli/chat/cli/prompts.rs @@ -2462,14 +2462,18 @@ mod tests { let prompt1 = rmcp::model::Prompt { name: "test_prompt".to_string(), description: Some("Test description".to_string()), + title: Some("Test Prompt".to_string()), + icons: None, arguments: Some(vec![ PromptArgument { name: "arg1".to_string(), description: Some("First argument".to_string()), + title: Some("Argument 1".to_string()), required: Some(true), }, PromptArgument { name: "arg2".to_string(), + title: Some("Argument 2".to_string()), description: None, required: Some(false), }, diff --git a/crates/chat-cli/src/cli/chat/cli/reply.rs b/crates/chat-cli/src/cli/chat/cli/reply.rs new file mode 100644 index 0000000000..0bddf2083e --- /dev/null +++ b/crates/chat-cli/src/cli/chat/cli/reply.rs @@ -0,0 +1,108 @@ +use clap::Args; +use crossterm::execute; +use crossterm::style::{ + self, + Color, +}; + +use super::editor::open_editor; +use crate::cli::chat::{ + ChatError, + ChatSession, + ChatState, +}; + +/// Arguments to the `/reply` command. +#[deny(missing_docs)] +#[derive(Debug, PartialEq, Args)] +pub struct ReplyArgs {} + +impl ReplyArgs { + pub async fn execute(self, session: &mut ChatSession) -> Result { + // Get the most recent assistant message from transcript + let last_assistant_message = session + .conversation + .transcript + .iter() + .rev() + .find(|msg| !msg.starts_with("> ")) + .cloned(); + + let initial_text = match last_assistant_message { + Some(msg) => { + // Format with > prefix for each line + msg.lines() + .map(|line| format!("> {}", line)) + .collect::>() + .join("\n") + }, + None => { + execute!( + session.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print("\nNo assistant message found to reply to.\n\n"), + style::SetForegroundColor(Color::Reset) + )?; + + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + }, + }; + + let content = match open_editor(Some(initial_text.clone())) { + Ok(content) => content, + Err(err) => { + execute!( + session.stderr, + style::SetForegroundColor(Color::Red), + style::Print(format!("\nError opening editor: {}\n\n", err)), + style::SetForegroundColor(Color::Reset) + )?; + + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + }, + }; + + Ok( + match content.trim().is_empty() || content.trim() == initial_text.trim() { + true => { + execute!( + session.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print("\nNo changes made in editor, not submitting.\n\n"), + style::SetForegroundColor(Color::Reset) + )?; + + ChatState::PromptUser { + skip_printing_tools: true, + } + }, + false => { + execute!( + session.stderr, + style::SetForegroundColor(Color::Green), + style::Print("\nContent loaded from editor. Submitting prompt...\n\n"), + style::SetForegroundColor(Color::Reset) + )?; + + // Display the content as if the user typed it + execute!( + session.stderr, + style::SetAttribute(style::Attribute::Reset), + style::SetForegroundColor(Color::Magenta), + style::Print("> "), + style::SetAttribute(style::Attribute::Reset), + style::Print(&content), + style::Print("\n") + )?; + + // Process the content as user input + ChatState::HandleInput { input: content } + }, + }, + ) + } +} diff --git a/crates/chat-cli/src/cli/chat/consts.rs b/crates/chat-cli/src/cli/chat/consts.rs index 21f6b1b8ea..d2ff1b2a0c 100644 --- a/crates/chat-cli/src/cli/chat/consts.rs +++ b/crates/chat-cli/src/cli/chat/consts.rs @@ -3,7 +3,7 @@ pub const MAX_CURRENT_WORKING_DIRECTORY_LEN: usize = 256; /// Limit to send the number of messages as part of chat. -pub const MAX_CONVERSATION_STATE_HISTORY_LEN: usize = 250; +pub const MAX_CONVERSATION_STATE_HISTORY_LEN: usize = 10000; /// Actual service limit is 800_000 pub const MAX_TOOL_RESPONSE_SIZE: usize = 400_000; diff --git a/crates/chat-cli/src/cli/chat/conversation.rs b/crates/chat-cli/src/cli/chat/conversation.rs index 1217c0289b..022e42e74d 100644 --- a/crates/chat-cli/src/cli/chat/conversation.rs +++ b/crates/chat-cli/src/cli/chat/conversation.rs @@ -199,7 +199,7 @@ impl ConversationState { next_message: None, history: VecDeque::new(), valid_history_range: Default::default(), - transcript: VecDeque::with_capacity(MAX_CONVERSATION_STATE_HISTORY_LEN), + transcript: VecDeque::new(), tools: format_tool_spec(tool_config), context_manager, tool_manager, @@ -253,6 +253,10 @@ impl ConversationState { self.transcript = checkpoint.main_transcript; self.latest_summary = checkpoint.main_latest_summary; self.valid_history_range = (0, self.history.len()); + if let Some(manager) = self.checkpoint_manager.as_mut() { + manager.message_locked = false; + manager.pending_user_message = None; + } } /// Enter tangent mode - creates checkpoint of current state @@ -912,6 +916,21 @@ Return only the JSON configuration, no additional text.", Ok(()) } + /// Reloads only built-in tools while preserving MCP tools + pub async fn reload_builtin_tools(&mut self, os: &mut Os, stderr: &mut impl Write) -> Result<(), ChatError> { + let builtin_tools = self + .tool_manager + .load_tools(os, stderr) + .await + .map_err(|e| ChatError::Custom(format!("Failed to reload built-in tools: {e}").into()))?; + + // Remove existing built-in tools and add updated ones, preserving MCP tools + self.tools.retain(|origin, _| *origin != ToolOrigin::Native); + self.tools.extend(format_tool_spec(builtin_tools)); + + Ok(()) + } + /// Swapping agent involves the following: /// - Reinstantiate the context manager /// - Swap agent on tool manager @@ -1370,7 +1389,7 @@ mod tests { // First, build a large conversation history. We need to ensure that the order is always // User -> Assistant -> User -> Assistant ...and so on. conversation.set_next_user_message("start".to_string()).await; - for i in 0..=(MAX_CONVERSATION_STATE_HISTORY_LEN + 100) { + for i in 0..=200 { let s = conversation .as_sendable_conversation_state(&os, &mut vec![], true) .await @@ -1400,7 +1419,7 @@ mod tests { ) .await; conversation.set_next_user_message("start".to_string()).await; - for i in 0..=(MAX_CONVERSATION_STATE_HISTORY_LEN + 100) { + for i in 0..=200 { let s = conversation .as_sendable_conversation_state(&os, &mut vec![], true) .await @@ -1436,7 +1455,7 @@ mod tests { ) .await; conversation.set_next_user_message("start".to_string()).await; - for i in 0..=(MAX_CONVERSATION_STATE_HISTORY_LEN + 100) { + for i in 0..=200 { let s = conversation .as_sendable_conversation_state(&os, &mut vec![], true) .await @@ -1496,7 +1515,7 @@ mod tests { // First, build a large conversation history. We need to ensure that the order is always // User -> Assistant -> User -> Assistant ...and so on. conversation.set_next_user_message("start".to_string()).await; - for i in 0..=(MAX_CONVERSATION_STATE_HISTORY_LEN + 100) { + for i in 0..=200 { let s = conversation .as_sendable_conversation_state(&os, &mut vec![], true) .await diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index fcdb8b30ef..5c46bff757 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -3030,11 +3030,11 @@ impl ChatSession { // Reset for next turn manager.tools_in_turn = 0; - manager.message_locked = false; // Unlock for next turn } else { // Clear pending message even if no tools were used manager.pending_user_message = None; } + manager.message_locked = false; // Unlock for next turn // Put manager back self.conversation.checkpoint_manager = Some(manager); diff --git a/crates/chat-cli/src/cli/chat/prompt.rs b/crates/chat-cli/src/cli/chat/prompt.rs index 210a0f635c..b4ebdae25a 100644 --- a/crates/chat-cli/src/cli/chat/prompt.rs +++ b/crates/chat-cli/src/cli/chat/prompt.rs @@ -54,6 +54,7 @@ pub const COMMANDS: &[&str] = &[ "/clear", "/help", "/editor", + "/reply", "/issue", "/quit", "/tools", diff --git a/crates/chat-cli/src/cli/chat/tool_manager.rs b/crates/chat-cli/src/cli/chat/tool_manager.rs index a957e3c858..718c12d968 100644 --- a/crates/chat-cli/src/cli/chat/tool_manager.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager.rs @@ -2112,7 +2112,9 @@ mod tests { // Create mock prompt bundles let prompt = rmcp::model::Prompt { name: "test_prompt".to_string(), + title: Some("Test Prompt".to_string()), description: Some("Test description".to_string()), + icons: None, arguments: None, }; @@ -2126,7 +2128,7 @@ mod tests { prompt_get: prompt, }; - let bundles = vec![&bundle1, &bundle2]; + let bundles = [&bundle1, &bundle2]; // Test finding specific server let found = bundles.iter().find(|b| b.server_name == "server1"); diff --git a/crates/chat-cli/src/cli/feed.json b/crates/chat-cli/src/cli/feed.json index 53c38195c2..6503c3af9e 100644 --- a/crates/chat-cli/src/cli/feed.json +++ b/crates/chat-cli/src/cli/feed.json @@ -12,10 +12,22 @@ }, { "type": "release", - "date": "2025-09-26", - "version": "1.16.3", - "title": "Version 1.16.3", + "date": "2025-09-29", + "version": "1.17.0", + "title": "Version 1.17.0", "changes": [ + { + "type": "added", + "description": "SSE support for MCP - [#2995](https://github.com/aws/amazon-q-developer-cli/pull/2995)" + }, + { + "type": "fixed", + "description": "Reloads only the built-in tools only for /experiment - [#3012](https://github.com/aws/amazon-q-developer-cli/pull/3012)" + }, + { + "type": "added", + "description": "Add a `/reply` command - [#2680](https://github.com/aws/amazon-q-developer-cli/pull/2680)" + }, { "type": "added", "description": "[Experimental] Adds checkpointing functionality using Git CLI commands - [#2896](https://github.com/aws/amazon-q-developer-cli/pull/2896)" @@ -39,6 +51,10 @@ { "type": "fixed", "description": "Improve error messages for dispatch failures - [#2969](https://github.com/aws/amazon-q-developer-cli/pull/2969)" + }, + { + "type": "added", + "description": "Enhanced MCP prompt management with improved UX - [#2953](https://github.com/aws/amazon-q-developer-cli/pull/2953)" } ] }, diff --git a/crates/chat-cli/src/mcp_client/client.rs b/crates/chat-cli/src/mcp_client/client.rs index 44473c3a74..fefaa1a9cb 100644 --- a/crates/chat-cli/src/mcp_client/client.rs +++ b/crates/chat-cli/src/mcp_client/client.rs @@ -6,6 +6,7 @@ use regex::Regex; use rmcp::model::{ CallToolRequestParam, CallToolResult, + ClientResult, ErrorCode, GetPromptRequestParam, GetPromptResult, @@ -42,17 +43,15 @@ use tokio::process::{ }; use tokio::task::JoinHandle; use tracing::{ - debug, error, info, }; use super::messenger::Messenger; -use super::oauth_util::HttpTransport; use super::{ AuthClientWrapper, + HttpServiceBuilder, OauthUtilError, - get_http_transport, }; use crate::cli::chat::server_messenger::ServerMessenger; use crate::cli::chat::tools::custom_tool::{ @@ -266,37 +265,10 @@ impl RunningService { decorate_with_auth_retry!(GetPromptRequestParam, get_prompt, GetPromptResult); } -pub type StdioTransport = (TokioChildProcess, Option); - -// TODO: add sse support (even though it's deprecated) -/// Represents the different transport mechanisms available for MCP (Model Context Protocol) -/// communication. -/// -/// This enum encapsulates the two primary ways to communicate with MCP servers: -/// - HTTP-based transport for remote servers -/// - Standard I/O transport for local process-based servers -pub enum Transport { - /// HTTP transport for communicating with remote MCP servers over network protocols. - /// Uses a streamable HTTP client with authentication support. - Http(HttpTransport), - /// Standard I/O transport for communicating with local MCP servers via child processes. - /// Communication happens through stdin/stdout pipes. - Stdio(StdioTransport), -} - -impl std::fmt::Debug for Transport { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Transport::Http(_) => f.debug_tuple("Http").field(&"HttpTransport").finish(), - Transport::Stdio(_) => f.debug_tuple("Stdio").field(&"TokioChildProcess").finish(), - } - } -} - /// This struct implements the [Service] trait from rmcp. It is within this trait the logic of /// server driven data flow (i.e. requests and notifications that are sent from the server) are /// handled. -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct McpClientService { pub config: CustomToolConfig, server_name: String, @@ -312,103 +284,14 @@ impl McpClientService { } } - pub async fn init(mut self, os: &Os) -> Result { + pub async fn init(self, os: &Os) -> Result { let os_clone = os.clone(); let handle: JoinHandle> = tokio::spawn(async move { let messenger_clone = self.messenger.clone(); let server_name = self.server_name.clone(); - let backup_config = self.config.clone(); - - let result: Result<_, McpClientError> = async { - let messenger_dup = messenger_clone.duplicate(); - let (service, stderr, auth_client) = match self.get_transport(&os_clone, &*messenger_dup).await? { - Transport::Stdio((child_process, stderr)) => { - let service = self - .into_dyn() - .serve::(child_process) - .await - .map_err(Box::new)?; - - (service, stderr, None) - }, - Transport::Http(http_transport) => { - match http_transport { - HttpTransport::WithAuth((transport, mut auth_client)) => { - // The crate does not automatically refresh tokens when they expire. We - // would need to handle that here - let url = &backup_config.url; - let service = match self.into_dyn().serve(transport).await.map_err(Box::new) { - Ok(service) => service, - Err(e) if matches!(*e, ClientInitializeError::ConnectionClosed(_)) => { - debug!("## mcp: first hand shake attempt failed: {:?}", e); - let refresh_res = auth_client.refresh_token().await; - let new_self = McpClientService::new( - server_name.clone(), - backup_config.clone(), - messenger_clone.clone(), - ); - - let scopes = &backup_config.oauth_scopes; - let timeout = backup_config.timeout; - let headers = &backup_config.headers; - let new_transport = - get_http_transport(&os_clone, url, timeout, scopes, headers,Some(auth_client.auth_client.clone()), &*messenger_dup).await?; - - match new_transport { - HttpTransport::WithAuth((new_transport, new_auth_client)) => { - auth_client = new_auth_client; - - match refresh_res { - Ok(_) => { - new_self.into_dyn().serve(new_transport).await.map_err(Box::new)? - }, - Err(e) => { - error!("## mcp: token refresh attempt failed: {:?}", e); - info!("Retry for http transport failed {e}. Possible reauth needed"); - // This could be because the refresh token is expired, in which - // case we would need to have user go through the auth flow - // again. We do this by deleting the cred - // and discarding the client to trigger a full auth flow - tokio::fs::remove_file(&auth_client.cred_full_path).await?; - let new_transport = - get_http_transport(&os_clone, url, timeout, scopes,headers,None, &*messenger_dup).await?; - - match new_transport { - HttpTransport::WithAuth((new_transport, new_auth_client)) => { - auth_client = new_auth_client; - new_self.into_dyn().serve(new_transport).await.map_err(Box::new)? - }, - HttpTransport::WithoutAuth(new_transport) => { - new_self.into_dyn().serve(new_transport).await.map_err(Box::new)? - }, - } - }, - } - }, - HttpTransport::WithoutAuth(new_transport) => - new_self.into_dyn().serve(new_transport).await.map_err(Box::new)?, - } - }, - Err(e) => return Err(e.into()), - }; - - (service, None, Some(auth_client)) - }, - HttpTransport::WithoutAuth(transport) => { - let service = self.into_dyn().serve(transport).await.map_err(Box::new)?; - - (service, None, None) - }, - } - }, - }; - Ok((service, stderr, auth_client)) - } - .await; - - let (service, child_stderr, auth_dropguard) = match result { + let (service, child_stderr, auth_dropguard) = match self.into_service(&os_clone, &messenger_clone).await { Ok((service, stderr, auth_dg)) => (service, stderr, auth_dg), Err(e) => { let msg = e.to_string(); @@ -498,18 +381,24 @@ impl McpClientService { Ok(InitializedMcpClient::Pending(handle)) } - async fn get_transport(&mut self, os: &Os, messenger: &dyn Messenger) -> Result { + async fn into_service( + mut self, + os: &Os, + messenger: &dyn Messenger, + ) -> Result< + ( + rmcp::service::RunningService>>, + Option, + Option, + ), + McpClientError, + > { let CustomToolConfig { r#type, url, - headers, - oauth_scopes: scopes, command: command_as_str, - args, - env: config_envs, - timeout, .. - } = &mut self.config; + } = &self.config; let is_malformed_http = matches!(r#type, TransportType::Http) && url.is_empty(); let is_malformed_stdio = matches!(r#type, TransportType::Stdio) && command_as_str.is_empty(); @@ -526,6 +415,13 @@ impl McpClientService { match r#type { TransportType::Stdio => { + let CustomToolConfig { + command: command_as_str, + args, + env: config_envs, + .. + } = &mut self.config; + let context = |input: &str| Ok(os.env.get(input).ok()); let home_dir = || os.env.home().map(|p| p.to_string_lossy().to_string()); let expanded_cmd = shellexpand::full_with_context(command_as_str, home_dir, context)?; @@ -544,12 +440,28 @@ impl McpClientService { let (tokio_child_process, child_stderr) = TokioChildProcess::builder(command).stderr(Stdio::piped()).spawn()?; - Ok(Transport::Stdio((tokio_child_process, child_stderr))) + let service = self + .into_dyn() + .serve::(tokio_child_process) + .await + .map_err(Box::new)?; + + Ok((service, child_stderr, None)) }, TransportType::Http => { - let http_transport = get_http_transport(os, url, *timeout, scopes, headers, None, messenger).await?; + let CustomToolConfig { + url, + headers, + oauth_scopes: scopes, + timeout, + .. + } = &self.config; + + let http_service_builder = HttpServiceBuilder::new(url, os, url, *timeout, scopes, headers, messenger); + + let (service, auth_client_wrapper) = http_service_builder.try_build(&self).await?; - Ok(Transport::Http(http_transport)) + Ok((service, None, auth_client_wrapper)) }, } } @@ -620,7 +532,7 @@ impl Service for McpClientService { _context: rmcp::service::RequestContext, ) -> Result<::Resp, rmcp::ErrorData> { match request { - ServerRequest::PingRequest(_) => Err(rmcp::ErrorData::method_not_found::()), + ServerRequest::PingRequest(_) => Ok(ClientResult::empty(())), ServerRequest::CreateMessageRequest(_) => Err(rmcp::ErrorData::method_not_found::< rmcp::model::CreateMessageRequestMethod, >()), @@ -660,6 +572,7 @@ impl Service for McpClientService { client_info: Implementation { name: "Q DEV CLI".to_string(), version: "1.0.0".to_string(), + ..Default::default() }, } } diff --git a/crates/chat-cli/src/mcp_client/oauth_util.rs b/crates/chat-cli/src/mcp_client/oauth_util.rs index 8af59a6c13..dfd1e9e0e7 100644 --- a/crates/chat-cli/src/mcp_client/oauth_util.rs +++ b/crates/chat-cli/src/mcp_client/oauth_util.rs @@ -15,22 +15,28 @@ use hyper::body::Bytes; use hyper::server::conn::http1; use hyper_util::rt::TokioIo; use reqwest::Client; -use rmcp::serde_json; +use rmcp::service::{ + DynService, + ServiceExt, +}; use rmcp::transport::auth::{ AuthClient, OAuthClientConfig, OAuthState, OAuthTokenResponse, }; -use rmcp::transport::streamable_http_client::{ - StreamableHttpClientTransportConfig, - StreamableHttpClientWorker, -}; +use rmcp::transport::sse_client::SseClientConfig; +use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig; use rmcp::transport::{ AuthorizationManager, AuthorizationSession, + SseClientTransport, StreamableHttpClientTransport, - WorkerTransport, +}; +use rmcp::{ + RoleClient, + Service, + serde_json, }; use serde::{ Deserialize, @@ -68,6 +74,8 @@ pub enum OauthUtilError { Serde(#[from] serde_json::Error), #[error("Missing authorization manager")] MissingAuthorizationManager, + #[error("Missing auth client when token refresh is needed")] + MissingAuthClient, #[error(transparent)] OneshotRecv(#[from] tokio::sync::oneshot::error::RecvError), #[error(transparent)] @@ -80,6 +88,10 @@ pub enum OauthUtilError { MalformDirectory, #[error("Missing credential")] MissingCredentials, + #[error("Failed to create a running service after running through all fallbacks: {0}")] + ServiceNotObtained(String), + #[error("{0}")] + SseTransport(String), } /// A guard that automatically cancels the cancellation token when dropped. @@ -95,6 +107,14 @@ impl Drop for LoopBackDropGuard { } } +/// OAuth Authorization Server metadata for endpoint discovery +#[derive(Clone, Serialize, Deserialize, Debug)] +pub struct OAuthMeta { + pub authorization_endpoint: String, + pub token_endpoint: String, + pub registration_endpoint: Option, +} + /// This is modeled after [OAuthClientConfig] /// It's only here because [OAuthClientConfig] does not implement Serialize and Deserialize #[derive(Clone, Serialize, Deserialize, Debug)] @@ -149,93 +169,281 @@ impl AuthClientWrapper { } } -/// HTTP transport wrapper that handles both authenticated and non-authenticated MCP connections. -/// -/// This enum provides two variants for different authentication scenarios: -/// - `WithAuth`: Used when the MCP server requires OAuth authentication, containing both the -/// transport worker and an auth client guard that manages credential persistence -/// - `WithoutAuth`: Used for servers that don't require authentication, containing only the basic -/// transport worker -/// -/// The appropriate variant is automatically selected based on the server's response to -/// an initial probe request during transport creation. -pub enum HttpTransport { - WithAuth( - ( - WorkerTransport>>, - AuthClientWrapper, - ), - ), - WithoutAuth(WorkerTransport>), -} - pub fn get_default_scopes() -> &'static [&'static str] { &["openid", "email", "profile", "offline_access"] } -pub async fn get_http_transport( - os: &Os, - url: &str, - timeout: u64, - scopes: &[String], - headers: &HashMap, - auth_client: Option>, - messenger: &dyn Messenger, -) -> Result { - let cred_dir = get_mcp_auth_dir(os)?; - let url = Url::from_str(url)?; - let key = compute_key(&url); - let cred_full_path = cred_dir.join(format!("{key}.token.json")); - let reg_full_path = cred_dir.join(format!("{key}.registration.json")); - - let mut client_builder = reqwest::ClientBuilder::new().timeout(std::time::Duration::from_millis(timeout)); - if !headers.is_empty() { - let headers = HeaderMap::try_from(headers).map_err(|e| OauthUtilError::Http(e.to_string()))?; - client_builder = client_builder.default_headers(headers); - }; - let reqwest_client = client_builder.build()?; - - // The probe request, like all other request, should adhere to the standards as per https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server - let mut probe_request = reqwest_client.post(url.clone()); - probe_request = probe_request.header("Accept", "application/json, text/event-stream"); - let probe_resp = probe_request.send().await?; - match probe_resp.status() { - StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { - let auth_client = match auth_client { - Some(auth_client) => auth_client, - None => { - let am = get_auth_manager( - url.clone(), - cred_full_path.clone(), - reg_full_path.clone(), - scopes, - messenger, - ) - .await?; - AuthClient::new(reqwest_client, am) - }, - }; - let transport = - StreamableHttpClientTransport::with_client(auth_client.clone(), StreamableHttpClientTransportConfig { - uri: url.as_str().into(), - allow_stateless: true, - ..Default::default() - }); +enum TransportType { + Http, + Sse, +} + +enum HttpServiceBuilderState { + AttemptConnection(TransportType, bool), + FailedBecauseTokenMightBeExpired, + Exhausted, +} - let auth_dg = AuthClientWrapper::new(cred_full_path, auth_client); +pub type HttpRunningService = ( + rmcp::service::RunningService>>, + Option, +); + +pub struct HttpServiceBuilder<'a> { + pub server_name: &'a str, + pub os: &'a Os, + pub url: &'a str, + pub timeout: u64, + pub scopes: &'a [String], + pub headers: &'a HashMap, + pub messenger: &'a dyn Messenger, +} - Ok(HttpTransport::WithAuth((transport, auth_dg))) - }, - _ => { - let transport = - StreamableHttpClientTransport::with_client(reqwest_client, StreamableHttpClientTransportConfig { - uri: url.as_str().into(), - allow_stateless: true, - ..Default::default() - }); - - Ok(HttpTransport::WithoutAuth(transport)) - }, +impl<'a> HttpServiceBuilder<'a> { + pub fn new( + server_name: &'a str, + os: &'a Os, + url: &'a str, + timeout: u64, + scopes: &'a [String], + headers: &'a HashMap, + messenger: &'a dyn Messenger, + ) -> Self { + Self { + server_name, + os, + url, + timeout, + scopes, + headers, + messenger, + } + } + + pub async fn try_build + Clone>( + self, + service: &S, + ) -> Result { + let HttpServiceBuilder { + server_name, + os, + url, + timeout, + scopes, + headers, + messenger, + } = self; + + let mut state = HttpServiceBuilderState::AttemptConnection(TransportType::Http, false); + let cred_dir = get_mcp_auth_dir(os)?; + let url = Url::from_str(url)?; + let key = compute_key(&url); + let cred_full_path = cred_dir.join(format!("{key}.token.json")); + let reg_full_path = cred_dir.join(format!("{key}.registration.json")); + let mut auth_client = None::>; + + let mut client_builder = reqwest::ClientBuilder::new().timeout(std::time::Duration::from_millis(timeout)); + if !headers.is_empty() { + let headers = HeaderMap::try_from(headers).map_err(|e| OauthUtilError::Http(e.to_string()))?; + client_builder = client_builder.default_headers(headers); + }; + let reqwest_client = client_builder.build()?; + + // The probe request, like all other request, should adhere to the standards as per https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#sending-messages-to-the-server + let probe_resp = reqwest_client + .post(url.clone()) + .header("Accept", "application/json, text/event-stream") + .send() + .await; + let is_probe_err = probe_resp.is_err(); + let is_status_401_or_403 = probe_resp + .as_ref() + .is_ok_and(|resp| resp.status() == StatusCode::UNAUTHORIZED || resp.status() == StatusCode::FORBIDDEN); + + let contains_auth_header = probe_resp.is_ok_and(|resp| { + resp.headers().get("www-authenticate").is_some_and(|v| { + let value_as_str = v.to_str(); + if let Ok(value) = value_as_str { + value.to_lowercase().contains("bearer") + } else { + false + } + }) + }); + let needs_auth = is_probe_err || is_status_401_or_403 || contains_auth_header; + + // Here we attempt the following in the order they are presented: + // 1. Build transport, first assume http on attempt one, sse on attempt two + // - If it fails and it needs auth, attempt to refresh token (#2) + // - If it fails and it does not need auth OR if it fails after a refresh, attempt sse (#3) + // 2. Refresh token, go back to #1 + // 3. Attempt sse + // - If it fails, abort (because at this point we have run out of things to try, note that + // refreshing of token is agnostic to the type of transport) + loop { + match state { + HttpServiceBuilderState::AttemptConnection(transport_type, has_refreshed) => { + if needs_auth { + let ac = match auth_client { + Some(ref auth_client) => auth_client.clone(), + None => { + let am = get_auth_manager( + url.clone(), + cred_full_path.clone(), + reg_full_path.clone(), + scopes, + messenger, + ) + .await?; + + let ac = AuthClient::new(reqwest_client.clone(), am); + auth_client.replace(ac.clone()); + ac + }, + }; + + match transport_type { + TransportType::Http => { + let transport = StreamableHttpClientTransport::with_client( + ac.clone(), + StreamableHttpClientTransportConfig { + uri: url.as_str().into(), + allow_stateless: true, + ..Default::default() + }, + ); + + match service.clone().into_dyn().serve(transport).await { + Ok(service) => { + let auth_client_wrapper = AuthClientWrapper::new(cred_full_path, ac); + return Ok((service, Some(auth_client_wrapper))); + }, + Err(e) => { + if !has_refreshed { + error!( + "## mcp: http handshake attempt failed for {server_name}: {:?}. Attempting to refresh token", + e + ); + // first we'll try refreshing the token + state = HttpServiceBuilderState::FailedBecauseTokenMightBeExpired; + } else { + error!( + "## mcp: http handshake attempt failed for {server_name}: {:?}. Attempting sse", + e + ); + state = + HttpServiceBuilderState::AttemptConnection(TransportType::Sse, true); + } + }, + } + }, + TransportType::Sse => { + let transport = SseClientTransport::start_with_client(ac.clone(), SseClientConfig { + sse_endpoint: url.as_str().into(), + ..Default::default() + }) + .await + .map_err(|e| OauthUtilError::SseTransport(e.to_string()))?; + + match service.clone().into_dyn().serve(transport).await { + Ok(service) => { + let auth_client_wrapper = AuthClientWrapper::new(cred_full_path, ac); + return Ok((service, Some(auth_client_wrapper))); + }, + Err(e) => { + // at this point we would have already tried refreshing + // we are out of things to try and should just fail + error!( + "## mcp: sse handshake attempted failed for {server_name}: {:?}. Aborting", + e + ); + state = HttpServiceBuilderState::Exhausted; + }, + } + }, + } + } else { + info!( + "## mcp: No OAuth endpoints discovered for {server_name}, using unauthenticated transport" + ); + + match transport_type { + TransportType::Http => { + info!("## mcp: attempting open http handshake for {server_name}"); + let transport = StreamableHttpClientTransport::with_client( + reqwest_client.clone(), + StreamableHttpClientTransportConfig { + uri: url.as_str().into(), + allow_stateless: true, + ..Default::default() + }, + ); + + match service.clone().into_dyn().serve(transport).await { + Ok(service) => return Ok((service, None)), + Err(e) => { + error!( + "## mcp: open http handshake attempted failed for {server_name}: {:?}. Attempting sse", + e + ); + state = HttpServiceBuilderState::AttemptConnection(TransportType::Sse, false); + }, + } + }, + TransportType::Sse => { + info!("## mcp: attempting open sse handshake for {server_name}"); + let transport = + SseClientTransport::start_with_client(reqwest_client.clone(), SseClientConfig { + sse_endpoint: url.as_str().into(), + ..Default::default() + }) + .await + .map_err(|e| OauthUtilError::SseTransport(e.to_string()))?; + + match service.clone().into_dyn().serve(transport).await { + Ok(service) => return Ok((service, None)), + Err(e) => { + error!( + "## mcp: open sse handshake attempted failed for {server_name}: {:?}. Aborting", + e + ); + state = HttpServiceBuilderState::Exhausted; + }, + } + }, + } + } + }, + HttpServiceBuilderState::FailedBecauseTokenMightBeExpired => { + let auth_client_ref = auth_client.as_ref().ok_or(OauthUtilError::MissingAuthClient)?; + let auth_client_wrapper = AuthClientWrapper::new(cred_full_path.clone(), auth_client_ref.clone()); + let refresh_res = auth_client_wrapper.refresh_token().await; + + if let Err(e) = refresh_res { + error!("## mcp: token refresh attempt failed: {:?}", e); + info!("Retry for http transport failed {e}. Possible reauth needed"); + // This could be because the refresh token is expired, in which + // case we would need to have user go through the auth flow + // again. We do this by deleting the cred + // and discarding the client to trigger a full auth flow + if cred_full_path.is_file() { + tokio::fs::remove_file(&cred_full_path).await?; + } + + // we'll also need to remove the auth client to force a reauth when we go + // back to attempt the first step again + auth_client.take(); + } + + state = HttpServiceBuilderState::AttemptConnection(TransportType::Http, true); + }, + HttpServiceBuilderState::Exhausted => { + return Err(OauthUtilError::ServiceNotObtained( + "Max number of retries exhausted".to_string(), + )); + }, + } + } } } @@ -305,7 +513,7 @@ async fn get_auth_manager_impl( ) -> Result<(AuthorizationManager, String), OauthUtilError> { let socket_addr = SocketAddr::from(([127, 0, 0, 1], 0)); let cancellation_token = tokio_util::sync::CancellationToken::new(); - let (tx, rx) = tokio::sync::oneshot::channel::(); + let (tx, rx) = tokio::sync::oneshot::channel::<(String, String)>(); let (actual_addr, _dg) = make_svc(tx, socket_addr, cancellation_token).await?; info!("Listening on local host port {:?} for oauth", actual_addr); @@ -318,8 +526,8 @@ async fn get_auth_manager_impl( let auth_url = oauth_state.get_authorization_url().await?; _ = messenger.send_oauth_link(auth_url).await; - let auth_code = rx.await?; - oauth_state.handle_callback(&auth_code).await?; + let (auth_code, csrf_token) = rx.await?; + oauth_state.handle_callback(&auth_code, &csrf_token).await?; let am = oauth_state .into_authorization_manager() .ok_or(OauthUtilError::MissingAuthorizationManager)?; @@ -408,13 +616,14 @@ fn get_stub_credentials() -> Result { } async fn make_svc( - one_shot_sender: Sender, + one_shot_sender: Sender<(String, String)>, socket_addr: SocketAddr, cancellation_token: CancellationToken, ) -> Result<(SocketAddr, LoopBackDropGuard), OauthUtilError> { + type AuthCodeSender = Sender<(String, String)>; #[derive(Clone, Debug)] struct LoopBackForSendingAuthCode { - one_shot_sender: Arc>>>, + one_shot_sender: Arc>>, } #[derive(Debug, thiserror::Error)] @@ -423,8 +632,8 @@ async fn make_svc( Poison(String), #[error(transparent)] Http(#[from] http::Error), - #[error("Failed to send auth code: {0}")] - Send(String), + #[error("Failed to send auth code")] + Send((String, String)), } fn mk_response(s: String) -> Result>, LoopBackError> { @@ -459,13 +668,14 @@ async fn make_svc( }; let code = params.get("code").cloned().unwrap_or_default(); + let state = params.get("state").cloned().unwrap_or_default(); if let Some(sender) = self_clone .one_shot_sender .lock() .map_err(|e| LoopBackError::Poison(e.to_string()))? .take() { - sender.send(code).map_err(LoopBackError::Send)?; + sender.send((code, state)).map_err(LoopBackError::Send)?; } resp