From bd7521c9a5eb57898254132b47f7fd6dd9e6919d Mon Sep 17 00:00:00 2001 From: abhraina-aws Date: Fri, 15 Aug 2025 15:53:52 -0700 Subject: [PATCH 1/7] feat: add github action for release notification (#2625) * feat: add github action for release notification --- .github/workflows/release-notification.yaml | 26 +++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 .github/workflows/release-notification.yaml diff --git a/.github/workflows/release-notification.yaml b/.github/workflows/release-notification.yaml new file mode 100644 index 0000000000..1abf3387bc --- /dev/null +++ b/.github/workflows/release-notification.yaml @@ -0,0 +1,26 @@ + +name: Release Notification + +on: + release: + types: [published] # Trigger on new releases being published + +jobs: + slack_notification: + runs-on: ubuntu-latest + steps: + - name: Send Release Details to Slack + uses: slackapi/slack-github-action@v1.23.0 # Or the latest version of this action + with: + payload: | + { + "release_name": "${{ github.event.release.name }}", + "tag_name": "${{ github.event.release.tag_name }}", + "release_url": "${{ github.event.release.html_url }}", + "author_name": "${{ github.event.release.author.login }}", + "repository_name": "${{ github.event.repository.name }}", + "repository_url": "${{ github.event.repository.html_url }}", + "release_description": ${{ toJSON(github.event.release.body) }} + } + env: + SLACK_WEBHOOK_URL: ${{ secrets.SLACK_WEBHOOK_URL }} # Use the secret for the webhook URL From 71c00814247c3c2d6e134c3cbd0f23f6745b1466 Mon Sep 17 00:00:00 2001 From: Felix Ding Date: Tue, 19 Aug 2025 15:06:49 -0700 Subject: [PATCH 2/7] feat(agent): hot swap (#2637) * changes prompt list result to be sent over via messenger * changes tool manager orchestrator tasks to keep prompts * changes mpsc to broadcast * restores prompt list functionality * restore prompt get functionality * adds api on tool manager to hotswap * spawns task to send deinit msg via messenger * adds slash 
command to hotswap agent * modifies load tool wait time depending on context * adds comments to retry logic for prompt completer * fixes lint * adds pid field to messenger message * adds interactive menu for swapping agent * fixes stale mcp load record * documents build method on tool manager builder and refactor to make the build method smaller --- crates/chat-cli/src/cli/chat/cli/profile.rs | 48 + crates/chat-cli/src/cli/chat/cli/prompts.rs | 13 +- crates/chat-cli/src/cli/chat/conversation.rs | 25 + crates/chat-cli/src/cli/chat/input_source.rs | 12 +- crates/chat-cli/src/cli/chat/mod.rs | 18 +- crates/chat-cli/src/cli/chat/prompt.rs | 102 +- .../chat-cli/src/cli/chat/server_messenger.rs | 25 + crates/chat-cli/src/cli/chat/tool_manager.rs | 1426 ++++++++++------- .../src/cli/chat/tools/custom_tool.rs | 18 +- crates/chat-cli/src/mcp_client/client.rs | 69 +- crates/chat-cli/src/mcp_client/messenger.rs | 5 + 11 files changed, 1075 insertions(+), 686 deletions(-) diff --git a/crates/chat-cli/src/cli/chat/cli/profile.rs b/crates/chat-cli/src/cli/chat/cli/profile.rs index 2a02063d25..fb6b17a67d 100644 --- a/crates/chat-cli/src/cli/chat/cli/profile.rs +++ b/crates/chat-cli/src/cli/chat/cli/profile.rs @@ -11,6 +11,7 @@ use crossterm::{ execute, queue, }; +use dialoguer::Select; use syntect::easy::HighlightLines; use syntect::highlighting::{ Style, @@ -77,6 +78,9 @@ pub enum AgentSubcommand { #[arg(long, short)] name: String, }, + /// Swap to a new agent at runtime + #[command(alias = "switch")] + Swap { name: Option }, } impl AgentSubcommand { @@ -224,6 +228,49 @@ impl AgentSubcommand { )?; }, }, + Self::Swap { name } => { + if let Some(name) = name { + session.conversation.swap_agent(os, &mut session.stderr, &name).await?; + } else { + let labels = session + .conversation + .agents + .agents + .keys() + .map(|name| name.as_str()) + .collect::>(); + + let name = { + let idx = match Select::with_theme(&crate::util::dialoguer_theme()) + .with_prompt("Choose one of the 
following agents") + .items(&labels) + .default(1) + .interact_on_opt(&dialoguer::console::Term::stdout()) + { + Ok(sel) => { + let _ = crossterm::execute!( + std::io::stdout(), + crossterm::style::SetForegroundColor(crossterm::style::Color::Magenta) + ); + sel + }, + // Ctrl‑C -> Err(Interrupted) + Err(dialoguer::Error::IO(ref e)) if e.kind() == std::io::ErrorKind::Interrupted => None, + Err(e) => { + return Err(ChatError::Custom( + format!("Dialog has failed to make a selection {e}").into(), + )); + }, + }; + + idx.and_then(|idx| labels.get(idx).cloned().map(str::to_string)) + }; + + if let Some(name) = name { + session.conversation.swap_agent(os, &mut session.stderr, &name).await?; + } + } + }, } Ok(ChatState::PromptUser { @@ -239,6 +286,7 @@ impl AgentSubcommand { Self::Set { .. } => "set", Self::Schema => "schema", Self::SetDefault { .. } => "set_default", + Self::Swap { .. } => "swap", } } } diff --git a/crates/chat-cli/src/cli/chat/cli/prompts.rs b/crates/chat-cli/src/cli/chat/cli/prompts.rs index efbdbc49ed..53b0012a57 100644 --- a/crates/chat-cli/src/cli/chat/cli/prompts.rs +++ b/crates/chat-cli/src/cli/chat/cli/prompts.rs @@ -38,12 +38,14 @@ pub enum GetPromptError { MissingClient, #[error("Missing prompt name")] MissingPromptName, - #[error("Synchronization error: {0}")] - Synchronization(String), #[error("Missing prompt bundle")] MissingPromptInfo, #[error(transparent)] General(#[from] eyre::Report), + #[error("Incorrect response type received")] + IncorrectResponseType, + #[error("Missing channel")] + MissingChannel, } #[deny(missing_docs)] @@ -76,10 +78,7 @@ impl PromptsArgs { } let terminal_width = session.terminal_width(); - let mut prompts_wl = session.conversation.tool_manager.prompts.write().map_err(|e| { - ChatError::Custom(format!("Poison error encountered while retrieving prompts: {}", e).into()) - })?; - session.conversation.tool_manager.refresh_prompts(&mut prompts_wl)?; + let prompts = 
session.conversation.tool_manager.list_prompts().await?; let mut longest_name = ""; let arg_pos = { let optimal_case = UnicodeWidthStr::width(longest_name) + terminal_width / 4; @@ -121,7 +120,7 @@ impl PromptsArgs { style::Print("\n"), style::Print(format!("{}\n", "▔".repeat(terminal_width))), )?; - let mut prompts_by_server: Vec<_> = prompts_wl + let mut prompts_by_server: Vec<_> = prompts .iter() .fold( HashMap::<&String, Vec<&PromptBundle>>::new(), diff --git a/crates/chat-cli/src/cli/chat/conversation.rs b/crates/chat-cli/src/cli/chat/conversation.rs index ca7b87d2c4..7c58febcf7 100644 --- a/crates/chat-cli/src/cli/chat/conversation.rs +++ b/crates/chat-cli/src/cli/chat/conversation.rs @@ -699,6 +699,31 @@ impl ConversationState { } self.transcript.push_back(message); } + + /// Swapping agent involves the following: + /// - Reinstantiate the context manager + /// - Swap agent on tool manager + pub async fn swap_agent( + &mut self, + os: &mut Os, + output: &mut impl Write, + agent_name: &str, + ) -> Result<(), ChatError> { + let agent = self.agents.switch(agent_name).map_err(ChatError::AgentSwapError)?; + self.context_manager.replace({ + ContextManager::from_agent(agent, calc_max_context_files_size(self.model_info.as_ref())) + .map_err(|e| ChatError::Custom(format!("Context manager has failed to instantiate: {e}").into()))? 
+ }); + + self.tool_manager + .swap_agent(os, output, agent) + .await + .map_err(ChatError::AgentSwapError)?; + + self.update_state(true).await; + + Ok(()) + } } /// Represents a conversation state that can be converted into a [FigConversationState] (the type diff --git a/crates/chat-cli/src/cli/chat/input_source.rs b/crates/chat-cli/src/cli/chat/input_source.rs index 028b2e2889..5d88abf6f3 100644 --- a/crates/chat-cli/src/cli/chat/input_source.rs +++ b/crates/chat-cli/src/cli/chat/input_source.rs @@ -1,7 +1,11 @@ use eyre::Result; use rustyline::error::ReadlineError; -use super::prompt::rl; +use super::prompt::{ + PromptQueryResponseReceiver, + PromptQuerySender, + rl, +}; #[cfg(unix)] use super::skim_integration::SkimCommandSelector; use crate::os::Os; @@ -28,11 +32,7 @@ mod inner { } impl InputSource { - pub fn new( - os: &Os, - sender: std::sync::mpsc::Sender>, - receiver: std::sync::mpsc::Receiver>, - ) -> Result { + pub fn new(os: &Os, sender: PromptQuerySender, receiver: PromptQueryResponseReceiver) -> Result { Ok(Self(inner::Inner::Readline(rl(os, sender, receiver)?))) } diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index 0d3b7b1d8c..103973565c 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -96,6 +96,8 @@ use tokio::sync::{ broadcast, }; use tool_manager::{ + PromptQuery, + PromptQueryResult, ToolManager, ToolManagerBuilder, }; @@ -334,11 +336,14 @@ impl ChatArgs { Some(default_model_opt.model_id.clone()) }; - let (prompt_request_sender, prompt_request_receiver) = std::sync::mpsc::channel::>(); - let (prompt_response_sender, prompt_response_receiver) = std::sync::mpsc::channel::>(); + let (prompt_request_sender, prompt_request_receiver) = tokio::sync::broadcast::channel::(5); + let (prompt_response_sender, prompt_response_receiver) = + tokio::sync::broadcast::channel::(5); let mut tool_manager = ToolManagerBuilder::default() - .prompt_list_sender(prompt_response_sender) 
- .prompt_list_receiver(prompt_request_receiver) + .prompt_query_result_sender(prompt_response_sender) + .prompt_query_receiver(prompt_request_receiver) + .prompt_query_sender(prompt_request_sender.clone()) + .prompt_query_result_receiver(prompt_response_receiver.resubscribe()) .conversation_id(&conversation_id) .agent(agents.get_active().cloned().unwrap_or_default()) .build(os, Box::new(std::io::stderr()), !self.no_interactive) @@ -470,6 +475,8 @@ pub enum ChatError { NonInteractiveToolApproval, #[error("The conversation history is too large to compact")] CompactHistoryFailure, + #[error("Failed to swap to agent: {0}")] + AgentSwapError(eyre::Report), } impl ChatError { @@ -486,6 +493,7 @@ impl ChatError { ChatError::GetPromptError(_) => None, ChatError::NonInteractiveToolApproval => None, ChatError::CompactHistoryFailure => None, + ChatError::AgentSwapError(_) => None, } } } @@ -504,6 +512,7 @@ impl ReasonCode for ChatError { ChatError::Auth(_) => "AuthError".to_string(), ChatError::NonInteractiveToolApproval => "NonInteractiveToolApproval".to_string(), ChatError::CompactHistoryFailure => "CompactHistoryFailure".to_string(), + ChatError::AgentSwapError(_) => "AgentSwapError".to_string(), } } } @@ -1602,6 +1611,7 @@ impl ChatSession { .await; if matches!(chat_state, ChatState::Exit) + || matches!(chat_state, ChatState::HandleResponseStream(_)) || matches!(chat_state, ChatState::HandleInput { input: _ }) // TODO(bskiser): this is just a hotfix for handling state changes // from manually running /compact, without impacting behavior of diff --git a/crates/chat-cli/src/cli/chat/prompt.rs b/crates/chat-cli/src/cli/chat/prompt.rs index 291fe35ba3..b785faa1d9 100644 --- a/crates/chat-cli/src/cli/chat/prompt.rs +++ b/crates/chat-cli/src/cli/chat/prompt.rs @@ -1,4 +1,5 @@ use std::borrow::Cow; +use std::cell::RefCell; use eyre::Result; use rustyline::completion::{ @@ -37,6 +38,10 @@ use winnow::stream::AsChar; pub use super::prompt_parser::generate_prompt; use 
super::prompt_parser::parse_prompt_components; +use super::tool_manager::{ + PromptQuery, + PromptQueryResult, +}; use crate::database::settings::Setting; use crate::os::Os; @@ -85,6 +90,9 @@ pub const COMMANDS: &[&str] = &[ "/subscribe", ]; +pub type PromptQuerySender = tokio::sync::broadcast::Sender; +pub type PromptQueryResponseReceiver = tokio::sync::broadcast::Receiver; + /// Complete commands that start with a slash fn complete_command(word: &str, start: usize) -> (usize, Vec) { ( @@ -134,29 +142,63 @@ impl PathCompleter { } pub struct PromptCompleter { - sender: std::sync::mpsc::Sender>, - receiver: std::sync::mpsc::Receiver>, + sender: PromptQuerySender, + receiver: RefCell, } impl PromptCompleter { - fn new(sender: std::sync::mpsc::Sender>, receiver: std::sync::mpsc::Receiver>) -> Self { - PromptCompleter { sender, receiver } + fn new(sender: PromptQuerySender, receiver: PromptQueryResponseReceiver) -> Self { + PromptCompleter { + sender, + receiver: RefCell::new(receiver), + } } fn complete_prompt(&self, word: &str) -> Result, ReadlineError> { let sender = &self.sender; - let receiver = &self.receiver; + let receiver = self.receiver.borrow_mut(); + let query = PromptQuery::Search(if !word.is_empty() { Some(word.to_string()) } else { None }); + sender - .send(if !word.is_empty() { Some(word.to_string()) } else { None }) + .send(query) .map_err(|e| ReadlineError::Io(std::io::Error::other(e.to_string())))?; - let prompt_info = receiver - .recv() - .map_err(|e| ReadlineError::Io(std::io::Error::other(e.to_string())))? - .iter() - .map(|n| format!("@{n}")) - .collect::>(); + // We only want stuff from the current tail end onward + let mut new_receiver = receiver.resubscribe(); + + // Here we poll on the receiver for [max_attempts] number of times. + // The reason for this is because we are trying to receive something managed by an async + // channel from a sync context. 
+ // If we ever switch back to a single threaded runtime for whatever reason, this function + // will not panic but nothing will be fetched because the thread that is doing + // try_recv is also the thread that is supposed to be doing the sending. + let mut attempts = 0; + let max_attempts = 5; + let query_res = loop { + match new_receiver.try_recv() { + Ok(result) => break result, + Err(_e) if attempts < max_attempts - 1 => { + attempts += 1; + std::thread::sleep(std::time::Duration::from_millis(100)); + }, + Err(e) => { + return Err(ReadlineError::Io(std::io::Error::other(eyre::eyre!( + "Failed to receive prompt info from complete prompt after {} attempts: {:?}", + max_attempts, + e + )))); + }, + } + }; + let matches = match query_res { + PromptQueryResult::Search(list) => list.into_iter().map(|n| format!("@{n}")).collect::>(), + PromptQueryResult::List(_) => { + return Err(ReadlineError::Io(std::io::Error::other(eyre::eyre!( + "Wrong query response type received", + )))); + }, + }; - Ok(prompt_info) + Ok(matches) } } @@ -166,7 +208,7 @@ pub struct ChatCompleter { } impl ChatCompleter { - fn new(sender: std::sync::mpsc::Sender>, receiver: std::sync::mpsc::Receiver>) -> Self { + fn new(sender: PromptQuerySender, receiver: PromptQueryResponseReceiver) -> Self { Self { path_completer: PathCompleter::new(), prompt_completer: PromptCompleter::new(sender, receiver), @@ -370,8 +412,8 @@ impl Highlighter for ChatHelper { pub fn rl( os: &Os, - sender: std::sync::mpsc::Sender>, - receiver: std::sync::mpsc::Receiver>, + sender: PromptQuerySender, + receiver: PromptQueryResponseReceiver, ) -> Result> { let edit_mode = match os.database.settings.get_string(Setting::ChatEditMode).as_deref() { Some("vi" | "vim") => EditMode::Vi, @@ -428,8 +470,8 @@ mod tests { #[test] fn test_chat_completer_command_completion() { - let (prompt_request_sender, _) = std::sync::mpsc::channel::>(); - let (_, prompt_response_receiver) = std::sync::mpsc::channel::>(); + let (prompt_request_sender, 
_) = tokio::sync::broadcast::channel::(5); + let (_, prompt_response_receiver) = tokio::sync::broadcast::channel::(5); let completer = ChatCompleter::new(prompt_request_sender, prompt_response_receiver); let line = "/h"; let pos = 2; // Position at the end of "/h" @@ -450,8 +492,8 @@ mod tests { #[test] fn test_chat_completer_no_completion() { - let (prompt_request_sender, _) = std::sync::mpsc::channel::>(); - let (_, prompt_response_receiver) = std::sync::mpsc::channel::>(); + let (prompt_request_sender, _) = tokio::sync::broadcast::channel::(5); + let (_, prompt_response_receiver) = tokio::sync::broadcast::channel::(5); let completer = ChatCompleter::new(prompt_request_sender, prompt_response_receiver); let line = "Hello, how are you?"; let pos = line.len(); @@ -469,8 +511,8 @@ mod tests { #[test] fn test_highlight_prompt_basic() { - let (prompt_request_sender, _) = std::sync::mpsc::channel::>(); - let (_, prompt_response_receiver) = std::sync::mpsc::channel::>(); + let (prompt_request_sender, _) = tokio::sync::broadcast::channel::(5); + let (_, prompt_response_receiver) = tokio::sync::broadcast::channel::(5); let helper = ChatHelper { completer: ChatCompleter::new(prompt_request_sender, prompt_response_receiver), hinter: ChatHinter::new(true), @@ -485,8 +527,8 @@ mod tests { #[test] fn test_highlight_prompt_with_warning() { - let (prompt_request_sender, _) = std::sync::mpsc::channel::>(); - let (_, prompt_response_receiver) = std::sync::mpsc::channel::>(); + let (prompt_request_sender, _) = tokio::sync::broadcast::channel::(5); + let (_, prompt_response_receiver) = tokio::sync::broadcast::channel::(5); let helper = ChatHelper { completer: ChatCompleter::new(prompt_request_sender, prompt_response_receiver), hinter: ChatHinter::new(true), @@ -501,8 +543,8 @@ mod tests { #[test] fn test_highlight_prompt_with_profile() { - let (prompt_request_sender, _) = std::sync::mpsc::channel::>(); - let (_, prompt_response_receiver) = std::sync::mpsc::channel::>(); + let 
(prompt_request_sender, _) = tokio::sync::broadcast::channel::(5); + let (_, prompt_response_receiver) = tokio::sync::broadcast::channel::(5); let helper = ChatHelper { completer: ChatCompleter::new(prompt_request_sender, prompt_response_receiver), hinter: ChatHinter::new(true), @@ -517,8 +559,8 @@ mod tests { #[test] fn test_highlight_prompt_with_profile_and_warning() { - let (prompt_request_sender, _) = std::sync::mpsc::channel::>(); - let (_, prompt_response_receiver) = std::sync::mpsc::channel::>(); + let (prompt_request_sender, _) = tokio::sync::broadcast::channel::(5); + let (_, prompt_response_receiver) = tokio::sync::broadcast::channel::(5); let helper = ChatHelper { completer: ChatCompleter::new(prompt_request_sender, prompt_response_receiver), hinter: ChatHinter::new(true), @@ -536,8 +578,8 @@ mod tests { #[test] fn test_highlight_prompt_invalid_format() { - let (prompt_request_sender, _) = std::sync::mpsc::channel::>(); - let (_, prompt_response_receiver) = std::sync::mpsc::channel::>(); + let (prompt_request_sender, _) = tokio::sync::broadcast::channel::(5); + let (_, prompt_response_receiver) = tokio::sync::broadcast::channel::(5); let helper = ChatHelper { completer: ChatCompleter::new(prompt_request_sender, prompt_response_receiver), hinter: ChatHinter::new(true), diff --git a/crates/chat-cli/src/cli/chat/server_messenger.rs b/crates/chat-cli/src/cli/chat/server_messenger.rs index 966600fc44..aaf685c399 100644 --- a/crates/chat-cli/src/cli/chat/server_messenger.rs +++ b/crates/chat-cli/src/cli/chat/server_messenger.rs @@ -19,21 +19,30 @@ pub enum UpdateEventMessage { ToolsListResult { server_name: String, result: eyre::Result, + pid: Option, }, PromptsListResult { server_name: String, result: eyre::Result, + pid: Option, }, ResourcesListResult { server_name: String, result: eyre::Result, + pid: Option, }, ResourceTemplatesListResult { server_name: String, result: eyre::Result, + pid: Option, }, InitStart { server_name: String, + pid: Option, + }, + 
Deinit { + server_name: String, + pid: Option, }, } @@ -55,6 +64,7 @@ impl ServerMessengerBuilder { ServerMessenger { server_name, update_event_sender: self.update_event_sender.clone(), + pid: None, } } } @@ -63,6 +73,7 @@ impl ServerMessengerBuilder { pub struct ServerMessenger { pub server_name: String, pub update_event_sender: Sender, + pub pid: Option, } #[async_trait::async_trait] @@ -73,6 +84,7 @@ impl Messenger for ServerMessenger { .send(UpdateEventMessage::ToolsListResult { server_name: self.server_name.clone(), result, + pid: self.pid, }) .await .map_err(|e| MessengerError::Custom(e.to_string()))?) @@ -84,6 +96,7 @@ impl Messenger for ServerMessenger { .send(UpdateEventMessage::PromptsListResult { server_name: self.server_name.clone(), result, + pid: self.pid, }) .await .map_err(|e| MessengerError::Custom(e.to_string()))?) @@ -98,6 +111,7 @@ impl Messenger for ServerMessenger { .send(UpdateEventMessage::ResourcesListResult { server_name: self.server_name.clone(), result, + pid: self.pid, }) .await .map_err(|e| MessengerError::Custom(e.to_string()))?) @@ -112,6 +126,7 @@ impl Messenger for ServerMessenger { .send(UpdateEventMessage::ResourceTemplatesListResult { server_name: self.server_name.clone(), result, + pid: self.pid, }) .await .map_err(|e| MessengerError::Custom(e.to_string()))?) @@ -122,11 +137,21 @@ impl Messenger for ServerMessenger { .update_event_sender .send(UpdateEventMessage::InitStart { server_name: self.server_name.clone(), + pid: self.pid, }) .await .map_err(|e| MessengerError::Custom(e.to_string()))?) 
} + fn send_deinit_msg(&self) { + let sender = self.update_event_sender.clone(); + let server_name = self.server_name.clone(); + let pid = self.pid; + tokio::spawn(async move { + let _ = sender.send(UpdateEventMessage::Deinit { server_name, pid }).await; + }); + } + fn duplicate(&self) -> Box { Box::new(self.clone()) } diff --git a/crates/chat-cli/src/cli/chat/tool_manager.rs b/crates/chat-cli/src/cli/chat/tool_manager.rs index 95739a1068..a8c8db40d5 100644 --- a/crates/chat-cli/src/cli/chat/tool_manager.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager.rs @@ -14,14 +14,11 @@ use std::io::{ }; use std::path::PathBuf; use std::pin::Pin; +use std::sync::Arc; use std::sync::atomic::{ AtomicBool, Ordering, }; -use std::sync::{ - Arc, - RwLock as SyncRwLock, -}; use std::time::{ Duration, Instant, @@ -50,9 +47,11 @@ use tokio::sync::{ use tokio::task::JoinHandle; use tracing::{ error, + info, warn, }; +use super::tools::custom_tool::CustomToolConfig; use crate::api_client::model::{ ToolResult, ToolResultContentBlock, @@ -149,22 +148,81 @@ pub enum LoadingRecord { Err(String), } -#[derive(Default)] pub struct ToolManagerBuilder { - prompt_list_sender: Option>>, - prompt_list_receiver: Option>>, + prompt_query_result_sender: Option>, + prompt_query_receiver: Option>, + prompt_query_sender: Option>, + prompt_query_result_receiver: Option>, + messenger_builder: Option, conversation_id: Option, - agent: Option, + has_new_stuff: Arc, + mcp_load_record: Arc>>>, + new_tool_specs: NewToolSpecs, + is_first_launch: bool, + agent: Option>>, +} + +impl Default for ToolManagerBuilder { + fn default() -> Self { + Self { + prompt_query_result_sender: Default::default(), + prompt_query_receiver: Default::default(), + prompt_query_sender: Default::default(), + prompt_query_result_receiver: Default::default(), + messenger_builder: Default::default(), + conversation_id: Default::default(), + has_new_stuff: Default::default(), + mcp_load_record: Default::default(), + new_tool_specs: 
Default::default(), + is_first_launch: true, + agent: Default::default(), + } + } +} + +impl From<&mut ToolManager> for ToolManagerBuilder { + fn from(value: &mut ToolManager) -> Self { + Self { + conversation_id: Some(value.conversation_id.clone()), + agent: Some(value.agent.clone()), + prompt_query_sender: value + .prompts_sender_receiver_pair + .as_ref() + .map(|(sender, _)| sender.clone()), + prompt_query_result_receiver: value.prompts_sender_receiver_pair.take().map(|(_, receiver)| receiver), + messenger_builder: value.messenger_builder.take(), + has_new_stuff: value.has_new_stuff.clone(), + mcp_load_record: value.mcp_load_record.clone(), + new_tool_specs: value.new_tool_specs.clone(), + // if we are getting a builder from an instantiated tool manager this field would be + // false + is_first_launch: false, + ..Default::default() + } + } } impl ToolManagerBuilder { - pub fn prompt_list_sender(mut self, sender: std::sync::mpsc::Sender>) -> Self { - self.prompt_list_sender.replace(sender); + pub fn prompt_query_result_sender(mut self, sender: tokio::sync::broadcast::Sender) -> Self { + self.prompt_query_result_sender.replace(sender); + self + } + + pub fn prompt_query_receiver(mut self, receiver: tokio::sync::broadcast::Receiver) -> Self { + self.prompt_query_receiver.replace(receiver); self } - pub fn prompt_list_receiver(mut self, receiver: std::sync::mpsc::Receiver>) -> Self { - self.prompt_list_receiver.replace(receiver); + pub fn prompt_query_sender(mut self, sender: tokio::sync::broadcast::Sender) -> Self { + self.prompt_query_sender.replace(sender); + self + } + + pub fn prompt_query_result_receiver( + mut self, + receiver: tokio::sync::broadcast::Receiver, + ) -> Self { + self.prompt_query_result_receiver.replace(receiver); self } @@ -174,17 +232,28 @@ impl ToolManagerBuilder { } pub fn agent(mut self, agent: Agent) -> Self { + let agent = Arc::new(Mutex::new(agent)); self.agent.replace(agent); self } + /// Creates a [ToolManager] based on the current 
fields populated, which consists of the + /// following: + /// - Instantiates child processes associated with the list of mcp servers in scope + /// - Spawns a loading display task that is used to show server loading status (if applicable) + /// - Spawns the orchestrator task (see [spawn_orchestrator_task] for more detail) (if + /// applicable) + /// - Finally, creates an instance of [ToolManager] pub async fn build( mut self, os: &mut Os, mut output: Box, interactive: bool, ) -> eyre::Result { - let McpServerConfig { mcp_servers } = self.agent.as_ref().map(|a| a.mcp_servers.clone()).unwrap_or_default(); + let McpServerConfig { mcp_servers } = match &self.agent { + Some(agent) => agent.lock().await.mcp_servers.clone(), + None => Default::default(), + }; debug_assert!(self.conversation_id.is_some()); let conversation_id = self.conversation_id.ok_or(eyre::eyre!("Missing conversation id"))?; @@ -202,22 +271,7 @@ impl ToolManagerBuilder { let pre_initialized = enabled_servers .into_iter() .filter_map(|(server_name, server_config)| { - if server_name.contains(MCP_SERVER_TOOL_DELIMITER) { - let _ = queue!( - output, - style::SetForegroundColor(style::Color::Red), - style::Print("✗ Invalid server name "), - style::SetForegroundColor(style::Color::Blue), - style::Print(&server_name), - style::ResetColor, - style::Print(". Server name cannot contain "), - style::SetForegroundColor(style::Color::Yellow), - style::Print(MCP_SERVER_TOOL_DELIMITER), - style::ResetColor, - style::Print("\n") - ); - None - } else if server_name == "builtin" { + if server_name == "builtin" { let _ = queue!( output, style::SetForegroundColor(style::Color::Red), @@ -249,361 +303,65 @@ impl ToolManagerBuilder { // Spawn a task for displaying the mcp loading statuses. // This is only necessary when we are in interactive mode AND there are servers to load. // Otherwise we do not need to be spawning this. 
- let (loading_display_task, loading_status_sender) = if interactive - && (total > 0 || !disabled_servers.is_empty()) - { - let (tx, mut rx) = tokio::sync::mpsc::channel::(50); - let disabled_servers_display_clone = disabled_servers_display.clone(); - ( - Some(tokio::task::spawn(async move { - let mut spinner_logo_idx: usize = 0; - let mut complete: usize = 0; - let mut failed: usize = 0; - - // Show disabled servers immediately - for server_name in &disabled_servers_display_clone { - queue_disabled_message(server_name, &mut output)?; - } - - if total > 0 { - queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?; - } - - loop { - match tokio::time::timeout(Duration::from_millis(50), rx.recv()).await { - Ok(Some(recv_result)) => match recv_result { - LoadingMsg::Done { name, time } => { - complete += 1; - execute!( - output, - cursor::MoveToColumn(0), - cursor::MoveUp(1), - terminal::Clear(terminal::ClearType::CurrentLine), - )?; - queue_success_message(&name, &time, &mut output)?; - queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?; - }, - LoadingMsg::Error { name, msg, time } => { - failed += 1; - execute!( - output, - cursor::MoveToColumn(0), - cursor::MoveUp(1), - terminal::Clear(terminal::ClearType::CurrentLine), - )?; - queue_failure_message(&name, &msg, time.as_str(), &mut output)?; - queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?; - }, - LoadingMsg::Warn { name, msg, time } => { - complete += 1; - execute!( - output, - cursor::MoveToColumn(0), - cursor::MoveUp(1), - terminal::Clear(terminal::ClearType::CurrentLine), - )?; - let msg = eyre::eyre!(msg.to_string()); - queue_warn_message(&name, &msg, time.as_str(), &mut output)?; - queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?; - }, - LoadingMsg::Terminate { still_loading } => { - if !still_loading.is_empty() && total > 0 { - execute!( - output, - cursor::MoveToColumn(0), - cursor::MoveUp(1), - 
terminal::Clear(terminal::ClearType::CurrentLine), - )?; - let msg = still_loading.iter().fold(String::new(), |mut acc, server_name| { - acc.push_str(format!("\n - {server_name}").as_str()); - acc - }); - let msg = eyre::eyre!(msg); - queue_incomplete_load_message(complete, total, &msg, &mut output)?; - } else if total > 0 { - // Clear the loading line if we have enabled servers - execute!( - output, - cursor::MoveToColumn(0), - cursor::MoveUp(1), - terminal::Clear(terminal::ClearType::CurrentLine), - )?; - } - execute!(output, style::Print("\n"),)?; - break; - }, - }, - Err(_e) => { - spinner_logo_idx = (spinner_logo_idx + 1) % SPINNER_CHARS.len(); - execute!( - output, - cursor::SavePosition, - cursor::MoveToColumn(0), - cursor::MoveUp(1), - style::Print(SPINNER_CHARS[spinner_logo_idx]), - cursor::RestorePosition - )?; - }, - _ => break, - } - output.flush()?; - } - Ok::<_, eyre::Report>(()) - })), - Some(tx), - ) - } else { - (None, None) - }; + let (loading_display_task, loading_status_sender) = + spawn_display_task(interactive, total, disabled_servers, output); let mut clients = HashMap::>::new(); - let mut loading_status_sender_clone = loading_status_sender.clone(); - let conv_id_clone = conversation_id.clone(); - let regex = Regex::new(VALID_TOOL_NAME)?; - let new_tool_specs = Arc::new(Mutex::new(HashMap::new())); - let new_tool_specs_clone = new_tool_specs.clone(); - let has_new_stuff = Arc::new(AtomicBool::new(false)); - let has_new_stuff_clone = has_new_stuff.clone(); + let new_tool_specs = self.new_tool_specs; + let has_new_stuff = self.has_new_stuff; let pending = Arc::new(RwLock::new(HashSet::::new())); - let pending_clone = pending.clone(); - let (mut msg_rx, messenger_builder) = ServerMessengerBuilder::new(20); - let telemetry_clone = os.telemetry.clone(); let notify = Arc::new(Notify::new()); - let notify_weak = Arc::downgrade(¬ify); - let load_record = Arc::new(Mutex::new(HashMap::>::new())); - let load_record_clone = load_record.clone(); - let 
agent = Arc::new(Mutex::new(self.agent.unwrap_or_default())); - let agent_clone = agent.clone(); + let load_record = self.mcp_load_record; + let agent = self.agent.unwrap_or_default(); let database = os.database.clone(); + let mut messenger_builder = self.messenger_builder.take(); + + // This is the orchestrator task that serves as a bridge between tool manager and mcp + // clients for server initiated async events + if let (Some(prompt_list_sender), Some(prompt_list_receiver)) = ( + self.prompt_query_result_sender.clone(), + self.prompt_query_receiver.as_ref().map(|r| r.resubscribe()), + ) { + let (msg_rx, builder) = ServerMessengerBuilder::new(20); + messenger_builder.replace(builder); + + let has_new_stuff = has_new_stuff.clone(); + let notify_weak = Arc::downgrade(¬ify); + let telemetry = os.telemetry.clone(); + let loading_status_sender = loading_status_sender.clone(); + let new_tool_specs = new_tool_specs.clone(); + let conv_id = conversation_id.clone(); + let pending = pending.clone(); + let regex = Regex::new(VALID_TOOL_NAME)?; + + spawn_orchestrator_task( + has_new_stuff, + loading_servers, + msg_rx, + prompt_list_receiver, + prompt_list_sender, + pending, + agent.clone(), + database, + regex, + notify_weak, + load_record.clone(), + telemetry, + loading_status_sender, + new_tool_specs, + total, + conv_id, + ); + } - tokio::spawn(async move { - let mut record_temp_buf = Vec::::new(); - let mut initialized = HashSet::::new(); - - enum ToolFilter { - All, - List(HashSet), - } - - impl ToolFilter { - pub fn should_include(&self, tool_name: &str) -> bool { - match self { - Self::All => true, - Self::List(set) => set.contains(tool_name), - } - } - } - - while let Some(msg) = msg_rx.recv().await { - record_temp_buf.clear(); - // For now we will treat every list result as if they contain the - // complete set of tools. This is not necessarily true in the future when - // request method on the mcp client no longer buffers all the pages from - // list calls. 
- match msg { - UpdateEventMessage::ToolsListResult { server_name, result } => { - let time_taken = loading_servers - .remove(&server_name) - .map_or("0.0".to_owned(), |init_time| { - let time_taken = (std::time::Instant::now() - init_time).as_secs_f64().abs(); - format!("{:.2}", time_taken) - }); - pending_clone.write().await.remove(&server_name); - let (tool_filter, alias_list) = { - let agent_lock = agent_clone.lock().await; - - // We will assume all tools are allowed if the tool list consists of 1 - // element and it's a * - let tool_filter = if agent_lock.tools.len() == 1 - && agent_lock.tools.first().map(String::as_str).is_some_and(|c| c == "*") - { - ToolFilter::All - } else { - let set = agent_lock - .tools - .iter() - .filter(|tool_name| tool_name.starts_with(&format!("@{server_name}"))) - .map(|full_name| { - match full_name.split_once(MCP_SERVER_TOOL_DELIMITER) { - Some((_, tool_name)) if !tool_name.is_empty() => tool_name, - _ => "*", - } - .to_string() - }) - .collect::>(); - - if set.contains("*") { - ToolFilter::All - } else { - ToolFilter::List(set) - } - }; - - let server_prefix = format!("@{server_name}"); - let alias_list = agent_lock.tool_aliases.iter().fold( - HashMap::::new(), - |mut acc, (full_path, model_tool_name)| { - if full_path.starts_with(&server_prefix) { - if let Some((_, host_tool_name)) = - full_path.split_once(MCP_SERVER_TOOL_DELIMITER) - { - acc.insert(host_tool_name.to_string(), model_tool_name.clone()); - } - } - acc - }, - ); - - (tool_filter, alias_list) - }; - - match result { - Ok(result) => { - let mut specs = result - .tools - .into_iter() - .filter_map(|v| serde_json::from_value::(v).ok()) - .filter(|spec| tool_filter.should_include(&spec.name)) - .collect::>(); - let mut sanitized_mapping = HashMap::::new(); - let process_result = process_tool_specs( - &database, - conv_id_clone.as_str(), - &server_name, - &mut specs, - &mut sanitized_mapping, - &alias_list, - ®ex, - &telemetry_clone, - ) - .await; - if let Some(sender) 
= &loading_status_sender_clone { - // Anomalies here are not considered fatal, thus we shall give - // warnings. - let msg = match process_result { - Ok(_) => LoadingMsg::Done { - name: server_name.clone(), - time: time_taken.clone(), - }, - Err(ref e) => LoadingMsg::Warn { - name: server_name.clone(), - msg: eyre::eyre!(e.to_string()), - time: time_taken.clone(), - }, - }; - if let Err(e) = sender.send(msg).await { - warn!( - "Error sending update message to display task: {:?}\nAssume display task has completed", - e - ); - loading_status_sender_clone.take(); - } - } - new_tool_specs_clone - .lock() - .await - .insert(server_name.clone(), (sanitized_mapping, specs)); - has_new_stuff_clone.store(true, Ordering::Release); - // Maintain a record of the server load: - let mut buf_writer = BufWriter::new(&mut record_temp_buf); - if let Err(e) = &process_result { - let _ = queue_warn_message( - server_name.as_str(), - e, - time_taken.as_str(), - &mut buf_writer, - ); - } else { - let _ = queue_success_message( - server_name.as_str(), - time_taken.as_str(), - &mut buf_writer, - ); - } - let _ = buf_writer.flush(); - drop(buf_writer); - let record = String::from_utf8_lossy(&record_temp_buf).to_string(); - let record = if process_result.is_err() { - LoadingRecord::Warn(record) - } else { - LoadingRecord::Success(record) - }; - load_record_clone - .lock() - .await - .entry(server_name.clone()) - .and_modify(|load_record| { - load_record.push(record.clone()); - }) - .or_insert(vec![record]); - }, - Err(e) => { - // Log error to chat Log - error!("Error loading server {server_name}: {:?}", e); - // Maintain a record of the server load: - let mut buf_writer = BufWriter::new(&mut record_temp_buf); - let _ = queue_failure_message(server_name.as_str(), &e, &time_taken, &mut buf_writer); - let _ = buf_writer.flush(); - drop(buf_writer); - let record = String::from_utf8_lossy(&record_temp_buf).to_string(); - let record = LoadingRecord::Err(record); - load_record_clone - .lock() - 
.await - .entry(server_name.clone()) - .and_modify(|load_record| { - load_record.push(record.clone()); - }) - .or_insert(vec![record]); - // Errors surfaced at this point (i.e. before [process_tool_specs] - // is called) are fatals and should be considered errors - if let Some(sender) = &loading_status_sender_clone { - let msg = LoadingMsg::Error { - name: server_name.clone(), - msg: e, - time: time_taken, - }; - if let Err(e) = sender.send(msg).await { - warn!( - "Error sending update message to display task: {:?}\nAssume display task has completed", - e - ); - loading_status_sender_clone.take(); - } - } - }, - } - if let Some(notify) = notify_weak.upgrade() { - initialized.insert(server_name); - if initialized.len() >= total { - notify.notify_one(); - } - } - }, - UpdateEventMessage::PromptsListResult { - server_name: _, - result: _, - } => {}, - UpdateEventMessage::ResourcesListResult { - server_name: _, - result: _, - } => {}, - UpdateEventMessage::ResourceTemplatesListResult { - server_name: _, - result: _, - } => {}, - UpdateEventMessage::InitStart { server_name } => { - pending_clone.write().await.insert(server_name.clone()); - loading_servers.insert(server_name, std::time::Instant::now()); - }, - } - } - }); - + debug_assert!(messenger_builder.is_some()); + let messenger_builder = messenger_builder.unwrap(); for (mut name, init_res) in pre_initialized { - let messenger = messenger_builder.build_with_name(name.clone()); + let mut messenger = messenger_builder.build_with_name(name.clone()); match init_res { Ok(mut client) => { + let pid = client.get_pid(); + messenger.pid = pid; client.assign_messenger(Box::new(messenger)); let mut client = Arc::new(client); while let Some(collided_client) = clients.insert(name.clone(), client) { @@ -624,99 +382,9 @@ impl ToolManagerBuilder { } } - // Set up task to handle prompt requests - let sender = self.prompt_list_sender.take(); - let receiver = self.prompt_list_receiver.take(); - let prompts = 
Arc::new(SyncRwLock::new(HashMap::default())); - // TODO: accommodate hot reload of mcp servers - if let (Some(sender), Some(receiver)) = (sender, receiver) { - let clients = clients.iter().fold(HashMap::new(), |mut acc, (n, c)| { - acc.insert(n.clone(), Arc::downgrade(c)); - acc - }); - let prompts_clone = prompts.clone(); - tokio::task::spawn_blocking(move || { - let receiver = Arc::new(std::sync::Mutex::new(receiver)); - loop { - let search_word = receiver.lock().map_err(|e| eyre::eyre!("{:?}", e))?.recv()?; - if clients - .values() - .any(|client| client.upgrade().is_some_and(|c| c.is_prompts_out_of_date())) - { - let mut prompts_wl = prompts_clone.write().map_err(|e| { - eyre::eyre!( - "Error retrieving write lock on prompts for tab complete {}", - e.to_string() - ) - })?; - *prompts_wl = clients.iter().fold( - HashMap::>::new(), - |mut acc, (server_name, client)| { - let Some(client) = client.upgrade() else { - return acc; - }; - let prompt_gets = client.list_prompt_gets(); - let Ok(prompt_gets) = prompt_gets.read() else { - tracing::error!("Error retrieving read lock for prompt gets for tab complete"); - return acc; - }; - for (prompt_name, prompt_get) in prompt_gets.iter() { - acc.entry(prompt_name.clone()) - .and_modify(|bundles| { - bundles.push(PromptBundle { - server_name: server_name.to_owned(), - prompt_get: prompt_get.clone(), - }); - }) - .or_insert(vec![PromptBundle { - server_name: server_name.to_owned(), - prompt_get: prompt_get.clone(), - }]); - } - client.prompts_updated(); - acc - }, - ); - } - let prompts_rl = prompts_clone.read().map_err(|e| { - eyre::eyre!( - "Error retrieving read lock on prompts for tab complete {}", - e.to_string() - ) - })?; - let filtered_prompts = prompts_rl - .iter() - .flat_map(|(prompt_name, bundles)| { - if bundles.len() > 1 { - bundles - .iter() - .map(|b| format!("{}/{}", b.server_name, prompt_name)) - .collect() - } else { - vec![prompt_name.to_owned()] - } - }) - .filter(|n| { - if let Some(p) = &search_word { 
- n.contains(p) - } else { - true - } - }) - .collect::>(); - if let Err(e) = sender.send(filtered_prompts) { - error!("Error sending prompts to chat helper: {:?}", e); - } - } - #[allow(unreachable_code)] - Ok::<(), eyre::Report>(()) - }); - } - Ok(ToolManager { conversation_id, clients, - prompts, pending_clients: pending, notify: Some(notify), loading_status_sender, @@ -727,6 +395,15 @@ impl ToolManagerBuilder { mcp_load_record: load_record, agent, disabled_servers: disabled_servers_display, + prompts_sender_receiver_pair: { + if let (Some(sender), Some(receiver)) = (self.prompt_query_sender, self.prompt_query_result_receiver) { + Some((sender, receiver)) + } else { + None + } + }, + messenger_builder: Some(messenger_builder), + is_first_launch: self.is_first_launch, ..Default::default() }) } @@ -743,6 +420,18 @@ pub struct PromptBundle { pub prompt_get: PromptGet, } +#[derive(Clone, Debug)] +pub enum PromptQuery { + List, + Search(Option), +} + +#[derive(Clone, Debug)] +pub enum PromptQueryResult { + List(HashMap>), + Search(Vec), +} + /// Categorizes different types of tool name validation failures: /// - `TooLong`: The tool name exceeds the maximum allowed length /// - `IllegalChar`: The tool name contains characters that are not allowed @@ -790,6 +479,14 @@ type ServerName = String; /// tool name). type NewToolSpecs = Arc, Vec)>>>; +/// A pair of channels used for prompt list communication between the tool manager and chat helper. +/// The sender broadcasts a list of available prompt names, while the receiver listens for +/// search queries to filter the prompt list. +type PromptsChannelPair = ( + tokio::sync::broadcast::Sender, + tokio::sync::broadcast::Receiver, +); + #[derive(Default, Debug)] /// Manages the lifecycle and interactions with tools from various sources, including MCP servers. 
/// This struct is responsible for initializing tools, handling tool requests, and maintaining @@ -811,19 +508,15 @@ pub struct ToolManager { /// to incorporate newly available tools from MCP servers. pub has_new_stuff: Arc, + /// Used by methods on the [ToolManager] to retrieve information from the orchestrator thread + prompts_sender_receiver_pair: Option, + /// Storage for newly discovered tool specifications from MCP servers that haven't yet been /// integrated into the main tool registry. This field holds a thread-safe reference to a map /// of server names to their tool specifications and name mappings, allowing concurrent updates /// from server initialization processes. new_tool_specs: NewToolSpecs, - /// Cache for prompts collected from different servers. - /// Key: prompt name - /// Value: a list of PromptBundle that has a prompt of this name. - /// This cache helps resolve prompt requests efficiently and handles - /// cases where multiple servers offer prompts with the same name. - pub prompts: Arc>>>, - /// A notifier to understand if the initial loading has completed. /// This is only used for initial loading and is discarded after. notify: Option>, @@ -858,9 +551,17 @@ pub struct ToolManager { /// List of disabled MCP server names for display purposes disabled_servers: Vec, - /// A collection of preferences that pertains to the conversation. 
+ /// A builder for mcp clients to communicate with the orchestrator task + /// We need to store this for when we switch agent - we need to be spawning messengers that are + /// already listened to by the orchestrator task + messenger_builder: Option, + + /// A collection of preferences that pertains to the conversation /// As far as tool manager goes, this is relevant for tool and server filters + /// We need to put this behind a lock because the orchestrator task depends on agent pub agent: Arc>, + + is_first_launch: bool, } impl Clone for ToolManager { @@ -870,7 +571,6 @@ impl Clone for ToolManager { clients: self.clients.clone(), has_new_stuff: self.has_new_stuff.clone(), new_tool_specs: self.new_tool_specs.clone(), - prompts: self.prompts.clone(), tn_map: self.tn_map.clone(), schema: self.schema.clone(), is_interactive: self.is_interactive, @@ -882,6 +582,35 @@ impl Clone for ToolManager { } impl ToolManager { + /// Swapping agent involves the following: + /// - Dropping all of the clients first to avoid resource contention + /// - Clearing fields that are already referenced by background tasks. 
We can't simply spawn new + /// instances of these fields because one or more background tasks are already depending on it + /// - Building a new tool manager builder from the current tool manager + /// - Building a tool manager from said tool manager builder + /// - Swapping the old with the new (the old would be dropped after we exit the scope of this + /// function) + /// - Calling load tools + pub async fn swap_agent(&mut self, os: &mut Os, output: &mut impl Write, agent: &Agent) -> eyre::Result<()> { + self.clients.clear(); + + let mut agent_lock = self.agent.lock().await; + *agent_lock = agent.clone(); + drop(agent_lock); + + self.mcp_load_record.lock().await.clear(); + + let builder = ToolManagerBuilder::from(&mut *self); + let mut new_tool_manager = builder.build(os, Box::new(std::io::sink()), true).await?; + std::mem::swap(self, &mut new_tool_manager); + + // we can discard the output here and let background server load take care of getting the + // new tools + let _ = self.load_tools(os, output).await?; + + Ok(()) + } + pub async fn load_tools( &mut self, os: &mut Os, @@ -957,7 +686,7 @@ impl ToolManager { }); // We need to cast it to erase the type otherwise the compiler will default to static // dispatch, which would result in an error of inconsistent match arm return type. 
- let timeout_fut: Pin>> = if self.clients.is_empty() { + let timeout_fut: Pin>> = if self.clients.is_empty() || !self.is_first_launch { // If there is no server loaded, we want to resolve immediately Box::pin(future::ready(())) } else if self.is_interactive { @@ -1185,7 +914,26 @@ impl ToolManager { } } - #[allow(clippy::await_holding_lock)] + pub async fn list_prompts(&self) -> Result>, GetPromptError> { + if let Some((query_sender, query_result_receiver)) = &self.prompts_sender_receiver_pair { + let mut new_receiver = query_result_receiver.resubscribe(); + query_sender + .send(PromptQuery::List) + .map_err(|e| GetPromptError::General(eyre::eyre!(e)))?; + let query_result = new_receiver + .recv() + .await + .map_err(|e| GetPromptError::General(eyre::eyre!(e)))?; + + Ok(match query_result { + PromptQueryResult::List(list) => list, + PromptQueryResult::Search(_) => return Err(GetPromptError::IncorrectResponseType), + }) + } else { + Err(GetPromptError::MissingChannel) + } + } + pub async fn get_prompt( &self, name: String, @@ -1196,93 +944,61 @@ impl ToolManager { Some((server_name, prompt_name)) => (Some(server_name.to_string()), Some(prompt_name.to_string())), }; let prompt_name = prompt_name.ok_or(GetPromptError::MissingPromptName)?; - // We need to use a sync lock here because this lock is also used in a blocking thread, - // necessitated by the fact that said thread is also responsible for using a sync channel, - // which is itself necessitated by the fact that consumer of said channel is calling from a - // sync function - let mut prompts_wl = self - .prompts - .write() - .map_err(|e| GetPromptError::Synchronization(e.to_string()))?; - let mut maybe_bundles = prompts_wl.get(&prompt_name); - let mut has_retried = false; - 'blk: loop { - match (maybe_bundles, server_name.as_ref(), has_retried) { + + if let Some((query_sender, query_result_receiver)) = &self.prompts_sender_receiver_pair { + query_sender + .send(PromptQuery::List) + .map_err(|e| 
GetPromptError::General(eyre::eyre!(e)))?; + let prompts = query_result_receiver + .resubscribe() + .recv() + .await + .map_err(|e| GetPromptError::General(eyre::eyre!(e)))?; + let PromptQueryResult::List(prompts) = prompts else { + return Err(GetPromptError::IncorrectResponseType); + }; + + match (prompts.get(&prompt_name), server_name.as_ref()) { // If we have more than one eligible clients but no server name specified - (Some(bundles), None, _) if bundles.len() > 1 => { - break 'blk Err(GetPromptError::AmbiguousPrompt(prompt_name.clone(), { + (Some(bundles), None) if bundles.len() > 1 => { + Err(GetPromptError::AmbiguousPrompt(prompt_name.clone(), { bundles.iter().fold("\n".to_string(), |mut acc, b| { acc.push_str(&format!("- @{}/{}\n", b.server_name, prompt_name)); acc }) - })); + })) }, // Normal case where we have enough info to proceed // Note that if bundle exists, it should never be empty - (Some(bundles), sn, _) => { + (Some(bundles), sn) => { let bundle = if bundles.len() > 1 { - let Some(server_name) = sn else { - maybe_bundles = None; - continue 'blk; + let Some(sn) = sn else { + return Err(GetPromptError::AmbiguousPrompt(prompt_name.clone(), { + bundles.iter().fold("\n".to_string(), |mut acc, b| { + acc.push_str(&format!("- @{}/{}\n", b.server_name, prompt_name)); + acc + }) + })); }; - let bundle = bundles.iter().find(|b| b.server_name == *server_name); + let bundle = bundles.iter().find(|b| b.server_name == *sn); match bundle { Some(bundle) => bundle, None => { - maybe_bundles = None; - continue 'blk; + return Err(GetPromptError::AmbiguousPrompt(prompt_name.clone(), { + bundles.iter().fold("\n".to_string(), |mut acc, b| { + acc.push_str(&format!("- @{}/{}\n", b.server_name, prompt_name)); + acc + }) + })); }, } } else { bundles.first().ok_or(GetPromptError::MissingPromptInfo)? 
}; - let server_name = bundle.server_name.clone(); - let client = self.clients.get(&server_name).ok_or(GetPromptError::MissingClient)?; - // Here we lazily update the out of date cache - if client.is_prompts_out_of_date() { - let prompt_gets = client.list_prompt_gets(); - let prompt_gets = prompt_gets - .read() - .map_err(|e| GetPromptError::Synchronization(e.to_string()))?; - for (prompt_name, prompt_get) in prompt_gets.iter() { - prompts_wl - .entry(prompt_name.clone()) - .and_modify(|bundles| { - let mut is_modified = false; - for bundle in &mut *bundles { - let mut updated_bundle = PromptBundle { - server_name: server_name.clone(), - prompt_get: prompt_get.clone(), - }; - if bundle.server_name == *server_name { - std::mem::swap(bundle, &mut updated_bundle); - is_modified = true; - break; - } - } - if !is_modified { - bundles.push(PromptBundle { - server_name: server_name.clone(), - prompt_get: prompt_get.clone(), - }); - } - }) - .or_insert(vec![PromptBundle { - server_name: server_name.clone(), - prompt_get: prompt_get.clone(), - }]); - } - client.prompts_updated(); - } - - let PromptBundle { prompt_get, .. } = prompts_wl - .get(&prompt_name) - .and_then(|bundles| bundles.iter().find(|b| b.server_name == server_name)) - .ok_or(GetPromptError::MissingPromptInfo)?; - // Here we need to convert the positional arguments into key value pair - // The assignment order is assumed to be the order of args as they are - // presented in PromptGet::arguments + let server_name = &bundle.server_name; + let client = self.clients.get(server_name).ok_or(GetPromptError::MissingClient)?; + let PromptBundle { prompt_get, .. 
} = bundle; let args = if let (Some(schema), Some(value)) = (&prompt_get.arguments, &arguments) { let params = schema.iter().zip(value.iter()).fold( HashMap::::new(), @@ -1304,57 +1020,573 @@ impl ToolManager { Some(serde_json::Value::Object(params)) }; let resp = client.request("prompts/get", params).await?; - break 'blk Ok(resp); - }, - // If we have no eligible clients this would mean one of the following: - // - The prompt does not exist, OR - // - This is the first time we have a query / our cache is out of date - // Both of which means we would have to requery - (None, _, false) => { - has_retried = true; - self.refresh_prompts(&mut prompts_wl)?; - maybe_bundles = prompts_wl.get(&prompt_name); - }, - (_, _, true) => { - break 'blk Err(GetPromptError::PromptNotFound(prompt_name)); + Ok(resp) }, + (None, _) => Err(GetPromptError::PromptNotFound(prompt_name)), } + } else { + Err(GetPromptError::MissingChannel) } } - pub fn refresh_prompts(&self, prompts_wl: &mut HashMap>) -> Result<(), GetPromptError> { - *prompts_wl = self.clients.iter().fold( - HashMap::>::new(), - |mut acc, (server_name, client)| { - let prompt_gets = client.list_prompt_gets(); - let Ok(prompt_gets) = prompt_gets.read() else { - tracing::error!("Error encountered while retrieving read lock"); - return acc; - }; - for (prompt_name, prompt_get) in prompt_gets.iter() { - acc.entry(prompt_name.clone()) - .and_modify(|bundles| { - bundles.push(PromptBundle { - server_name: server_name.to_owned(), - prompt_get: prompt_get.clone(), - }); - }) - .or_insert(vec![PromptBundle { - server_name: server_name.to_owned(), - prompt_get: prompt_get.clone(), - }]); - } - acc - }, - ); - Ok(()) - } - pub async fn pending_clients(&self) -> Vec { self.pending_clients.read().await.iter().cloned().collect::>() } } +type DisplayTaskJoinHandle = JoinHandle>; +type LoadingStatusSender = tokio::sync::mpsc::Sender; + +/// This function spawns a background task whose sole responsibility is to listen for incoming +/// 
server loading status and display them to the output. +/// It returns a join handle to the task as well as a sender with which loading status is to be +/// reported. +fn spawn_display_task( + interactive: bool, + total: usize, + disabled_servers: Vec<(String, CustomToolConfig)>, + mut output: Box, +) -> (Option, Option) { + if interactive && (total > 0 || !disabled_servers.is_empty()) { + let (tx, mut rx) = tokio::sync::mpsc::channel::(50); + ( + Some(tokio::task::spawn(async move { + let mut spinner_logo_idx: usize = 0; + let mut complete: usize = 0; + let mut failed: usize = 0; + + // Show disabled servers immediately + for (server_name, _) in &disabled_servers { + queue_disabled_message(server_name, &mut output)?; + } + + if total > 0 { + queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?; + } + + loop { + match tokio::time::timeout(Duration::from_millis(50), rx.recv()).await { + Ok(Some(recv_result)) => match recv_result { + LoadingMsg::Done { name, time } => { + complete += 1; + execute!( + output, + cursor::MoveToColumn(0), + cursor::MoveUp(1), + terminal::Clear(terminal::ClearType::CurrentLine), + )?; + queue_success_message(&name, &time, &mut output)?; + queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?; + }, + LoadingMsg::Error { name, msg, time } => { + failed += 1; + execute!( + output, + cursor::MoveToColumn(0), + cursor::MoveUp(1), + terminal::Clear(terminal::ClearType::CurrentLine), + )?; + queue_failure_message(&name, &msg, time.as_str(), &mut output)?; + queue_init_message(spinner_logo_idx, complete, failed, total, &mut output)?; + }, + LoadingMsg::Warn { name, msg, time } => { + complete += 1; + execute!( + output, + cursor::MoveToColumn(0), + cursor::MoveUp(1), + terminal::Clear(terminal::ClearType::CurrentLine), + )?; + let msg = eyre::eyre!(msg.to_string()); + queue_warn_message(&name, &msg, time.as_str(), &mut output)?; + queue_init_message(spinner_logo_idx, complete, failed, total, &mut 
output)?; + }, + LoadingMsg::Terminate { still_loading } => { + if !still_loading.is_empty() && total > 0 { + execute!( + output, + cursor::MoveToColumn(0), + cursor::MoveUp(1), + terminal::Clear(terminal::ClearType::CurrentLine), + )?; + let msg = still_loading.iter().fold(String::new(), |mut acc, server_name| { + acc.push_str(format!("\n - {server_name}").as_str()); + acc + }); + let msg = eyre::eyre!(msg); + queue_incomplete_load_message(complete, total, &msg, &mut output)?; + } else if total > 0 { + // Clear the loading line if we have enabled servers + execute!( + output, + cursor::MoveToColumn(0), + cursor::MoveUp(1), + terminal::Clear(terminal::ClearType::CurrentLine), + )?; + } + execute!(output, style::Print("\n"),)?; + break; + }, + }, + Err(_e) => { + spinner_logo_idx = (spinner_logo_idx + 1) % SPINNER_CHARS.len(); + execute!( + output, + cursor::SavePosition, + cursor::MoveToColumn(0), + cursor::MoveUp(1), + style::Print(SPINNER_CHARS[spinner_logo_idx]), + cursor::RestorePosition + )?; + }, + _ => break, + } + output.flush()?; + } + Ok::<_, eyre::Report>(()) + })), + Some(tx), + ) + } else { + (None, None) + } +} + +/// This function spawns the orchestrator task that has the following responsibilities: +/// - Listens for server driven events (see [UpdateEventMessage] for a list of current applicable +/// events). These are things such as tool list (because we fetch tools in the background), prompt +/// list, tool list update, and prompt list updates. In the future, if when we support sampling +/// and we have not yet moved to the official rust MCP crate, we would also be using this task to +/// facilitate it. +/// - Listens for prompt list request and serve them. Unlike tools, we do *not* cache prompts on the +/// conversation state. This is because prompts do not need to be sent to the model every turn. +/// Instead, the prompts are cached in a hashmap that is owned by the orchestrator task. 
+/// +/// Note that there should be exactly one instance of this task running per session. Should there +/// be any need to instantiate a new [ToolManager] (e.g. swapping agents), see +/// [ToolManager::swap_agent] for how this should be done. +#[allow(clippy::too_many_arguments)] +fn spawn_orchestrator_task( + has_new_stuff: Arc, + mut loading_servers: HashMap, + mut msg_rx: tokio::sync::mpsc::Receiver, + mut prompt_list_receiver: tokio::sync::broadcast::Receiver, + mut prompt_list_sender: tokio::sync::broadcast::Sender, + pending: Arc>>, + agent: Arc>, + database: Database, + regex: Regex, + notify_weak: std::sync::Weak, + load_record: Arc>>>, + telemetry: TelemetryThread, + loading_status_sender: Option, + new_tool_specs: NewToolSpecs, + total: usize, + conv_id: String, +) { + tokio::spawn(async move { + use tokio::sync::broadcast::Sender as BroadcastSender; + use tokio::sync::mpsc::Sender as MpscSender; + + let mut record_temp_buf = Vec::::new(); + let mut initialized = HashSet::::new(); + let mut prompts = HashMap::>::new(); + + enum ToolFilter { + All, + List(HashSet), + } + + impl ToolFilter { + pub fn should_include(&self, tool_name: &str) -> bool { + match self { + Self::All => true, + Self::List(set) => set.contains(tool_name), + } + } + } + + // We separate this into its own function for ease of maintenance since things written + // in select arms don't have type hints + #[inline] + async fn handle_prompt_queries( + query: PromptQuery, + prompts: &HashMap>, + prompt_query_response_sender: &mut BroadcastSender, + ) { + match query { + PromptQuery::List => { + let query_res = PromptQueryResult::List(prompts.clone()); + if let Err(e) = prompt_query_response_sender.send(query_res) { + error!("Error sending prompts to chat helper: {:?}", e); + } + }, + PromptQuery::Search(search_word) => { + let filtered_prompts = prompts + .iter() + .flat_map(|(prompt_name, bundles)| { + if bundles.len() > 1 { + bundles + .iter() + .map(|b| format!("{}/{}", b.server_name, 
prompt_name)) + .collect() + } else { + vec![prompt_name.to_owned()] + } + }) + .filter(|n| { + if let Some(p) = &search_word { + n.contains(p) + } else { + true + } + }) + .collect::>(); + + let query_res = PromptQueryResult::Search(filtered_prompts); + if let Err(e) = prompt_query_response_sender.send(query_res) { + error!("Error sending prompts to chat helper: {:?}", e); + } + }, + } + } + + // We separate this into its own function for ease of maintenance since things written + // in select arms don't have type hints + #[inline] + #[allow(clippy::too_many_arguments)] + async fn handle_messenger_msg( + msg: UpdateEventMessage, + loading_servers: &mut HashMap, + record_temp_buf: &mut Vec, + pending: &Arc>>, + agent: &Arc>, + database: &Database, + conv_id: &str, + regex: &Regex, + telemetry_clone: &TelemetryThread, + mut loading_status_sender: Option<&MpscSender>, + new_tool_specs: &NewToolSpecs, + has_new_stuff: &Arc, + load_record: &Arc>>>, + notify_weak: &std::sync::Weak, + initialized: &mut HashSet, + prompts: &mut HashMap>, + total: usize, + ) { + record_temp_buf.clear(); + // For now we will treat every list result as if they contain the + // complete set of tools. This is not necessarily true in the future when + // request method on the mcp client no longer buffers all the pages from + // list calls. + match msg { + UpdateEventMessage::ToolsListResult { + server_name, + result, + pid, + } => { + let pid = pid.unwrap(); + if !is_process_running(pid) { + info!( + "Received tool list result from {server_name} but its associated process {pid} is no longer running. Ignoring." 
+ ); + return; + } + let time_taken = loading_servers + .remove(&server_name) + .map_or("0.0".to_owned(), |init_time| { + let time_taken = (std::time::Instant::now() - init_time).as_secs_f64().abs(); + format!("{:.2}", time_taken) + }); + pending.write().await.remove(&server_name); + let (tool_filter, alias_list) = { + let agent_lock = agent.lock().await; + + // We will assume all tools are allowed if the tool list consists of 1 + // element and it's a * + let tool_filter = if agent_lock.tools.len() == 1 + && agent_lock.tools.first().map(String::as_str).is_some_and(|c| c == "*") + { + ToolFilter::All + } else { + let set = agent_lock + .tools + .iter() + .filter(|tool_name| tool_name.starts_with(&format!("@{server_name}"))) + .map(|full_name| { + match full_name.split_once(MCP_SERVER_TOOL_DELIMITER) { + Some((_, tool_name)) if !tool_name.is_empty() => tool_name, + _ => "*", + } + .to_string() + }) + .collect::>(); + + if set.contains("*") { + ToolFilter::All + } else { + ToolFilter::List(set) + } + }; + + let server_prefix = format!("@{server_name}"); + let alias_list = agent_lock.tool_aliases.iter().fold( + HashMap::::new(), + |mut acc, (full_path, model_tool_name)| { + if full_path.starts_with(&server_prefix) { + if let Some((_, host_tool_name)) = full_path.split_once(MCP_SERVER_TOOL_DELIMITER) { + acc.insert(host_tool_name.to_string(), model_tool_name.clone()); + } + } + acc + }, + ); + + (tool_filter, alias_list) + }; + + match result { + Ok(result) => { + let mut specs = result + .tools + .into_iter() + .filter_map(|v| serde_json::from_value::(v).ok()) + .filter(|spec| tool_filter.should_include(&spec.name)) + .collect::>(); + let mut sanitized_mapping = HashMap::::new(); + let process_result = process_tool_specs( + database, + conv_id, + &server_name, + &mut specs, + &mut sanitized_mapping, + &alias_list, + regex, + telemetry_clone, + ) + .await; + if let Some(sender) = &loading_status_sender { + // Anomalies here are not considered fatal, thus we shall give 
+ // warnings. + let msg = match process_result { + Ok(_) => LoadingMsg::Done { + name: server_name.clone(), + time: time_taken.clone(), + }, + Err(ref e) => LoadingMsg::Warn { + name: server_name.clone(), + msg: eyre::eyre!(e.to_string()), + time: time_taken.clone(), + }, + }; + if let Err(e) = sender.send(msg).await { + warn!( + "Error sending update message to display task: {:?}\nAssume display task has completed", + e + ); + loading_status_sender.take(); + } + } + new_tool_specs + .lock() + .await + .insert(server_name.clone(), (sanitized_mapping, specs)); + has_new_stuff.store(true, Ordering::Release); + // Maintain a record of the server load: + let mut buf_writer = BufWriter::new(&mut *record_temp_buf); + if let Err(e) = &process_result { + let _ = + queue_warn_message(server_name.as_str(), e, time_taken.as_str(), &mut buf_writer); + } else { + let _ = + queue_success_message(server_name.as_str(), time_taken.as_str(), &mut buf_writer); + } + let _ = buf_writer.flush(); + drop(buf_writer); + let record = String::from_utf8_lossy(record_temp_buf).to_string(); + let record = if process_result.is_err() { + LoadingRecord::Warn(record) + } else { + LoadingRecord::Success(record) + }; + load_record + .lock() + .await + .entry(server_name.clone()) + .and_modify(|load_record| { + load_record.push(record.clone()); + }) + .or_insert(vec![record]); + }, + Err(e) => { + // Log error to chat Log + error!("Error loading server {server_name}: {:?}", e); + // Maintain a record of the server load: + let mut buf_writer = BufWriter::new(&mut *record_temp_buf); + let _ = queue_failure_message(server_name.as_str(), &e, &time_taken, &mut buf_writer); + let _ = buf_writer.flush(); + drop(buf_writer); + let record = String::from_utf8_lossy(record_temp_buf).to_string(); + let record = LoadingRecord::Err(record); + load_record + .lock() + .await + .entry(server_name.clone()) + .and_modify(|load_record| { + load_record.push(record.clone()); + }) + .or_insert(vec![record]); + // Errors 
surfaced at this point (i.e. before [process_tool_specs] + // is called) are fatals and should be considered errors + if let Some(sender) = &loading_status_sender { + let msg = LoadingMsg::Error { + name: server_name.clone(), + msg: e, + time: time_taken, + }; + if let Err(e) = sender.send(msg).await { + warn!( + "Error sending update message to display task: {:?}\nAssume display task has completed", + e + ); + loading_status_sender.take(); + } + } + }, + } + if let Some(notify) = notify_weak.upgrade() { + initialized.insert(server_name); + if initialized.len() >= total { + notify.notify_one(); + } + } + }, + UpdateEventMessage::PromptsListResult { + server_name, + result, + pid, + } => match result { + Ok(prompt_list_result) if pid.is_some() => { + let pid = pid.unwrap(); + if !is_process_running(pid) { + info!( + "Received prompt list result from {server_name} but its associated process {pid} is no longer running. Ignoring." + ); + return; + } + // We first need to clear all the PromptGets that are associated with + // this server because PromptsListResult is declaring what is available + // (and not the diff) + prompts + .values_mut() + .for_each(|bundles| bundles.retain(|bundle| bundle.server_name != server_name)); + + // And then we update them with the new comers + for result in prompt_list_result.prompts { + let Ok(prompt_get) = serde_json::from_value::(result) else { + error!("Failed to deserialize prompt get from server {server_name}"); + continue; + }; + prompts + .entry(prompt_get.name.clone()) + .and_modify(|bundles| { + bundles.push(PromptBundle { + server_name: server_name.clone(), + prompt_get: prompt_get.clone(), + }); + }) + .or_insert_with(|| { + vec![PromptBundle { + server_name: server_name.clone(), + prompt_get, + }] + }); + } + }, + Ok(_) => { + error!("Received prompt list result without pid from {server_name}. 
Ignoring."); + }, + Err(e) => { + error!("Error fetching prompts from server {server_name}: {:?}", e); + let mut buf_writer = BufWriter::new(&mut *record_temp_buf); + let _ = queue_prompts_load_error_message(&server_name, &e, &mut buf_writer); + let _ = buf_writer.flush(); + drop(buf_writer); + let record = String::from_utf8_lossy(record_temp_buf).to_string(); + let record = LoadingRecord::Err(record); + load_record + .lock() + .await + .entry(server_name.clone()) + .and_modify(|load_record| { + load_record.push(record.clone()); + }) + .or_insert(vec![record]); + }, + }, + UpdateEventMessage::ResourcesListResult { + server_name: _, + result: _, + pid: _, + } => {}, + UpdateEventMessage::ResourceTemplatesListResult { + server_name: _, + result: _, + pid: _, + } => {}, + UpdateEventMessage::InitStart { server_name, .. } => { + pending.write().await.insert(server_name.clone()); + loading_servers.insert(server_name, std::time::Instant::now()); + }, + UpdateEventMessage::Deinit { server_name, .. } => { + // Only prompts are stored here so we'll just be clearing that + // In the future if we are also storing tools, we need to make sure that + // the tools are also pruned. + for (_prompt_name, bundles) in prompts.iter_mut() { + bundles.retain(|bundle| bundle.server_name != server_name); + } + prompts.retain(|_, bundles| !bundles.is_empty()); + has_new_stuff.store(true, Ordering::Release); + }, + } + } + + loop { + tokio::select! 
{ + Ok(query) = prompt_list_receiver.recv() => { + handle_prompt_queries(query, &prompts, &mut prompt_list_sender).await; + }, + Some(msg) = msg_rx.recv() => { + handle_messenger_msg( + msg, + &mut loading_servers, + &mut record_temp_buf, + &pending, + &agent, + &database, + conv_id.as_str(), + &regex, + &telemetry, + loading_status_sender.as_ref(), + &new_tool_specs, + &has_new_stuff, + &load_record, + &notify_weak, + &mut initialized, + &mut prompts, + total + ).await; + }, + // Nothing else to poll + else => { + tracing::info!("Tool manager orchestrator task exited"); + break; + }, + } + } + }); +} + #[allow(clippy::too_many_arguments)] async fn process_tool_specs( database: &Database, @@ -1476,6 +1708,22 @@ fn sanitize_name(orig: String, regex: &regex::Regex, hasher: &mut impl Hasher) - } } +// Add this function to check if a process is still running +fn is_process_running(pid: u32) -> bool { + #[cfg(unix)] + { + let system = sysinfo::System::new_all(); + system.process(sysinfo::Pid::from(pid as usize)).is_some() + } + #[cfg(windows)] + { + // TODO: fill in the process health check for windows when we officially support + // windows + _ = pid; + true + } +} + fn queue_success_message(name: &str, time_taken: &str, output: &mut impl Write) -> eyre::Result<()> { Ok(queue!( output, @@ -1623,6 +1871,14 @@ fn queue_incomplete_load_message( )?) } +fn queue_prompts_load_error_message(name: &str, msg: &eyre::Report, output: &mut impl Write) -> eyre::Result<()> { + Ok(queue!( + output, + style::Print(format!("Prompt list for {name} failed with the following message: \n")), + style::Print(msg), + )?) 
+} + #[cfg(test)] mod tests { use super::*; diff --git a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs index 0163a37ae5..3daaf50752 100644 --- a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs +++ b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs @@ -1,7 +1,6 @@ use std::collections::HashMap; use std::io::Write; use std::sync::Arc; -use std::sync::atomic::Ordering; use crossterm::{ queue, @@ -31,7 +30,6 @@ use crate::mcp_client::{ JsonRpcStdioTransport, MessageContent, Messenger, - PromptGet, ServerCapabilities, StdioTransport, ToolCallResult, @@ -172,9 +170,9 @@ impl CustomToolClient { } } - pub fn list_prompt_gets(&self) -> Arc>> { + pub fn get_pid(&self) -> Option { match self { - CustomToolClient::Stdio { client, .. } => client.prompt_gets.clone(), + CustomToolClient::Stdio { client, .. } => client.server_process_id.as_ref().map(|pid| pid.as_u32()), } } @@ -184,18 +182,6 @@ impl CustomToolClient { CustomToolClient::Stdio { client, .. } => Ok(client.notify(method, params).await?), } } - - pub fn is_prompts_out_of_date(&self) -> bool { - match self { - CustomToolClient::Stdio { client, .. } => client.is_prompts_out_of_date.load(Ordering::Relaxed), - } - } - - pub fn prompts_updated(&self) { - match self { - CustomToolClient::Stdio { client, .. } => client.is_prompts_out_of_date.store(false, Ordering::Relaxed), - } - } } /// Represents a custom tool that can be invoked through the Model Context Protocol (MCP). 
diff --git a/crates/chat-cli/src/mcp_client/client.rs b/crates/chat-cli/src/mcp_client/client.rs index 004c0623a9..27918c5773 100644 --- a/crates/chat-cli/src/mcp_client/client.rs +++ b/crates/chat-cli/src/mcp_client/client.rs @@ -124,7 +124,7 @@ pub struct Client { server_name: String, transport: Arc, timeout: u64, - server_process_id: Option, + pub server_process_id: Option, client_info: serde_json::Value, current_id: Arc, pub messenger: Option>, @@ -276,6 +276,9 @@ where if let Some(process_id) = self.server_process_id { let _ = terminate_process(process_id); } + if let Some(ref messenger) = self.messenger { + messenger.send_deinit_msg(); + } } } @@ -604,44 +607,34 @@ fn examine_server_capabilities(ser_cap: &JsonRpcResponse) -> Result<(), ClientEr Ok(()) } -// TODO: after we move prompts to tool manager, use the messenger to notify the listener spawned by -// tool manager to update its own field. Currently this function does not make use of the -// messesnger. #[allow(clippy::borrowed_box)] -async fn fetch_prompts_and_notify_with_messenger(client: &Client, _messenger: Option<&Box>) +async fn fetch_prompts_and_notify_with_messenger(client: &Client, messenger: Option<&Box>) where T: Transport, { - let Ok(resp) = client.request("prompts/list", None).await else { - tracing::error!("Prompt list query failed for {0}", client.server_name); - return; - }; - let Some(result) = resp.result else { - tracing::warn!("Prompt list query returned no result for {0}", client.server_name); - return; - }; - let Some(prompts) = result.get("prompts") else { - tracing::warn!( - "Prompt list query result contained no field named prompts for {0}", - client.server_name - ); - return; - }; - let Ok(prompts) = serde_json::from_value::>(prompts.clone()) else { - tracing::error!("Prompt list query deserialization failed for {0}", client.server_name); - return; - }; - let Ok(mut lock) = client.prompt_gets.write() else { - tracing::error!( - "Failed to obtain write lock for prompt list query 
for {0}", - client.server_name - ); - return; + let prompt_list_result = 'prompt_list_result: { + let Ok(resp) = client.request("prompts/list", None).await else { + tracing::error!("Prompt list query failed for {0}", client.server_name); + return; + }; + let Some(result) = resp.result else { + tracing::warn!("Prompt list query returned no result for {0}", client.server_name); + return; + }; + let prompt_list_result = match serde_json::from_value::(result) { + Ok(res) => res, + Err(e) => { + let msg = format!("Failed to deserialize tool result from {}: {:?}", client.server_name, e); + break 'prompt_list_result Err(eyre::eyre!(msg)); + }, + }; + Ok::(prompt_list_result) }; - lock.clear(); - for prompt in prompts { - let name = prompt.name.clone(); - lock.insert(name, prompt); + + if let Some(messenger) = messenger { + if let Err(e) = messenger.send_prompts_list_result(prompt_list_result).await { + tracing::error!("Failed to send prompt result through messenger: {:?}", e); + } } } @@ -674,11 +667,11 @@ where }; Ok::(tool_list_result) }; + if let Some(messenger) = messenger { - let _ = messenger - .send_tools_list_result(tool_list_result) - .await - .map_err(|e| tracing::error!("Failed to send tool result through messenger {:?}", e)); + if let Err(e) = messenger.send_tools_list_result(tool_list_result).await { + tracing::error!("Failed to send tool result through messenger {:?}", e); + } } } diff --git a/crates/chat-cli/src/mcp_client/messenger.rs b/crates/chat-cli/src/mcp_client/messenger.rs index 14f79e518a..75723cd9c7 100644 --- a/crates/chat-cli/src/mcp_client/messenger.rs +++ b/crates/chat-cli/src/mcp_client/messenger.rs @@ -37,6 +37,9 @@ pub trait Messenger: std::fmt::Debug + Send + Sync + 'static { /// Signals to the orchestrator that a server has started initializing async fn send_init_msg(&self) -> Result<(), MessengerError>; + /// Signals to the orchestrator that a server has deinitialized + fn send_deinit_msg(&self); + /// Creates a duplicate of the 
messenger object /// This function is used to create a new instance of the messenger with the same configuration fn duplicate(&self) -> Box; @@ -79,6 +82,8 @@ impl Messenger for NullMessenger { Ok(()) } + fn send_deinit_msg(&self) {} + fn duplicate(&self) -> Box { Box::new(NullMessenger) } From fdeb08d39123b05fbad77eea145fbccec91896e4 Mon Sep 17 00:00:00 2001 From: abhraina-aws Date: Wed, 20 Aug 2025 16:19:25 -0700 Subject: [PATCH 3/7] feat: Implement wildcard pattern matching for agent allowedTools (#2612) - Add globset-based pattern matching to support wildcards (* and ?) in allowedTools - Create util/pattern_matching.rs module with matches_any_pattern function - Update all native tools (fs_read, fs_write, execute_bash, use_aws, knowledge) to use pattern matching - Update MCP custom tools to support wildcard patterns while preserving exact server-level matching - Standardize imports across tool files for consistency - Maintain backward compatibility with existing exact-match behavior Enables agent configs like: - "fs_*" matches fs_read, fs_write - "@mcp-server/tool_*" matches tool_read, tool_write - "execute_*" matches execute_bash, execute_cmd --- crates/chat-cli/src/cli/agent/mod.rs | 137 ++++++++++++++++-- crates/chat-cli/src/cli/chat/conversation.rs | 2 + .../src/cli/chat/tools/custom_tool.rs | 21 +-- .../src/cli/chat/tools/execute/mod.rs | 3 +- crates/chat-cli/src/cli/chat/tools/fs_read.rs | 3 +- .../chat-cli/src/cli/chat/tools/fs_write.rs | 3 +- .../chat-cli/src/cli/chat/tools/knowledge.rs | 3 +- crates/chat-cli/src/cli/chat/tools/use_aws.rs | 3 +- crates/chat-cli/src/util/mod.rs | 1 + crates/chat-cli/src/util/pattern_matching.rs | 65 +++++++++ docs/agent-format.md | 64 +++++++- 11 files changed, 276 insertions(+), 29 deletions(-) create mode 100644 crates/chat-cli/src/util/pattern_matching.rs diff --git a/crates/chat-cli/src/cli/agent/mod.rs b/crates/chat-cli/src/cli/agent/mod.rs index a11c0cb7e2..0f8e425e57 100644 --- 
a/crates/chat-cli/src/cli/agent/mod.rs +++ b/crates/chat-cli/src/cli/agent/mod.rs @@ -693,18 +693,31 @@ impl Agents { /// Returns a label to describe the permission status for a given tool. pub fn display_label(&self, tool_name: &str, origin: &ToolOrigin) -> String { + use crate::util::pattern_matching::matches_any_pattern; + let tool_trusted = self.get_active().is_some_and(|a| { + if matches!(origin, &ToolOrigin::Native) { + return matches_any_pattern(&a.allowed_tools, tool_name); + } + a.allowed_tools.iter().any(|name| { - // Here the tool names can take the following forms: - // - @{server_name}{delimiter}{tool_name} - // - native_tool_name - name == tool_name && matches!(origin, &ToolOrigin::Native) - || name.strip_prefix("@").is_some_and(|remainder| { - remainder - .split_once(MCP_SERVER_TOOL_DELIMITER) - .is_some_and(|(_left, right)| right == tool_name) - || remainder == >::borrow(origin) - }) + name.strip_prefix("@").is_some_and(|remainder| { + remainder + .split_once(MCP_SERVER_TOOL_DELIMITER) + .is_some_and(|(_left, right)| right == tool_name) + || remainder == >::borrow(origin) + }) || { + if let Some(server_name) = name.strip_prefix("@").and_then(|s| s.split('/').next()) { + if server_name == >::borrow(origin) { + let tool_pattern = format!("@{}/{}", server_name, tool_name); + matches_any_pattern(&a.allowed_tools, &tool_pattern) + } else { + false + } + } else { + false + } + } }) }); @@ -942,4 +955,108 @@ mod tests { assert!(validate_agent_name("invalid!").is_err()); assert!(validate_agent_name("invalid space").is_err()); } + + #[test] + fn test_display_label_no_active_agent() { + let agents = Agents::default(); + + let label = agents.display_label("fs_read", &ToolOrigin::Native); + // With no active agent, it should fall back to default permissions + // fs_read has a default of "trusted" + assert!(label.contains("trusted"), "fs_read should show default trusted permission, instead found: {}", label); + } + + #[test] + fn 
test_display_label_trust_all_tools() { + let mut agents = Agents::default(); + agents.trust_all_tools = true; + + // Should be trusted even if not in allowed_tools + let label = agents.display_label("random_tool", &ToolOrigin::Native); + assert!(label.contains("trusted"), "trust_all_tools should make everything trusted, instead found: {}", label); + } + + #[test] + fn test_display_label_default_permissions() { + let agents = Agents::default(); + + // Test default permissions for known tools + let fs_read_label = agents.display_label("fs_read", &ToolOrigin::Native); + assert!(fs_read_label.contains("trusted"), "fs_read should be trusted by default, instead found: {}", fs_read_label); + + let fs_write_label = agents.display_label("fs_write", &ToolOrigin::Native); + assert!(fs_write_label.contains("not trusted"), "fs_write should not be trusted by default, instead found: {}", fs_write_label); + + let execute_bash_label = agents.display_label("execute_bash", &ToolOrigin::Native); + assert!(execute_bash_label.contains("read-only"), "execute_bash should show read-only by default, instead found: {}", execute_bash_label); + } + + #[test] + fn test_display_label_comprehensive_patterns() { + let mut agents = Agents::default(); + + // Create agent with all types of patterns + let mut allowed_tools = HashSet::new(); + // Native exact match + allowed_tools.insert("fs_read".to_string()); + // Native wildcard + allowed_tools.insert("execute_*".to_string()); + // MCP server exact (allows all tools from that server) + allowed_tools.insert("@server1".to_string()); + // MCP tool exact + allowed_tools.insert("@server2/specific_tool".to_string()); + // MCP tool wildcard + allowed_tools.insert("@server3/tool_*".to_string()); + + let agent = Agent { + schema: "test".to_string(), + name: "test-agent".to_string(), + description: None, + prompt: None, + mcp_servers: Default::default(), + tools: Vec::new(), + tool_aliases: Default::default(), + allowed_tools, + tools_settings: 
Default::default(), + resources: Vec::new(), + hooks: Default::default(), + use_legacy_mcp_json: false, + path: None, + }; + + agents.agents.insert("test-agent".to_string(), agent); + agents.active_idx = "test-agent".to_string(); + + // Test 1: Native exact match + let label = agents.display_label("fs_read", &ToolOrigin::Native); + assert!(label.contains("trusted"), "fs_read should be trusted (exact match), instead found: {}", label); + + // Test 2: Native wildcard match + let label = agents.display_label("execute_bash", &ToolOrigin::Native); + assert!(label.contains("trusted"), "execute_bash should match execute_* pattern, instead found: {}", label); + + // Test 3: Native no match + let label = agents.display_label("fs_write", &ToolOrigin::Native); + assert!(!label.contains("trusted") || label.contains("not trusted"), "fs_write should not be trusted, instead found: {}", label); + + // Test 4: MCP server exact match (allows any tool from server1) + let label = agents.display_label("any_tool", &ToolOrigin::McpServer("server1".to_string())); + assert!(label.contains("trusted"), "Server-level permission should allow any tool, instead found: {}", label); + + // Test 5: MCP tool exact match + let label = agents.display_label("specific_tool", &ToolOrigin::McpServer("server2".to_string())); + assert!(label.contains("trusted"), "Exact MCP tool should be trusted, instead found: {}", label); + + // Test 6: MCP tool wildcard match + let label = agents.display_label("tool_read", &ToolOrigin::McpServer("server3".to_string())); + assert!(label.contains("trusted"), "tool_read should match @server3/tool_* pattern, instead found: {}", label); + + // Test 7: MCP tool no match + let label = agents.display_label("other_tool", &ToolOrigin::McpServer("server2".to_string())); + assert!(!label.contains("trusted") || label.contains("not trusted"), "Non-matching MCP tool should not be trusted, instead found: {}", label); + + // Test 8: MCP server no match + let label = 
agents.display_label("some_tool", &ToolOrigin::McpServer("unknown_server".to_string())); + assert!(!label.contains("trusted") || label.contains("not trusted"), "Unknown server should not be trusted, instead found: {}", label); + } } diff --git a/crates/chat-cli/src/cli/chat/conversation.rs b/crates/chat-cli/src/cli/chat/conversation.rs index 7c58febcf7..6b1751cd15 100644 --- a/crates/chat-cli/src/cli/chat/conversation.rs +++ b/crates/chat-cli/src/cli/chat/conversation.rs @@ -135,6 +135,7 @@ impl ConversationState { current_model_id: Option, os: &Os, ) -> Self { + let model = if let Some(model_id) = current_model_id { match get_model_info(&model_id, os).await { Ok(info) => Some(info), @@ -1278,4 +1279,5 @@ mod tests { conversation.set_next_user_message(i.to_string()).await; } } + } diff --git a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs index 3daaf50752..67134f7bf7 100644 --- a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs +++ b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs @@ -35,6 +35,8 @@ use crate::mcp_client::{ ToolCallResult, }; use crate::os::Os; +use crate::util::pattern_matching::matches_any_pattern; +use crate::util::MCP_SERVER_TOOL_DELIMITER; // TODO: support http transport type #[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq, JsonSchema)] @@ -274,7 +276,6 @@ impl CustomTool { } pub fn eval_perm(&self, agent: &Agent) -> PermissionEvalResult { - use crate::util::MCP_SERVER_TOOL_DELIMITER; let Self { name: tool_name, client, @@ -282,15 +283,17 @@ impl CustomTool { } = self; let server_name = client.get_server_name(); - if agent.allowed_tools.contains(&format!("@{server_name}")) - || agent - .allowed_tools - .contains(&format!("@{server_name}{MCP_SERVER_TOOL_DELIMITER}{tool_name}")) - { - PermissionEvalResult::Allow - } else { - PermissionEvalResult::Ask + let server_pattern = format!("@{server_name}"); + if agent.allowed_tools.contains(&server_pattern) { + return 
PermissionEvalResult::Allow; + } + + let tool_pattern = format!("@{server_name}{MCP_SERVER_TOOL_DELIMITER}{tool_name}"); + if matches_any_pattern(&agent.allowed_tools, &tool_pattern) { + return PermissionEvalResult::Allow; } + + PermissionEvalResult::Ask } } diff --git a/crates/chat-cli/src/cli/chat/tools/execute/mod.rs b/crates/chat-cli/src/cli/chat/tools/execute/mod.rs index f035f9e601..a1e7b9c8e5 100644 --- a/crates/chat-cli/src/cli/chat/tools/execute/mod.rs +++ b/crates/chat-cli/src/cli/chat/tools/execute/mod.rs @@ -23,6 +23,7 @@ use crate::cli::chat::tools::{ }; use crate::cli::chat::util::truncate_safe; use crate::os::Os; +use crate::util::pattern_matching::matches_any_pattern; // Platform-specific modules #[cfg(windows)] @@ -204,7 +205,7 @@ impl ExecuteCommand { let Self { command, .. } = self; let tool_name = if cfg!(windows) { "execute_cmd" } else { "execute_bash" }; - let is_in_allowlist = agent.allowed_tools.contains(tool_name); + let is_in_allowlist = matches_any_pattern(&agent.allowed_tools, tool_name); match agent.tools_settings.get(tool_name) { Some(settings) if is_in_allowlist => { let Settings { diff --git a/crates/chat-cli/src/cli/chat/tools/fs_read.rs b/crates/chat-cli/src/cli/chat/tools/fs_read.rs index 9285d79305..dc30c336cb 100644 --- a/crates/chat-cli/src/cli/chat/tools/fs_read.rs +++ b/crates/chat-cli/src/cli/chat/tools/fs_read.rs @@ -48,6 +48,7 @@ use crate::cli::chat::{ sanitize_unicode_tags, }; use crate::os::Os; +use crate::util::pattern_matching::matches_any_pattern; #[derive(Debug, Clone, Deserialize)] pub struct FsRead { @@ -118,7 +119,7 @@ impl FsRead { true } - let is_in_allowlist = agent.allowed_tools.contains("fs_read"); + let is_in_allowlist = matches_any_pattern(&agent.allowed_tools, "fs_read"); match agent.tools_settings.get("fs_read") { Some(settings) if is_in_allowlist => { let Settings { diff --git a/crates/chat-cli/src/cli/chat/tools/fs_write.rs b/crates/chat-cli/src/cli/chat/tools/fs_write.rs index e305d9304c..79151244f6 
100644 --- a/crates/chat-cli/src/cli/chat/tools/fs_write.rs +++ b/crates/chat-cli/src/cli/chat/tools/fs_write.rs @@ -47,6 +47,7 @@ use crate::cli::agent::{ }; use crate::cli::chat::line_tracker::FileLineTracker; use crate::os::Os; +use crate::util::pattern_matching::matches_any_pattern; static SYNTAX_SET: LazyLock = LazyLock::new(SyntaxSet::load_defaults_newlines); static THEME_SET: LazyLock = LazyLock::new(ThemeSet::load_defaults); @@ -425,7 +426,7 @@ impl FsWrite { denied_paths: Vec, } - let is_in_allowlist = agent.allowed_tools.contains("fs_write"); + let is_in_allowlist = matches_any_pattern(&agent.allowed_tools, "fs_write"); match agent.tools_settings.get("fs_write") { Some(settings) if is_in_allowlist => { let Settings { diff --git a/crates/chat-cli/src/cli/chat/tools/knowledge.rs b/crates/chat-cli/src/cli/chat/tools/knowledge.rs index 191fbb86d5..b4392d183f 100644 --- a/crates/chat-cli/src/cli/chat/tools/knowledge.rs +++ b/crates/chat-cli/src/cli/chat/tools/knowledge.rs @@ -19,6 +19,7 @@ use crate::cli::agent::{ }; use crate::database::settings::Setting; use crate::os::Os; +use crate::util::pattern_matching::matches_any_pattern; use crate::util::knowledge_store::KnowledgeStore; /// The Knowledge tool allows storing and retrieving information across chat sessions. 
@@ -490,7 +491,7 @@ impl Knowledge { pub fn eval_perm(&self, agent: &Agent) -> PermissionEvalResult { _ = self; - if agent.allowed_tools.contains("knowledge") { + if matches_any_pattern(&agent.allowed_tools, "knowledge") { PermissionEvalResult::Allow } else { PermissionEvalResult::Ask diff --git a/crates/chat-cli/src/cli/chat/tools/use_aws.rs b/crates/chat-cli/src/cli/chat/tools/use_aws.rs index 2a70cd1604..01b09126e8 100644 --- a/crates/chat-cli/src/cli/chat/tools/use_aws.rs +++ b/crates/chat-cli/src/cli/chat/tools/use_aws.rs @@ -29,6 +29,7 @@ use crate::cli::agent::{ PermissionEvalResult, }; use crate::os::Os; +use crate::util::pattern_matching::matches_any_pattern; const READONLY_OPS: [&str; 6] = ["get", "describe", "list", "ls", "search", "batch_get"]; @@ -184,7 +185,7 @@ impl UseAws { } let Self { service_name, .. } = self; - let is_in_allowlist = agent.allowed_tools.contains("use_aws"); + let is_in_allowlist = matches_any_pattern(&agent.allowed_tools, "use_aws"); match agent.tools_settings.get("use_aws") { Some(settings) if is_in_allowlist => { let settings = match serde_json::from_value::(settings.clone()) { diff --git a/crates/chat-cli/src/util/mod.rs b/crates/chat-cli/src/util/mod.rs index 576ba37acc..5282b6b6ce 100644 --- a/crates/chat-cli/src/util/mod.rs +++ b/crates/chat-cli/src/util/mod.rs @@ -2,6 +2,7 @@ pub mod consts; pub mod directories; pub mod knowledge_store; pub mod open; +pub mod pattern_matching; pub mod process; pub mod spinner; pub mod system_info; diff --git a/crates/chat-cli/src/util/pattern_matching.rs b/crates/chat-cli/src/util/pattern_matching.rs new file mode 100644 index 0000000000..616f1d098e --- /dev/null +++ b/crates/chat-cli/src/util/pattern_matching.rs @@ -0,0 +1,65 @@ +use std::collections::HashSet; +use globset::Glob; + +/// Check if a string matches any pattern in a set of patterns +pub fn matches_any_pattern(patterns: &HashSet, text: &str) -> bool { + patterns.iter().any(|pattern| { + // Exact match first + if pattern == 
text { + return true; + } + + // Glob pattern match if contains wildcards + if pattern.contains('*') || pattern.contains('?') { + if let Ok(glob) = Glob::new(pattern) { + return glob.compile_matcher().is_match(text); + } + } + + false + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::HashSet; + + #[test] + fn test_exact_match() { + let mut patterns = HashSet::new(); + patterns.insert("fs_read".to_string()); + + assert!(matches_any_pattern(&patterns, "fs_read")); + assert!(!matches_any_pattern(&patterns, "fs_write")); + } + + #[test] + fn test_wildcard_patterns() { + let mut patterns = HashSet::new(); + patterns.insert("fs_*".to_string()); + + assert!(matches_any_pattern(&patterns, "fs_read")); + assert!(matches_any_pattern(&patterns, "fs_write")); + assert!(!matches_any_pattern(&patterns, "execute_bash")); + } + + #[test] + fn test_mcp_patterns() { + let mut patterns = HashSet::new(); + patterns.insert("@mcp-server/*".to_string()); + + assert!(matches_any_pattern(&patterns, "@mcp-server/tool1")); + assert!(matches_any_pattern(&patterns, "@mcp-server/tool2")); + assert!(!matches_any_pattern(&patterns, "@other-server/tool")); + } + + #[test] + fn test_question_mark_wildcard() { + let mut patterns = HashSet::new(); + patterns.insert("fs_?ead".to_string()); + + assert!(matches_any_pattern(&patterns, "fs_read")); + assert!(!matches_any_pattern(&patterns, "fs_write")); + } +} diff --git a/docs/agent-format.md b/docs/agent-format.md index 4fdfe1a276..328eb9689d 100644 --- a/docs/agent-format.md +++ b/docs/agent-format.md @@ -144,18 +144,72 @@ The `allowedTools` field specifies which tools can be used without prompting the { "allowedTools": [ "fs_read", + "fs_*", "@git/git_status", + "@server/read_*", "@fetch" ] } ``` -You can allow: -- Specific built-in tools by name (e.g., `"fs_read"`) -- Specific MCP tools using `@server_name/tool_name` (e.g., `"@git/git_status"`) -- All tools from an MCP server using `@server_name` (e.g., `"@fetch"`) +You can 
allow tools using several patterns: -Unlike the `tools` field, the `allowedTools` field does not support the `"*"` wildcard for allowing all tools. To allow specific tools, you must list them individually or use server-level wildcards with the `@server_name` syntax. +### Exact Matches +- **Built-in tools**: `"fs_read"`, `"execute_bash"`, `"knowledge"` +- **Specific MCP tools**: `"@server_name/tool_name"` (e.g., `"@git/git_status"`) +- **All tools from MCP server**: `"@server_name"` (e.g., `"@fetch"`) + +### Wildcard Patterns +The `allowedTools` field supports glob-style wildcard patterns using `*` and `?`: + +#### Native Tool Patterns +- **Prefix wildcard**: `"fs_*"` → matches `fs_read`, `fs_write`, `fs_anything` +- **Suffix wildcard**: `"*_bash"` → matches `execute_bash`, `run_bash` +- **Middle wildcard**: `"fs_*_tool"` → matches `fs_read_tool`, `fs_write_tool` +- **Single character**: `"fs_?ead"` → matches `fs_read`, `fs_head` (but not `fs_write`) + +#### MCP Tool Patterns +- **Tool prefix**: `"@server/read_*"` → matches `@server/read_file`, `@server/read_config` +- **Tool suffix**: `"@server/*_get"` → matches `@server/issue_get`, `@server/data_get` +- **Server pattern**: `"@*-mcp/read_*"` → matches `@git-mcp/read_file`, `@db-mcp/read_data` +- **Any tool from pattern servers**: `"@git-*/*"` → matches any tool from servers matching `git-*` + +### Examples + +```json +{ + "allowedTools": [ + // Exact matches + "fs_read", + "knowledge", + "@server/specific_tool", + + // Native tool wildcards + "fs_*", // All filesystem tools + "execute_*", // All execute tools + "*_test", // Any tool ending in _test + + // MCP tool wildcards + "@server/api_*", // All API tools from server + "@server/read_*", // All read tools from server + "@git-server/get_*_info", // Tools like get_user_info, get_repo_info + "@*/status", // Status tool from any server + + // Server-level permissions + "@fetch", // All tools from fetch server + "@git-*" // All tools from any git-* server + ] +} +``` 
+ +### Pattern Matching Rules +- **`*`** matches any sequence of characters (including none) +- **`?`** matches exactly one character +- **Exact matches** take precedence over patterns +- **Server-level permissions** (`@server_name`) allow all tools from that server +- **Case-sensitive** matching + +Unlike the `tools` field, the `allowedTools` field does not support the `"*"` wildcard for allowing all tools. To allow tools, you must use specific patterns or server-level permissions. ## ToolsSettings Field From 334cbd0d8b6fac739c38b8249c51eebf622dd0b6 Mon Sep 17 00:00:00 2001 From: Felix Ding Date: Wed, 20 Aug 2025 16:45:34 -0700 Subject: [PATCH 4/7] fixes unwrap on pid (#2657) --- crates/chat-cli/src/cli/chat/tool_manager.rs | 37 ++++++++++++++++---- 1 file changed, 30 insertions(+), 7 deletions(-) diff --git a/crates/chat-cli/src/cli/chat/tool_manager.rs b/crates/chat-cli/src/cli/chat/tool_manager.rs index a8c8db40d5..0304a2568b 100644 --- a/crates/chat-cli/src/cli/chat/tool_manager.rs +++ b/crates/chat-cli/src/cli/chat/tool_manager.rs @@ -1285,13 +1285,6 @@ fn spawn_orchestrator_task( result, pid, } => { - let pid = pid.unwrap(); - if !is_process_running(pid) { - info!( - "Received tool list result from {server_name} but its associated process {pid} is no longer running. Ignoring." - ); - return; - } let time_taken = loading_servers .remove(&server_name) .map_or("0.0".to_owned(), |init_time| { @@ -1347,6 +1340,36 @@ fn spawn_orchestrator_task( match result { Ok(result) => { + if pid.is_none_or(|pid| !is_process_running(pid)) { + let pid = pid.map_or("unknown".to_string(), |pid| pid.to_string()); + info!( + "Received tool list result from {server_name} but its associated process {pid} is no longer running. Ignoring." 
+ ); + + let mut buf_writer = BufWriter::new(&mut *record_temp_buf); + let _ = queue_failure_message( + &server_name, + &eyre::eyre!("Process associated is no longer running"), + &time_taken, + &mut buf_writer, + ); + let _ = buf_writer.flush(); + drop(buf_writer); + let record_content = String::from_utf8_lossy(record_temp_buf).to_string(); + let record = LoadingRecord::Err(record_content); + + load_record + .lock() + .await + .entry(server_name.clone()) + .and_modify(|load_record| { + load_record.push(record.clone()); + }) + .or_insert(vec![record]); + + return; + } + let mut specs = result .tools .into_iter() From 0ce09b46f5dcb40adbba222100170ecdbc1fa014 Mon Sep 17 00:00:00 2001 From: evanliu048 Date: Wed, 20 Aug 2025 16:54:52 -0700 Subject: [PATCH 5/7] ci fix (#2658) --- crates/chat-cli/src/cli/agent/mod.rs | 114 +++++++++++++----- crates/chat-cli/src/cli/chat/conversation.rs | 2 - .../src/cli/chat/tools/custom_tool.rs | 2 +- .../chat-cli/src/cli/chat/tools/knowledge.rs | 2 +- crates/chat-cli/src/util/pattern_matching.rs | 16 +-- 5 files changed, 94 insertions(+), 42 deletions(-) diff --git a/crates/chat-cli/src/cli/agent/mod.rs b/crates/chat-cli/src/cli/agent/mod.rs index 0f8e425e57..d75a1f733d 100644 --- a/crates/chat-cli/src/cli/agent/mod.rs +++ b/crates/chat-cli/src/cli/agent/mod.rs @@ -694,12 +694,12 @@ impl Agents { /// Returns a label to describe the permission status for a given tool. 
pub fn display_label(&self, tool_name: &str, origin: &ToolOrigin) -> String { use crate::util::pattern_matching::matches_any_pattern; - + let tool_trusted = self.get_active().is_some_and(|a| { if matches!(origin, &ToolOrigin::Native) { return matches_any_pattern(&a.allowed_tools, tool_name); } - + a.allowed_tools.iter().any(|name| { name.strip_prefix("@").is_some_and(|remainder| { remainder @@ -959,42 +959,62 @@ mod tests { #[test] fn test_display_label_no_active_agent() { let agents = Agents::default(); - + let label = agents.display_label("fs_read", &ToolOrigin::Native); // With no active agent, it should fall back to default permissions // fs_read has a default of "trusted" - assert!(label.contains("trusted"), "fs_read should show default trusted permission, instead found: {}", label); + assert!( + label.contains("trusted"), + "fs_read should show default trusted permission, instead found: {}", + label + ); } #[test] fn test_display_label_trust_all_tools() { let mut agents = Agents::default(); agents.trust_all_tools = true; - + // Should be trusted even if not in allowed_tools let label = agents.display_label("random_tool", &ToolOrigin::Native); - assert!(label.contains("trusted"), "trust_all_tools should make everything trusted, instead found: {}", label); + assert!( + label.contains("trusted"), + "trust_all_tools should make everything trusted, instead found: {}", + label + ); } #[test] fn test_display_label_default_permissions() { let agents = Agents::default(); - + // Test default permissions for known tools let fs_read_label = agents.display_label("fs_read", &ToolOrigin::Native); - assert!(fs_read_label.contains("trusted"), "fs_read should be trusted by default, instead found: {}", fs_read_label); - + assert!( + fs_read_label.contains("trusted"), + "fs_read should be trusted by default, instead found: {}", + fs_read_label + ); + let fs_write_label = agents.display_label("fs_write", &ToolOrigin::Native); - assert!(fs_write_label.contains("not trusted"), 
"fs_write should not be trusted by default, instead found: {}", fs_write_label); - + assert!( + fs_write_label.contains("not trusted"), + "fs_write should not be trusted by default, instead found: {}", + fs_write_label + ); + let execute_bash_label = agents.display_label("execute_bash", &ToolOrigin::Native); - assert!(execute_bash_label.contains("read-only"), "execute_bash should show read-only by default, instead found: {}", execute_bash_label); + assert!( + execute_bash_label.contains("read-only"), + "execute_bash should show read-only by default, instead found: {}", + execute_bash_label + ); } #[test] fn test_display_label_comprehensive_patterns() { let mut agents = Agents::default(); - + // Create agent with all types of patterns let mut allowed_tools = HashSet::new(); // Native exact match @@ -1007,7 +1027,7 @@ mod tests { allowed_tools.insert("@server2/specific_tool".to_string()); // MCP tool wildcard allowed_tools.insert("@server3/tool_*".to_string()); - + let agent = Agent { schema: "test".to_string(), name: "test-agent".to_string(), @@ -1023,40 +1043,72 @@ mod tests { use_legacy_mcp_json: false, path: None, }; - + agents.agents.insert("test-agent".to_string(), agent); agents.active_idx = "test-agent".to_string(); - + // Test 1: Native exact match let label = agents.display_label("fs_read", &ToolOrigin::Native); - assert!(label.contains("trusted"), "fs_read should be trusted (exact match), instead found: {}", label); - + assert!( + label.contains("trusted"), + "fs_read should be trusted (exact match), instead found: {}", + label + ); + // Test 2: Native wildcard match let label = agents.display_label("execute_bash", &ToolOrigin::Native); - assert!(label.contains("trusted"), "execute_bash should match execute_* pattern, instead found: {}", label); - + assert!( + label.contains("trusted"), + "execute_bash should match execute_* pattern, instead found: {}", + label + ); + // Test 3: Native no match let label = agents.display_label("fs_write", 
&ToolOrigin::Native); - assert!(!label.contains("trusted") || label.contains("not trusted"), "fs_write should not be trusted, instead found: {}", label); - + assert!( + !label.contains("trusted") || label.contains("not trusted"), + "fs_write should not be trusted, instead found: {}", + label + ); + // Test 4: MCP server exact match (allows any tool from server1) let label = agents.display_label("any_tool", &ToolOrigin::McpServer("server1".to_string())); - assert!(label.contains("trusted"), "Server-level permission should allow any tool, instead found: {}", label); - + assert!( + label.contains("trusted"), + "Server-level permission should allow any tool, instead found: {}", + label + ); + // Test 5: MCP tool exact match let label = agents.display_label("specific_tool", &ToolOrigin::McpServer("server2".to_string())); - assert!(label.contains("trusted"), "Exact MCP tool should be trusted, instead found: {}", label); - + assert!( + label.contains("trusted"), + "Exact MCP tool should be trusted, instead found: {}", + label + ); + // Test 6: MCP tool wildcard match let label = agents.display_label("tool_read", &ToolOrigin::McpServer("server3".to_string())); - assert!(label.contains("trusted"), "tool_read should match @server3/tool_* pattern, instead found: {}", label); - + assert!( + label.contains("trusted"), + "tool_read should match @server3/tool_* pattern, instead found: {}", + label + ); + // Test 7: MCP tool no match let label = agents.display_label("other_tool", &ToolOrigin::McpServer("server2".to_string())); - assert!(!label.contains("trusted") || label.contains("not trusted"), "Non-matching MCP tool should not be trusted, instead found: {}", label); - + assert!( + !label.contains("trusted") || label.contains("not trusted"), + "Non-matching MCP tool should not be trusted, instead found: {}", + label + ); + // Test 8: MCP server no match let label = agents.display_label("some_tool", &ToolOrigin::McpServer("unknown_server".to_string())); - 
assert!(!label.contains("trusted") || label.contains("not trusted"), "Unknown server should not be trusted, instead found: {}", label); + assert!( + !label.contains("trusted") || label.contains("not trusted"), + "Unknown server should not be trusted, instead found: {}", + label + ); } } diff --git a/crates/chat-cli/src/cli/chat/conversation.rs b/crates/chat-cli/src/cli/chat/conversation.rs index 6b1751cd15..7c58febcf7 100644 --- a/crates/chat-cli/src/cli/chat/conversation.rs +++ b/crates/chat-cli/src/cli/chat/conversation.rs @@ -135,7 +135,6 @@ impl ConversationState { current_model_id: Option, os: &Os, ) -> Self { - let model = if let Some(model_id) = current_model_id { match get_model_info(&model_id, os).await { Ok(info) => Some(info), @@ -1279,5 +1278,4 @@ mod tests { conversation.set_next_user_message(i.to_string()).await; } } - } diff --git a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs index 67134f7bf7..2fe2aa1f37 100644 --- a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs +++ b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs @@ -35,8 +35,8 @@ use crate::mcp_client::{ ToolCallResult, }; use crate::os::Os; -use crate::util::pattern_matching::matches_any_pattern; use crate::util::MCP_SERVER_TOOL_DELIMITER; +use crate::util::pattern_matching::matches_any_pattern; // TODO: support http transport type #[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq, JsonSchema)] diff --git a/crates/chat-cli/src/cli/chat/tools/knowledge.rs b/crates/chat-cli/src/cli/chat/tools/knowledge.rs index b4392d183f..639e0969e6 100644 --- a/crates/chat-cli/src/cli/chat/tools/knowledge.rs +++ b/crates/chat-cli/src/cli/chat/tools/knowledge.rs @@ -19,8 +19,8 @@ use crate::cli::agent::{ }; use crate::database::settings::Setting; use crate::os::Os; -use crate::util::pattern_matching::matches_any_pattern; use crate::util::knowledge_store::KnowledgeStore; +use crate::util::pattern_matching::matches_any_pattern; /// The 
Knowledge tool allows storing and retrieving information across chat sessions. /// It provides semantic search capabilities for files, directories, and text content. diff --git a/crates/chat-cli/src/util/pattern_matching.rs b/crates/chat-cli/src/util/pattern_matching.rs index 616f1d098e..cb663ac035 100644 --- a/crates/chat-cli/src/util/pattern_matching.rs +++ b/crates/chat-cli/src/util/pattern_matching.rs @@ -1,4 +1,5 @@ use std::collections::HashSet; + use globset::Glob; /// Check if a string matches any pattern in a set of patterns @@ -8,28 +9,29 @@ pub fn matches_any_pattern(patterns: &HashSet, text: &str) -> bool { if pattern == text { return true; } - + // Glob pattern match if contains wildcards if pattern.contains('*') || pattern.contains('?') { if let Ok(glob) = Glob::new(pattern) { return glob.compile_matcher().is_match(text); } } - + false }) } #[cfg(test)] mod tests { - use super::*; use std::collections::HashSet; + use super::*; + #[test] fn test_exact_match() { let mut patterns = HashSet::new(); patterns.insert("fs_read".to_string()); - + assert!(matches_any_pattern(&patterns, "fs_read")); assert!(!matches_any_pattern(&patterns, "fs_write")); } @@ -38,7 +40,7 @@ mod tests { fn test_wildcard_patterns() { let mut patterns = HashSet::new(); patterns.insert("fs_*".to_string()); - + assert!(matches_any_pattern(&patterns, "fs_read")); assert!(matches_any_pattern(&patterns, "fs_write")); assert!(!matches_any_pattern(&patterns, "execute_bash")); @@ -48,7 +50,7 @@ mod tests { fn test_mcp_patterns() { let mut patterns = HashSet::new(); patterns.insert("@mcp-server/*".to_string()); - + assert!(matches_any_pattern(&patterns, "@mcp-server/tool1")); assert!(matches_any_pattern(&patterns, "@mcp-server/tool2")); assert!(!matches_any_pattern(&patterns, "@other-server/tool")); @@ -58,7 +60,7 @@ mod tests { fn test_question_mark_wildcard() { let mut patterns = HashSet::new(); patterns.insert("fs_?ead".to_string()); - + assert!(matches_any_pattern(&patterns, "fs_read")); 
assert!(!matches_any_pattern(&patterns, "fs_write")); } From 5c6fe2800c10e90fc85df8a333f38ee353083ac9 Mon Sep 17 00:00:00 2001 From: evanliu048 Date: Wed, 20 Aug 2025 17:19:25 -0700 Subject: [PATCH 6/7] feat: added mcp admin level configuration with GetProfile (#2639) * first pass * add notification when /mcp & /tools * clear all tool related filed in agent * store mcp_enabled in chatsession & conversationstate * delete duplicate api call * set mcp_enabled value after load * remove clear mcp configs method * clippy * remain@builtin/ and *, add a ut for clear mcp config --- crates/chat-cli/src/api_client/error.rs | 10 ++ crates/chat-cli/src/api_client/mod.rs | 16 ++ crates/chat-cli/src/cli/agent/mod.rs | 170 +++++++++++++++--- .../src/cli/agent/root_command_args.rs | 15 +- crates/chat-cli/src/cli/chat/cli/mcp.rs | 22 ++- crates/chat-cli/src/cli/chat/cli/persist.rs | 2 + crates/chat-cli/src/cli/chat/cli/profile.rs | 7 +- crates/chat-cli/src/cli/chat/cli/tools.rs | 11 ++ crates/chat-cli/src/cli/chat/conversation.rs | 11 ++ crates/chat-cli/src/cli/chat/mod.rs | 41 ++++- crates/chat-cli/src/cli/mcp.rs | 9 +- 11 files changed, 272 insertions(+), 42 deletions(-) diff --git a/crates/chat-cli/src/api_client/error.rs b/crates/chat-cli/src/api_client/error.rs index 37420fb72e..4ac80f329c 100644 --- a/crates/chat-cli/src/api_client/error.rs +++ b/crates/chat-cli/src/api_client/error.rs @@ -1,5 +1,6 @@ use amzn_codewhisperer_client::operation::create_subscription_token::CreateSubscriptionTokenError; use amzn_codewhisperer_client::operation::generate_completions::GenerateCompletionsError; +use amzn_codewhisperer_client::operation::get_profile::GetProfileError; use amzn_codewhisperer_client::operation::list_available_customizations::ListAvailableCustomizationsError; use amzn_codewhisperer_client::operation::list_available_models::ListAvailableModelsError; use amzn_codewhisperer_client::operation::list_available_profiles::ListAvailableProfilesError; @@ -100,6 +101,9 @@ pub enum 
ApiClientError { #[error("No default model found in the ListAvailableModels API response")] DefaultModelNotFound, + + #[error(transparent)] + GetProfileError(#[from] SdkError), } impl ApiClientError { @@ -125,6 +129,7 @@ impl ApiClientError { Self::Credentials(_e) => None, Self::ListAvailableModelsError(e) => sdk_status_code(e), Self::DefaultModelNotFound => None, + Self::GetProfileError(e) => sdk_status_code(e), } } } @@ -152,6 +157,7 @@ impl ReasonCode for ApiClientError { Self::Credentials(_) => "CredentialsError".to_string(), Self::ListAvailableModelsError(e) => sdk_error_code(e), Self::DefaultModelNotFound => "DefaultModelNotFound".to_string(), + Self::GetProfileError(e) => sdk_error_code(e), } } } @@ -199,6 +205,10 @@ mod tests { ListAvailableCustomizationsError::unhandled(""), response(), )), + ApiClientError::GetProfileError(SdkError::service_error( + GetProfileError::unhandled(""), + response(), + )), ApiClientError::ListAvailableModelsError(SdkError::service_error( ListAvailableModelsError::unhandled(""), response(), diff --git a/crates/chat-cli/src/api_client/mod.rs b/crates/chat-cli/src/api_client/mod.rs index 20b97c71a2..26f1e7a1f3 100644 --- a/crates/chat-cli/src/api_client/mod.rs +++ b/crates/chat-cli/src/api_client/mod.rs @@ -16,6 +16,7 @@ use amzn_codewhisperer_client::operation::create_subscription_token::CreateSubsc use amzn_codewhisperer_client::types::Origin::Cli; use amzn_codewhisperer_client::types::{ Model, + OptInFeatureToggle, OptOutPreference, SubscriptionStatus, TelemetryEvent, @@ -334,6 +335,21 @@ impl ApiClient { Ok(res) } + pub async fn is_mcp_enabled(&self) -> Result { + let request = self + .client + .get_profile() + .set_profile_arn(self.profile.as_ref().map(|p| p.arn.clone())); + + let response = request.send().await?; + let mcp_enabled = response + .profile() + .opt_in_features() + .and_then(|features| features.mcp_configuration()) + .is_none_or(|config| matches!(config.toggle(), OptInFeatureToggle::On)); + Ok(mcp_enabled) + } + 
pub async fn create_subscription_token(&self) -> Result { if cfg!(test) { return Ok(CreateSubscriptionTokenOutput::builder() diff --git a/crates/chat-cli/src/cli/agent/mod.rs b/crates/chat-cli/src/cli/agent/mod.rs index d75a1f733d..7089b33ff9 100644 --- a/crates/chat-cli/src/cli/agent/mod.rs +++ b/crates/chat-cli/src/cli/agent/mod.rs @@ -286,6 +286,7 @@ impl Agent { os: &Os, agent_path: impl AsRef, legacy_mcp_config: &mut Option, + mcp_enabled: bool, ) -> Result { let content = os.fs.read(&agent_path).await?; let mut agent = serde_json::from_slice::(&content).map_err(|e| AgentConfigError::InvalidJson { @@ -293,16 +294,44 @@ impl Agent { path: agent_path.as_ref().to_path_buf(), })?; - if agent.use_legacy_mcp_json && legacy_mcp_config.is_none() { - let config = load_legacy_mcp_config(os).await.unwrap_or_default(); - if let Some(config) = config { - legacy_mcp_config.replace(config); + if mcp_enabled { + if agent.use_legacy_mcp_json && legacy_mcp_config.is_none() { + let config = load_legacy_mcp_config(os).await.unwrap_or_default(); + if let Some(config) = config { + legacy_mcp_config.replace(config); + } } + agent.thaw(agent_path.as_ref(), legacy_mcp_config.as_ref())?; + } else { + agent.clear_mcp_configs(); + // Thaw the agent with empty MCP config to finalize normalization. 
+ agent.thaw(agent_path.as_ref(), None)?; } - - agent.thaw(agent_path.as_ref(), legacy_mcp_config.as_ref())?; Ok(agent) } + + /// Clear all MCP configurations while preserving built-in tools + pub fn clear_mcp_configs(&mut self) { + self.mcp_servers = McpServerConfig::default(); + self.use_legacy_mcp_json = false; + + // Transform tools: "*" → "@builtin", remove MCP refs + self.tools = self + .tools + .iter() + .filter_map(|tool| match tool.as_str() { + "*" => Some("@builtin".to_string()), + t if !is_mcp_tool_ref(t) => Some(t.to_string()), + _ => None, + }) + .collect(); + + // Remove MCP references from other fields + self.allowed_tools.retain(|tool| !is_mcp_tool_ref(tool)); + self.tool_aliases.retain(|orig, _| !is_mcp_tool_ref(&orig.to_string())); + self.tools_settings + .retain(|target, _| !is_mcp_tool_ref(&target.to_string())); + } } /// Result of evaluating tool permissions, indicating whether a tool should be allowed, @@ -382,7 +411,19 @@ impl Agents { agent_name: Option<&str>, skip_migration: bool, output: &mut impl Write, + mcp_enabled: bool, ) -> (Self, AgentsLoadMetadata) { + if !mcp_enabled { + let _ = execute!( + output, + style::SetForegroundColor(Color::Yellow), + style::Print("\n"), + style::Print("⚠️ WARNING: "), + style::SetForegroundColor(Color::Reset), + style::Print("MCP functionality has been disabled by your administrator.\n\n"), + ); + } + // Tracking metadata about the performed load operation. 
let mut load_metadata = AgentsLoadMetadata::default(); @@ -429,7 +470,7 @@ impl Agents { }; let mut agents = Vec::::new(); - let results = load_agents_from_entries(files, os, &mut global_mcp_config).await; + let results = load_agents_from_entries(files, os, &mut global_mcp_config, mcp_enabled).await; for result in results { match result { Ok(agent) => agents.push(agent), @@ -467,7 +508,7 @@ impl Agents { }; let mut agents = Vec::::new(); - let results = load_agents_from_entries(files, os, &mut global_mcp_config).await; + let results = load_agents_from_entries(files, os, &mut global_mcp_config, mcp_enabled).await; for result in results { match result { Ok(agent) => agents.push(agent), @@ -607,27 +648,30 @@ impl Agents { all_agents.push({ let mut agent = Agent::default(); - 'load_legacy_mcp_json: { - if global_mcp_config.is_none() { - let Ok(global_mcp_path) = directories::chat_legacy_global_mcp_config(os) else { - tracing::error!("Error obtaining legacy mcp json path. Skipping"); - break 'load_legacy_mcp_json; - }; - let legacy_mcp_config = match McpServerConfig::load_from_file(os, global_mcp_path).await { - Ok(config) => config, - Err(e) => { - tracing::error!("Error loading global mcp json path: {e}. Skipping"); + if mcp_enabled { + 'load_legacy_mcp_json: { + if global_mcp_config.is_none() { + let Ok(global_mcp_path) = directories::chat_legacy_global_mcp_config(os) else { + tracing::error!("Error obtaining legacy mcp json path. Skipping"); break 'load_legacy_mcp_json; - }, - }; - global_mcp_config.replace(legacy_mcp_config); + }; + let legacy_mcp_config = match McpServerConfig::load_from_file(os, global_mcp_path).await { + Ok(config) => config, + Err(e) => { + tracing::error!("Error loading global mcp json path: {e}. 
Skipping"); + break 'load_legacy_mcp_json; + }, + }; + global_mcp_config.replace(legacy_mcp_config); + } } - } - if let Some(config) = &global_mcp_config { - agent.mcp_servers = config.clone(); + if let Some(config) = &global_mcp_config { + agent.mcp_servers = config.clone(); + } + } else { + agent.mcp_servers = McpServerConfig::default(); } - agent }); @@ -763,6 +807,7 @@ async fn load_agents_from_entries( mut files: ReadDir, os: &Os, global_mcp_config: &mut Option, + mcp_enabled: bool, ) -> Vec> { let mut res = Vec::>::new(); @@ -773,7 +818,7 @@ async fn load_agents_from_entries( .and_then(OsStr::to_str) .is_some_and(|s| s == "json") { - res.push(Agent::load(os, file_path, global_mcp_config).await); + res.push(Agent::load(os, file_path, global_mcp_config, mcp_enabled).await); } } @@ -820,6 +865,13 @@ fn default_schema() -> String { "https://raw.githubusercontent.com/aws/amazon-q-developer-cli/refs/heads/main/schemas/agent-v1.json".into() } +// Check if a tool reference is MCP-specific (not @builtin and starts with @) +pub fn is_mcp_tool_ref(s: &str) -> bool { + // @builtin is not MCP, it's a reference to all built-in tools + // Any other @ prefix is MCP (e.g., "@git", "@git/git_status") + !s.starts_with("@builtin") && s.starts_with('@') +} + #[cfg(test)] fn validate_agent_name(name: &str) -> eyre::Result<()> { // Check if name is empty @@ -840,8 +892,9 @@ fn validate_agent_name(name: &str) -> eyre::Result<()> { #[cfg(test)] mod tests { - use super::*; + use serde_json::json; + use super::*; const INPUT: &str = r#" { "name": "some_agent", @@ -956,6 +1009,69 @@ mod tests { assert!(validate_agent_name("invalid space").is_err()); } + #[test] + fn test_clear_mcp_configs_with_builtin_variants() { + let mut agent: Agent = serde_json::from_value(json!({ + "name": "test", + "tools": [ + "@builtin", + "@builtin/fs_read", + "@builtin/execute_bash", + "@git", + "@git/status", + "fs_write" + ], + "allowedTools": [ + "@builtin/fs_read", + "@git/status", + "fs_write" + ], + 
"toolAliases": { + "@builtin/fs_read": "read", + "@git/status": "git_st" + }, + "toolsSettings": { + "@builtin/fs_write": { "allowedPaths": ["~/**"] }, + "@git/commit": { "sign": true } + } + })) + .unwrap(); + + agent.clear_mcp_configs(); + + // All @builtin variants should be preserved while MCP tools should be removed + assert!(agent.tools.contains(&"@builtin".to_string())); + assert!(agent.tools.contains(&"@builtin/fs_read".to_string())); + assert!(agent.tools.contains(&"@builtin/execute_bash".to_string())); + assert!(agent.tools.contains(&"fs_write".to_string())); + assert!(!agent.tools.contains(&"@git".to_string())); + assert!(!agent.tools.contains(&"@git/status".to_string())); + + assert!(agent.allowed_tools.contains("@builtin/fs_read")); + assert!(agent.allowed_tools.contains("fs_write")); + assert!(!agent.allowed_tools.contains("@git/status")); + + // Check tool aliases - need to iterate since we can't construct OriginalToolName directly + let has_builtin_alias = agent + .tool_aliases + .iter() + .any(|(k, v)| k.to_string() == "@builtin/fs_read" && v == "read"); + assert!(has_builtin_alias, "@builtin/fs_read alias should be preserved"); + + let has_git_alias = agent.tool_aliases.iter().any(|(k, _)| k.to_string() == "@git/status"); + assert!(!has_git_alias, "@git/status alias should be removed"); + + // Check tool settings - need to iterate since we can't construct ToolSettingTarget directly + let has_builtin_setting = agent + .tools_settings + .iter() + .any(|(k, _)| k.to_string() == "@builtin/fs_write"); + assert!(has_builtin_setting, "@builtin/fs_write settings should be preserved"); + + let has_git_setting = agent.tools_settings.iter().any(|(k, _)| k.to_string() == "@git/commit"); + assert!(!has_git_setting, "@git/commit settings should be removed"); + } + #[test] fn test_display_label_no_active_agent() { let agents = Agents::default(); diff --git a/crates/chat-cli/src/cli/agent/root_command_args.rs b/crates/chat-cli/src/cli/agent/root_command_args.rs 
index 0a7c01eaec..f0b6ded75a 100644 --- a/crates/chat-cli/src/cli/agent/root_command_args.rs +++ b/crates/chat-cli/src/cli/agent/root_command_args.rs @@ -74,9 +74,16 @@ pub struct AgentArgs { impl AgentArgs { pub async fn execute(self, os: &mut Os) -> Result { let mut stderr = std::io::stderr(); + let mcp_enabled = match os.client.is_mcp_enabled().await { + Ok(enabled) => enabled, + Err(err) => { + tracing::warn!(?err, "Failed to check MCP configuration, defaulting to enabled"); + true + }, + }; match self.cmd { Some(AgentSubcommands::List) | None => { - let agents = Agents::load(os, None, true, &mut stderr).await.0; + let agents = Agents::load(os, None, true, &mut stderr, mcp_enabled).await.0; let agent_with_path = agents .agents @@ -101,7 +108,7 @@ impl AgentArgs { writeln!(stderr, "{}", output_str)?; }, Some(AgentSubcommands::Create { name, directory, from }) => { - let mut agents = Agents::load(os, None, true, &mut stderr).await.0; + let mut agents = Agents::load(os, None, true, &mut stderr, mcp_enabled).await.0; let path_with_file_name = create_agent(os, &mut agents, name.clone(), directory, from).await?; let editor_cmd = std::env::var("EDITOR").unwrap_or_else(|_| "vi".to_string()); let mut cmd = std::process::Command::new(editor_cmd); @@ -133,7 +140,7 @@ impl AgentArgs { }, Some(AgentSubcommands::Validate { path }) => { let mut global_mcp_config = None::; - let agent = Agent::load(os, path.as_str(), &mut global_mcp_config).await; + let agent = Agent::load(os, path.as_str(), &mut global_mcp_config, mcp_enabled).await; 'validate: { match agent { @@ -251,7 +258,7 @@ impl AgentArgs { } }, Some(AgentSubcommands::SetDefault { name }) => { - let mut agents = Agents::load(os, None, true, &mut stderr).await.0; + let mut agents = Agents::load(os, None, true, &mut stderr, mcp_enabled).await.0; match agents.switch(&name) { Ok(agent) => { os.database diff --git a/crates/chat-cli/src/cli/chat/cli/mcp.rs b/crates/chat-cli/src/cli/chat/cli/mcp.rs index e653ddca7d..82a9740c5e 
100644 --- a/crates/chat-cli/src/cli/chat/cli/mcp.rs +++ b/crates/chat-cli/src/cli/chat/cli/mcp.rs @@ -1,9 +1,10 @@ use std::io::Write; use clap::Args; -use crossterm::{ - queue, - style, +use crossterm::queue; +use crossterm::style::{ + self, + Color, }; use crate::cli::chat::tool_manager::LoadingRecord; @@ -19,6 +20,21 @@ pub struct McpArgs; impl McpArgs { pub async fn execute(self, session: &mut ChatSession) -> Result { + if !session.conversation.mcp_enabled { + queue!( + session.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print("\n"), + style::Print("⚠️ WARNING: "), + style::SetForegroundColor(Color::Reset), + style::Print("MCP functionality has been disabled by your administrator.\n\n"), + )?; + session.stderr.flush()?; + return Ok(ChatState::PromptUser { + skip_printing_tools: true, + }); + } + let terminal_width = session.terminal_width(); let still_loading = session .conversation diff --git a/crates/chat-cli/src/cli/chat/cli/persist.rs b/crates/chat-cli/src/cli/chat/cli/persist.rs index 197a5ab2cb..1f4568f7ed 100644 --- a/crates/chat-cli/src/cli/chat/cli/persist.rs +++ b/crates/chat-cli/src/cli/chat/cli/persist.rs @@ -95,6 +95,8 @@ impl PersistSubcommand { let mut new_state: ConversationState = tri!(serde_json::from_str(&contents), "import from", &path); std::mem::swap(&mut new_state.tool_manager, &mut session.conversation.tool_manager); + std::mem::swap(&mut new_state.mcp_enabled, &mut session.conversation.mcp_enabled); + std::mem::swap(&mut new_state.model_info, &mut session.conversation.model_info); std::mem::swap( &mut new_state.context_manager, &mut session.conversation.context_manager, diff --git a/crates/chat-cli/src/cli/chat/cli/profile.rs b/crates/chat-cli/src/cli/chat/cli/profile.rs index fb6b17a67d..868944f9b3 100644 --- a/crates/chat-cli/src/cli/chat/cli/profile.rs +++ b/crates/chat-cli/src/cli/chat/cli/profile.rs @@ -132,7 +132,9 @@ impl AgentSubcommand { .map_err(|e| ChatError::Custom(format!("Error printing agent schema: 
{e}").into()))?; }, Self::Create { name, directory, from } => { - let mut agents = Agents::load(os, None, true, &mut session.stderr).await.0; + let mut agents = Agents::load(os, None, true, &mut session.stderr, session.conversation.mcp_enabled) + .await + .0; let path_with_file_name = create_agent(os, &mut agents, name.clone(), directory, from) .await .map_err(|e| ChatError::Custom(Cow::Owned(e.to_string())))?; @@ -144,7 +146,8 @@ impl AgentSubcommand { return Err(ChatError::Custom("Editor process did not exit with success".into())); } - let new_agent = Agent::load(os, &path_with_file_name, &mut None).await; + let new_agent = + Agent::load(os, &path_with_file_name, &mut None, session.conversation.mcp_enabled).await; match new_agent { Ok(agent) => { session.conversation.agents.agents.insert(agent.name.clone(), agent); diff --git a/crates/chat-cli/src/cli/chat/cli/tools.rs b/crates/chat-cli/src/cli/chat/cli/tools.rs index 04aecca0a4..ce05dce3cb 100644 --- a/crates/chat-cli/src/cli/chat/cli/tools.rs +++ b/crates/chat-cli/src/cli/chat/cli/tools.rs @@ -170,6 +170,17 @@ impl ToolsArgs { )?; } + if !session.conversation.mcp_enabled { + queue!( + session.stderr, + style::SetForegroundColor(Color::Yellow), + style::Print("\n"), + style::Print("⚠️ WARNING: "), + style::SetForegroundColor(Color::Reset), + style::Print("MCP functionality has been disabled by your administrator.\n\n"), + )?; + } + Ok(ChatState::default()) } diff --git a/crates/chat-cli/src/cli/chat/conversation.rs b/crates/chat-cli/src/cli/chat/conversation.rs index 7c58febcf7..ef1bf241b6 100644 --- a/crates/chat-cli/src/cli/chat/conversation.rs +++ b/crates/chat-cli/src/cli/chat/conversation.rs @@ -124,6 +124,8 @@ pub struct ConversationState { /// Maps from a file path to [FileLineTracker] #[serde(default)] pub file_line_tracker: HashMap, + #[serde(default = "default_true")] + pub mcp_enabled: bool, } impl ConversationState { @@ -134,6 +136,7 @@ impl ConversationState { tool_manager: ToolManager, 
current_model_id: Option, os: &Os, + mcp_enabled: bool, ) -> Self { let model = if let Some(model_id) = current_model_id { match get_model_info(&model_id, os).await { @@ -180,6 +183,7 @@ impl ConversationState { model: None, model_info: model, file_line_tracker: HashMap::new(), + mcp_enabled, } } @@ -1006,6 +1010,9 @@ fn enforce_tool_use_history_invariants(history: &mut VecDeque, too } } +fn default_true() -> bool { + true +} #[cfg(test)] mod tests { use super::super::message::AssistantToolUse; @@ -1124,6 +1131,7 @@ mod tests { tool_manager, None, &os, + false, ) .await; @@ -1156,6 +1164,7 @@ mod tests { tool_manager.clone(), None, &os, + false, ) .await; conversation.set_next_user_message("start".to_string()).await; @@ -1191,6 +1200,7 @@ mod tests { tool_manager.clone(), None, &os, + false, ) .await; conversation.set_next_user_message("start".to_string()).await; @@ -1245,6 +1255,7 @@ mod tests { tool_manager, None, &os, + false, ) .await; diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index 103973565c..2b167e272d 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -251,9 +251,19 @@ impl ChatArgs { let conversation_id = uuid::Uuid::new_v4().to_string(); info!(?conversation_id, "Generated new conversation id"); + // Check MCP status once at the beginning of the session + let mcp_enabled = match os.client.is_mcp_enabled().await { + Ok(enabled) => enabled, + Err(err) => { + tracing::warn!(?err, "Failed to check MCP configuration, defaulting to enabled"); + true + }, + }; + let agents = { let skip_migration = self.no_interactive; - let (mut agents, md) = Agents::load(os, self.agent.as_deref(), skip_migration, &mut stderr).await; + let (mut agents, md) = + Agents::load(os, self.agent.as_deref(), skip_migration, &mut stderr, mcp_enabled).await; agents.trust_all_tools = self.trust_all_tools; os.telemetry @@ -268,9 +278,11 @@ impl ChatArgs { .map_err(|err| error!(?err, "failed to send agent 
config init telemetry")) .ok(); - if agents - .get_active() - .is_some_and(|a| !a.mcp_servers.mcp_servers.is_empty()) + // Only show MCP safety message if MCP is enabled and has servers + if mcp_enabled + && agents + .get_active() + .is_some_and(|a| !a.mcp_servers.mcp_servers.is_empty()) { if !self.no_interactive && !os.database.settings.get_bool(Setting::McpLoadedBefore).unwrap_or(false) { execute!( @@ -364,6 +376,7 @@ impl ChatArgs { model_id, tool_config, !self.no_interactive, + mcp_enabled, ) .await? .spawn(os) @@ -589,6 +602,7 @@ impl ChatSession { model_id: Option, tool_config: HashMap, interactive: bool, + mcp_enabled: bool, ) -> Result { // Reload prior conversation let mut existing_conversation = false; @@ -624,11 +638,23 @@ impl ChatSession { } } cs.agents = agents; + cs.mcp_enabled = mcp_enabled; cs.update_state(true).await; cs.enforce_tool_use_history_invariants(); cs }, - false => ConversationState::new(conversation_id, agents, tool_config, tool_manager, model_id, os).await, + false => { + ConversationState::new( + conversation_id, + agents, + tool_config, + tool_manager, + model_id, + os, + mcp_enabled, + ) + .await + }, }; // Spawn a task for listening and broadcasting sigints. 
@@ -2967,6 +2993,7 @@ mod tests { None, tool_config, true, + false, ) .await .unwrap() @@ -3108,6 +3135,7 @@ mod tests { None, tool_config, true, + false, ) .await .unwrap() @@ -3204,6 +3232,7 @@ mod tests { None, tool_config, true, + false, ) .await .unwrap() @@ -3278,6 +3307,7 @@ mod tests { None, tool_config, true, + false, ) .await .unwrap() @@ -3328,6 +3358,7 @@ mod tests { None, tool_config, true, + false, ) .await .unwrap() diff --git a/crates/chat-cli/src/cli/mcp.rs b/crates/chat-cli/src/cli/mcp.rs index cc57f32345..c70951a9b5 100644 --- a/crates/chat-cli/src/cli/mcp.rs +++ b/crates/chat-cli/src/cli/mcp.rs @@ -416,7 +416,14 @@ impl StatusArgs { async fn get_mcp_server_configs(os: &mut Os) -> Result, bool)>>> { let mut results = BTreeMap::new(); let mut stderr = std::io::stderr(); - let agents = Agents::load(os, None, true, &mut stderr).await.0; + let mcp_enabled = match os.client.is_mcp_enabled().await { + Ok(enabled) => enabled, + Err(err) => { + tracing::warn!(?err, "Failed to check MCP configuration, defaulting to enabled"); + true + }, + }; + let agents = Agents::load(os, None, true, &mut stderr, mcp_enabled).await.0; let global_path = directories::chat_global_agent_path(os)?; for (_, agent) in agents.agents { let scope = if agent From b0ddbda032a496c602bc163e621dd52ad0eb970d Mon Sep 17 00:00:00 2001 From: Felix Ding Date: Thu, 21 Aug 2025 09:48:25 -0700 Subject: [PATCH 7/7] fix(agent): tool permission (#2619) * adds warnings for when tool settings are overridden by allowed tools * adjusts tool settings eval order * modifies doc * moves warning to be displayed after splash screen * canonicalizes paths prior to making glob sets * simplifies overridden warning message printing logic * adds more doc on path globbing --- crates/chat-cli/src/cli/agent/mod.rs | 58 +++++- crates/chat-cli/src/cli/chat/mod.rs | 20 ++- .../src/cli/chat/tools/custom_tool.rs | 2 +- .../src/cli/chat/tools/execute/mod.rs | 66 ++++--- crates/chat-cli/src/cli/chat/tools/fs_read.rs | 
114 +++++++----- .../chat-cli/src/cli/chat/tools/fs_write.rs | 165 +++++++++++++----- .../chat-cli/src/cli/chat/tools/knowledge.rs | 4 +- crates/chat-cli/src/cli/chat/tools/mod.rs | 14 +- crates/chat-cli/src/cli/chat/tools/use_aws.rs | 38 ++-- crates/chat-cli/src/util/directories.rs | 78 +++++++++ docs/agent-format.md | 1 + docs/built-in-tools.md | 8 +- 12 files changed, 425 insertions(+), 143 deletions(-) diff --git a/crates/chat-cli/src/cli/agent/mod.rs b/crates/chat-cli/src/cli/agent/mod.rs index 7089b33ff9..668a5f1c28 100644 --- a/crates/chat-cli/src/cli/agent/mod.rs +++ b/crates/chat-cli/src/cli/agent/mod.rs @@ -213,8 +213,8 @@ impl Agent { self.path = Some(path.to_path_buf()); + let mut stderr = std::io::stderr(); if let (true, Some(legacy_mcp_config)) = (self.use_legacy_mcp_json, legacy_mcp_config) { - let mut stderr = std::io::stderr(); for (name, legacy_server) in &legacy_mcp_config.mcp_servers { if mcp_servers.mcp_servers.contains_key(name) { let _ = queue!( @@ -238,6 +238,31 @@ impl Agent { } } + stderr.flush()?; + + Ok(()) + } + + pub fn print_overridden_permissions(&self, output: &mut impl Write) -> Result<(), AgentConfigError> { + let execute_name = if cfg!(windows) { "execute_cmd" } else { "execute_bash" }; + for allowed_tool in &self.allowed_tools { + if let Some(settings) = self.tools_settings.get(allowed_tool.as_str()) { + // currently we only have four native tools that offers tool settings + let overridden_settings_key = match allowed_tool.as_str() { + "fs_read" | "fs_write" => Some("allowedPaths"), + "use_aws" => Some("allowedServices"), + name if name == execute_name => Some("allowedCommands"), + _ => None, + }; + + if let Some(key) = overridden_settings_key { + if let Some(ref override_settings) = settings.get(key).map(|value| format!("{key}: {value}")) { + queue_permission_override_warning(allowed_tool.as_str(), override_settings, output)?; + } + } + } + } + Ok(()) } @@ -861,6 +886,28 @@ async fn load_legacy_mcp_config(os: &Os) -> 
eyre::Result }) } +pub fn queue_permission_override_warning( + tool_name: &str, + overridden_settings: &str, + output: &mut impl Write, +) -> Result<(), std::io::Error> { + Ok(queue!( + output, + style::SetForegroundColor(Color::Yellow), + style::Print("WARNING: "), + style::ResetColor, + style::Print("You have trusted "), + style::SetForegroundColor(Color::Green), + style::Print(tool_name), + style::ResetColor, + style::Print(" tool, which overrides the toolsSettings: "), + style::SetForegroundColor(Color::Cyan), + style::Print(overridden_settings), + style::ResetColor, + style::Print("\n"), + )?) +} + fn default_schema() -> String { "https://raw.githubusercontent.com/aws/amazon-q-developer-cli/refs/heads/main/schemas/agent-v1.json".into() } @@ -1088,8 +1135,10 @@ mod tests { #[test] fn test_display_label_trust_all_tools() { - let mut agents = Agents::default(); - agents.trust_all_tools = true; + let agents = Agents { + trust_all_tools: true, + ..Default::default() + }; // Should be trusted even if not in allowed_tools let label = agents.display_label("random_tool", &ToolOrigin::Native); @@ -1119,7 +1168,8 @@ mod tests { fs_write_label ); - let execute_bash_label = agents.display_label("execute_bash", &ToolOrigin::Native); + let execute_name = if cfg!(windows) { "execute_cmd" } else { "execute_bash" }; + let execute_bash_label = agents.display_label(execute_name, &ToolOrigin::Native); assert!( execute_bash_label.contains("read-only"), "execute_bash should show read-only by default, instead found: {}", diff --git a/crates/chat-cli/src/cli/chat/mod.rs b/crates/chat-cli/src/cli/chat/mod.rs index 2b167e272d..a0c1779318 100644 --- a/crates/chat-cli/src/cli/chat/mod.rs +++ b/crates/chat-cli/src/cli/chat/mod.rs @@ -125,7 +125,10 @@ use util::{ use winnow::Partial; use winnow::stream::Offset; -use super::agent::PermissionEvalResult; +use super::agent::{ + DEFAULT_AGENT_NAME, + PermissionEvalResult, +}; use crate::api_client::model::ToolResultStatus; use 
crate::api_client::{ self, @@ -634,7 +637,7 @@ impl ChatSession { ": cannot resume conversation with {profile} because it no longer exists. Using default.\n" )) )?; - let _ = agents.switch("default"); + let _ = agents.switch(DEFAULT_AGENT_NAME); } } cs.agents = agents; @@ -1207,6 +1210,11 @@ impl ChatSession { )) )?; } + + if let Some(agent) = self.conversation.agents.get_active() { + agent.print_overridden_permissions(&mut self.stderr)?; + } + self.stderr.flush()?; if let Some(ref model_info) = self.conversation.model_info { @@ -1775,6 +1783,12 @@ impl ChatSession { .clone() .unwrap_or(tool_use.name.clone()); self.conversation.agents.trust_tools(vec![formatted_tool_name]); + + if let Some(agent) = self.conversation.agents.get_active() { + agent + .print_overridden_permissions(&mut self.stderr) + .map_err(|_e| ChatError::Custom("Failed to validate agent tool settings".into()))?; + } } tool_use.accepted = true; @@ -1842,7 +1856,7 @@ impl ChatSession { self.conversation .agents .get_active() - .is_some_and(|a| match tool.tool.requires_acceptance(a) { + .is_some_and(|a| match tool.tool.requires_acceptance(os, a) { PermissionEvalResult::Allow => true, PermissionEvalResult::Ask => false, PermissionEvalResult::Deny(matches) => { diff --git a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs index 2fe2aa1f37..fafb55a9a4 100644 --- a/crates/chat-cli/src/cli/chat/tools/custom_tool.rs +++ b/crates/chat-cli/src/cli/chat/tools/custom_tool.rs @@ -275,7 +275,7 @@ impl CustomTool { + TokenCounter::count_tokens(self.params.as_ref().map_or("", |p| p.as_str().unwrap_or_default())) } - pub fn eval_perm(&self, agent: &Agent) -> PermissionEvalResult { + pub fn eval_perm(&self, _os: &Os, agent: &Agent) -> PermissionEvalResult { let Self { name: tool_name, client, diff --git a/crates/chat-cli/src/cli/chat/tools/execute/mod.rs b/crates/chat-cli/src/cli/chat/tools/execute/mod.rs index a1e7b9c8e5..388c48476b 100644 --- 
a/crates/chat-cli/src/cli/chat/tools/execute/mod.rs +++ b/crates/chat-cli/src/cli/chat/tools/execute/mod.rs @@ -187,7 +187,7 @@ impl ExecuteCommand { Ok(()) } - pub fn eval_perm(&self, agent: &Agent) -> PermissionEvalResult { + pub fn eval_perm(&self, _os: &Os, agent: &Agent) -> PermissionEvalResult { #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] struct Settings { @@ -207,7 +207,7 @@ impl ExecuteCommand { let tool_name = if cfg!(windows) { "execute_cmd" } else { "execute_bash" }; let is_in_allowlist = matches_any_pattern(&agent.allowed_tools, tool_name); match agent.tools_settings.get(tool_name) { - Some(settings) if is_in_allowlist => { + Some(settings) => { let Settings { allowed_commands, denied_commands, @@ -231,7 +231,9 @@ impl ExecuteCommand { return PermissionEvalResult::Deny(denied_match_set); } - if self.requires_acceptance(Some(&allowed_commands), allow_read_only) { + if is_in_allowlist { + PermissionEvalResult::Allow + } else if self.requires_acceptance(Some(&allowed_commands), allow_read_only) { PermissionEvalResult::Ask } else { PermissionEvalResult::Allow @@ -268,10 +270,7 @@ pub fn format_output(output: &str, max_size: usize) -> String { #[cfg(test)] mod tests { - use std::collections::{ - HashMap, - HashSet, - }; + use std::collections::HashMap; use super::*; use crate::cli::agent::ToolSettingTarget; @@ -422,21 +421,17 @@ mod tests { } } - #[test] - fn test_eval_perm() { + #[tokio::test] + async fn test_eval_perm() { let tool_name = if cfg!(windows) { "execute_cmd" } else { "execute_bash" }; - let agent = Agent { + let mut agent = Agent { name: "test_agent".to_string(), - allowed_tools: { - let mut allowed_tools = HashSet::::new(); - allowed_tools.insert(tool_name.to_string()); - allowed_tools - }, tools_settings: { let mut map = HashMap::::new(); map.insert( ToolSettingTarget(tool_name.to_string()), serde_json::json!({ + "allowedCommands": ["allow_wild_card .*", "allow_exact"], "deniedCommands": ["git .*"] }), ); @@ -444,22 
+439,53 @@ mod tests { }, ..Default::default() }; + let os = Os::new().await.unwrap(); - let tool = serde_json::from_value::(serde_json::json!({ + let tool_one = serde_json::from_value::(serde_json::json!({ "command": "git status", })) .unwrap(); - let res = tool.eval_perm(&agent); + let res = tool_one.eval_perm(&os, &agent); assert!(matches!(res, PermissionEvalResult::Deny(ref rules) if rules.contains(&"\\Agit .*\\z".to_string()))); - let tool = serde_json::from_value::(serde_json::json!({ - "command": "echo hello", + let tool_two = serde_json::from_value::(serde_json::json!({ + "command": "this_is_not_a_read_only_command", + })) + .unwrap(); + + let res = tool_two.eval_perm(&os, &agent); + assert!(matches!(res, PermissionEvalResult::Ask)); + + let tool_allow_wild_card = serde_json::from_value::(serde_json::json!({ + "command": "allow_wild_card some_arg", + })) + .unwrap(); + let res = tool_allow_wild_card.eval_perm(&os, &agent); + assert!(matches!(res, PermissionEvalResult::Allow)); + + let tool_allow_exact_should_ask = serde_json::from_value::(serde_json::json!({ + "command": "allow_exact some_arg", + })) + .unwrap(); + let res = tool_allow_exact_should_ask.eval_perm(&os, &agent); + assert!(matches!(res, PermissionEvalResult::Ask)); + + let tool_allow_exact_should_allow = serde_json::from_value::(serde_json::json!({ + "command": "allow_exact", })) .unwrap(); + let res = tool_allow_exact_should_allow.eval_perm(&os, &agent); + assert!(matches!(res, PermissionEvalResult::Allow)); + + agent.allowed_tools.insert(tool_name.to_string()); - let res = tool.eval_perm(&agent); + let res = tool_two.eval_perm(&os, &agent); assert!(matches!(res, PermissionEvalResult::Allow)); + + // Denied list should remain denied + let res = tool_one.eval_perm(&os, &agent); + assert!(matches!(res, PermissionEvalResult::Deny(ref rules) if rules.contains(&"\\Agit .*\\z".to_string()))); } #[tokio::test] diff --git a/crates/chat-cli/src/cli/chat/tools/fs_read.rs 
b/crates/chat-cli/src/cli/chat/tools/fs_read.rs index dc30c336cb..a11924e9a2 100644 --- a/crates/chat-cli/src/cli/chat/tools/fs_read.rs +++ b/crates/chat-cli/src/cli/chat/tools/fs_read.rs @@ -11,10 +11,7 @@ use eyre::{ Result, bail, }; -use globset::{ - Glob, - GlobSetBuilder, -}; +use globset::GlobSetBuilder; use serde::{ Deserialize, Serialize, @@ -48,6 +45,7 @@ use crate::cli::chat::{ sanitize_unicode_tags, }; use crate::os::Os; +use crate::util::directories; use crate::util::pattern_matching::matches_any_pattern; #[derive(Debug, Clone, Deserialize)] @@ -103,7 +101,7 @@ impl FsRead { } } - pub fn eval_perm(&self, agent: &Agent) -> PermissionEvalResult { + pub fn eval_perm(&self, os: &Os, agent: &Agent) -> PermissionEvalResult { #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] struct Settings { @@ -121,7 +119,7 @@ impl FsRead { let is_in_allowlist = matches_any_pattern(&agent.allowed_tools, "fs_read"); match agent.tools_settings.get("fs_read") { - Some(settings) if is_in_allowlist => { + Some(settings) => { let Settings { allowed_paths, denied_paths, @@ -136,10 +134,11 @@ impl FsRead { let allow_set = { let mut builder = GlobSetBuilder::new(); for path in &allowed_paths { - if let Ok(glob) = Glob::new(path) { - builder.add(glob); - } else { - warn!("Failed to create glob from path given: {path}. Ignoring."); + let Ok(path) = directories::canonicalizes_path(os, path) else { + continue; + }; + if let Err(e) = directories::add_gitignore_globs(&mut builder, path.as_str()) { + warn!("Failed to create glob from path given: {path}: {e}. Ignoring."); } } builder.build() @@ -149,11 +148,17 @@ impl FsRead { let deny_set = { let mut builder = GlobSetBuilder::new(); for path in &denied_paths { - if let Ok(glob) = Glob::new(path) { - sanitized_deny_list.push(path); - builder.add(glob); - } else { - warn!("Failed to create glob from path given: {path}. 
Ignoring."); + let Ok(processed_path) = directories::canonicalizes_path(os, path) else { + continue; + }; + match directories::add_gitignore_globs(&mut builder, processed_path.as_str()) { + Ok(_) => { + // Note that we need to push twice here because for each rule we + // are creating two globs (one for file and one for directory) + sanitized_deny_list.push(path); + sanitized_deny_list.push(path); + }, + Err(e) => warn!("Failed to create glob from path given: {path}: {e}. Ignoring."), } } builder.build() @@ -169,7 +174,11 @@ impl FsRead { FsReadOperation::Line(FsLine { path, .. }) | FsReadOperation::Directory(FsDirectory { path, .. }) | FsReadOperation::Search(FsSearch { path, .. }) => { - let denied_match_set = deny_set.matches(path); + let Ok(path) = directories::canonicalizes_path(os, path) else { + ask = true; + continue; + }; + let denied_match_set = deny_set.matches(path.as_ref() as &str); if !denied_match_set.is_empty() { let deny_res = PermissionEvalResult::Deny({ denied_match_set @@ -183,14 +192,24 @@ impl FsRead { // We only want to ask if we are not allowing read only // operation - if !allow_read_only && !allow_set.is_match(path) { + if !is_in_allowlist + && !allow_read_only + && !allow_set.is_match(path.as_ref() as &str) + { ask = true; } }, FsReadOperation::Image(fs_image) => { let paths = &fs_image.image_paths; - let denied_match_set = - paths.iter().flat_map(|p| deny_set.matches(p)).collect::>(); + let denied_match_set = paths + .iter() + .flat_map(|path| { + let Ok(path) = directories::canonicalizes_path(os, path) else { + return vec![]; + }; + deny_set.matches(path.as_ref() as &str) + }) + .collect::>(); if !denied_match_set.is_empty() { let deny_res = PermissionEvalResult::Deny({ denied_match_set @@ -204,7 +223,10 @@ impl FsRead { // We only want to ask if we are not allowing read only // operation - if !allow_read_only && !paths.iter().any(|path| allow_set.is_match(path)) { + if !is_in_allowlist + && !allow_read_only + && 
!paths.iter().any(|path| allow_set.is_match(path)) + { ask = true; } }, @@ -839,10 +861,7 @@ fn format_mode(mode: u32) -> [char; 9] { #[cfg(test)] mod tests { - use std::collections::{ - HashMap, - HashSet, - }; + use std::collections::HashMap; use super::*; use crate::cli::agent::ToolSettingTarget; @@ -1377,24 +1396,19 @@ mod tests { ); } - #[test] - fn test_eval_perm() { - const DENIED_PATH_ONE: &str = "/some/denied/path"; - const DENIED_PATH_GLOB: &str = "/denied/glob/**/path"; + #[tokio::test] + async fn test_eval_perm() { + const DENIED_PATH_OR_FILE: &str = "/some/denied/path"; + const DENIED_PATH_OR_FILE_GLOB: &str = "/denied/glob/**/path"; - let agent = Agent { + let mut agent = Agent { name: "test_agent".to_string(), - allowed_tools: { - let mut allowed_tools = HashSet::::new(); - allowed_tools.insert("fs_read".to_string()); - allowed_tools - }, tools_settings: { let mut map = HashMap::::new(); map.insert( ToolSettingTarget("fs_read".to_string()), serde_json::json!({ - "deniedPaths": [DENIED_PATH_ONE, DENIED_PATH_GLOB] + "deniedPaths": [DENIED_PATH_OR_FILE, DENIED_PATH_OR_FILE_GLOB] }), ); map @@ -1402,23 +1416,35 @@ mod tests { ..Default::default() }; - let tool = serde_json::from_value::(serde_json::json!({ + let os = Os::new().await.unwrap(); + + let tool_one = serde_json::from_value::(serde_json::json!({ "operations": [ - { "path": DENIED_PATH_ONE, "mode": "Line", "start_line": 1, "end_line": 2 }, - { "path": "/denied/glob", "mode": "Directory" }, - { "path": "/denied/glob/child_one/path", "mode": "Directory" }, - { "path": "/denied/glob/child_one/grand_child_one/path", "mode": "Directory" }, - { "path": TEST_FILE_PATH, "mode": "Search", "pattern": "hello" } + { "path": DENIED_PATH_OR_FILE, "mode": "Line", "start_line": 1, "end_line": 2 }, + { "path": format!("{DENIED_PATH_OR_FILE}/child"), "mode": "Line", "start_line": 1, "end_line": 2 }, + { "path": "/denied/glob/middle_one/middle_two/path", "mode": "Line", "start_line": 1, "end_line": 2 }, + { 
"path": "/denied/glob/middle_one/middle_two/path/child", "mode": "Line", "start_line": 1, "end_line": 2 }, ], })) .unwrap(); - let res = tool.eval_perm(&agent); + let res = tool_one.eval_perm(&os, &agent); + assert!(matches!( + res, + PermissionEvalResult::Deny(ref deny_list) + if deny_list.iter().filter(|p| *p == DENIED_PATH_OR_FILE_GLOB).collect::>().len() == 2 + && deny_list.iter().filter(|p| *p == DENIED_PATH_OR_FILE).collect::>().len() == 2 + )); + + agent.allowed_tools.insert("fs_read".to_string()); + + // Denied set should remain denied + let res = tool_one.eval_perm(&os, &agent); assert!(matches!( res, PermissionEvalResult::Deny(ref deny_list) - if deny_list.iter().filter(|p| *p == DENIED_PATH_GLOB).collect::>().len() == 2 - && deny_list.iter().filter(|p| *p == DENIED_PATH_ONE).collect::>().len() == 1 + if deny_list.iter().filter(|p| *p == DENIED_PATH_OR_FILE_GLOB).collect::>().len() == 2 + && deny_list.iter().filter(|p| *p == DENIED_PATH_OR_FILE).collect::>().len() == 2 )); } } diff --git a/crates/chat-cli/src/cli/chat/tools/fs_write.rs b/crates/chat-cli/src/cli/chat/tools/fs_write.rs index 79151244f6..6222b0cd57 100644 --- a/crates/chat-cli/src/cli/chat/tools/fs_write.rs +++ b/crates/chat-cli/src/cli/chat/tools/fs_write.rs @@ -17,10 +17,7 @@ use eyre::{ bail, eyre, }; -use globset::{ - Glob, - GlobSetBuilder, -}; +use globset::GlobSetBuilder; use serde::Deserialize; use similar::DiffableStr; use syntect::easy::HighlightLines; @@ -47,6 +44,7 @@ use crate::cli::agent::{ }; use crate::cli::chat::line_tracker::FileLineTracker; use crate::os::Os; +use crate::util::directories; use crate::util::pattern_matching::matches_any_pattern; static SYNTAX_SET: LazyLock = LazyLock::new(SyntaxSet::load_defaults_newlines); @@ -416,7 +414,7 @@ impl FsWrite { } } - pub fn eval_perm(&self, agent: &Agent) -> PermissionEvalResult { + pub fn eval_perm(&self, os: &Os, agent: &Agent) -> PermissionEvalResult { #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] struct 
Settings { @@ -428,7 +426,7 @@ impl FsWrite { let is_in_allowlist = matches_any_pattern(&agent.allowed_tools, "fs_write"); match agent.tools_settings.get("fs_write") { - Some(settings) if is_in_allowlist => { + Some(settings) => { let Settings { allowed_paths, denied_paths, @@ -442,10 +440,11 @@ impl FsWrite { let allow_set = { let mut builder = GlobSetBuilder::new(); for path in &allowed_paths { - if let Ok(glob) = Glob::new(path) { - builder.add(glob); - } else { - warn!("Failed to create glob from path given: {path}. Ignoring."); + let Ok(path) = directories::canonicalizes_path(os, path) else { + continue; + }; + if let Err(e) = directories::add_gitignore_globs(&mut builder, path.as_str()) { + warn!("Failed to create glob from path given: {path}: {e}. Ignoring."); } } builder.build() @@ -455,11 +454,17 @@ impl FsWrite { let deny_set = { let mut builder = GlobSetBuilder::new(); for path in &denied_paths { - if let Ok(glob) = Glob::new(path) { - sanitized_deny_list.push(path); - builder.add(glob); - } else { - warn!("Failed to create glob from path given: {path}. Ignoring."); + let Ok(processed_path) = directories::canonicalizes_path(os, path) else { + continue; + }; + match directories::add_gitignore_globs(&mut builder, processed_path.as_str()) { + Ok(_) => { + // Note that we need to push twice here because for each rule we + // are creating two globs (one for file and one for directory) + sanitized_deny_list.push(path); + sanitized_deny_list.push(path); + }, + Err(e) => warn!("Failed to create glob from path given: {path}: {e}. Ignoring."), } } builder.build() @@ -472,7 +477,10 @@ impl FsWrite { | Self::Insert { path, .. } | Self::Append { path, .. } | Self::StrReplace { path, .. 
} => { - let denied_match_set = deny_set.matches(path); + let Ok(path) = directories::canonicalizes_path(os, path) else { + return PermissionEvalResult::Ask; + }; + let denied_match_set = deny_set.matches(path.as_ref() as &str); if !denied_match_set.is_empty() { return PermissionEvalResult::Deny({ denied_match_set @@ -481,7 +489,7 @@ impl FsWrite { .collect::>() }); } - if allow_set.is_match(path) { + if is_in_allowlist || allow_set.is_match(path.as_ref() as &str) { return PermissionEvalResult::Allow; } }, @@ -803,10 +811,7 @@ fn syntect_to_crossterm_color(syntect: syntect::highlighting::Color) -> style::C #[cfg(test)] mod tests { - use std::collections::{ - HashMap, - HashSet, - }; + use std::collections::HashMap; use super::*; use crate::cli::agent::ToolSettingTarget; @@ -1260,23 +1265,21 @@ mod tests { assert_eq!(nested_content, "content in nested path\n"); } - #[test] - fn test_eval_perm() { - const DENIED_PATH_ONE: &str = "/some/denied/path/**"; - const DENIED_PATH_GLOB: &str = "/denied/glob/**/path/**"; + #[tokio::test] + async fn test_eval_perm() { + const DENIED_PATH_ONE: &str = "/some/denied/path"; + const DENIED_PATH_GLOB: &str = "/denied/glob/**/path"; + const ALLOW_PATH_ONE: &str = "/some/allow/path"; + const ALLOW_PATH_GLOB: &str = "/allowed/glob/**/path"; - let agent = Agent { + let mut agent = Agent { name: "test_agent".to_string(), - allowed_tools: { - let mut allowed_tools = HashSet::::new(); - allowed_tools.insert("fs_write".to_string()); - allowed_tools - }, tools_settings: { let mut map = HashMap::::new(); map.insert( ToolSettingTarget("fs_write".to_string()), serde_json::json!({ + "allowedPaths": [ALLOW_PATH_ONE, ALLOW_PATH_GLOB], "deniedPaths": [DENIED_PATH_ONE, DENIED_PATH_GLOB] }), ); @@ -1285,51 +1288,129 @@ mod tests { ..Default::default() }; - let tool = serde_json::from_value::(serde_json::json!({ + let os = Os::new().await.unwrap(); + + // Test path not matching any patterns - should ask + let tool_should_ask = 
serde_json::from_value::(serde_json::json!({ "path": "/not/a/denied/path/file.txt", "command": "create", "file_text": "content in nested path" })) .unwrap(); - let res = tool.eval_perm(&agent); + let res = tool_should_ask.eval_perm(&os, &agent); assert!(matches!(res, PermissionEvalResult::Ask)); - let tool = serde_json::from_value::(serde_json::json!({ - "path": format!("{DENIED_PATH_ONE}/file.txt"), + // Test path matching denied pattern - should deny + let tool_should_deny = serde_json::from_value::(serde_json::json!({ + "path": "/some/denied/path/file.txt", + "command": "create", + "file_text": "content in nested path" + })) + .unwrap(); + + let res = tool_should_deny.eval_perm(&os, &agent); + assert!( + matches!(res, PermissionEvalResult::Deny(ref deny_list) if deny_list.contains(&DENIED_PATH_ONE.to_string())) + ); + + let tool_should_deny = serde_json::from_value::(serde_json::json!({ + "path": "/some/denied/path/subdir/", "command": "create", "file_text": "content in nested path" })) .unwrap(); - let res = tool.eval_perm(&agent); + let res = tool_should_deny.eval_perm(&os, &agent); + assert!(matches!(res, PermissionEvalResult::Deny(ref deny_list) if + deny_list.contains(&DENIED_PATH_ONE.to_string()))); + + let tool_should_deny = serde_json::from_value::(serde_json::json!({ + "path": "/some/denied/path", + "command": "create", + "file_text": "content in nested path" + })) + .unwrap(); + + let res = tool_should_deny.eval_perm(&os, &agent); assert!( matches!(res, PermissionEvalResult::Deny(ref deny_list) if deny_list.contains(&DENIED_PATH_ONE.to_string())) ); - let tool = serde_json::from_value::(serde_json::json!({ - "path": format!("/denied/glob/child_one/path/file.txt"), + // Test nested glob pattern matching - should deny + let tool_three = serde_json::from_value::(serde_json::json!({ + "path": "/denied/glob/child_one/path/file.txt", "command": "create", "file_text": "content in nested path" })) .unwrap(); - let res = tool.eval_perm(&agent); + let res = 
tool_three.eval_perm(&os, &agent); assert!( matches!(res, PermissionEvalResult::Deny(ref deny_list) if deny_list.contains(&DENIED_PATH_GLOB.to_string())) ); - let tool = serde_json::from_value::(serde_json::json!({ - "path": format!("/denied/glob/child_one/grand_child_one/path/file.txt"), + // Test deeply nested glob pattern matching - should deny + let tool_four = serde_json::from_value::(serde_json::json!({ + "path": "/denied/glob/child_one/grand_child_one/path/file.txt", "command": "create", "file_text": "content in nested path" })) .unwrap(); - let res = tool.eval_perm(&agent); + let res = tool_four.eval_perm(&os, &agent); assert!( matches!(res, PermissionEvalResult::Deny(ref deny_list) if deny_list.contains(&DENIED_PATH_GLOB.to_string())) ); + + let tool_should_allow = serde_json::from_value::(serde_json::json!({ + "path": "/some/allow/path/some_file.txt", + "command": "create", + "file_text": "content in nested path" + })) + .unwrap(); + + let res = tool_should_allow.eval_perm(&os, &agent); + assert!(matches!(res, PermissionEvalResult::Allow)); + + let tool_should_allow_with_subdir = serde_json::from_value::(serde_json::json!({ + "path": "/some/allow/path/subdir/file.txt", + "command": "create", + "file_text": "content in nested path" + })) + .unwrap(); + + let res = tool_should_allow_with_subdir.eval_perm(&os, &agent); + assert!(matches!(res, PermissionEvalResult::Allow)); + + let tool_should_allow_glob = serde_json::from_value::(serde_json::json!({ + "path": "/allowed/glob/child_one/grand_child_one/path/some_file.txt", + "command": "create", + "file_text": "content in nested path" + })) + .unwrap(); + + let res = tool_should_allow_glob.eval_perm(&os, &agent); + assert!(matches!(res, PermissionEvalResult::Allow)); + + // Test that denied patterns take precedence over allowed tools list + agent.allowed_tools.insert("fs_write".to_string()); + + let res = tool_four.eval_perm(&os, &agent); + assert!( + matches!(res, PermissionEvalResult::Deny(ref deny_list) if 
deny_list.contains(&DENIED_PATH_GLOB.to_string())) + ); + + // Test that exact directory name in allowed pattern works + let tool_exact_allowed_dir = serde_json::from_value::(serde_json::json!({ + "path": "/some/allow/path", + "command": "create", + "file_text": "content" + })) + .unwrap(); + + let res = tool_exact_allowed_dir.eval_perm(&os, &agent); + assert!(matches!(res, PermissionEvalResult::Allow)); } #[tokio::test] diff --git a/crates/chat-cli/src/cli/chat/tools/knowledge.rs b/crates/chat-cli/src/cli/chat/tools/knowledge.rs index 639e0969e6..9f85c1b01c 100644 --- a/crates/chat-cli/src/cli/chat/tools/knowledge.rs +++ b/crates/chat-cli/src/cli/chat/tools/knowledge.rs @@ -489,8 +489,10 @@ impl Knowledge { }) } - pub fn eval_perm(&self, agent: &Agent) -> PermissionEvalResult { + pub fn eval_perm(&self, os: &Os, agent: &Agent) -> PermissionEvalResult { _ = self; + _ = os; + if matches_any_pattern(&agent.allowed_tools, "knowledge") { PermissionEvalResult::Allow } else { diff --git a/crates/chat-cli/src/cli/chat/tools/mod.rs b/crates/chat-cli/src/cli/chat/tools/mod.rs index ea2aef2529..ae7a30900f 100644 --- a/crates/chat-cli/src/cli/chat/tools/mod.rs +++ b/crates/chat-cli/src/cli/chat/tools/mod.rs @@ -101,16 +101,16 @@ impl Tool { } /// Whether or not the tool should prompt the user to accept before [Self::invoke] is called. 
- pub fn requires_acceptance(&self, agent: &Agent) -> PermissionEvalResult { + pub fn requires_acceptance(&self, os: &Os, agent: &Agent) -> PermissionEvalResult { match self { - Tool::FsRead(fs_read) => fs_read.eval_perm(agent), - Tool::FsWrite(fs_write) => fs_write.eval_perm(agent), - Tool::ExecuteCommand(execute_command) => execute_command.eval_perm(agent), - Tool::UseAws(use_aws) => use_aws.eval_perm(agent), - Tool::Custom(custom_tool) => custom_tool.eval_perm(agent), + Tool::FsRead(fs_read) => fs_read.eval_perm(os, agent), + Tool::FsWrite(fs_write) => fs_write.eval_perm(os, agent), + Tool::ExecuteCommand(execute_command) => execute_command.eval_perm(os, agent), + Tool::UseAws(use_aws) => use_aws.eval_perm(os, agent), + Tool::Custom(custom_tool) => custom_tool.eval_perm(os, agent), Tool::GhIssue(_) => PermissionEvalResult::Allow, Tool::Thinking(_) => PermissionEvalResult::Allow, - Tool::Knowledge(knowledge) => knowledge.eval_perm(agent), + Tool::Knowledge(knowledge) => knowledge.eval_perm(os, agent), } } diff --git a/crates/chat-cli/src/cli/chat/tools/use_aws.rs b/crates/chat-cli/src/cli/chat/tools/use_aws.rs index 01b09126e8..456510b5bf 100644 --- a/crates/chat-cli/src/cli/chat/tools/use_aws.rs +++ b/crates/chat-cli/src/cli/chat/tools/use_aws.rs @@ -174,7 +174,7 @@ impl UseAws { } } - pub fn eval_perm(&self, agent: &Agent) -> PermissionEvalResult { + pub fn eval_perm(&self, _os: &Os, agent: &Agent) -> PermissionEvalResult { #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] struct Settings { @@ -187,7 +187,7 @@ impl UseAws { let Self { service_name, .. 
} = self; let is_in_allowlist = matches_any_pattern(&agent.allowed_tools, "use_aws"); match agent.tools_settings.get("use_aws") { - Some(settings) if is_in_allowlist => { + Some(settings) => { let settings = match serde_json::from_value::(settings.clone()) { Ok(settings) => settings, Err(e) => { @@ -198,7 +198,7 @@ impl UseAws { if settings.denied_services.contains(service_name) { return PermissionEvalResult::Deny(vec![service_name.clone()]); } - if settings.allowed_services.contains(service_name) { + if is_in_allowlist || settings.allowed_services.contains(service_name) { return PermissionEvalResult::Allow; } PermissionEvalResult::Ask @@ -217,8 +217,6 @@ impl UseAws { #[cfg(test)] mod tests { - use std::collections::HashSet; - use super::*; use crate::cli::agent::ToolSettingTarget; @@ -342,9 +340,9 @@ mod tests { } } - #[test] - fn test_eval_perm() { - let cmd = use_aws! {{ + #[tokio::test] + async fn test_eval_perm() { + let cmd_one = use_aws! {{ "service_name": "s3", "operation_name": "put-object", "region": "us-west-2", @@ -352,13 +350,8 @@ mod tests { "label": "" }}; - let agent = Agent { + let mut agent = Agent { name: "test_agent".to_string(), - allowed_tools: { - let mut allowed_tools = HashSet::::new(); - allowed_tools.insert("use_aws".to_string()); - allowed_tools - }, tools_settings: { let mut map = HashMap::::new(); map.insert( @@ -372,10 +365,12 @@ mod tests { ..Default::default() }; - let res = cmd.eval_perm(&agent); + let os = Os::new().await.unwrap(); + + let res = cmd_one.eval_perm(&os, &agent); assert!(matches!(res, PermissionEvalResult::Deny(ref services) if services.contains(&"s3".to_string()))); - let cmd = use_aws! {{ + let cmd_two = use_aws! 
{{ "service_name": "api_gateway", "operation_name": "request", "region": "us-west-2", @@ -383,7 +378,16 @@ mod tests { "label": "" }}; - let res = cmd.eval_perm(&agent); + let res = cmd_two.eval_perm(&os, &agent); assert!(matches!(res, PermissionEvalResult::Ask)); + + agent.allowed_tools.insert("use_aws".to_string()); + + let res = cmd_two.eval_perm(&os, &agent); + assert!(matches!(res, PermissionEvalResult::Allow)); + + // Denied services should still be denied after trusting tool + let res = cmd_one.eval_perm(&os, &agent); + assert!(matches!(res, PermissionEvalResult::Deny(ref services) if services.contains(&"s3".to_string()))); } } diff --git a/crates/chat-cli/src/util/directories.rs b/crates/chat-cli/src/util/directories.rs index a726cf376c..64006e467a 100644 --- a/crates/chat-cli/src/util/directories.rs +++ b/crates/chat-cli/src/util/directories.rs @@ -1,8 +1,13 @@ +use std::env::VarError; use std::path::{ PathBuf, StripPrefixError, }; +use globset::{ + Glob, + GlobSetBuilder, +}; use thiserror::Error; use crate::os::Os; @@ -28,6 +33,10 @@ pub enum DirectoryError { IntoString(#[from] std::ffi::IntoStringError), #[error(transparent)] StripPrefix(#[from] StripPrefixError), + #[error(transparent)] + PathExpand(#[from] shellexpand::LookupError), + #[error(transparent)] + GlobCreation(#[from] globset::Error), } type Result = std::result::Result; @@ -165,6 +174,27 @@ pub fn chat_local_agent_dir(os: &Os) -> Result { Ok(cwd.join(WORKSPACE_AGENT_DIR_RELATIVE)) } +/// Canonicalizes path given by expanding the path given +pub fn canonicalizes_path(os: &Os, path_as_str: &str) -> Result { + let context = |input: &str| Ok(os.env.get(input).ok()); + let home_dir = || os.env.home().map(|p| p.to_string_lossy().to_string()); + + Ok(shellexpand::full_with_context(path_as_str, home_dir, context)?.to_string()) +} + +/// Given a globset builder and a path, build globs for both the file and directory patterns +/// This is needed because by default glob does not match children of a 
dir so we need both +/// patterns to exist in a globset. +pub fn add_gitignore_globs(builder: &mut GlobSetBuilder, path: &str) -> Result<()> { + let glob_for_file = Glob::new(path)?; + let glob_for_dir = Glob::new(&format!("{path}/**"))?; + + builder.add(glob_for_file); + builder.add(glob_for_dir); + + Ok(()) +} + /// Derives the absolute path to an agent config directory given a "workspace directory". /// A workspace directory is a directory where q chat is to be launched /// @@ -306,4 +336,52 @@ mod tests { let tmpdir = macos_tempdir().unwrap(); println!("{:?}", tmpdir); } + + #[tokio::test] + async fn test_canonicalizes_path() { + use std::fs; + + use tempfile::TempDir; + + let temp_dir = TempDir::new().unwrap(); + let temp_path = temp_dir.path(); + + // Create a test file and directory + let test_file = temp_path.join("test_file.txt"); + let test_dir = temp_path.join("test_dir"); + fs::write(&test_file, "test content").unwrap(); + fs::create_dir(&test_dir).unwrap(); + + let test_os = Os::new().await.unwrap(); + unsafe { + test_os.env.set_var("HOME", "/home/testuser"); + test_os.env.set_var("TEST_VAR", "test_value"); + } + + // Test home directory expansion + let result = canonicalizes_path(&test_os, "~/test").unwrap(); + assert_eq!(result, "/home/testuser/test"); + + // Test environment variable expansion + let result = canonicalizes_path(&test_os, "$TEST_VAR/path").unwrap(); + assert_eq!(result, "test_value/path"); + + // Test combined expansion + let result = canonicalizes_path(&test_os, "~/$TEST_VAR").unwrap(); + assert_eq!(result, "/home/testuser/test_value"); + + // Test absolute path (no expansion needed) + let result = canonicalizes_path(&test_os, "/absolute/path").unwrap(); + assert_eq!(result, "/absolute/path"); + + // Test relative path (no expansion needed) + let result = canonicalizes_path(&test_os, "relative/path").unwrap(); + assert_eq!(result, "relative/path"); + + // Test glob prefixed paths + let result = canonicalizes_path(&test_os, 
"**/path").unwrap(); + assert_eq!(result, "**/path"); + let result = canonicalizes_path(&test_os, "**/middle/**/path").unwrap(); + assert_eq!(result, "**/middle/**/path"); + } } diff --git a/docs/agent-format.md b/docs/agent-format.md index 328eb9689d..c005ad13cd 100644 --- a/docs/agent-format.md +++ b/docs/agent-format.md @@ -214,6 +214,7 @@ Unlike the `tools` field, the `allowedTools` field does not support the `"*"` wi ## ToolsSettings Field The `toolsSettings` field provides configuration for specific tools. Each tool can have its own unique configuration options. +Note that specifications that configure allowable patterns will be overridden if the tool is also included in `allowedTools`. ```json { diff --git a/docs/built-in-tools.md b/docs/built-in-tools.md index 28f7a92565..49aaf5d715 100644 --- a/docs/built-in-tools.md +++ b/docs/built-in-tools.md @@ -57,8 +57,8 @@ Tool for reading files, directories, and images. | Option | Type | Default | Description | |--------|------|---------|-------------| -| `allowedPaths` | array of strings | `[]` | List of paths that can be read without prompting. Supports glob patterns | -| `deniedPaths` | array of strings | `[]` | List of paths that are denied. Supports glob patterns. Deny rules are evaluated before allow rules | +| `allowedPaths` | array of strings | `[]` | List of paths that can be read without prompting. Supports glob patterns. Glob patterns have the same behavior as gitignore. For example, `~/temp` would match `~/temp/child` and `~/temp/child/grandchild` | +| `deniedPaths` | array of strings | `[]` | List of paths that are denied. Supports glob patterns. Deny rules are evaluated before allow rules. Glob patterns have the same behavior as gitignore. For example, `~/temp` would match `~/temp/child` and `~/temp/child/grandchild` | ## Fs_write Tool @@ -81,8 +81,8 @@ Tool for creating and editing files. 
| Option | Type | Default | Description | |--------|------|---------|-------------| -| `allowedPaths` | array of strings | `[]` | List of paths that can be written to without prompting. Supports glob patterns | -| `deniedPaths` | array of strings | `[]` | List of paths that are denied. Supports glob patterns. Deny rules are evaluated before allow rules | +| `allowedPaths` | array of strings | `[]` | List of paths that can be written to without prompting. Supports glob patterns. Glob patterns have the same behavior as gitignore. For example, `~/temp` would match `~/temp/child` and `~/temp/child/grandchild` | +| `deniedPaths` | array of strings | `[]` | List of paths that are denied. Supports glob patterns. Deny rules are evaluated before allow rules. Glob patterns have the same behavior as gitignore. For example, `~/temp` would match `~/temp/child` and `~/temp/child/grandchild` | ## Report_issue Tool