- ApprovalDecision enum (Approved/Denied/ApprovedAlways)
- Approval channel (crossbeam) from TUI to agent loop
- Agent checks config.permission_for() on each client tool call
- "always" auto-executes, "never" auto-denies, "ask" prompts
- ApprovedAlways upgrades session permission for future calls
- Unit tests for permissions, decisions, error messages
387 lines
17 KiB
Rust
387 lines
17 KiB
Rust
//! Agent service — async message bus between TUI and Sol gRPC session.
|
|
//!
|
|
//! The TUI sends `AgentRequest`s and receives `AgentEvent`s through
|
|
//! crossbeam channels. The gRPC session runs on a background tokio task,
|
|
//! so the UI thread never blocks on network I/O.
|
|
//!
|
|
//! Tool approval: when a client tool requires approval ("ask" in config),
|
|
//! the agent emits `ApprovalNeeded` and waits for a `decide()` call from
|
|
//! the TUI before executing or denying.
|
|
//!
|
|
//! This module is designed to be usable as a library — nothing here
|
|
//! depends on ratatui or terminal state.
|
|
|
|
use crossbeam_channel::{Receiver, Sender};
|
|
|
|
use super::client::{self, CodeSession};
|
|
use super::config::LoadedConfig;
|
|
|
|
/// Turn raw internal errors into something a human can read.
///
/// Known transport failures (broken pipe / closed stream, closed channel,
/// timeout, refused connection, reset agent) map to fixed short messages.
/// For `invalid_request_error`s, the text of the first `"msg":"…"` field is
/// surfaced when present and non-empty; otherwise a generic message is used.
/// Anything else is passed through with literal `\n` / `\"` escapes tidied
/// and capped at 120 characters (117 kept plus `…`).
fn friendly_error(e: &str) -> String {
    // Marker that precedes the human-readable part of an invalid_request_error.
    const MSG_KEY: &str = "\"msg\":\"";
    // Longest passthrough message, in characters, before it gets truncated.
    const MAX_CHARS: usize = 120;

    let lower = e.to_lowercase();
    if lower.contains("broken pipe") || lower.contains("stream closed") || lower.contains("h2 protocol") {
        "sol disconnected — try again or restart with /exit".into()
    } else if lower.contains("channel closed") || lower.contains("send on closed") {
        "connection to sol lost".into()
    } else if lower.contains("timed out") || lower.contains("timeout") {
        "request timed out — sol may be overloaded".into()
    } else if lower.contains("connection refused") {
        "can't reach sol — is it running?".into()
    } else if lower.contains("not found") && lower.contains("agent") {
        "sol's agent was reset — reconnect with /exit".into()
    } else if lower.contains("invalid_request_error") {
        // MSG_KEY is ASCII, so `start + MSG_KEY.len()` is always a char boundary.
        let extracted = e
            .find(MSG_KEY)
            .map(|start| &e[start + MSG_KEY.len()..])
            .and_then(|rest| rest.find('"').map(|end| &rest[..end]))
            // An empty message would show the user nothing at all.
            .filter(|msg| !msg.is_empty());
        match extracted {
            Some(msg) => msg.to_string(),
            None => "request error from sol".into(),
        }
    } else {
        let clean = e.replace("\\n", " ").replace("\\\"", "'");
        if clean.chars().count() > MAX_CHARS {
            // Truncate on char boundaries: byte-slicing panics on non-ASCII text.
            let mut short: String = clean.chars().take(MAX_CHARS - 3).collect();
            short.push('…');
            short
        } else {
            clean
        }
    }
}
|
|
|
|
// ── Requests (TUI → Agent) ──────────────────────────────────────────────

/// A request from the UI to the agent backend.
///
/// Derives `Debug` and `Clone` like the other message enums in this module
/// (`ApprovalDecision`, `AgentEvent`), so requests can be logged and
/// inspected in tests.
#[derive(Debug, Clone)]
pub enum AgentRequest {
    /// Send a chat message to Sol.
    Chat { text: String },
    /// End the session gracefully.
    End,
}
|
|
|
|
// ── Approval (TUI → Agent) ─────────────────────────────────────────────

