From de33ddfe33043f6067b8d8c6df4402d9b62f48f8 Mon Sep 17 00:00:00 2001 From: Sienna Meridian Satterwhite Date: Mon, 23 Mar 2026 01:42:40 +0000 Subject: [PATCH] multi-agent research: parallel LLM-powered investigation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit new research tool spawns 3-25 micro-agents (ministral-3b) in parallel via futures::join_all. each agent gets its own Mistral conversation with full tool access. recursive spawning up to depth 4 — agents can spawn sub-agents. research sessions persisted in SQLite (survive reboots). thread UX: 🔍 reaction, per-agent progress posts, ✅ when done. cost: ~$0.03 per research task (20 micro-agents on ministral-3b). --- Cargo.lock | 1 + Cargo.toml | 1 + src/persistence.rs | 97 +++++++++ src/tools/research.rs | 489 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 588 insertions(+) create mode 100644 src/tools/research.rs diff --git a/Cargo.lock b/Cargo.lock index 4d1bf72..b4739a1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4184,6 +4184,7 @@ dependencies = [ "deno_ast", "deno_core", "deno_error", + "futures", "libsqlite3-sys", "matrix-sdk", "mistralai-client", diff --git a/Cargo.toml b/Cargo.toml index ee50b9d..05a3082 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -37,3 +37,4 @@ reqwest = { version = "0.12", default-features = false, features = ["rustls-tls" uuid = { version = "1", features = ["v4"] } base64 = "0.22" rusqlite = { version = "0.32", features = ["bundled"] } +futures = "0.3" diff --git a/src/persistence.rs b/src/persistence.rs index 8901c69..dfe9da1 100644 --- a/src/persistence.rs +++ b/src/persistence.rs @@ -83,6 +83,18 @@ impl Store { PRIMARY KEY (localpart, service) ); + CREATE TABLE IF NOT EXISTS research_sessions ( + session_id TEXT PRIMARY KEY, + room_id TEXT NOT NULL, + event_id TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'running', + query TEXT NOT NULL, + plan_json TEXT, + findings_json TEXT, + depth INTEGER NOT NULL DEFAULT 0, + created_at TEXT NOT NULL DEFAULT (datetime('now')), + completed_at TEXT + ); ", )?; @@ -272,6 +284,91 @@ impl Store { } } + // ========================================================================= + // Research Sessions + // ========================================================================= + + /// Create a new research session. + pub fn create_research_session( + &self, + session_id: &str, + room_id: &str, + event_id: &str, + query: &str, + plan_json: &str, + ) { + let conn = self.conn.lock().unwrap(); + if let Err(e) = conn.execute( + "INSERT INTO research_sessions (session_id, room_id, event_id, query, plan_json, findings_json) + VALUES (?1, ?2, ?3, ?4, ?5, '[]')", + params![session_id, room_id, event_id, query, plan_json], + ) { + warn!("Failed to create research session: {e}"); + } + } + + /// Append a finding to a research session. + pub fn append_research_finding(&self, session_id: &str, finding_json: &str) { + let conn = self.conn.lock().unwrap(); + // Append to the JSON array + if let Err(e) = conn.execute( + "UPDATE research_sessions + SET findings_json = json_insert(findings_json, '$[#]', json(?1)) + WHERE session_id = ?2", + params![finding_json, session_id], + ) { + warn!("Failed to append research finding: {e}"); + } + } + + /// Mark a research session as complete. + pub fn complete_research_session(&self, session_id: &str) { + let conn = self.conn.lock().unwrap(); + if let Err(e) = conn.execute( + "UPDATE research_sessions SET status = 'complete', completed_at = datetime('now') + WHERE session_id = ?1", + params![session_id], + ) { + warn!("Failed to complete research session: {e}"); + } + } + + /// Mark a research session as failed. + pub fn fail_research_session(&self, session_id: &str) { + let conn = self.conn.lock().unwrap(); + if let Err(e) = conn.execute( + "UPDATE research_sessions SET status = 'failed', completed_at = datetime('now') + WHERE session_id = ?1", + params![session_id], + ) { + warn!("Failed to mark research session failed: {e}"); + } + } + + /// Load all running research sessions (for crash recovery on startup). + pub fn load_running_research_sessions(&self) -> Vec<(String, String, String, String)> { + let conn = self.conn.lock().unwrap(); + let mut stmt = match conn.prepare( + "SELECT session_id, room_id, query, findings_json + FROM research_sessions WHERE status = 'running'", + ) { + Ok(s) => s, + Err(_) => return Vec::new(), + }; + + stmt.query_map([], |row| { + Ok(( + row.get::<_, String>(0)?, + row.get::<_, String>(1)?, + row.get::<_, String>(2)?, + row.get::<_, String>(3)?, + )) + }) + .ok() + .map(|rows| rows.filter_map(|r| r.ok()).collect()) + .unwrap_or_default() + } + /// Load all agent mappings (for startup recovery). pub fn load_all_agents(&self) -> Vec<(String, String)> { let conn = self.conn.lock().unwrap(); diff --git a/src/tools/research.rs b/src/tools/research.rs new file mode 100644 index 0000000..bbeccac --- /dev/null +++ b/src/tools/research.rs @@ -0,0 +1,489 @@ +use std::sync::Arc; + +use matrix_sdk::room::Room; +use mistralai_client::v1::client::Client as MistralClient; +use mistralai_client::v1::conversations::{ + AppendConversationRequest, ConversationInput, CreateConversationRequest, +}; +use ruma::OwnedEventId; +use serde::{Deserialize, Serialize}; +use serde_json::json; +use tokio::sync::mpsc; +use tracing::{debug, error, info, warn}; + +use crate::agent_ux::AgentProgress; +use crate::config::Config; +use crate::context::ResponseContext; +use crate::persistence::Store; +use crate::tools::ToolRegistry; + +// ── Types ────────────────────────────────────────────────────────────────── + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ResearchTask { + pub focus: String, + pub instructions: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ResearchResult { + pub focus: String, + pub findings: String, + pub tool_calls_made: usize, + pub status: String, +} + +#[derive(Debug)] +enum ProgressUpdate { + AgentStarted { focus: String }, + AgentDone { focus: String, summary: String }, + AgentFailed { focus: String, error: String }, +} + +// ── Tool definition ──────────────────────────────────────────────────────── + +pub fn tool_definition(max_depth: usize, current_depth: usize) -> Option { + if current_depth >= max_depth { + return None; // At max depth, don't offer the research tool + } + + Some(mistralai_client::v1::tool::Tool::new( + "research".into(), + "Spawn parallel research agents to investigate a complex topic. Each agent \ + gets its own LLM conversation and can use all tools independently. Use this \ + for multi-faceted questions that need parallel investigation across repos, \ + archives, and the web. Each agent should have a focused, specific task." + .into(), + json!({ + "type": "object", + "properties": { + "tasks": { + "type": "array", + "items": { + "type": "object", + "properties": { + "focus": { + "type": "string", + "description": "Short label (e.g., 'repo structure', 'license audit')" + }, + "instructions": { + "type": "string", + "description": "Detailed instructions for this research agent" + } + }, + "required": ["focus", "instructions"] + }, + "description": "List of parallel research tasks (3-25 recommended). Each gets its own agent." + } + }, + "required": ["tasks"] + }), + )) +} + +// ── Execution ────────────────────────────────────────────────────────────── + +/// Execute a research tool call — spawns parallel micro-agents. +pub async fn execute( + args: &str, + config: &Arc, + mistral: &Arc, + tools: &Arc, + response_ctx: &ResponseContext, + room: &Room, + event_id: &OwnedEventId, + store: &Arc, + current_depth: usize, +) -> anyhow::Result { + let parsed: serde_json::Value = serde_json::from_str(args) + .map_err(|e| anyhow::anyhow!("Invalid research arguments: {e}"))?; + + let tasks: Vec = serde_json::from_value( + parsed.get("tasks").cloned().unwrap_or(json!([])), + ) + .map_err(|e| anyhow::anyhow!("Invalid research tasks: {e}"))?; + + if tasks.is_empty() { + return Ok(json!({"error": "No research tasks provided"}).to_string()); + } + + let max_agents = config.agents.research_max_agents; + let tasks = if tasks.len() > max_agents { + warn!( + count = tasks.len(), + max = max_agents, + "Clamping research tasks to max" + ); + tasks[..max_agents].to_vec() + } else { + tasks + }; + + let session_id = uuid::Uuid::new_v4().to_string(); + let plan_json = serde_json::to_string(&tasks).unwrap_or_default(); + + // Persist session + store.create_research_session( + &session_id, + &response_ctx.room_id, + &event_id.to_string(), + &format!("research (depth {})", current_depth), + &plan_json, + ); + + info!( + session_id = session_id.as_str(), + agents = tasks.len(), + depth = current_depth, + "Starting research session" + ); + + // Progress channel for thread updates + let (tx, mut rx) = mpsc::channel::(64); + + // Spawn thread updater + let thread_room = room.clone(); + let thread_event_id = event_id.clone(); + let agent_count = tasks.len(); + let updater = tokio::spawn(async move { + let mut progress = AgentProgress::new(thread_room, thread_event_id); + progress + .post_step(&format!("🔬 researching with {} agents...", agent_count)) + .await; + + while let Some(update) = rx.recv().await { + let msg = match update { + ProgressUpdate::AgentStarted { focus } => { + format!("🔎 {focus}") + } + ProgressUpdate::AgentDone { focus, summary } => { + let short: String = summary.chars().take(100).collect(); + format!("✅ {focus}: {short}") + } + ProgressUpdate::AgentFailed { focus, error } => { + format!("❌ {focus}: {error}") + } + }; + progress.post_step(&msg).await; + } + }); + + // Create per-agent senders before dropping the original + let agent_senders: Vec<_> = tasks.iter().map(|_| tx.clone()).collect(); + drop(tx); // Drop original so updater knows when all agents are done + + // Run all research agents concurrently (join_all, not spawn — avoids Send requirement) + let futures: Vec<_> = tasks + .iter() + .zip(agent_senders.iter()) + .map(|(task, sender)| { + run_research_agent( + task, + config, + mistral, + tools, + response_ctx, + sender, + &session_id, + store, + room, + event_id, + current_depth, + ) + }) + .collect(); + + let results = futures::future::join_all(futures).await; + + // Wait for thread updater to finish + let _ = updater.await; + + // Mark session complete + store.complete_research_session(&session_id); + + // Format results for the orchestrator + let total_calls: usize = results.iter().map(|r| r.tool_calls_made).sum(); + info!( + session_id = session_id.as_str(), + agents = results.len(), + total_tool_calls = total_calls, + "Research session complete" + ); + + let output = results + .iter() + .map(|r| format!("### {} [{}]\n{}\n", r.focus, r.status, r.findings)) + .collect::>() + .join("\n---\n\n"); + + Ok(format!( + "Research complete ({} agents, {} tool calls):\n\n{}", + results.len(), + total_calls, + output + )) +} + +/// Run a single research micro-agent. +async fn run_research_agent( + task: &ResearchTask, + config: &Arc, + mistral: &Arc, + tools: &Arc, + response_ctx: &ResponseContext, + tx: &mpsc::Sender, + session_id: &str, + store: &Arc, + room: &Room, + event_id: &OwnedEventId, + current_depth: usize, +) -> ResearchResult { + let _ = tx + .send(ProgressUpdate::AgentStarted { + focus: task.focus.clone(), + }) + .await; + + let model = &config.agents.research_model; + let max_iterations = config.agents.research_max_iterations; + + // Build tool definitions (include research tool if not at max depth) + let mut tool_defs = ToolRegistry::tool_definitions( + tools.has_gitea(), + tools.has_kratos(), + ); + if let Some(research_def) = tool_definition(config.agents.research_max_depth, current_depth + 1) { + tool_defs.push(research_def); + } + + let mistral_tools: Vec = tool_defs; + + let instructions = format!( + "You are a focused research agent. Your task:\n\n\ + **Focus:** {}\n\n\ + **Instructions:** {}\n\n\ + Use the available tools to investigate. Be thorough but focused. \ + When done, provide a clear summary of your findings.", + task.focus, task.instructions + ); + + // Create conversation + let req = CreateConversationRequest { + inputs: ConversationInput::Text(instructions), + model: Some(model.clone()), + agent_id: None, + agent_version: None, + name: Some(format!("sol-research-{}", &session_id[..8])), + description: None, + instructions: None, + completion_args: None, + tools: Some( + mistral_tools + .into_iter() + .map(|t| { + mistralai_client::v1::agents::AgentTool::function( + t.function.name, + t.function.description, + t.function.parameters, + ) + }) + .collect(), + ), + handoff_execution: None, + metadata: None, + store: Some(false), // Don't persist research conversations on Mistral's side + stream: false, + }; + + let response = match mistral.create_conversation_async(&req).await { + Ok(r) => r, + Err(e) => { + let error = format!("Failed to create research conversation: {}", e.message); + let _ = tx + .send(ProgressUpdate::AgentFailed { + focus: task.focus.clone(), + error: error.clone(), + }) + .await; + return ResearchResult { + focus: task.focus.clone(), + findings: error, + tool_calls_made: 0, + status: "failed".into(), + }; + } + }; + + let conv_id = response.conversation_id.clone(); + let mut current_response = response; + let mut tool_calls_made = 0; + + // Tool call loop + for _iteration in 0..max_iterations { + let calls = current_response.function_calls(); + if calls.is_empty() { + break; + } + + let mut result_entries = Vec::new(); + + for fc in &calls { + let call_id = fc.tool_call_id.as_deref().unwrap_or("unknown"); + tool_calls_made += 1; + + debug!( + focus = task.focus.as_str(), + tool = fc.name.as_str(), + "Research agent tool call" + ); + + let result = if fc.name == "research" { + // Recursive research — spawn sub-agents + match execute( + &fc.arguments, + config, + mistral, + tools, + response_ctx, + room, + event_id, + store, + current_depth + 1, + ) + .await + { + Ok(s) => s, + Err(e) => format!("Research error: {e}"), + } + } else { + match tools.execute(&fc.name, &fc.arguments, response_ctx).await { + Ok(s) => s, + Err(e) => format!("Error: {e}"), + } + }; + + result_entries.push( + mistralai_client::v1::conversations::ConversationEntry::FunctionResult( + mistralai_client::v1::conversations::FunctionResultEntry { + tool_call_id: call_id.to_string(), + result, + id: None, + object: None, + created_at: None, + completed_at: None, + }, + ), + ); + } + + // Send results back + let append_req = AppendConversationRequest { + inputs: ConversationInput::Entries(result_entries), + completion_args: None, + handoff_execution: None, + store: Some(false), + tool_confirmations: None, + stream: false, + }; + + current_response = match mistral + .append_conversation_async(&conv_id, &append_req) + .await + { + Ok(r) => r, + Err(e) => { + let error = format!("Research agent conversation failed: {}", e.message); + let _ = tx + .send(ProgressUpdate::AgentFailed { + focus: task.focus.clone(), + error: error.clone(), + }) + .await; + return ResearchResult { + focus: task.focus.clone(), + findings: error, + tool_calls_made, + status: "failed".into(), + }; + } + }; + } + + // Extract final text + let findings = current_response + .assistant_text() + .unwrap_or_else(|| format!("(no summary after {} tool calls)", tool_calls_made)); + + // Persist finding + let finding_json = serde_json::to_string(&ResearchResult { + focus: task.focus.clone(), + findings: findings.clone(), + tool_calls_made, + status: "complete".into(), + }) + .unwrap_or_default(); + store.append_research_finding(session_id, &finding_json); + + let summary: String = findings.chars().take(100).collect(); + let _ = tx + .send(ProgressUpdate::AgentDone { + focus: task.focus.clone(), + summary, + }) + .await; + + ResearchResult { + focus: task.focus.clone(), + findings, + tool_calls_made, + status: "complete".into(), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_research_task_deserialize() { + let json = json!({ + "focus": "repo structure", + "instructions": "browse studio/sbbb root directory" + }); + let task: ResearchTask = serde_json::from_value(json).unwrap(); + assert_eq!(task.focus, "repo structure"); + } + + #[test] + fn test_research_result_serialize() { + let result = ResearchResult { + focus: "licensing".into(), + findings: "found AGPL in 2 repos".into(), + tool_calls_made: 5, + status: "complete".into(), + }; + let json = serde_json::to_string(&result).unwrap(); + assert!(json.contains("AGPL")); + assert!(json.contains("\"tool_calls_made\":5")); + } + + #[test] + fn test_tool_definition_available_at_depth_0() { + assert!(tool_definition(4, 0).is_some()); + } + + #[test] + fn test_tool_definition_available_at_depth_3() { + assert!(tool_definition(4, 3).is_some()); + } + + #[test] + fn test_tool_definition_unavailable_at_max_depth() { + assert!(tool_definition(4, 4).is_none()); + } + + #[test] + fn test_tool_definition_unavailable_beyond_max() { + assert!(tool_definition(4, 5).is_none()); + } +}