|
| 1 | +use std::io::Write; |
| 2 | +use std::sync::Arc; |
| 3 | + |
| 4 | +use agent::AgentHandle; |
| 5 | +use agent::agent_config::load_agents; |
| 6 | +use agent::agent_loop::protocol::LoopEndReason; |
| 7 | +use agent::mcp::McpManager; |
| 8 | +use agent::protocol::{ |
| 9 | + AgentEvent, |
| 10 | + AgentStopReason, |
| 11 | + ApprovalResult, |
| 12 | + ContentChunk, |
| 13 | + SendApprovalResultArgs, |
| 14 | + SendPromptArgs, |
| 15 | +}; |
| 16 | +use agent::types::AgentSnapshot; |
| 17 | +use eyre::{ |
| 18 | + Result, |
| 19 | + bail, |
| 20 | +}; |
| 21 | +use rts::{ |
| 22 | + RtsModel, |
| 23 | + RtsModelState, |
| 24 | +}; |
| 25 | +use serde::{ |
| 26 | + Deserialize, |
| 27 | + Serialize, |
| 28 | +}; |
| 29 | +use tracing::{ |
| 30 | + debug, |
| 31 | + error, |
| 32 | + info, |
| 33 | + warn, |
| 34 | +}; |
| 35 | + |
| 36 | +use crate::os::Os; |
| 37 | + |
1 | 38 | mod rts; |
| 39 | + |
| 40 | +#[derive(Debug, Clone, Serialize, Deserialize)] |
| 41 | +struct JsonOutput { |
| 42 | + /// Whether or not the user turn completed successfully |
| 43 | + is_error: bool, |
| 44 | + /// Text from the final message, if available |
| 45 | + result: Option<String>, |
| 46 | + /// The number of requests sent to the model |
| 47 | + number_of_requests: u32, |
| 48 | + /// The number of tool use / tool result pairs in the turn |
| 49 | + /// |
| 50 | + /// This could be less than the number of requests in the case of retries |
| 51 | + number_of_cycles: u32, |
| 52 | + /// Duration of the turn, in milliseconds |
| 53 | + duration_ms: u32, |
| 54 | +} |
| 55 | + |
| 56 | +#[derive(Clone, Debug, Serialize, Deserialize)] |
| 57 | +pub struct QueryResult { |
| 58 | + pub context_summary: String, |
| 59 | + pub task_summary: String, |
| 60 | +} |
| 61 | + |
| 62 | +#[derive(Debug)] |
| 63 | +pub struct SubAgent<'a> { |
| 64 | + pub query: &'a str, |
| 65 | + pub agent_name: Option<&'a str>, |
| 66 | + pub embedded_user_msg: Option<&'a str>, |
| 67 | + pub dangerously_trust_all_tools: bool, |
| 68 | +} |
| 69 | + |
| 70 | +impl<'a> SubAgent<'a> { |
| 71 | + pub async fn query(self, os: &mut Os, output: &mut impl Write) -> Result<QueryResult> { |
| 72 | + let mut snapshot = AgentSnapshot::default(); |
| 73 | + |
| 74 | + let model = { |
| 75 | + let rts_state: RtsModelState = snapshot |
| 76 | + .model_state |
| 77 | + .as_ref() |
| 78 | + .and_then(|s| { |
| 79 | + serde_json::from_value(s.clone()) |
| 80 | + .map_err(|err| error!(?err, ?s, "failed to deserialize RTS state")) |
| 81 | + .ok() |
| 82 | + }) |
| 83 | + .unwrap_or({ |
| 84 | + let state = RtsModelState::new(); |
| 85 | + info!(?state.conversation_id, "generated new conversation id"); |
| 86 | + state |
| 87 | + }); |
| 88 | + Arc::new(RtsModel::new( |
| 89 | + os.client.clone(), |
| 90 | + rts_state.conversation_id, |
| 91 | + rts_state.model_id, |
| 92 | + )) |
| 93 | + }; |
| 94 | + |
| 95 | + if let Some(name) = self.agent_name { |
| 96 | + let (configs, _) = load_agents().await?; |
| 97 | + if let Some(cfg) = configs.into_iter().find(|c| c.name() == name) { |
| 98 | + snapshot.agent_config = cfg.config().clone(); |
| 99 | + } else { |
| 100 | + bail!("unable to find agent with name: {}", name); |
| 101 | + } |
| 102 | + }; |
| 103 | + |
| 104 | + let mcp_manager_handle = McpManager::default().spawn(); |
| 105 | + let agent = agent::Agent::new(snapshot, model, mcp_manager_handle).await?.spawn(); |
| 106 | + |
| 107 | + self.main_loop(agent, output).await |
| 108 | + } |
| 109 | + |
| 110 | + async fn main_loop(&self, mut agent: AgentHandle, output: &mut impl Write) -> Result<QueryResult> { |
| 111 | + // First, wait for agent initialization |
| 112 | + while let Ok(evt) = agent.recv().await { |
| 113 | + if matches!(evt, AgentEvent::Mcp(_)) { |
| 114 | + info!(?evt, "received mcp agent event"); |
| 115 | + // TODO: Send it through conduit |
| 116 | + } |
| 117 | + if matches!(evt, AgentEvent::Initialized) { |
| 118 | + break; |
| 119 | + } |
| 120 | + } |
| 121 | + |
| 122 | + agent |
| 123 | + .send_prompt(SendPromptArgs { |
| 124 | + content: vec![ContentChunk::Text(self.query.to_string())], |
| 125 | + should_continue_turn: None, |
| 126 | + }) |
| 127 | + .await?; |
| 128 | + |
| 129 | + // Holds the final result of the user turn. |
| 130 | + #[allow(unused_assignments)] |
| 131 | + let mut user_turn_metadata = None; |
| 132 | + |
| 133 | + loop { |
| 134 | + let Ok(evt) = agent.recv().await else { |
| 135 | + bail!("channel closed"); |
| 136 | + }; |
| 137 | + debug!(?evt, "received new agent event"); |
| 138 | + |
| 139 | + // Check for exit conditions |
| 140 | + match &evt { |
| 141 | + AgentEvent::Update(evt) => { |
| 142 | + info!(?evt, "received update event"); |
| 143 | + println!("received update event {:?}", evt); |
| 144 | + }, |
| 145 | + AgentEvent::EndTurn(metadata) => { |
| 146 | + user_turn_metadata = Some(metadata.clone()); |
| 147 | + break; |
| 148 | + }, |
| 149 | + AgentEvent::Stop(AgentStopReason::Error(agent_error)) => { |
| 150 | + bail!("agent encountered an error: {:?}", agent_error) |
| 151 | + }, |
| 152 | + AgentEvent::ApprovalRequest { id, tool_use, .. } => { |
| 153 | + if !self.dangerously_trust_all_tools { |
| 154 | + bail!("Tool approval is required: {:?}", tool_use); |
| 155 | + } else { |
| 156 | + warn!(?tool_use, "trust all is enabled, ignoring approval request"); |
| 157 | + agent |
| 158 | + .send_tool_use_approval_result(SendApprovalResultArgs { |
| 159 | + id: id.clone(), |
| 160 | + result: ApprovalResult::Approve, |
| 161 | + }) |
| 162 | + .await?; |
| 163 | + } |
| 164 | + }, |
| 165 | + AgentEvent::Mcp(evt) => { |
| 166 | + info!(?evt, "received mcp agent event"); |
| 167 | + }, |
| 168 | + _ => {}, |
| 169 | + } |
| 170 | + } |
| 171 | + |
| 172 | + let md = user_turn_metadata.expect("user turn metadata should exist"); |
| 173 | + let is_error = md.end_reason != LoopEndReason::UserTurnEnd || md.result.as_ref().is_none_or(|v| v.is_err()); |
| 174 | + let result = md.result.and_then(|r| r.ok().map(|m| m.text())); |
| 175 | + |
| 176 | + let output = JsonOutput { |
| 177 | + result, |
| 178 | + is_error, |
| 179 | + number_of_requests: md.total_request_count, |
| 180 | + number_of_cycles: md.number_of_cycles, |
| 181 | + duration_ms: md.turn_duration.map(|d| d.as_millis() as u32).unwrap_or_default(), |
| 182 | + }; |
| 183 | + |
| 184 | + info!(?output, "sub agent routine completed"); |
| 185 | + |
| 186 | + Ok(QueryResult { |
| 187 | + context_summary: Default::default(), |
| 188 | + task_summary: Default::default(), |
| 189 | + }) |
| 190 | + } |
| 191 | +} |
| 192 | + |
| 193 | +pub fn temp_func() { |
| 194 | + let rt = tokio::runtime::Builder::new_multi_thread() |
| 195 | + .enable_all() |
| 196 | + .build() |
| 197 | + .expect("failed to build runtime"); |
| 198 | + |
| 199 | + rt.block_on(test_sub_agent_routine()); |
| 200 | +} |
| 201 | + |
| 202 | +async fn test_sub_agent_routine() { |
| 203 | + let sub_agent = SubAgent { |
| 204 | + query: "What notion docs do I have", |
| 205 | + agent_name: Some("test_test"), |
| 206 | + embedded_user_msg: None, |
| 207 | + dangerously_trust_all_tools: true, |
| 208 | + }; |
| 209 | + |
| 210 | + let mut os = Os::new().await.expect("failed to spawn os"); |
| 211 | + let mut output = Vec::<u8>::new(); |
| 212 | + |
| 213 | + _ = sub_agent.query(&mut os, &mut output).await; |
| 214 | +} |
0 commit comments