/// The user's answer to an `AgentEvent::ApprovalNeeded` prompt.
///
/// Every variant carries a `call_id`, intended to identify the prompted
/// tool call.
#[derive(Clone, Debug)]
pub enum ApprovalDecision {
    /// Go ahead and run the tool this once.
    Approved { call_id: String },
    /// Refuse the tool; the model receives an error instead of a result.
    Denied { call_id: String },
    /// Run the tool and treat `tool_name` as "always" for the rest of this session.
    ApprovedAlways { call_id: String, tool_name: String },
    // Planned: ApprovedRemote { call_id } — run on the server sidecar instead.
}
|
|
|
|
// ── Events (Agent → TUI) ───────────────────────────────────────────────

/// A notification from the background tasks to the UI.
#[derive(Debug, Clone)]
pub enum AgentEvent {
    /// Sol has begun producing a reply.
    Generating,
    /// A tool call is waiting for the user to approve or deny it.
    ApprovalNeeded {
        call_id: String,
        name: String,
        args_summary: String,
    },
    /// A tool call was allowed and is running.
    ToolExecuting { name: String, detail: String },
    /// A tool call finished; `success` is `false` when it failed or was refused.
    ToolDone { name: String, success: bool },
    /// Sol's complete reply together with its token counts.
    Response {
        text: String,
        input_tokens: u32,
        output_tokens: u32,
    },
    /// A recoverable error reported by Sol.
    Error { message: String },
    /// Short status text for the title bar.
    Status { message: String },
    /// Whether Sol is reachable (`true`) or not (`false`).
    Health { connected: bool },
    /// The session is over.
    SessionEnded,
}
|
|
|
|
// ── Agent handle (owned by TUI) ────────────────────────────────────────
|
|
|
|
/// Handle for the TUI to communicate with the background agent task.
|
|
pub struct AgentHandle {
|
|
req_tx: Sender<AgentRequest>,
|
|
approval_tx: Sender<ApprovalDecision>,
|
|
pub rx: Receiver<AgentEvent>,
|
|
}
|
|
|
|
impl AgentHandle {
|
|
/// Send a chat message. Non-blocking.
|
|
pub fn chat(&self, text: &str) {
|
|
let _ = self.req_tx.try_send(AgentRequest::Chat { text: text.to_string() });
|
|
}
|
|
|
|
/// Request session end. Non-blocking.
|
|
pub fn end(&self) {
|
|
let _ = self.req_tx.try_send(AgentRequest::End);
|
|
}
|
|
|
|
/// Submit a tool approval decision. Non-blocking.
|
|
pub fn decide(&self, decision: ApprovalDecision) {
|
|
let _ = self.approval_tx.try_send(decision);
|
|
}
|
|
|
|
/// Drain all pending events. Non-blocking.
|
|
pub fn poll_events(&self) -> Vec<AgentEvent> {
|
|
let mut events = Vec::new();
|
|
while let Ok(event) = self.rx.try_recv() {
|
|
events.push(event);
|
|
}
|
|
events
|
|
}
|
|
}
|
|
|
|
// ── Spawn ──────────────────────────────────────────────────────────────
|
|
|
|
/// Spawn the agent background task. Returns a handle for the TUI.
|
|
pub fn spawn(
|
|
session: CodeSession,
|
|
endpoint: String,
|
|
config: LoadedConfig,
|
|
project_path: String,
|
|
) -> AgentHandle {
|
|
let (req_tx, req_rx) = crossbeam_channel::bounded::<AgentRequest>(32);
|
|
let (evt_tx, evt_rx) = crossbeam_channel::bounded::<AgentEvent>(256);
|
|
let (approval_tx, approval_rx) = crossbeam_channel::bounded::<ApprovalDecision>(8);
|
|
|
|
tokio::spawn(agent_loop(session, config, project_path, req_rx, approval_rx, evt_tx.clone()));
|
|
tokio::spawn(heartbeat_loop(endpoint, evt_tx));
|
|
|
|
AgentHandle { req_tx, approval_tx, rx: evt_rx }
|
|
}
|
|
|
|
/// Ping the gRPC endpoint every second to check if Sol is reachable.
|
|
async fn heartbeat_loop(endpoint: String, evt_tx: Sender<AgentEvent>) {
|
|
use sunbeam_proto::sunbeam_code_v1::code_agent_client::CodeAgentClient;
|
|
|
|
let mut last_state = true;
|
|
let _ = evt_tx.try_send(AgentEvent::Health { connected: true });
|
|
|
|
loop {
|
|
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
|
let connected = CodeAgentClient::connect(endpoint.clone()).await.is_ok();
|
|
if connected != last_state {
|
|
let _ = evt_tx.try_send(AgentEvent::Health { connected });
|
|
last_state = connected;
|
|
}
|
|
}
|
|
}
|
|
|
|
/// The background agent loop. Reads requests, calls gRPC, handles tool
|
|
/// approval and execution.
|
|
async fn agent_loop(
|
|
mut session: CodeSession,
|
|
mut config: LoadedConfig,
|
|
project_path: String,
|
|
req_rx: Receiver<AgentRequest>,
|
|
approval_rx: Receiver<ApprovalDecision>,
|
|
evt_tx: Sender<AgentEvent>,
|
|
) {
|
|
loop {
|
|
let req = match tokio::task::block_in_place(|| req_rx.recv()) {
|
|
Ok(req) => req,
|
|
Err(_) => break,
|
|
};
|
|
|
|
match req {
|
|
AgentRequest::Chat { text } => {
|
|
let _ = evt_tx.try_send(AgentEvent::Generating);
|
|
|
|
match session.chat(&text).await {
|
|
Ok(resp) => {
|
|
// Process events — handle tool calls with approval
|
|
for event in &resp.events {
|
|
match event {
|
|
client::ChatEvent::ToolCall { call_id, name, args, needs_approval } => {
|
|
let perm = config.permission_for(name);
|
|
|
|
match perm {
|
|
"always" => {
|
|
// Execute immediately
|
|
let _ = evt_tx.try_send(AgentEvent::ToolExecuting {
|
|
name: name.clone(),
|
|
detail: truncate_args(args),
|
|
});
|
|
let result = super::tools::execute(name, args, &project_path);
|
|
let _ = evt_tx.try_send(AgentEvent::ToolDone {
|
|
name: name.clone(),
|
|
success: true,
|
|
});
|
|
// Tool result already sent by client.rs
|
|
}
|
|
"never" => {
|
|
let _ = evt_tx.try_send(AgentEvent::ToolDone {
|
|
name: name.clone(),
|
|
success: false,
|
|
});
|
|
// Tool denial already sent by client.rs
|
|
}
|
|
_ => {
|
|
// "ask" — need user approval
|
|
let _ = evt_tx.try_send(AgentEvent::ApprovalNeeded {
|
|
call_id: call_id.clone(),
|
|
name: name.clone(),
|
|
args_summary: truncate_args(args),
|
|
});
|
|
|
|
// Wait for approval decision (blocking on crossbeam)
|
|
match tokio::task::block_in_place(|| approval_rx.recv()) {
|
|
Ok(ApprovalDecision::Approved { .. }) => {
|
|
let _ = evt_tx.try_send(AgentEvent::ToolExecuting {
|
|
name: name.clone(),
|
|
detail: truncate_args(args),
|
|
});
|
|
// Tool already executed by client.rs
|
|
let _ = evt_tx.try_send(AgentEvent::ToolDone {
|
|
name: name.clone(),
|
|
success: true,
|
|
});
|
|
}
|
|
Ok(ApprovalDecision::ApprovedAlways { tool_name, .. }) => {
|
|
config.upgrade_to_always(&tool_name);
|
|
let _ = evt_tx.try_send(AgentEvent::ToolExecuting {
|
|
name: name.clone(),
|
|
detail: truncate_args(args),
|
|
});
|
|
let _ = evt_tx.try_send(AgentEvent::ToolDone {
|
|
name: name.clone(),
|
|
success: true,
|
|
});
|
|
}
|
|
Ok(ApprovalDecision::Denied { .. }) => {
|
|
let _ = evt_tx.try_send(AgentEvent::ToolDone {
|
|
name: name.clone(),
|
|
success: false,
|
|
});
|
|
}
|
|
Err(_) => break,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
client::ChatEvent::ToolStart { name, detail } => {
|
|
let _ = evt_tx.try_send(AgentEvent::ToolExecuting {
|
|
name: name.clone(),
|
|
detail: detail.clone(),
|
|
});
|
|
}
|
|
client::ChatEvent::ToolDone { name, success } => {
|
|
let _ = evt_tx.try_send(AgentEvent::ToolDone {
|
|
name: name.clone(),
|
|
success: *success,
|
|
});
|
|
}
|
|
client::ChatEvent::Status(msg) => {
|
|
let _ = evt_tx.try_send(AgentEvent::Status {
|
|
message: msg.clone(),
|
|
});
|
|
}
|
|
client::ChatEvent::Error(msg) => {
|
|
let _ = evt_tx.try_send(AgentEvent::Error {
|
|
message: friendly_error(msg),
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
let _ = evt_tx.try_send(AgentEvent::Response {
|
|
text: resp.text,
|
|
input_tokens: resp.input_tokens,
|
|
output_tokens: resp.output_tokens,
|
|
});
|
|
}
|
|
Err(e) => {
|
|
let _ = evt_tx.try_send(AgentEvent::Error {
|
|
message: friendly_error(&e.to_string()),
|
|
});
|
|
}
|
|
}
|
|
}
|
|
AgentRequest::End => {
|
|
let _ = session.end().await;
|
|
let _ = evt_tx.try_send(AgentEvent::SessionEnded);
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Shorten tool arguments for one-line display.
///
/// Inputs of at most 80 characters are returned unchanged. Longer inputs keep
/// their first 77 characters followed by `…`; for ASCII input that is 80 bytes.
/// Counting characters rather than bytes avoids the panic that byte-slicing
/// hits when a multibyte character straddles the cut.
fn truncate_args(args: &str) -> String {
    const MAX_CHARS: usize = 80;
    const KEEP_CHARS: usize = MAX_CHARS - 3;

    if args.chars().count() <= MAX_CHARS {
        return args.to_string();
    }
    let mut short: String = args.chars().take(KEEP_CHARS).collect();
    short.push('…');
    short
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_approval_decision_variants() {
        assert!(matches!(
            ApprovalDecision::Approved { call_id: "c1".into() },
            ApprovalDecision::Approved { .. }
        ));
        assert!(matches!(
            ApprovalDecision::Denied { call_id: "c2".into() },
            ApprovalDecision::Denied { .. }
        ));
        assert!(matches!(
            ApprovalDecision::ApprovedAlways {
                call_id: "c3".into(),
                tool_name: "bash".into(),
            },
            ApprovalDecision::ApprovedAlways { .. }
        ));
    }

    #[test]
    fn test_permission_routing() {
        let config = LoadedConfig::default();
        let expectations = [
            // auto-approved by default
            ("file_read", "always"),
            ("grep", "always"),
            ("list_directory", "always"),
            // prompt the user first
            ("file_write", "ask"),
            ("bash", "ask"),
            ("search_replace", "ask"),
            // unrecognised tools fall back to prompting
            ("unknown_tool", "ask"),
        ];
        for (tool, expected) in expectations {
            assert_eq!(config.permission_for(tool), expected, "permission for {}", tool);
        }
    }

    #[test]
    fn test_permission_upgrade() {
        let mut config = LoadedConfig::default();
        assert_eq!(config.permission_for("bash"), "ask");

        config.upgrade_to_always("bash");
        assert_eq!(config.permission_for("bash"), "always");

        // The upgrade is scoped to the named tool only.
        assert_eq!(config.permission_for("file_write"), "ask");
    }

    #[test]
    fn test_friendly_error_messages() {
        let cases = [
            (
                "h2 protocol error: stream closed because of a broken pipe",
                "sol disconnected — try again or restart with /exit",
            ),
            ("channel closed", "connection to sol lost"),
            ("connection refused", "can't reach sol — is it running?"),
            ("request timed out", "request timed out — sol may be overloaded"),
        ];
        for (raw, friendly) in cases {
            assert_eq!(friendly_error(raw), friendly, "input: {:?}", raw);
        }
    }

    #[test]
    fn test_truncate_args() {
        assert_eq!(truncate_args("short"), "short");

        let truncated = truncate_args(&"a".repeat(100));
        assert!(truncated.len() <= 81);
        assert!(truncated.ends_with('…'));
    }
}
|