Skip to content

Commit 5004b8b

Browse files
committed
implements api to invoke new session with agent crate
1 parent 81d125f commit 5004b8b

File tree

4 files changed

+223
-0
lines changed

4 files changed

+223
-0
lines changed

crates/chat-cli/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ version.workspace = true
88
license.workspace = true
99
default-run = "chat_cli"
1010

11+
[[bin]]
12+
name = "temp_bin"
13+
path = "src/bin/temp_bin.rs"
14+
1115
[lints]
1216
workspace = true
1317

crates/chat-cli/src/agent/mod.rs

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,214 @@
1+
use std::io::Write;
2+
use std::sync::Arc;
3+
4+
use agent::AgentHandle;
5+
use agent::agent_config::load_agents;
6+
use agent::agent_loop::protocol::LoopEndReason;
7+
use agent::mcp::McpManager;
8+
use agent::protocol::{
9+
AgentEvent,
10+
AgentStopReason,
11+
ApprovalResult,
12+
ContentChunk,
13+
SendApprovalResultArgs,
14+
SendPromptArgs,
15+
};
16+
use agent::types::AgentSnapshot;
17+
use eyre::{
18+
Result,
19+
bail,
20+
};
21+
use rts::{
22+
RtsModel,
23+
RtsModelState,
24+
};
25+
use serde::{
26+
Deserialize,
27+
Serialize,
28+
};
29+
use tracing::{
30+
debug,
31+
error,
32+
info,
33+
warn,
34+
};
35+
36+
use crate::os::Os;
37+
138
mod rts;
39+
40+
#[derive(Debug, Clone, Serialize, Deserialize)]
41+
struct JsonOutput {
42+
/// Whether or not the user turn completed successfully
43+
is_error: bool,
44+
/// Text from the final message, if available
45+
result: Option<String>,
46+
/// The number of requests sent to the model
47+
number_of_requests: u32,
48+
/// The number of tool use / tool result pairs in the turn
49+
///
50+
/// This could be less than the number of requests in the case of retries
51+
number_of_cycles: u32,
52+
/// Duration of the turn, in milliseconds
53+
duration_ms: u32,
54+
}
55+
56+
#[derive(Clone, Debug, Serialize, Deserialize)]
57+
pub struct QueryResult {
58+
pub context_summary: String,
59+
pub task_summary: String,
60+
}
61+
62+
#[derive(Debug)]
63+
pub struct SubAgent<'a> {
64+
pub query: &'a str,
65+
pub agent_name: Option<&'a str>,
66+
pub embedded_user_msg: Option<&'a str>,
67+
pub dangerously_trust_all_tools: bool,
68+
}
69+
70+
impl<'a> SubAgent<'a> {
71+
pub async fn query(self, os: &mut Os, output: &mut impl Write) -> Result<QueryResult> {
72+
let mut snapshot = AgentSnapshot::default();
73+
74+
let model = {
75+
let rts_state: RtsModelState = snapshot
76+
.model_state
77+
.as_ref()
78+
.and_then(|s| {
79+
serde_json::from_value(s.clone())
80+
.map_err(|err| error!(?err, ?s, "failed to deserialize RTS state"))
81+
.ok()
82+
})
83+
.unwrap_or({
84+
let state = RtsModelState::new();
85+
info!(?state.conversation_id, "generated new conversation id");
86+
state
87+
});
88+
Arc::new(RtsModel::new(
89+
os.client.clone(),
90+
rts_state.conversation_id,
91+
rts_state.model_id,
92+
))
93+
};
94+
95+
if let Some(name) = self.agent_name {
96+
let (configs, _) = load_agents().await?;
97+
if let Some(cfg) = configs.into_iter().find(|c| c.name() == name) {
98+
snapshot.agent_config = cfg.config().clone();
99+
} else {
100+
bail!("unable to find agent with name: {}", name);
101+
}
102+
};
103+
104+
let mcp_manager_handle = McpManager::default().spawn();
105+
let agent = agent::Agent::new(snapshot, model, mcp_manager_handle).await?.spawn();
106+
107+
self.main_loop(agent, output).await
108+
}
109+
110+
async fn main_loop(&self, mut agent: AgentHandle, output: &mut impl Write) -> Result<QueryResult> {
111+
// First, wait for agent initialization
112+
while let Ok(evt) = agent.recv().await {
113+
if matches!(evt, AgentEvent::Mcp(_)) {
114+
info!(?evt, "received mcp agent event");
115+
// TODO: Send it through conduit
116+
}
117+
if matches!(evt, AgentEvent::Initialized) {
118+
break;
119+
}
120+
}
121+
122+
agent
123+
.send_prompt(SendPromptArgs {
124+
content: vec![ContentChunk::Text(self.query.to_string())],
125+
should_continue_turn: None,
126+
})
127+
.await?;
128+
129+
// Holds the final result of the user turn.
130+
#[allow(unused_assignments)]
131+
let mut user_turn_metadata = None;
132+
133+
loop {
134+
let Ok(evt) = agent.recv().await else {
135+
bail!("channel closed");
136+
};
137+
debug!(?evt, "received new agent event");
138+
139+
// Check for exit conditions
140+
match &evt {
141+
AgentEvent::Update(evt) => {
142+
info!(?evt, "received update event");
143+
println!("received update event {:?}", evt);
144+
},
145+
AgentEvent::EndTurn(metadata) => {
146+
user_turn_metadata = Some(metadata.clone());
147+
break;
148+
},
149+
AgentEvent::Stop(AgentStopReason::Error(agent_error)) => {
150+
bail!("agent encountered an error: {:?}", agent_error)
151+
},
152+
AgentEvent::ApprovalRequest { id, tool_use, .. } => {
153+
if !self.dangerously_trust_all_tools {
154+
bail!("Tool approval is required: {:?}", tool_use);
155+
} else {
156+
warn!(?tool_use, "trust all is enabled, ignoring approval request");
157+
agent
158+
.send_tool_use_approval_result(SendApprovalResultArgs {
159+
id: id.clone(),
160+
result: ApprovalResult::Approve,
161+
})
162+
.await?;
163+
}
164+
},
165+
AgentEvent::Mcp(evt) => {
166+
info!(?evt, "received mcp agent event");
167+
},
168+
_ => {},
169+
}
170+
}
171+
172+
let md = user_turn_metadata.expect("user turn metadata should exist");
173+
let is_error = md.end_reason != LoopEndReason::UserTurnEnd || md.result.as_ref().is_none_or(|v| v.is_err());
174+
let result = md.result.and_then(|r| r.ok().map(|m| m.text()));
175+
176+
let output = JsonOutput {
177+
result,
178+
is_error,
179+
number_of_requests: md.total_request_count,
180+
number_of_cycles: md.number_of_cycles,
181+
duration_ms: md.turn_duration.map(|d| d.as_millis() as u32).unwrap_or_default(),
182+
};
183+
184+
info!(?output, "sub agent routine completed");
185+
186+
Ok(QueryResult {
187+
context_summary: Default::default(),
188+
task_summary: Default::default(),
189+
})
190+
}
191+
}
192+
193+
pub fn temp_func() {
194+
let rt = tokio::runtime::Builder::new_multi_thread()
195+
.enable_all()
196+
.build()
197+
.expect("failed to build runtime");
198+
199+
rt.block_on(test_sub_agent_routine());
200+
}
201+
202+
async fn test_sub_agent_routine() {
203+
let sub_agent = SubAgent {
204+
query: "What notion docs do I have",
205+
agent_name: Some("test_test"),
206+
embedded_user_msg: None,
207+
dangerously_trust_all_tools: true,
208+
};
209+
210+
let mut os = Os::new().await.expect("failed to spawn os");
211+
let mut output = Vec::<u8>::new();
212+
213+
_ = sub_agent.query(&mut os, &mut output).await;
214+
}
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
use chat_cli::agent::temp_func;
2+
3+
fn main() {
4+
temp_func();
5+
}

crates/chat-cli/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
//! This lib.rs is only here for testing purposes.
33
//! `test_mcp_server/test_server.rs` is declared as a separate binary and would need a way to
44
//! reference types defined inside of this crate, hence the export.
5+
pub mod agent;
56
pub mod api_client;
67
pub mod auth;
78
pub mod aws_common;

0 commit comments

Comments
 (0)