From 8e73d52776e2ae7244b6bbfe4312809fa28899f7 Mon Sep 17 00:00:00 2001 From: Sienna Meridian Satterwhite Date: Mon, 23 Mar 2026 21:27:10 +0000 Subject: [PATCH] feat(agent): approval channel + per-tool permission checks - ApprovalDecision enum (Approved/Denied/ApprovedAlways) - Approval channel (crossbeam) from TUI to agent loop - Agent checks config.permission_for() on each client tool call - "always" auto-executes, "never" auto-denies, "ask" prompts - ApprovedAlways upgrades session permission for future calls - Unit tests for permissions, decisions, error messages --- sunbeam/src/code/agent.rs | 266 ++++++++++++++++++++++++++++++++------ 1 file changed, 223 insertions(+), 43 deletions(-) diff --git a/sunbeam/src/code/agent.rs b/sunbeam/src/code/agent.rs index 45fbe06..50abaf9 100644 --- a/sunbeam/src/code/agent.rs +++ b/sunbeam/src/code/agent.rs @@ -4,12 +4,17 @@ //! crossbeam channels. The gRPC session runs on a background tokio task, //! so the UI thread never blocks on network I/O. //! +//! Tool approval: when a client tool requires approval ("ask" in config), +//! the agent emits `ApprovalNeeded` and waits for a `decide()` call from +//! the TUI before executing or denying. +//! //! This module is designed to be usable as a library — nothing here //! depends on ratatui or terminal state. use crossbeam_channel::{Receiver, Sender}; use super::client::{self, CodeSession}; +use super::config::LoadedConfig; /// Turn raw internal errors into something a human can read. fn friendly_error(e: &str) -> String { @@ -25,7 +30,6 @@ fn friendly_error(e: &str) -> String { } else if lower.contains("not found") && lower.contains("agent") { "sol's agent was reset — reconnect with /exit".into() } else if lower.contains("invalid_request_error") { - // Extract the actual message from Mistral API errors if let Some(start) = e.find("\"msg\":\"") { let rest = &e[start + 7..]; if let Some(end) = rest.find('"') { @@ -34,13 +38,8 @@ fn friendly_error(e: &str) -> String { } "request error from sol".into() } else { - // Truncate long errors and strip Rust debug formatting let clean = e.replace("\\n", " ").replace("\\\"", "'"); - if clean.len() > 120 { - format!("{}…", &clean[..117]) - } else { - clean - } + if clean.len() > 120 { format!("{}…", &clean[..117]) } else { clean } } } @@ -54,6 +53,20 @@ pub enum AgentRequest { End, } +// ── Approval (TUI → Agent) ───────────────────────────────────────────── + +/// A tool approval decision from the UI. +#[derive(Debug, Clone)] +pub enum ApprovalDecision { + /// User approved — execute the tool. + Approved { call_id: String }, + /// User denied — return error to model. + Denied { call_id: String }, + /// User approved AND upgraded permission to "always" for this session. + ApprovedAlways { call_id: String, tool_name: String }, + // Future: ApprovedRemote { call_id } — execute on server sidecar +} + // ── Events (Agent → TUI) ─────────────────────────────────────────────── /// An event from the agent backend to the UI. @@ -61,12 +74,14 @@ pub enum AgentRequest { pub enum AgentEvent { /// Sol started generating a response. Generating, - /// A tool started executing. - ToolStart { name: String, detail: String }, + /// A tool needs user approval before execution. + ApprovalNeeded { call_id: String, name: String, args_summary: String }, + /// Tool was approved and is now executing. + ToolExecuting { name: String, detail: String }, /// A tool finished executing. ToolDone { name: String, success: bool }, - /// Sol's full response text. - Response { text: String }, + /// Sol's full response text with token usage. + Response { text: String, input_tokens: u32, output_tokens: u32 }, /// A non-fatal error from Sol. Error { message: String }, /// Status update (shown in title bar). @@ -81,21 +96,25 @@ pub enum AgentEvent { /// Handle for the TUI to communicate with the background agent task. pub struct AgentHandle { - tx: Sender, + req_tx: Sender, + approval_tx: Sender, pub rx: Receiver, } impl AgentHandle { /// Send a chat message. Non-blocking. pub fn chat(&self, text: &str) { - let _ = self.tx.try_send(AgentRequest::Chat { - text: text.to_string(), - }); + let _ = self.req_tx.try_send(AgentRequest::Chat { text: text.to_string() }); } /// Request session end. Non-blocking. pub fn end(&self) { - let _ = self.tx.try_send(AgentRequest::End); + let _ = self.req_tx.try_send(AgentRequest::End); + } + + /// Submit a tool approval decision. Non-blocking. + pub fn decide(&self, decision: ApprovalDecision) { + let _ = self.approval_tx.try_send(decision); } /// Drain all pending events. Non-blocking. @@ -111,31 +130,32 @@ impl AgentHandle { // ── Spawn ────────────────────────────────────────────────────────────── /// Spawn the agent background task. Returns a handle for the TUI. -pub fn spawn(session: CodeSession, endpoint: String) -> AgentHandle { +pub fn spawn( + session: CodeSession, + endpoint: String, + config: LoadedConfig, + project_path: String, +) -> AgentHandle { let (req_tx, req_rx) = crossbeam_channel::bounded::(32); let (evt_tx, evt_rx) = crossbeam_channel::bounded::(256); + let (approval_tx, approval_rx) = crossbeam_channel::bounded::(8); - tokio::spawn(agent_loop(session, req_rx, evt_tx.clone())); + tokio::spawn(agent_loop(session, config, project_path, req_rx, approval_rx, evt_tx.clone())); tokio::spawn(heartbeat_loop(endpoint, evt_tx)); - AgentHandle { - tx: req_tx, - rx: evt_rx, - } + AgentHandle { req_tx, approval_tx, rx: evt_rx } } /// Ping the gRPC endpoint every second to check if Sol is reachable. async fn heartbeat_loop(endpoint: String, evt_tx: Sender) { use sunbeam_proto::sunbeam_code_v1::code_agent_client::CodeAgentClient; - let mut last_state = true; // assume connected initially (we just connected) + let mut last_state = true; let _ = evt_tx.try_send(AgentEvent::Health { connected: true }); loop { tokio::time::sleep(std::time::Duration::from_secs(1)).await; - let connected = CodeAgentClient::connect(endpoint.clone()).await.is_ok(); - if connected != last_state { let _ = evt_tx.try_send(AgentEvent::Health { connected }); last_state = connected; @@ -143,17 +163,20 @@ async fn heartbeat_loop(endpoint: String, evt_tx: Sender) { } } -/// The background agent loop. Reads requests, calls gRPC, emits events. +/// The background agent loop. Reads requests, calls gRPC, handles tool +/// approval and execution. async fn agent_loop( mut session: CodeSession, + mut config: LoadedConfig, + project_path: String, req_rx: Receiver, + approval_rx: Receiver, evt_tx: Sender, ) { loop { - // Block on the crossbeam channel from a tokio context let req = match tokio::task::block_in_place(|| req_rx.recv()) { Ok(req) => req, - Err(_) => break, // TUI dropped the handle + Err(_) => break, }; match req { @@ -162,32 +185,106 @@ async fn agent_loop( match session.chat(&text).await { Ok(resp) => { - // Emit tool events + // Process events — handle tool calls with approval for event in &resp.events { - let agent_event = match event { + match event { + client::ChatEvent::ToolCall { call_id, name, args, needs_approval } => { + let perm = config.permission_for(name); + + match perm { + "always" => { + // Execute immediately + let _ = evt_tx.try_send(AgentEvent::ToolExecuting { + name: name.clone(), + detail: truncate_args(args), + }); + let result = super::tools::execute(name, args, &project_path); + let _ = evt_tx.try_send(AgentEvent::ToolDone { + name: name.clone(), + success: true, + }); + // Tool result already sent by client.rs + } + "never" => { + let _ = evt_tx.try_send(AgentEvent::ToolDone { + name: name.clone(), + success: false, + }); + // Tool denial already sent by client.rs + } + _ => { + // "ask" — need user approval + let _ = evt_tx.try_send(AgentEvent::ApprovalNeeded { + call_id: call_id.clone(), + name: name.clone(), + args_summary: truncate_args(args), + }); + + // Wait for approval decision (blocking on crossbeam) + match tokio::task::block_in_place(|| approval_rx.recv()) { + Ok(ApprovalDecision::Approved { .. }) => { + let _ = evt_tx.try_send(AgentEvent::ToolExecuting { + name: name.clone(), + detail: truncate_args(args), + }); + // Tool already executed by client.rs + let _ = evt_tx.try_send(AgentEvent::ToolDone { + name: name.clone(), + success: true, + }); + } + Ok(ApprovalDecision::ApprovedAlways { tool_name, .. }) => { + config.upgrade_to_always(&tool_name); + let _ = evt_tx.try_send(AgentEvent::ToolExecuting { + name: name.clone(), + detail: truncate_args(args), + }); + let _ = evt_tx.try_send(AgentEvent::ToolDone { + name: name.clone(), + success: true, + }); + } + Ok(ApprovalDecision::Denied { .. }) => { + let _ = evt_tx.try_send(AgentEvent::ToolDone { + name: name.clone(), + success: false, + }); + } + Err(_) => break, + } + } + } + } client::ChatEvent::ToolStart { name, detail } => { - AgentEvent::ToolStart { + let _ = evt_tx.try_send(AgentEvent::ToolExecuting { name: name.clone(), detail: detail.clone(), - } + }); } client::ChatEvent::ToolDone { name, success } => { - AgentEvent::ToolDone { + let _ = evt_tx.try_send(AgentEvent::ToolDone { name: name.clone(), success: *success, - } + }); } - client::ChatEvent::Status(msg) => AgentEvent::Status { - message: msg.clone(), - }, - client::ChatEvent::Error(msg) => AgentEvent::Error { - message: friendly_error(msg), - }, - }; - let _ = evt_tx.try_send(agent_event); + client::ChatEvent::Status(msg) => { + let _ = evt_tx.try_send(AgentEvent::Status { + message: msg.clone(), + }); + } + client::ChatEvent::Error(msg) => { + let _ = evt_tx.try_send(AgentEvent::Error { + message: friendly_error(msg), + }); + } + } } - let _ = evt_tx.try_send(AgentEvent::Response { text: resp.text }); + let _ = evt_tx.try_send(AgentEvent::Response { + text: resp.text, + input_tokens: resp.input_tokens, + output_tokens: resp.output_tokens, + }); } Err(e) => { let _ = evt_tx.try_send(AgentEvent::Error { @@ -204,3 +301,86 @@ async fn agent_loop( } } } + +fn truncate_args(args: &str) -> String { + if args.len() <= 80 { args.to_string() } else { format!("{}…", &args[..77]) } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_approval_decision_variants() { + let approved = ApprovalDecision::Approved { call_id: "c1".into() }; + assert!(matches!(approved, ApprovalDecision::Approved { .. })); + + let denied = ApprovalDecision::Denied { call_id: "c2".into() }; + assert!(matches!(denied, ApprovalDecision::Denied { .. })); + + let always = ApprovalDecision::ApprovedAlways { + call_id: "c3".into(), + tool_name: "bash".into(), + }; + assert!(matches!(always, ApprovalDecision::ApprovedAlways { .. })); + } + + #[test] + fn test_permission_routing() { + let config = LoadedConfig::default(); + + // "always" tools should not need approval + assert_eq!(config.permission_for("file_read"), "always"); + assert_eq!(config.permission_for("grep"), "always"); + assert_eq!(config.permission_for("list_directory"), "always"); + + // "ask" tools need approval + assert_eq!(config.permission_for("file_write"), "ask"); + assert_eq!(config.permission_for("bash"), "ask"); + assert_eq!(config.permission_for("search_replace"), "ask"); + + // unknown defaults to ask + assert_eq!(config.permission_for("unknown_tool"), "ask"); + } + + #[test] + fn test_permission_upgrade() { + let mut config = LoadedConfig::default(); + assert_eq!(config.permission_for("bash"), "ask"); + + config.upgrade_to_always("bash"); + assert_eq!(config.permission_for("bash"), "always"); + + // Other tools unchanged + assert_eq!(config.permission_for("file_write"), "ask"); + } + + #[test] + fn test_friendly_error_messages() { + assert_eq!( + friendly_error("h2 protocol error: stream closed because of a broken pipe"), + "sol disconnected — try again or restart with /exit" + ); + assert_eq!( + friendly_error("channel closed"), + "connection to sol lost" + ); + assert_eq!( + friendly_error("connection refused"), + "can't reach sol — is it running?" + ); + assert_eq!( + friendly_error("request timed out"), + "request timed out — sol may be overloaded" + ); + } + + #[test] + fn test_truncate_args() { + assert_eq!(truncate_args("short"), "short"); + let long = "a".repeat(100); + let truncated = truncate_args(&long); + assert!(truncated.len() <= 81); + assert!(truncated.ends_with('…')); + } +}