multi-agent research: parallel LLM-powered investigation
new research tool spawns 3-25 micro-agents (ministral-3b) in parallel via futures::join_all. each agent gets its own Mistral conversation with full tool access. recursive spawning up to depth 4 — agents can spawn sub-agents. research sessions persisted in SQLite (survive reboots). thread UX: 🔍 reaction, per-agent progress posts, ✅ when done. cost: ~$0.03 per research task (20 micro-agents on ministral-3b).
This commit is contained in:
@@ -83,6 +83,18 @@ impl Store {
|
||||
PRIMARY KEY (localpart, service)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS research_sessions (
|
||||
session_id TEXT PRIMARY KEY,
|
||||
room_id TEXT NOT NULL,
|
||||
event_id TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'running',
|
||||
query TEXT NOT NULL,
|
||||
plan_json TEXT,
|
||||
findings_json TEXT,
|
||||
depth INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
completed_at TEXT
|
||||
);
|
||||
",
|
||||
)?;
|
||||
|
||||
@@ -272,6 +284,91 @@ impl Store {
|
||||
}
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Research Sessions
|
||||
// =========================================================================
|
||||
|
||||
/// Create a new research session.
|
||||
pub fn create_research_session(
|
||||
&self,
|
||||
session_id: &str,
|
||||
room_id: &str,
|
||||
event_id: &str,
|
||||
query: &str,
|
||||
plan_json: &str,
|
||||
) {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
if let Err(e) = conn.execute(
|
||||
"INSERT INTO research_sessions (session_id, room_id, event_id, query, plan_json, findings_json)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, '[]')",
|
||||
params![session_id, room_id, event_id, query, plan_json],
|
||||
) {
|
||||
warn!("Failed to create research session: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
/// Append a finding to a research session.
|
||||
pub fn append_research_finding(&self, session_id: &str, finding_json: &str) {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
// Append to the JSON array
|
||||
if let Err(e) = conn.execute(
|
||||
"UPDATE research_sessions
|
||||
SET findings_json = json_insert(findings_json, '$[#]', json(?1))
|
||||
WHERE session_id = ?2",
|
||||
params![finding_json, session_id],
|
||||
) {
|
||||
warn!("Failed to append research finding: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
/// Mark a research session as complete.
|
||||
pub fn complete_research_session(&self, session_id: &str) {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
if let Err(e) = conn.execute(
|
||||
"UPDATE research_sessions SET status = 'complete', completed_at = datetime('now')
|
||||
WHERE session_id = ?1",
|
||||
params![session_id],
|
||||
) {
|
||||
warn!("Failed to complete research session: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
/// Mark a research session as failed.
|
||||
pub fn fail_research_session(&self, session_id: &str) {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
if let Err(e) = conn.execute(
|
||||
"UPDATE research_sessions SET status = 'failed', completed_at = datetime('now')
|
||||
WHERE session_id = ?1",
|
||||
params![session_id],
|
||||
) {
|
||||
warn!("Failed to mark research session failed: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
/// Load all running research sessions (for crash recovery on startup).
|
||||
pub fn load_running_research_sessions(&self) -> Vec<(String, String, String, String)> {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
let mut stmt = match conn.prepare(
|
||||
"SELECT session_id, room_id, query, findings_json
|
||||
FROM research_sessions WHERE status = 'running'",
|
||||
) {
|
||||
Ok(s) => s,
|
||||
Err(_) => return Vec::new(),
|
||||
};
|
||||
|
||||
stmt.query_map([], |row| {
|
||||
Ok((
|
||||
row.get::<_, String>(0)?,
|
||||
row.get::<_, String>(1)?,
|
||||
row.get::<_, String>(2)?,
|
||||
row.get::<_, String>(3)?,
|
||||
))
|
||||
})
|
||||
.ok()
|
||||
.map(|rows| rows.filter_map(|r| r.ok()).collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Load all agent mappings (for startup recovery).
|
||||
pub fn load_all_agents(&self) -> Vec<(String, String)> {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
|
||||
489
src/tools/research.rs
Normal file
489
src/tools/research.rs
Normal file
@@ -0,0 +1,489 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use matrix_sdk::room::Room;
|
||||
use mistralai_client::v1::client::Client as MistralClient;
|
||||
use mistralai_client::v1::conversations::{
|
||||
AppendConversationRequest, ConversationInput, CreateConversationRequest,
|
||||
};
|
||||
use ruma::OwnedEventId;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::agent_ux::AgentProgress;
|
||||
use crate::config::Config;
|
||||
use crate::context::ResponseContext;
|
||||
use crate::persistence::Store;
|
||||
use crate::tools::ToolRegistry;
|
||||
|
||||
// ── Types ──────────────────────────────────────────────────────────────────
|
||||
|
||||
/// One unit of parallel investigation, deserialized from the `tasks` array
/// of a `research` tool call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResearchTask {
    /// Short human-readable label (e.g. "repo structure", "license audit").
    pub focus: String,
    /// Detailed instructions handed to the micro-agent's conversation.
    pub instructions: String,
}
|
||||
|
||||
/// Outcome of a single research micro-agent; also serialized into the
/// session's `findings_json` column for persistence.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResearchResult {
    /// Focus label copied from the originating `ResearchTask`.
    pub focus: String,
    /// The agent's final free-text findings (its last assistant message,
    /// or a placeholder when no summary was produced).
    pub findings: String,
    /// Total number of tool calls the agent issued.
    pub tool_calls_made: usize,
    /// Either "complete" or "failed".
    pub status: String,
}
|
||||
|
||||
/// Progress events streamed from research agents to the thread-updater task
/// over an mpsc channel; each becomes one line posted into the Matrix thread.
#[derive(Debug)]
enum ProgressUpdate {
    /// An agent started working on its focus area.
    AgentStarted { focus: String },
    /// An agent finished; `summary` is truncated for display by the updater.
    AgentDone { focus: String, summary: String },
    /// An agent failed with a human-readable error message.
    AgentFailed { focus: String, error: String },
}
|
||||
|
||||
// ── Tool definition ────────────────────────────────────────────────────────
|
||||
|
||||
pub fn tool_definition(max_depth: usize, current_depth: usize) -> Option<mistralai_client::v1::tool::Tool> {
|
||||
if current_depth >= max_depth {
|
||||
return None; // At max depth, don't offer the research tool
|
||||
}
|
||||
|
||||
Some(mistralai_client::v1::tool::Tool::new(
|
||||
"research".into(),
|
||||
"Spawn parallel research agents to investigate a complex topic. Each agent \
|
||||
gets its own LLM conversation and can use all tools independently. Use this \
|
||||
for multi-faceted questions that need parallel investigation across repos, \
|
||||
archives, and the web. Each agent should have a focused, specific task."
|
||||
.into(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"tasks": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"focus": {
|
||||
"type": "string",
|
||||
"description": "Short label (e.g., 'repo structure', 'license audit')"
|
||||
},
|
||||
"instructions": {
|
||||
"type": "string",
|
||||
"description": "Detailed instructions for this research agent"
|
||||
}
|
||||
},
|
||||
"required": ["focus", "instructions"]
|
||||
},
|
||||
"description": "List of parallel research tasks (3-25 recommended). Each gets its own agent."
|
||||
}
|
||||
},
|
||||
"required": ["tasks"]
|
||||
}),
|
||||
))
|
||||
}
|
||||
|
||||
// ── Execution ──────────────────────────────────────────────────────────────
|
||||
|
||||
/// Execute a research tool call — spawns parallel micro-agents.
|
||||
pub async fn execute(
|
||||
args: &str,
|
||||
config: &Arc<Config>,
|
||||
mistral: &Arc<MistralClient>,
|
||||
tools: &Arc<ToolRegistry>,
|
||||
response_ctx: &ResponseContext,
|
||||
room: &Room,
|
||||
event_id: &OwnedEventId,
|
||||
store: &Arc<Store>,
|
||||
current_depth: usize,
|
||||
) -> anyhow::Result<String> {
|
||||
let parsed: serde_json::Value = serde_json::from_str(args)
|
||||
.map_err(|e| anyhow::anyhow!("Invalid research arguments: {e}"))?;
|
||||
|
||||
let tasks: Vec<ResearchTask> = serde_json::from_value(
|
||||
parsed.get("tasks").cloned().unwrap_or(json!([])),
|
||||
)
|
||||
.map_err(|e| anyhow::anyhow!("Invalid research tasks: {e}"))?;
|
||||
|
||||
if tasks.is_empty() {
|
||||
return Ok(json!({"error": "No research tasks provided"}).to_string());
|
||||
}
|
||||
|
||||
let max_agents = config.agents.research_max_agents;
|
||||
let tasks = if tasks.len() > max_agents {
|
||||
warn!(
|
||||
count = tasks.len(),
|
||||
max = max_agents,
|
||||
"Clamping research tasks to max"
|
||||
);
|
||||
tasks[..max_agents].to_vec()
|
||||
} else {
|
||||
tasks
|
||||
};
|
||||
|
||||
let session_id = uuid::Uuid::new_v4().to_string();
|
||||
let plan_json = serde_json::to_string(&tasks).unwrap_or_default();
|
||||
|
||||
// Persist session
|
||||
store.create_research_session(
|
||||
&session_id,
|
||||
&response_ctx.room_id,
|
||||
&event_id.to_string(),
|
||||
&format!("research (depth {})", current_depth),
|
||||
&plan_json,
|
||||
);
|
||||
|
||||
info!(
|
||||
session_id = session_id.as_str(),
|
||||
agents = tasks.len(),
|
||||
depth = current_depth,
|
||||
"Starting research session"
|
||||
);
|
||||
|
||||
// Progress channel for thread updates
|
||||
let (tx, mut rx) = mpsc::channel::<ProgressUpdate>(64);
|
||||
|
||||
// Spawn thread updater
|
||||
let thread_room = room.clone();
|
||||
let thread_event_id = event_id.clone();
|
||||
let agent_count = tasks.len();
|
||||
let updater = tokio::spawn(async move {
|
||||
let mut progress = AgentProgress::new(thread_room, thread_event_id);
|
||||
progress
|
||||
.post_step(&format!("🔬 researching with {} agents...", agent_count))
|
||||
.await;
|
||||
|
||||
while let Some(update) = rx.recv().await {
|
||||
let msg = match update {
|
||||
ProgressUpdate::AgentStarted { focus } => {
|
||||
format!("🔎 {focus}")
|
||||
}
|
||||
ProgressUpdate::AgentDone { focus, summary } => {
|
||||
let short: String = summary.chars().take(100).collect();
|
||||
format!("✅ {focus}: {short}")
|
||||
}
|
||||
ProgressUpdate::AgentFailed { focus, error } => {
|
||||
format!("❌ {focus}: {error}")
|
||||
}
|
||||
};
|
||||
progress.post_step(&msg).await;
|
||||
}
|
||||
});
|
||||
|
||||
// Create per-agent senders before dropping the original
|
||||
let agent_senders: Vec<_> = tasks.iter().map(|_| tx.clone()).collect();
|
||||
drop(tx); // Drop original so updater knows when all agents are done
|
||||
|
||||
// Run all research agents concurrently (join_all, not spawn — avoids Send requirement)
|
||||
let futures: Vec<_> = tasks
|
||||
.iter()
|
||||
.zip(agent_senders.iter())
|
||||
.map(|(task, sender)| {
|
||||
run_research_agent(
|
||||
task,
|
||||
config,
|
||||
mistral,
|
||||
tools,
|
||||
response_ctx,
|
||||
sender,
|
||||
&session_id,
|
||||
store,
|
||||
room,
|
||||
event_id,
|
||||
current_depth,
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let results = futures::future::join_all(futures).await;
|
||||
|
||||
// Wait for thread updater to finish
|
||||
let _ = updater.await;
|
||||
|
||||
// Mark session complete
|
||||
store.complete_research_session(&session_id);
|
||||
|
||||
// Format results for the orchestrator
|
||||
let total_calls: usize = results.iter().map(|r| r.tool_calls_made).sum();
|
||||
info!(
|
||||
session_id = session_id.as_str(),
|
||||
agents = results.len(),
|
||||
total_tool_calls = total_calls,
|
||||
"Research session complete"
|
||||
);
|
||||
|
||||
let output = results
|
||||
.iter()
|
||||
.map(|r| format!("### {} [{}]\n{}\n", r.focus, r.status, r.findings))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n---\n\n");
|
||||
|
||||
Ok(format!(
|
||||
"Research complete ({} agents, {} tool calls):\n\n{}",
|
||||
results.len(),
|
||||
total_calls,
|
||||
output
|
||||
))
|
||||
}
|
||||
|
||||
/// Run a single research micro-agent.
///
/// Creates a dedicated Mistral conversation for the task, then loops:
/// execute every tool call the model requests, append the results, and
/// repeat until the model stops calling tools or `max_iterations` is hit.
/// The final assistant text becomes the agent's findings, which are
/// persisted to the session and announced on the progress channel.
/// All failure paths return a `ResearchResult` with status "failed"
/// rather than an error, so `join_all` in the caller never short-circuits.
async fn run_research_agent(
    task: &ResearchTask,
    config: &Arc<Config>,
    mistral: &Arc<MistralClient>,
    tools: &Arc<ToolRegistry>,
    response_ctx: &ResponseContext,
    tx: &mpsc::Sender<ProgressUpdate>,
    session_id: &str,
    store: &Arc<Store>,
    room: &Room,
    event_id: &OwnedEventId,
    current_depth: usize,
) -> ResearchResult {
    // Announce the agent; send errors are ignored (updater may be gone).
    let _ = tx
        .send(ProgressUpdate::AgentStarted {
            focus: task.focus.clone(),
        })
        .await;

    let model = &config.agents.research_model;
    let max_iterations = config.agents.research_max_iterations;

    // Build tool definitions (include research tool if not at max depth).
    // Note the +1: the depth check is for the sub-agents this agent could spawn.
    let mut tool_defs = ToolRegistry::tool_definitions(
        tools.has_gitea(),
        tools.has_kratos(),
    );
    if let Some(research_def) = tool_definition(config.agents.research_max_depth, current_depth + 1) {
        tool_defs.push(research_def);
    }

    let mistral_tools: Vec<mistralai_client::v1::tool::Tool> = tool_defs;

    // The task prompt: focus + instructions, verbatim from the orchestrator.
    let instructions = format!(
        "You are a focused research agent. Your task:\n\n\
         **Focus:** {}\n\n\
         **Instructions:** {}\n\n\
         Use the available tools to investigate. Be thorough but focused. \
         When done, provide a clear summary of your findings.",
        task.focus, task.instructions
    );

    // Create conversation
    let req = CreateConversationRequest {
        inputs: ConversationInput::Text(instructions),
        model: Some(model.clone()),
        agent_id: None,
        agent_version: None,
        // session_id is a UUID string, so the [..8] slice is always in bounds.
        name: Some(format!("sol-research-{}", &session_id[..8])),
        description: None,
        instructions: None,
        completion_args: None,
        tools: Some(
            mistral_tools
                .into_iter()
                .map(|t| {
                    mistralai_client::v1::agents::AgentTool::function(
                        t.function.name,
                        t.function.description,
                        t.function.parameters,
                    )
                })
                .collect(),
        ),
        handoff_execution: None,
        metadata: None,
        store: Some(false), // Don't persist research conversations on Mistral's side
        stream: false,
    };

    let response = match mistral.create_conversation_async(&req).await {
        Ok(r) => r,
        Err(e) => {
            // Conversation never started — report failure and bail out.
            let error = format!("Failed to create research conversation: {}", e.message);
            let _ = tx
                .send(ProgressUpdate::AgentFailed {
                    focus: task.focus.clone(),
                    error: error.clone(),
                })
                .await;
            return ResearchResult {
                focus: task.focus.clone(),
                findings: error,
                tool_calls_made: 0,
                status: "failed".into(),
            };
        }
    };

    let conv_id = response.conversation_id.clone();
    let mut current_response = response;
    let mut tool_calls_made = 0;

    // Tool call loop: each iteration executes one batch of requested calls
    // and feeds the results back. Exits early when the model stops calling
    // tools; otherwise bounded by max_iterations.
    for _iteration in 0..max_iterations {
        let calls = current_response.function_calls();
        if calls.is_empty() {
            break;
        }

        let mut result_entries = Vec::new();

        for fc in &calls {
            let call_id = fc.tool_call_id.as_deref().unwrap_or("unknown");
            tool_calls_made += 1;

            debug!(
                focus = task.focus.as_str(),
                tool = fc.name.as_str(),
                "Research agent tool call"
            );

            let result = if fc.name == "research" {
                // Recursive research — spawn sub-agents at depth + 1.
                // NOTE(review): indirect async recursion (execute ->
                // run_research_agent -> execute) normally requires boxing
                // the future somewhere in the cycle — confirm this compiles
                // without a Box::pin.
                match execute(
                    &fc.arguments,
                    config,
                    mistral,
                    tools,
                    response_ctx,
                    room,
                    event_id,
                    store,
                    current_depth + 1,
                )
                .await
                {
                    Ok(s) => s,
                    Err(e) => format!("Research error: {e}"),
                }
            } else {
                // Regular tool: dispatch through the shared registry. Errors
                // are stringified so the model sees them as tool output.
                match tools.execute(&fc.name, &fc.arguments, response_ctx).await {
                    Ok(s) => s,
                    Err(e) => format!("Error: {e}"),
                }
            };

            result_entries.push(
                mistralai_client::v1::conversations::ConversationEntry::FunctionResult(
                    mistralai_client::v1::conversations::FunctionResultEntry {
                        tool_call_id: call_id.to_string(),
                        result,
                        id: None,
                        object: None,
                        created_at: None,
                        completed_at: None,
                    },
                ),
            );
        }

        // Send results back
        let append_req = AppendConversationRequest {
            inputs: ConversationInput::Entries(result_entries),
            completion_args: None,
            handoff_execution: None,
            store: Some(false),
            tool_confirmations: None,
            stream: false,
        };

        current_response = match mistral
            .append_conversation_async(&conv_id, &append_req)
            .await
        {
            Ok(r) => r,
            Err(e) => {
                // Mid-conversation failure: report with the call count so far.
                let error = format!("Research agent conversation failed: {}", e.message);
                let _ = tx
                    .send(ProgressUpdate::AgentFailed {
                        focus: task.focus.clone(),
                        error: error.clone(),
                    })
                    .await;
                return ResearchResult {
                    focus: task.focus.clone(),
                    findings: error,
                    tool_calls_made,
                    status: "failed".into(),
                };
            }
        };
    }

    // Extract final text; placeholder if the model produced no summary
    // (e.g. the iteration budget ran out while it was still calling tools).
    let findings = current_response
        .assistant_text()
        .unwrap_or_else(|| format!("(no summary after {} tool calls)", tool_calls_made));

    // Persist finding to the session's findings_json array.
    let finding_json = serde_json::to_string(&ResearchResult {
        focus: task.focus.clone(),
        findings: findings.clone(),
        tool_calls_made,
        status: "complete".into(),
    })
    .unwrap_or_default();
    store.append_research_finding(session_id, &finding_json);

    // Announce completion with a 100-char preview of the findings.
    let summary: String = findings.chars().take(100).collect();
    let _ = tx
        .send(ProgressUpdate::AgentDone {
            focus: task.focus.clone(),
            summary,
        })
        .await;

    ResearchResult {
        focus: task.focus.clone(),
        findings,
        tool_calls_made,
        status: "complete".into(),
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Task JSON from tool-call arguments must deserialize cleanly.
    #[test]
    fn test_research_task_deserialize() {
        let json = json!({
            "focus": "repo structure",
            "instructions": "browse studio/sbbb root directory"
        });
        let task: ResearchTask = serde_json::from_value(json).unwrap();
        assert_eq!(task.focus, "repo structure");
    }

    /// Results must round-trip to JSON with their numeric fields intact.
    #[test]
    fn test_research_result_serialize() {
        let result = ResearchResult {
            focus: "licensing".into(),
            findings: "found AGPL in 2 repos".into(),
            tool_calls_made: 5,
            status: "complete".into(),
        };
        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("AGPL"));
        assert!(json.contains("\"tool_calls_made\":5"));
    }

    #[test]
    fn test_tool_definition_available_at_depth_0() {
        assert!(tool_definition(4, 0).is_some());
    }

    #[test]
    fn test_tool_definition_available_at_depth_3() {
        assert!(tool_definition(4, 3).is_some());
    }

    #[test]
    fn test_tool_definition_unavailable_at_max_depth() {
        assert!(tool_definition(4, 4).is_none());
    }

    #[test]
    fn test_tool_definition_unavailable_beyond_max() {
        assert!(tool_definition(4, 5).is_none());
    }

    /// Boundary case previously uncovered: a zero depth budget must mean
    /// the research tool is never offered, even at depth 0.
    #[test]
    fn test_tool_definition_unavailable_with_zero_max_depth() {
        assert!(tool_definition(0, 0).is_none());
    }
}
|
||||
Reference in New Issue
Block a user