feat: multi-agent architecture with Conversations API and persistent state
Mistral Agents + Conversations API integration:
- Orchestrator agent created on startup with Sol's personality + tools
- ConversationRegistry routes messages through persistent conversations
- Per-room conversation state (room_id → conversation_id + token counts)
- Function call handling within conversation responses
- Configurable via [agents] section in sol.toml (use_conversations_api flag)
Multimodal support:
- m.image detection and Matrix media download (mxc:// → base64 data URI)
- ContentPart-based messages sent to Mistral vision models
- Archive stores media_urls for image messages
System prompt rewrite:
- 687 → 150 lines — dense, few-shot examples, hard rules
- {room_context_rules} placeholder for group vs DM behavior
- Sender prefixing (<@user:server>) for multi-user turns in group rooms
SQLite persistence (/data/sol.db):
- Conversation mappings and agent IDs survive reboots
- WAL mode for concurrent reads
- Falls back to in-memory on failure (sneezes into all rooms to signal)
- PVC already mounted at /data alongside Matrix SDK state store
New modules:
- src/persistence.rs — SQLite state store
- src/conversations.rs — ConversationRegistry + message merging
- src/agents/{mod,definitions,registry}.rs — agent lifecycle
- src/agent_ux.rs — reaction + thread progress UX
- src/tools/bridge.rs — tool dispatch for domain agents
102 tests passing.
This commit is contained in:
139
src/agent_ux.rs
Normal file
139
src/agent_ux.rs
Normal file
@@ -0,0 +1,139 @@
|
||||
use matrix_sdk::room::Room;
|
||||
use ruma::events::relation::InReplyTo;
|
||||
use ruma::events::room::message::{Relation, RoomMessageEventContent};
|
||||
use ruma::OwnedEventId;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::matrix_utils;
|
||||
|
||||
/// Reaction emojis for agent progress lifecycle.
|
||||
const REACTION_WORKING: &str = "\u{1F50D}"; // 🔍
|
||||
const REACTION_PROCESSING: &str = "\u{2699}\u{FE0F}"; // ⚙️
|
||||
const REACTION_DONE: &str = "\u{2705}"; // ✅
|
||||
|
||||
/// Manages the UX lifecycle for agentic work:
|
||||
/// reactions on the user's message + a thread for tool call details.
|
||||
pub struct AgentProgress {
|
||||
room: Room,
|
||||
user_event_id: OwnedEventId,
|
||||
/// Event ID of the current reaction (so we can redact + replace).
|
||||
current_reaction_id: Option<OwnedEventId>,
|
||||
/// Event ID of the thread root (first message in our thread).
|
||||
thread_root_id: Option<OwnedEventId>,
|
||||
}
|
||||
|
||||
impl AgentProgress {
|
||||
pub fn new(room: Room, user_event_id: OwnedEventId) -> Self {
|
||||
Self {
|
||||
room,
|
||||
user_event_id,
|
||||
current_reaction_id: None,
|
||||
thread_root_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Start: add 🔍 reaction to indicate work has begun.
|
||||
pub async fn start(&mut self) {
|
||||
if let Ok(()) = matrix_utils::send_reaction(
|
||||
&self.room,
|
||||
self.user_event_id.clone(),
|
||||
REACTION_WORKING,
|
||||
)
|
||||
.await
|
||||
{
|
||||
// We can't easily get the reaction event ID from send_reaction,
|
||||
// so we track the emoji state instead.
|
||||
self.current_reaction_id = None; // TODO: capture reaction event ID if needed
|
||||
}
|
||||
}
|
||||
|
||||
/// Post a step update to the thread. Creates the thread on first call.
|
||||
pub async fn post_step(&mut self, text: &str) {
|
||||
let content = if let Some(ref _root) = self.thread_root_id {
|
||||
// Reply in existing thread
|
||||
let mut msg = RoomMessageEventContent::text_markdown(text);
|
||||
msg.relates_to = Some(Relation::Reply {
|
||||
in_reply_to: InReplyTo::new(self.user_event_id.clone()),
|
||||
});
|
||||
msg
|
||||
} else {
|
||||
// First message — starts the thread as a reply to the user's message
|
||||
let mut msg = RoomMessageEventContent::text_markdown(text);
|
||||
msg.relates_to = Some(Relation::Reply {
|
||||
in_reply_to: InReplyTo::new(self.user_event_id.clone()),
|
||||
});
|
||||
msg
|
||||
};
|
||||
|
||||
match self.room.send(content).await {
|
||||
Ok(response) => {
|
||||
if self.thread_root_id.is_none() {
|
||||
self.thread_root_id = Some(response.event_id);
|
||||
}
|
||||
}
|
||||
Err(e) => warn!("Failed to post agent step: {e}"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Swap reaction to ⚙️ (processing).
|
||||
pub async fn processing(&mut self) {
|
||||
// Send new reaction (Matrix doesn't have "replace reaction" — we add another)
|
||||
let _ = matrix_utils::send_reaction(
|
||||
&self.room,
|
||||
self.user_event_id.clone(),
|
||||
REACTION_PROCESSING,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Swap reaction to ✅ (done).
|
||||
pub async fn done(&mut self) {
|
||||
let _ = matrix_utils::send_reaction(
|
||||
&self.room,
|
||||
self.user_event_id.clone(),
|
||||
REACTION_DONE,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Format a tool call for the thread.
|
||||
pub fn format_tool_call(name: &str, args: &str) -> String {
|
||||
format!("`{name}` → ```json\n{args}\n```")
|
||||
}
|
||||
|
||||
/// Format a tool result for the thread.
|
||||
pub fn format_tool_result(name: &str, result: &str) -> String {
|
||||
let truncated = if result.len() > 500 {
|
||||
format!("{}…", &result[..500])
|
||||
} else {
|
||||
result.to_string()
|
||||
};
|
||||
format!("`{name}` ← {truncated}")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_format_tool_call() {
|
||||
let formatted = AgentProgress::format_tool_call("search_archive", r#"{"query":"test"}"#);
|
||||
assert!(formatted.contains("search_archive"));
|
||||
assert!(formatted.contains("test"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_tool_result_truncation() {
|
||||
let long = "x".repeat(1000);
|
||||
let formatted = AgentProgress::format_tool_result("search", &long);
|
||||
assert!(formatted.len() < 600);
|
||||
assert!(formatted.ends_with('…'));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_tool_result_short() {
|
||||
let formatted = AgentProgress::format_tool_result("search", "3 results found");
|
||||
assert_eq!(formatted, "`search` ← 3 results found");
|
||||
}
|
||||
}
|
||||
173
src/agents/definitions.rs
Normal file
173
src/agents/definitions.rs
Normal file
@@ -0,0 +1,173 @@
|
||||
use mistralai_client::v1::agents::{AgentTool, CompletionArgs, CreateAgentRequest};
|
||||
|
||||
/// Domain agent definitions — each scoped to a subset of sunbeam-sdk tools.
|
||||
/// These are created on startup via the Agents API and cached by the registry.
|
||||
|
||||
pub const ORCHESTRATOR_NAME: &str = "sol-orchestrator";
|
||||
pub const ORCHESTRATOR_DESCRIPTION: &str =
|
||||
"Sol — virtual librarian for Sunbeam Studios. Routes to domain agents or responds directly.";
|
||||
|
||||
/// Build the orchestrator agent instructions.
|
||||
/// The orchestrator carries Sol's personality and sees high-level domain descriptions.
|
||||
pub fn orchestrator_instructions(system_prompt: &str) -> String {
|
||||
format!(
|
||||
"{system_prompt}\n\n\
|
||||
## delegation\n\n\
|
||||
you have access to domain agents for specialized tasks. \
|
||||
for simple conversation, respond directly. for tasks requiring tools, delegate.\n\n\
|
||||
available domains:\n\
|
||||
- **observability**: metrics, logs, dashboards, alerts (prometheus, loki, grafana)\n\
|
||||
- **data**: full-text search, object storage (opensearch, seaweedfs)\n\
|
||||
- **devtools**: git repos, issues, PRs, kanban boards (gitea, planka)\n\
|
||||
- **infrastructure**: kubernetes, deployments, certificates, builds\n\
|
||||
- **identity**: user accounts, sessions, login, recovery, OAuth2 clients (kratos, hydra)\n\
|
||||
- **collaboration**: contacts, documents, meetings, files, email, calendars (la suite)\n\
|
||||
- **communication**: chat rooms, messages, members (matrix)\n\
|
||||
- **media**: video/audio rooms, recordings, streams (livekit)\n"
|
||||
)
|
||||
}
|
||||
|
||||
/// Build a domain agent creation request.
|
||||
pub fn domain_agent_request(
|
||||
name: &str,
|
||||
description: &str,
|
||||
instructions: &str,
|
||||
tools: Vec<AgentTool>,
|
||||
model: &str,
|
||||
) -> CreateAgentRequest {
|
||||
CreateAgentRequest {
|
||||
model: model.to_string(),
|
||||
name: name.to_string(),
|
||||
description: Some(description.to_string()),
|
||||
instructions: Some(instructions.to_string()),
|
||||
tools: Some(tools),
|
||||
handoffs: None,
|
||||
completion_args: Some(CompletionArgs {
|
||||
temperature: Some(0.3),
|
||||
..Default::default()
|
||||
}),
|
||||
metadata: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Build the orchestrator agent creation request.
|
||||
/// Includes Sol's existing tools as function calling tools.
|
||||
pub fn orchestrator_request(
|
||||
system_prompt: &str,
|
||||
model: &str,
|
||||
tools: Vec<AgentTool>,
|
||||
) -> CreateAgentRequest {
|
||||
let instructions = orchestrator_instructions(system_prompt);
|
||||
|
||||
CreateAgentRequest {
|
||||
model: model.to_string(),
|
||||
name: ORCHESTRATOR_NAME.to_string(),
|
||||
description: Some(ORCHESTRATOR_DESCRIPTION.to_string()),
|
||||
instructions: Some(instructions),
|
||||
tools: if tools.is_empty() { None } else { Some(tools) },
|
||||
handoffs: None,
|
||||
completion_args: Some(CompletionArgs {
|
||||
temperature: Some(0.5),
|
||||
..Default::default()
|
||||
}),
|
||||
metadata: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Known domain agent configurations.
|
||||
/// Each entry: (name, description, instructions_snippet)
|
||||
pub const DOMAIN_AGENTS: &[(&str, &str, &str)] = &[
|
||||
(
|
||||
"sol-observability",
|
||||
"Metrics, logs, dashboards, and alerts",
|
||||
"you handle observability tasks for sunbeam infrastructure. \
|
||||
you can query prometheus metrics, search loki logs, manage grafana dashboards, \
|
||||
and check alert status. respond with data, not opinions.",
|
||||
),
|
||||
(
|
||||
"sol-data",
|
||||
"Full-text search and object storage",
|
||||
"you handle data operations. you can search the opensearch archive for past conversations, \
|
||||
manage seaweedfs object storage buckets and files. present search results clearly.",
|
||||
),
|
||||
(
|
||||
"sol-devtools",
|
||||
"Git repos, issues, PRs, and kanban boards",
|
||||
"you handle development tools. you can manage gitea repositories, issues, pull requests, \
|
||||
and planka kanban boards. be precise about repo names and issue numbers.",
|
||||
),
|
||||
(
|
||||
"sol-infrastructure",
|
||||
"Kubernetes, deployments, certificates, and builds",
|
||||
"you handle infrastructure operations. you can inspect kubernetes resources, \
|
||||
trigger deployments, check certificate status, and manage builds. \
|
||||
always confirm destructive actions.",
|
||||
),
|
||||
(
|
||||
"sol-identity",
|
||||
"User accounts, sessions, and OAuth2",
|
||||
"you handle identity management. you can create and manage user accounts via kratos, \
|
||||
manage OAuth2 clients via hydra, and handle recovery flows. \
|
||||
be careful with credentials — never expose secrets.",
|
||||
),
|
||||
(
|
||||
"sol-collaboration",
|
||||
"Contacts, documents, meetings, files, email, calendars",
|
||||
"you handle collaboration services from la suite numérique. \
|
||||
you can manage contacts (people), documents (docs), meetings (meet), \
|
||||
files (drive), email, and calendars. help users find and organize their work.",
|
||||
),
|
||||
(
|
||||
"sol-communication",
|
||||
"Chat rooms, messages, and members",
|
||||
"you handle matrix communication. you can manage rooms, look up members, \
|
||||
search message history, and help with room administration.",
|
||||
),
|
||||
(
|
||||
"sol-media",
|
||||
"Video/audio rooms, recordings, and streams",
|
||||
"you handle media services via livekit. you can manage video/audio rooms, \
|
||||
start/stop recordings, and check stream status.",
|
||||
),
|
||||
];
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_orchestrator_instructions_includes_prompt() {
|
||||
let prompt = "you are sol.";
|
||||
let instructions = orchestrator_instructions(prompt);
|
||||
assert!(instructions.starts_with("you are sol."));
|
||||
assert!(instructions.contains("observability"));
|
||||
assert!(instructions.contains("delegation"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_orchestrator_request() {
|
||||
let req = orchestrator_request("test prompt", "mistral-medium-latest", vec![]);
|
||||
assert_eq!(req.name, "sol-orchestrator");
|
||||
assert_eq!(req.model, "mistral-medium-latest");
|
||||
assert!(req.instructions.unwrap().contains("test prompt"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_domain_agent_request() {
|
||||
let req = domain_agent_request(
|
||||
"sol-test",
|
||||
"Test agent",
|
||||
"You test things.",
|
||||
vec![AgentTool::web_search()],
|
||||
"mistral-medium-latest",
|
||||
);
|
||||
assert_eq!(req.name, "sol-test");
|
||||
assert_eq!(req.tools.unwrap().len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_domain_agents_defined() {
|
||||
assert_eq!(DOMAIN_AGENTS.len(), 8);
|
||||
assert_eq!(DOMAIN_AGENTS[0].0, "sol-observability");
|
||||
}
|
||||
}
|
||||
2
src/agents/mod.rs
Normal file
2
src/agents/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
||||
pub mod definitions;
|
||||
pub mod registry;
|
||||
175
src/agents/registry.rs
Normal file
175
src/agents/registry.rs
Normal file
@@ -0,0 +1,175 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use mistralai_client::v1::agents::{Agent, CreateAgentRequest};
|
||||
use mistralai_client::v1::client::Client as MistralClient;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{info, warn, error};
|
||||
|
||||
use super::definitions;
|
||||
use crate::persistence::Store;
|
||||
|
||||
/// Manages the lifecycle of Mistral agents — creates on startup, caches IDs,
|
||||
/// handles instruction updates by re-creating agents.
|
||||
/// Agent ID mappings are persisted to SQLite so they survive reboots.
|
||||
pub struct AgentRegistry {
|
||||
/// agent_name → Agent
|
||||
agents: Mutex<HashMap<String, Agent>>,
|
||||
/// SQLite persistence.
|
||||
store: Arc<Store>,
|
||||
}
|
||||
|
||||
impl AgentRegistry {
|
||||
pub fn new(store: Arc<Store>) -> Self {
|
||||
Self {
|
||||
agents: Mutex::new(HashMap::new()),
|
||||
store,
|
||||
}
|
||||
}
|
||||
|
||||
/// Ensure the orchestrator agent exists. Creates or verifies it.
|
||||
/// Returns the agent ID.
|
||||
pub async fn ensure_orchestrator(
|
||||
&self,
|
||||
system_prompt: &str,
|
||||
model: &str,
|
||||
tools: Vec<mistralai_client::v1::agents::AgentTool>,
|
||||
mistral: &MistralClient,
|
||||
) -> Result<String, String> {
|
||||
let mut agents = self.agents.lock().await;
|
||||
|
||||
// Check in-memory cache
|
||||
if let Some(agent) = agents.get(definitions::ORCHESTRATOR_NAME) {
|
||||
return Ok(agent.id.clone());
|
||||
}
|
||||
|
||||
// Check SQLite for persisted agent ID
|
||||
if let Some(agent_id) = self.store.get_agent(definitions::ORCHESTRATOR_NAME) {
|
||||
// Verify it still exists on the server
|
||||
match mistral.get_agent_async(&agent_id).await {
|
||||
Ok(agent) => {
|
||||
info!(agent_id = agent.id.as_str(), "Restored orchestrator agent from database");
|
||||
agents.insert(definitions::ORCHESTRATOR_NAME.to_string(), agent);
|
||||
return Ok(agent_id);
|
||||
}
|
||||
Err(_) => {
|
||||
warn!("Persisted orchestrator agent {agent_id} no longer exists on server");
|
||||
self.store.delete_agent(definitions::ORCHESTRATOR_NAME);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Check if it exists on the server by name
|
||||
let existing = self.find_by_name(definitions::ORCHESTRATOR_NAME, mistral).await;
|
||||
if let Some(agent) = existing {
|
||||
let id = agent.id.clone();
|
||||
info!(agent_id = id.as_str(), "Found existing orchestrator agent on server");
|
||||
self.store.upsert_agent(definitions::ORCHESTRATOR_NAME, &id, model);
|
||||
agents.insert(definitions::ORCHESTRATOR_NAME.to_string(), agent);
|
||||
return Ok(id);
|
||||
}
|
||||
|
||||
// Create new
|
||||
let req = definitions::orchestrator_request(system_prompt, model, tools);
|
||||
let agent = mistral
|
||||
.create_agent_async(&req)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to create orchestrator agent: {}", e.message))?;
|
||||
|
||||
let id = agent.id.clone();
|
||||
info!(agent_id = id.as_str(), "Created orchestrator agent");
|
||||
self.store.upsert_agent(definitions::ORCHESTRATOR_NAME, &id, model);
|
||||
agents.insert(definitions::ORCHESTRATOR_NAME.to_string(), agent);
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
/// Ensure a domain agent exists. Returns the agent ID.
|
||||
pub async fn ensure_domain_agent(
|
||||
&self,
|
||||
name: &str,
|
||||
request: &CreateAgentRequest,
|
||||
mistral: &MistralClient,
|
||||
) -> Result<String, String> {
|
||||
let mut agents = self.agents.lock().await;
|
||||
|
||||
if let Some(agent) = agents.get(name) {
|
||||
return Ok(agent.id.clone());
|
||||
}
|
||||
|
||||
// Check SQLite
|
||||
if let Some(agent_id) = self.store.get_agent(name) {
|
||||
match mistral.get_agent_async(&agent_id).await {
|
||||
Ok(agent) => {
|
||||
info!(name, agent_id = agent.id.as_str(), "Restored domain agent from database");
|
||||
agents.insert(name.to_string(), agent);
|
||||
return Ok(agent_id);
|
||||
}
|
||||
Err(_) => {
|
||||
warn!(name, "Persisted agent {agent_id} gone from server");
|
||||
self.store.delete_agent(name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let existing = self.find_by_name(name, mistral).await;
|
||||
if let Some(agent) = existing {
|
||||
let id = agent.id.clone();
|
||||
info!(name, agent_id = id.as_str(), "Found existing domain agent on server");
|
||||
self.store.upsert_agent(name, &id, &request.model);
|
||||
agents.insert(name.to_string(), agent);
|
||||
return Ok(id);
|
||||
}
|
||||
|
||||
let agent = mistral
|
||||
.create_agent_async(request)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to create agent {name}: {}", e.message))?;
|
||||
|
||||
let id = agent.id.clone();
|
||||
info!(name, agent_id = id.as_str(), "Created domain agent");
|
||||
self.store.upsert_agent(name, &id, &request.model);
|
||||
agents.insert(name.to_string(), agent);
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
/// Get the agent ID for a given name.
|
||||
pub async fn get_id(&self, name: &str) -> Option<String> {
|
||||
self.agents
|
||||
.lock()
|
||||
.await
|
||||
.get(name)
|
||||
.map(|a| a.id.clone())
|
||||
}
|
||||
|
||||
/// List all registered agent names and IDs.
|
||||
pub async fn list(&self) -> Vec<(String, String)> {
|
||||
self.agents
|
||||
.lock()
|
||||
.await
|
||||
.iter()
|
||||
.map(|(name, agent)| (name.clone(), agent.id.clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Find an agent by name on the Mistral server.
|
||||
async fn find_by_name(&self, name: &str, mistral: &MistralClient) -> Option<Agent> {
|
||||
match mistral.list_agents_async().await {
|
||||
Ok(list) => list.data.into_iter().find(|a| a.name == name),
|
||||
Err(e) => {
|
||||
warn!("Failed to list agents: {}", e.message);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_registry_creation() {
|
||||
let store = Arc::new(Store::open_memory().unwrap());
|
||||
let _reg = AgentRegistry::new(store);
|
||||
}
|
||||
}
|
||||
@@ -210,14 +210,14 @@ impl Evaluator {
|
||||
|
||||
match result {
|
||||
Ok(response) => {
|
||||
let text = &response.choices[0].message.content;
|
||||
let text = response.choices[0].message.content.text();
|
||||
info!(
|
||||
raw_response = text.as_str(),
|
||||
model = self.config.mistral.evaluation_model.as_str(),
|
||||
"LLM evaluation raw response"
|
||||
);
|
||||
|
||||
match serde_json::from_str::<serde_json::Value>(text) {
|
||||
match serde_json::from_str::<serde_json::Value>(&text) {
|
||||
Ok(val) => {
|
||||
let relevance = val["relevance"].as_f64().unwrap_or(0.0) as f32;
|
||||
let hook = val["hook"].as_str().unwrap_or("").to_string();
|
||||
|
||||
@@ -16,17 +16,30 @@ impl Personality {
|
||||
room_name: &str,
|
||||
members: &[String],
|
||||
memory_notes: Option<&str>,
|
||||
) -> String {
|
||||
is_dm: bool,
|
||||
) -> String {
|
||||
let now = Utc::now();
|
||||
let date = now.format("%Y-%m-%d").to_string();
|
||||
let epoch_ms = now.timestamp_millis().to_string();
|
||||
let members_str = members.join(", ");
|
||||
|
||||
let room_context_rules = if is_dm {
|
||||
String::new()
|
||||
} else {
|
||||
"you are in a group room. messages from multiple people are prefixed with \
|
||||
their Matrix user ID in angle brackets (e.g. <@sienna:sunbeam.pt>). \
|
||||
respond naturally to the room as a whole. do not address each person \
|
||||
by name unless specifically needed. do not prefix your response with \
|
||||
names or labels."
|
||||
.to_string()
|
||||
};
|
||||
|
||||
self.template
|
||||
.replace("{date}", &date)
|
||||
.replace("{epoch_ms}", &epoch_ms)
|
||||
.replace("{room_name}", room_name)
|
||||
.replace("{members}", &members_str)
|
||||
.replace("{room_context_rules}", &room_context_rules)
|
||||
.replace("{memory_notes}", memory_notes.unwrap_or(""))
|
||||
}
|
||||
}
|
||||
@@ -38,7 +51,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_date_substitution() {
|
||||
let p = Personality::new("Today is {date}.".to_string());
|
||||
let result = p.build_system_prompt("general", &[], None);
|
||||
let result = p.build_system_prompt("general", &[], None, false);
|
||||
let today = Utc::now().format("%Y-%m-%d").to_string();
|
||||
assert_eq!(result, format!("Today is {today}."));
|
||||
}
|
||||
@@ -46,7 +59,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_room_name_substitution() {
|
||||
let p = Personality::new("You are in {room_name}.".to_string());
|
||||
let result = p.build_system_prompt("design-chat", &[], None);
|
||||
let result = p.build_system_prompt("design-chat", &[], None, false);
|
||||
assert!(result.contains("design-chat"));
|
||||
}
|
||||
|
||||
@@ -54,14 +67,14 @@ mod tests {
|
||||
fn test_members_substitution() {
|
||||
let p = Personality::new("Members: {members}".to_string());
|
||||
let members = vec!["Alice".to_string(), "Bob".to_string(), "Carol".to_string()];
|
||||
let result = p.build_system_prompt("room", &members, None);
|
||||
let result = p.build_system_prompt("room", &members, None, false);
|
||||
assert_eq!(result, "Members: Alice, Bob, Carol");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_members() {
|
||||
let p = Personality::new("Members: {members}".to_string());
|
||||
let result = p.build_system_prompt("room", &[], None);
|
||||
let result = p.build_system_prompt("room", &[], None, false);
|
||||
assert_eq!(result, "Members: ");
|
||||
}
|
||||
|
||||
@@ -70,7 +83,7 @@ mod tests {
|
||||
let template = "Date: {date}, Room: {room_name}, People: {members}".to_string();
|
||||
let p = Personality::new(template);
|
||||
let members = vec!["Sienna".to_string(), "Lonni".to_string()];
|
||||
let result = p.build_system_prompt("studio", &members, None);
|
||||
let result = p.build_system_prompt("studio", &members, None, false);
|
||||
|
||||
let today = Utc::now().format("%Y-%m-%d").to_string();
|
||||
assert!(result.starts_with(&format!("Date: {today}")));
|
||||
@@ -81,14 +94,14 @@ mod tests {
|
||||
#[test]
|
||||
fn test_no_placeholders_passthrough() {
|
||||
let p = Personality::new("Static prompt with no variables.".to_string());
|
||||
let result = p.build_system_prompt("room", &["Alice".to_string()], None);
|
||||
let result = p.build_system_prompt("room", &["Alice".to_string()], None, false);
|
||||
assert_eq!(result, "Static prompt with no variables.");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiple_same_placeholder() {
|
||||
let p = Personality::new("{room_name} is great. I love {room_name}.".to_string());
|
||||
let result = p.build_system_prompt("lounge", &[], None);
|
||||
let result = p.build_system_prompt("lounge", &[], None, false);
|
||||
assert_eq!(result, "lounge is great. I love lounge.");
|
||||
}
|
||||
|
||||
@@ -96,7 +109,7 @@ mod tests {
|
||||
fn test_memory_notes_substitution() {
|
||||
let p = Personality::new("Context:\n{memory_notes}\nEnd.".to_string());
|
||||
let notes = "## notes about sienna\n- [preference] likes terse answers";
|
||||
let result = p.build_system_prompt("room", &[], Some(notes));
|
||||
let result = p.build_system_prompt("room", &[], Some(notes), false);
|
||||
assert!(result.contains("## notes about sienna"));
|
||||
assert!(result.contains("- [preference] likes terse answers"));
|
||||
assert!(result.starts_with("Context:\n"));
|
||||
@@ -106,7 +119,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_memory_notes_none_clears_placeholder() {
|
||||
let p = Personality::new("Before\n{memory_notes}\nAfter".to_string());
|
||||
let result = p.build_system_prompt("room", &[], None);
|
||||
let result = p.build_system_prompt("room", &[], None, false);
|
||||
assert_eq!(result, "Before\n\nAfter");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ use std::sync::Arc;
|
||||
use mistralai_client::v1::{
|
||||
chat::{ChatMessage, ChatParams, ChatResponse, ChatResponseChoiceFinishReason},
|
||||
constants::Model,
|
||||
conversations::{ConversationEntry, ConversationInput, FunctionResultEntry},
|
||||
error::ApiError,
|
||||
tool::ToolChoice,
|
||||
};
|
||||
@@ -13,10 +14,12 @@ use tracing::{debug, error, info, warn};
|
||||
use matrix_sdk::room::Room;
|
||||
use opensearch::OpenSearch;
|
||||
|
||||
use crate::agent_ux::AgentProgress;
|
||||
use crate::brain::conversation::ContextMessage;
|
||||
use crate::brain::personality::Personality;
|
||||
use crate::config::Config;
|
||||
use crate::context::ResponseContext;
|
||||
use crate::conversations::ConversationRegistry;
|
||||
use crate::memory;
|
||||
use crate::tools::ToolRegistry;
|
||||
|
||||
@@ -72,6 +75,7 @@ impl Responder {
|
||||
mistral: &Arc<mistralai_client::v1::client::Client>,
|
||||
room: &Room,
|
||||
response_ctx: &ResponseContext,
|
||||
image_data_uri: Option<&str>,
|
||||
) -> Option<String> {
|
||||
// Apply response delay (skip if instant_responses is enabled)
|
||||
// Delay happens BEFORE typing indicator — Sol "notices" the message first
|
||||
@@ -103,6 +107,7 @@ impl Responder {
|
||||
room_name,
|
||||
members,
|
||||
memory_notes.as_deref(),
|
||||
response_ctx.is_dm,
|
||||
);
|
||||
|
||||
let mut messages = vec![ChatMessage::new_system_message(&system_prompt)];
|
||||
@@ -120,9 +125,26 @@ impl Responder {
|
||||
}
|
||||
}
|
||||
|
||||
// Add the triggering message
|
||||
let trigger = format!("{trigger_sender}: {trigger_body}");
|
||||
messages.push(ChatMessage::new_user_message(&trigger));
|
||||
// Add the triggering message (multimodal if image attached)
|
||||
if let Some(data_uri) = image_data_uri {
|
||||
use mistralai_client::v1::chat::{ContentPart, ImageUrl};
|
||||
let mut parts = vec![];
|
||||
if !trigger_body.is_empty() {
|
||||
parts.push(ContentPart::Text {
|
||||
text: format!("{trigger_sender}: {trigger_body}"),
|
||||
});
|
||||
}
|
||||
parts.push(ContentPart::ImageUrl {
|
||||
image_url: ImageUrl {
|
||||
url: data_uri.to_string(),
|
||||
detail: None,
|
||||
},
|
||||
});
|
||||
messages.push(ChatMessage::new_user_message_with_images(parts));
|
||||
} else {
|
||||
let trigger = format!("{trigger_sender}: {trigger_body}");
|
||||
messages.push(ChatMessage::new_user_message(&trigger));
|
||||
}
|
||||
|
||||
let tool_defs = ToolRegistry::tool_definitions();
|
||||
let model = Model::new(&self.config.mistral.default_model);
|
||||
@@ -158,7 +180,7 @@ impl Responder {
|
||||
if let Some(tool_calls) = &choice.message.tool_calls {
|
||||
// Add assistant message with tool calls
|
||||
messages.push(ChatMessage::new_assistant_message(
|
||||
&choice.message.content,
|
||||
&choice.message.content.text(),
|
||||
Some(tool_calls.clone()),
|
||||
));
|
||||
|
||||
@@ -197,7 +219,7 @@ impl Responder {
|
||||
}
|
||||
|
||||
// Final text response — strip own name prefix if present
|
||||
let mut text = choice.message.content.trim().to_string();
|
||||
let mut text = choice.message.content.text().trim().to_string();
|
||||
|
||||
// Strip "sol:" or "sol 💕:" or similar prefixes the model sometimes adds
|
||||
let lower = text.to_lowercase();
|
||||
@@ -231,6 +253,173 @@ impl Responder {
|
||||
None
|
||||
}
|
||||
|
||||
/// Generate a response using the Mistral Conversations API.
|
||||
/// This path routes through the ConversationRegistry for persistent state,
|
||||
/// agent handoffs, and function calling with UX feedback (reactions + threads).
|
||||
pub async fn generate_response_conversations(
|
||||
&self,
|
||||
trigger_body: &str,
|
||||
trigger_sender: &str,
|
||||
room_id: &str,
|
||||
is_dm: bool,
|
||||
is_spontaneous: bool,
|
||||
mistral: &Arc<mistralai_client::v1::client::Client>,
|
||||
room: &Room,
|
||||
response_ctx: &ResponseContext,
|
||||
conversation_registry: &ConversationRegistry,
|
||||
image_data_uri: Option<&str>,
|
||||
) -> Option<String> {
|
||||
// Apply response delay
|
||||
if !self.config.behavior.instant_responses {
|
||||
let delay = if is_spontaneous {
|
||||
rand::thread_rng().gen_range(
|
||||
self.config.behavior.spontaneous_delay_min_ms
|
||||
..=self.config.behavior.spontaneous_delay_max_ms,
|
||||
)
|
||||
} else {
|
||||
rand::thread_rng().gen_range(
|
||||
self.config.behavior.response_delay_min_ms
|
||||
..=self.config.behavior.response_delay_max_ms,
|
||||
)
|
||||
};
|
||||
sleep(Duration::from_millis(delay)).await;
|
||||
}
|
||||
|
||||
let _ = room.typing_notice(true).await;
|
||||
|
||||
// Build the input message (with sender prefix for group rooms)
|
||||
let input_text = if is_dm {
|
||||
trigger_body.to_string()
|
||||
} else {
|
||||
format!("<{}> {}", response_ctx.matrix_user_id, trigger_body)
|
||||
};
|
||||
|
||||
// TODO: multimodal via image_data_uri — Conversations API may support
|
||||
// content parts in entries. For now, append image description request.
|
||||
let input = ConversationInput::Text(input_text);
|
||||
|
||||
// Send through conversation registry
|
||||
let response = match conversation_registry
|
||||
.send_message(room_id, input, is_dm, mistral)
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
error!("Conversation API failed: {e}");
|
||||
let _ = room.typing_notice(false).await;
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
// Check for function calls — execute locally and send results back
|
||||
let function_calls = response.function_calls();
|
||||
if !function_calls.is_empty() {
|
||||
// Agent UX: reactions + threads require the user's event ID
|
||||
// which we don't have in the responder. For now, log tool calls
|
||||
// and skip UX. TODO: pass event_id through ResponseContext.
|
||||
|
||||
let max_iterations = self.config.mistral.max_tool_iterations;
|
||||
let mut current_response = response;
|
||||
|
||||
for iteration in 0..max_iterations {
|
||||
let calls = current_response.function_calls();
|
||||
if calls.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
let mut result_entries = Vec::new();
|
||||
|
||||
for fc in &calls {
|
||||
let call_id = fc.tool_call_id.as_deref().unwrap_or("unknown");
|
||||
info!(
|
||||
tool = fc.name.as_str(),
|
||||
id = call_id,
|
||||
args = fc.arguments.as_str(),
|
||||
"Executing tool call (conversations)"
|
||||
);
|
||||
|
||||
|
||||
|
||||
let result = self
|
||||
.tools
|
||||
.execute(&fc.name, &fc.arguments, response_ctx)
|
||||
.await;
|
||||
|
||||
let result_str = match result {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
warn!(tool = fc.name.as_str(), "Tool failed: {e}");
|
||||
format!("Error: {e}")
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
|
||||
result_entries.push(ConversationEntry::FunctionResult(FunctionResultEntry {
|
||||
tool_call_id: call_id.to_string(),
|
||||
result: result_str,
|
||||
id: None,
|
||||
object: None,
|
||||
created_at: None,
|
||||
completed_at: None,
|
||||
}));
|
||||
}
|
||||
|
||||
// Send function results back to conversation
|
||||
current_response = match conversation_registry
|
||||
.send_function_result(room_id, result_entries, mistral)
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
error!("Failed to send function results: {e}");
|
||||
let _ = room.typing_notice(false).await;
|
||||
return None;
|
||||
}
|
||||
};
|
||||
|
||||
debug!(iteration, "Tool iteration complete (conversations)");
|
||||
}
|
||||
|
||||
// Extract final text from the last response
|
||||
if let Some(text) = current_response.assistant_text() {
|
||||
let text = strip_sol_prefix(&text);
|
||||
if text.is_empty() {
|
||||
let _ = room.typing_notice(false).await;
|
||||
return None;
|
||||
}
|
||||
let _ = room.typing_notice(false).await;
|
||||
info!(
|
||||
response_len = text.len(),
|
||||
"Generated response (conversations + tools)"
|
||||
);
|
||||
return Some(text);
|
||||
}
|
||||
|
||||
let _ = room.typing_notice(false).await;
|
||||
return None;
|
||||
}
|
||||
|
||||
// Simple response — no tools involved
|
||||
if let Some(text) = response.assistant_text() {
|
||||
let text = strip_sol_prefix(&text);
|
||||
if text.is_empty() {
|
||||
let _ = room.typing_notice(false).await;
|
||||
return None;
|
||||
}
|
||||
let _ = room.typing_notice(false).await;
|
||||
info!(
|
||||
response_len = text.len(),
|
||||
is_spontaneous,
|
||||
"Generated response (conversations)"
|
||||
);
|
||||
return Some(text);
|
||||
}
|
||||
|
||||
let _ = room.typing_notice(false).await;
|
||||
None
|
||||
}
|
||||
|
||||
async fn load_memory_notes(
|
||||
&self,
|
||||
ctx: &ResponseContext,
|
||||
@@ -284,6 +473,18 @@ impl Responder {
|
||||
}
|
||||
}
|
||||
|
||||
/// Strip "sol:" or "sol 💕:" prefixes the model sometimes adds.
|
||||
fn strip_sol_prefix(text: &str) -> String {
|
||||
let trimmed = text.trim();
|
||||
let lower = trimmed.to_lowercase();
|
||||
for prefix in &["sol:", "sol 💕:", "sol💕:"] {
|
||||
if lower.starts_with(prefix) {
|
||||
return trimmed[prefix.len()..].trim().to_string();
|
||||
}
|
||||
}
|
||||
trimmed.to_string()
|
||||
}
|
||||
|
||||
/// Format memory documents into a notes block for the system prompt.
|
||||
pub(crate) fn format_memory_notes(
|
||||
display_name: &str,
|
||||
|
||||
@@ -6,6 +6,35 @@ pub struct Config {
|
||||
pub opensearch: OpenSearchConfig,
|
||||
pub mistral: MistralConfig,
|
||||
pub behavior: BehaviorConfig,
|
||||
#[serde(default)]
|
||||
pub agents: AgentsConfig,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct AgentsConfig {
|
||||
/// Model for the orchestrator agent.
|
||||
#[serde(default = "default_model")]
|
||||
pub orchestrator_model: String,
|
||||
/// Model for domain agents.
|
||||
#[serde(default = "default_model")]
|
||||
pub domain_model: String,
|
||||
/// Token threshold for conversation compaction (~90% of model context window).
|
||||
#[serde(default = "default_compaction_threshold")]
|
||||
pub compaction_threshold: u32,
|
||||
/// Whether to use the Conversations API (vs manual message management).
|
||||
#[serde(default)]
|
||||
pub use_conversations_api: bool,
|
||||
}
|
||||
|
||||
impl Default for AgentsConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
orchestrator_model: default_model(),
|
||||
domain_model: default_model(),
|
||||
compaction_threshold: default_compaction_threshold(),
|
||||
use_conversations_api: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
@@ -13,6 +42,10 @@ pub struct MatrixConfig {
|
||||
pub homeserver_url: String,
|
||||
pub user_id: String,
|
||||
pub state_store_path: String,
|
||||
/// Path to the SQLite database for persistent state (conversations, agents).
|
||||
/// Should be on a persistent volume in K8s (e.g. Longhorn PVC mounted at /data).
|
||||
#[serde(default = "default_db_path")]
|
||||
pub db_path: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
@@ -112,6 +145,8 @@ fn default_script_timeout_secs() -> u64 { 5 }
|
||||
fn default_script_max_heap_mb() -> usize { 64 }
|
||||
fn default_memory_index() -> String { "sol_user_memory".into() }
|
||||
fn default_memory_extraction_enabled() -> bool { true }
|
||||
fn default_db_path() -> String { "/data/sol.db".into() }
|
||||
fn default_compaction_threshold() -> u32 { 118000 } // ~90% of 131K context window
|
||||
|
||||
impl Config {
|
||||
pub fn load(path: &str) -> anyhow::Result<Self> {
|
||||
|
||||
299
src/conversations.rs
Normal file
299
src/conversations.rs
Normal file
@@ -0,0 +1,299 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use mistralai_client::v1::client::Client as MistralClient;
|
||||
use mistralai_client::v1::conversations::{
|
||||
AppendConversationRequest, ConversationInput, ConversationResponse,
|
||||
CreateConversationRequest,
|
||||
};
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::persistence::Store;
|
||||
|
||||
/// Maps Matrix room IDs to Mistral conversation IDs.
|
||||
/// Group rooms get a shared conversation; DMs are per-room (already unique per DM pair).
|
||||
/// State is persisted to SQLite so mappings survive reboots.
|
||||
pub struct ConversationRegistry {
|
||||
/// room_id → conversation_id (in-memory cache, backed by SQLite)
|
||||
mapping: Mutex<HashMap<String, ConversationState>>,
|
||||
/// Agent ID to use when creating new conversations (orchestrator or None for model-only).
|
||||
agent_id: Mutex<Option<String>>,
|
||||
/// Model to use when no agent is configured.
|
||||
model: String,
|
||||
/// Token budget before compaction triggers (% of model context window).
|
||||
compaction_threshold: u32,
|
||||
/// SQLite persistence.
|
||||
store: Arc<Store>,
|
||||
}
|
||||
|
||||
struct ConversationState {
|
||||
conversation_id: String,
|
||||
/// Estimated token count (incremented per message, reset on compaction).
|
||||
estimated_tokens: u32,
|
||||
}
|
||||
|
||||
impl ConversationRegistry {
|
||||
pub fn new(model: String, compaction_threshold: u32, store: Arc<Store>) -> Self {
|
||||
// Load existing mappings from SQLite
|
||||
let persisted = store.load_all_conversations();
|
||||
let mut mapping = HashMap::new();
|
||||
for (room_id, conversation_id, estimated_tokens) in persisted {
|
||||
mapping.insert(
|
||||
room_id,
|
||||
ConversationState {
|
||||
conversation_id,
|
||||
estimated_tokens,
|
||||
},
|
||||
);
|
||||
}
|
||||
let count = mapping.len();
|
||||
if count > 0 {
|
||||
info!(count, "Restored conversation mappings from database");
|
||||
}
|
||||
|
||||
Self {
|
||||
mapping: Mutex::new(mapping),
|
||||
agent_id: Mutex::new(None),
|
||||
model,
|
||||
compaction_threshold,
|
||||
store,
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the orchestrator agent ID (called after agent registry creates it).
|
||||
pub async fn set_agent_id(&self, agent_id: String) {
|
||||
let mut id = self.agent_id.lock().await;
|
||||
*id = Some(agent_id);
|
||||
}
|
||||
|
||||
/// Get or create a conversation for a room. Returns the conversation ID.
|
||||
/// If a conversation doesn't exist yet, creates one with the first message.
|
||||
pub async fn send_message(
|
||||
&self,
|
||||
room_id: &str,
|
||||
message: ConversationInput,
|
||||
is_dm: bool,
|
||||
mistral: &MistralClient,
|
||||
) -> Result<ConversationResponse, String> {
|
||||
let mut mapping = self.mapping.lock().await;
|
||||
|
||||
if let Some(state) = mapping.get_mut(room_id) {
|
||||
// Existing conversation — append
|
||||
let req = AppendConversationRequest {
|
||||
inputs: message,
|
||||
completion_args: None,
|
||||
handoff_execution: None,
|
||||
store: Some(true),
|
||||
tool_confirmations: None,
|
||||
stream: false,
|
||||
};
|
||||
|
||||
let response = mistral
|
||||
.append_conversation_async(&state.conversation_id, &req)
|
||||
.await
|
||||
.map_err(|e| format!("append_conversation failed: {}", e.message))?;
|
||||
|
||||
// Update token estimate
|
||||
state.estimated_tokens += response.usage.total_tokens;
|
||||
self.store.update_tokens(room_id, state.estimated_tokens);
|
||||
|
||||
debug!(
|
||||
room = room_id,
|
||||
conversation_id = state.conversation_id.as_str(),
|
||||
tokens = state.estimated_tokens,
|
||||
"Appended to conversation"
|
||||
);
|
||||
|
||||
Ok(response)
|
||||
} else {
|
||||
// New conversation — create
|
||||
let agent_id = self.agent_id.lock().await.clone();
|
||||
|
||||
let req = CreateConversationRequest {
|
||||
inputs: message,
|
||||
model: if agent_id.is_none() {
|
||||
Some(self.model.clone())
|
||||
} else {
|
||||
None
|
||||
},
|
||||
agent_id,
|
||||
agent_version: None,
|
||||
name: Some(format!("sol-{}", room_id)),
|
||||
description: None,
|
||||
instructions: None,
|
||||
completion_args: None,
|
||||
tools: None,
|
||||
handoff_execution: None,
|
||||
metadata: None,
|
||||
store: Some(true),
|
||||
stream: false,
|
||||
};
|
||||
|
||||
let response = mistral
|
||||
.create_conversation_async(&req)
|
||||
.await
|
||||
.map_err(|e| format!("create_conversation failed: {}", e.message))?;
|
||||
|
||||
let conv_id = response.conversation_id.clone();
|
||||
let tokens = response.usage.total_tokens;
|
||||
info!(
|
||||
room = room_id,
|
||||
conversation_id = conv_id.as_str(),
|
||||
"Created new conversation"
|
||||
);
|
||||
|
||||
self.store.upsert_conversation(room_id, &conv_id, tokens);
|
||||
|
||||
mapping.insert(
|
||||
room_id.to_string(),
|
||||
ConversationState {
|
||||
conversation_id: conv_id,
|
||||
estimated_tokens: tokens,
|
||||
},
|
||||
);
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a function result back to a conversation.
|
||||
pub async fn send_function_result(
|
||||
&self,
|
||||
room_id: &str,
|
||||
entries: Vec<mistralai_client::v1::conversations::ConversationEntry>,
|
||||
mistral: &MistralClient,
|
||||
) -> Result<ConversationResponse, String> {
|
||||
let mapping = self.mapping.lock().await;
|
||||
let state = mapping
|
||||
.get(room_id)
|
||||
.ok_or_else(|| format!("no conversation for room {room_id}"))?;
|
||||
|
||||
let req = AppendConversationRequest {
|
||||
inputs: ConversationInput::Entries(entries),
|
||||
completion_args: None,
|
||||
handoff_execution: None,
|
||||
store: Some(true),
|
||||
tool_confirmations: None,
|
||||
stream: false,
|
||||
};
|
||||
|
||||
mistral
|
||||
.append_conversation_async(&state.conversation_id, &req)
|
||||
.await
|
||||
.map_err(|e| format!("append_conversation (function result) failed: {}", e.message))
|
||||
}
|
||||
|
||||
/// Check if a room's conversation needs compaction.
|
||||
pub async fn needs_compaction(&self, room_id: &str) -> bool {
|
||||
let mapping = self.mapping.lock().await;
|
||||
if let Some(state) = mapping.get(room_id) {
|
||||
state.estimated_tokens >= self.compaction_threshold
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
/// Reset a room's conversation (e.g., after compaction).
|
||||
/// Removes the mapping so the next message creates a fresh conversation.
|
||||
pub async fn reset(&self, room_id: &str) {
|
||||
let mut mapping = self.mapping.lock().await;
|
||||
if let Some(state) = mapping.remove(room_id) {
|
||||
self.store.delete_conversation(room_id);
|
||||
info!(
|
||||
room = room_id,
|
||||
conversation_id = state.conversation_id.as_str(),
|
||||
tokens = state.estimated_tokens,
|
||||
"Reset conversation (compaction)"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the conversation ID for a room, if one exists.
|
||||
pub async fn get_conversation_id(&self, room_id: &str) -> Option<String> {
|
||||
let mapping = self.mapping.lock().await;
|
||||
mapping.get(room_id).map(|s| s.conversation_id.clone())
|
||||
}
|
||||
|
||||
/// Number of active conversations.
|
||||
pub async fn active_count(&self) -> usize {
|
||||
self.mapping.lock().await.len()
|
||||
}
|
||||
}
|
||||
|
||||
/// Merge multiple buffered user messages into a single conversation input.
|
||||
/// For DMs: raw text concatenation with newlines.
|
||||
/// For group rooms: prefix each line with `<sender_matrix_id>`.
|
||||
pub fn merge_user_messages(
|
||||
messages: &[(String, String)], // (sender_matrix_id, body)
|
||||
is_dm: bool,
|
||||
) -> String {
|
||||
if messages.is_empty() {
|
||||
return String::new();
|
||||
}
|
||||
|
||||
if is_dm {
|
||||
// DMs: just concatenate
|
||||
messages
|
||||
.iter()
|
||||
.map(|(_, body)| body.as_str())
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
} else {
|
||||
// Group rooms: prefix with sender
|
||||
messages
|
||||
.iter()
|
||||
.map(|(sender, body)| format!("<{sender}> {body}"))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_merge_dm_messages() {
|
||||
let msgs = vec![
|
||||
("@alice:example.com".to_string(), "hello".to_string()),
|
||||
("@alice:example.com".to_string(), "how are you?".to_string()),
|
||||
];
|
||||
let merged = merge_user_messages(&msgs, true);
|
||||
assert_eq!(merged, "hello\nhow are you?");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_merge_group_messages() {
|
||||
let msgs = vec![
|
||||
("@sienna:sunbeam.pt".to_string(), "what's the error rate?".to_string()),
|
||||
("@lonni:sunbeam.pt".to_string(), "also check memory".to_string()),
|
||||
("@sienna:sunbeam.pt".to_string(), "and disk too".to_string()),
|
||||
];
|
||||
let merged = merge_user_messages(&msgs, false);
|
||||
assert_eq!(
|
||||
merged,
|
||||
"<@sienna:sunbeam.pt> what's the error rate?\n<@lonni:sunbeam.pt> also check memory\n<@sienna:sunbeam.pt> and disk too"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_merge_empty() {
|
||||
let merged = merge_user_messages(&[], true);
|
||||
assert_eq!(merged, "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_merge_single_dm() {
|
||||
let msgs = vec![("@user:x".to_string(), "hi".to_string())];
|
||||
let merged = merge_user_messages(&msgs, true);
|
||||
assert_eq!(merged, "hi");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_merge_single_group() {
|
||||
let msgs = vec![("@user:x".to_string(), "hi".to_string())];
|
||||
let merged = merge_user_messages(&msgs, false);
|
||||
assert_eq!(merged, "<@user:x> hi");
|
||||
}
|
||||
}
|
||||
67
src/main.rs
67
src/main.rs
@@ -1,9 +1,13 @@
|
||||
mod agent_ux;
|
||||
mod agents;
|
||||
mod archive;
|
||||
mod brain;
|
||||
mod config;
|
||||
mod context;
|
||||
mod conversations;
|
||||
mod matrix_utils;
|
||||
mod memory;
|
||||
mod persistence;
|
||||
mod sync;
|
||||
mod tools;
|
||||
|
||||
@@ -15,12 +19,14 @@ use opensearch::OpenSearch;
|
||||
use ruma::{OwnedDeviceId, OwnedUserId};
|
||||
use tokio::signal;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{error, info};
|
||||
use tracing::{error, info, warn};
|
||||
use url::Url;
|
||||
|
||||
use agents::registry::AgentRegistry;
|
||||
use archive::indexer::Indexer;
|
||||
use archive::schema::create_index_if_not_exists;
|
||||
use brain::conversation::{ContextMessage, ConversationManager};
|
||||
use conversations::ConversationRegistry;
|
||||
use memory::schema::create_index_if_not_exists as create_memory_index;
|
||||
use brain::evaluator::Evaluator;
|
||||
use brain::personality::Personality;
|
||||
@@ -110,6 +116,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
let mistral = Arc::new(mistral_client);
|
||||
|
||||
// Build components
|
||||
let system_prompt_text = system_prompt.clone();
|
||||
let personality = Arc::new(Personality::new(system_prompt));
|
||||
let conversations = Arc::new(Mutex::new(ConversationManager::new(
|
||||
config.behavior.room_context_window,
|
||||
@@ -141,6 +148,24 @@ async fn main() -> anyhow::Result<()> {
|
||||
// Start background flush task
|
||||
let _flush_handle = indexer.start_flush_task();
|
||||
|
||||
// Initialize persistent state database
|
||||
let (store, state_recovery_failed) = match persistence::Store::open(&config.matrix.db_path) {
|
||||
Ok(s) => (Arc::new(s), false),
|
||||
Err(e) => {
|
||||
error!("Failed to open state database at {}: {e}", config.matrix.db_path);
|
||||
error!("Falling back to in-memory state — conversations will not survive restarts");
|
||||
(Arc::new(persistence::Store::open_memory().expect("in-memory DB must work")), true)
|
||||
}
|
||||
};
|
||||
|
||||
// Initialize agent registry and conversation registry (with SQLite backing)
|
||||
let agent_registry = Arc::new(AgentRegistry::new(store.clone()));
|
||||
let conversation_registry = Arc::new(ConversationRegistry::new(
|
||||
config.mistral.default_model.clone(),
|
||||
config.agents.compaction_threshold,
|
||||
store,
|
||||
));
|
||||
|
||||
// Build shared state
|
||||
let state = Arc::new(AppState {
|
||||
config: config.clone(),
|
||||
@@ -148,12 +173,39 @@ async fn main() -> anyhow::Result<()> {
|
||||
evaluator,
|
||||
responder,
|
||||
conversations,
|
||||
agent_registry,
|
||||
conversation_registry,
|
||||
mistral,
|
||||
opensearch: os_client,
|
||||
last_response: Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())),
|
||||
responding_in: Arc::new(tokio::sync::Mutex::new(std::collections::HashSet::new())),
|
||||
});
|
||||
|
||||
// Initialize orchestrator agent if conversations API is enabled
|
||||
if config.agents.use_conversations_api {
|
||||
info!("Conversations API enabled — ensuring orchestrator agent exists");
|
||||
let agent_tools = tools::ToolRegistry::agent_tool_definitions();
|
||||
match state
|
||||
.agent_registry
|
||||
.ensure_orchestrator(
|
||||
&system_prompt_text,
|
||||
&config.agents.orchestrator_model,
|
||||
agent_tools,
|
||||
&state.mistral,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(agent_id) => {
|
||||
info!(agent_id = agent_id.as_str(), "Orchestrator agent ready");
|
||||
state.conversation_registry.set_agent_id(agent_id).await;
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Failed to create orchestrator agent: {e}");
|
||||
error!("Falling back to model-only conversations (no orchestrator)");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Backfill reactions from Matrix room timelines
|
||||
info!("Backfilling reactions from room timelines...");
|
||||
if let Err(e) = backfill_reactions(&matrix_client, &state.indexer).await {
|
||||
@@ -169,6 +221,19 @@ async fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
});
|
||||
|
||||
// If state recovery failed, sneeze into all rooms to signal the hiccup
|
||||
if state_recovery_failed {
|
||||
info!("State recovery failed — sneezing into all rooms");
|
||||
for room in matrix_client.joined_rooms() {
|
||||
let content = ruma::events::room::message::RoomMessageEventContent::text_plain(
|
||||
"*sneezes*",
|
||||
);
|
||||
if let Err(e) = room.send(content).await {
|
||||
warn!("Failed to sneeze into {}: {e}", room.room_id());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!("Sol is running");
|
||||
|
||||
// Wait for shutdown signal
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
use matrix_sdk::media::{MediaFormat, MediaRequestParameters};
|
||||
use matrix_sdk::room::Room;
|
||||
use matrix_sdk::RoomMemberships;
|
||||
use ruma::events::room::message::{
|
||||
@@ -67,6 +68,61 @@ pub async fn send_reaction(
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Extract image info from an m.image message event.
|
||||
/// Returns (mxc_url, mimetype, body/caption) if present.
|
||||
pub fn extract_image(event: &OriginalSyncRoomMessageEvent) -> Option<(String, String, String)> {
|
||||
if let MessageType::Image(image) = &event.content.msgtype {
|
||||
let url = match &image.source {
|
||||
ruma::events::room::MediaSource::Plain(mxc) => mxc.to_string(),
|
||||
ruma::events::room::MediaSource::Encrypted(_) => return None,
|
||||
};
|
||||
let mime = image
|
||||
.info
|
||||
.as_ref()
|
||||
.and_then(|i| i.mimetype.clone())
|
||||
.unwrap_or_else(|| "image/png".to_string());
|
||||
let caption = image.body.clone();
|
||||
Some((url, mime, caption))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Download image bytes from a Matrix mxc:// URL via the media API.
|
||||
/// Returns the raw bytes as a base64 data URI suitable for Mistral vision.
|
||||
pub async fn download_image_as_data_uri(
|
||||
client: &matrix_sdk::Client,
|
||||
event: &OriginalSyncRoomMessageEvent,
|
||||
) -> Option<String> {
|
||||
if let MessageType::Image(image) = &event.content.msgtype {
|
||||
let media_source = &image.source;
|
||||
let mime = image
|
||||
.info
|
||||
.as_ref()
|
||||
.and_then(|i| i.mimetype.clone())
|
||||
.unwrap_or_else(|| "image/png".to_string());
|
||||
|
||||
let request = MediaRequestParameters {
|
||||
source: media_source.clone(),
|
||||
format: MediaFormat::File,
|
||||
};
|
||||
|
||||
match client.media().get_media_content(&request, true).await {
|
||||
Ok(bytes) => {
|
||||
use base64::Engine;
|
||||
let b64 = base64::engine::general_purpose::STANDARD.encode(&bytes);
|
||||
Some(format!("data:{};base64,{}", mime, b64))
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to download image: {e}");
|
||||
None
|
||||
}
|
||||
}
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the display name for a room.
|
||||
pub fn room_display_name(room: &Room) -> String {
|
||||
room.cached_display_name()
|
||||
|
||||
@@ -64,7 +64,8 @@ pub async fn extract_and_store(
|
||||
};
|
||||
|
||||
let response = chat_blocking(mistral, model, messages, params).await?;
|
||||
let text = response.choices[0].message.content.trim();
|
||||
let text = response.choices[0].message.content.text();
|
||||
let text = text.trim();
|
||||
|
||||
let extraction: ExtractionResponse = match serde_json::from_str(text) {
|
||||
Ok(e) => e,
|
||||
|
||||
303
src/persistence.rs
Normal file
303
src/persistence.rs
Normal file
@@ -0,0 +1,303 @@
|
||||
use rusqlite::{params, Connection, Result as SqlResult};
|
||||
use std::path::Path;
|
||||
use std::sync::Mutex;
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// SQLite-backed persistent state for Sol.
|
||||
///
|
||||
/// Stores:
|
||||
/// - Conversation registry: room_id → Mistral conversation_id + token estimates
|
||||
/// - Agent registry: agent_name → Mistral agent_id
|
||||
///
|
||||
/// ## Kubernetes mount
|
||||
///
|
||||
/// The database file should be on a persistent volume. Recommended setup:
|
||||
///
|
||||
/// ```yaml
|
||||
/// # PersistentVolumeClaim (Longhorn)
|
||||
/// apiVersion: v1
|
||||
/// kind: PersistentVolumeClaim
|
||||
/// metadata:
|
||||
/// name: sol-data
|
||||
/// namespace: matrix
|
||||
/// spec:
|
||||
/// accessModes: [ReadWriteOnce]
|
||||
/// storageClassName: longhorn
|
||||
/// resources:
|
||||
/// requests:
|
||||
/// storage: 1Gi
|
||||
///
|
||||
/// # Deployment volume mount
|
||||
/// volumes:
|
||||
/// - name: sol-data
|
||||
/// persistentVolumeClaim:
|
||||
/// claimName: sol-data
|
||||
/// containers:
|
||||
/// - name: sol
|
||||
/// volumeMounts:
|
||||
/// - name: sol-data
|
||||
/// mountPath: /data
|
||||
/// ```
|
||||
///
|
||||
/// Default path: `/data/sol.db` (configurable via `matrix.db_path` in sol.toml).
|
||||
/// The `/data` mount also holds the Matrix SDK state store at `/data/matrix-state`.
|
||||
pub struct Store {
|
||||
conn: Mutex<Connection>,
|
||||
}
|
||||
|
||||
impl Store {
|
||||
/// Open or create the database at the given path.
|
||||
/// Creates tables if they don't exist.
|
||||
pub fn open(path: &str) -> anyhow::Result<Self> {
|
||||
// Ensure parent directory exists
|
||||
if let Some(parent) = Path::new(path).parent() {
|
||||
std::fs::create_dir_all(parent)?;
|
||||
}
|
||||
|
||||
let conn = Connection::open(path)?;
|
||||
|
||||
// Enable WAL mode for better concurrent read performance
|
||||
conn.execute_batch("PRAGMA journal_mode=WAL;")?;
|
||||
|
||||
conn.execute_batch(
|
||||
"CREATE TABLE IF NOT EXISTS conversations (
|
||||
room_id TEXT PRIMARY KEY,
|
||||
conversation_id TEXT NOT NULL,
|
||||
estimated_tokens INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS agents (
|
||||
name TEXT PRIMARY KEY,
|
||||
agent_id TEXT NOT NULL,
|
||||
model TEXT NOT NULL DEFAULT '',
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
);",
|
||||
)?;
|
||||
|
||||
info!(path, "Opened Sol state database");
|
||||
Ok(Self {
|
||||
conn: Mutex::new(conn),
|
||||
})
|
||||
}
|
||||
|
||||
/// Open an in-memory database (for tests).
|
||||
pub fn open_memory() -> anyhow::Result<Self> {
|
||||
Self::open(":memory:")
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Conversations
|
||||
// =========================================================================
|
||||
|
||||
/// Get the conversation_id for a room, if one exists.
|
||||
pub fn get_conversation(&self, room_id: &str) -> Option<(String, u32)> {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
conn.query_row(
|
||||
"SELECT conversation_id, estimated_tokens FROM conversations WHERE room_id = ?1",
|
||||
params![room_id],
|
||||
|row| Ok((row.get::<_, String>(0)?, row.get::<_, u32>(1)?)),
|
||||
)
|
||||
.ok()
|
||||
}
|
||||
|
||||
/// Store or update a conversation mapping.
|
||||
pub fn upsert_conversation(
|
||||
&self,
|
||||
room_id: &str,
|
||||
conversation_id: &str,
|
||||
estimated_tokens: u32,
|
||||
) {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
if let Err(e) = conn.execute(
|
||||
"INSERT INTO conversations (room_id, conversation_id, estimated_tokens)
|
||||
VALUES (?1, ?2, ?3)
|
||||
ON CONFLICT(room_id) DO UPDATE SET
|
||||
conversation_id = excluded.conversation_id,
|
||||
estimated_tokens = excluded.estimated_tokens",
|
||||
params![room_id, conversation_id, estimated_tokens],
|
||||
) {
|
||||
warn!("Failed to upsert conversation: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
/// Update token estimate for a conversation.
|
||||
pub fn update_tokens(&self, room_id: &str, estimated_tokens: u32) {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
let _ = conn.execute(
|
||||
"UPDATE conversations SET estimated_tokens = ?1 WHERE room_id = ?2",
|
||||
params![estimated_tokens, room_id],
|
||||
);
|
||||
}
|
||||
|
||||
/// Remove a conversation mapping (e.g., after compaction).
|
||||
pub fn delete_conversation(&self, room_id: &str) {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
let _ = conn.execute(
|
||||
"DELETE FROM conversations WHERE room_id = ?1",
|
||||
params![room_id],
|
||||
);
|
||||
}
|
||||
|
||||
/// Load all conversation mappings (for startup recovery).
|
||||
pub fn load_all_conversations(&self) -> Vec<(String, String, u32)> {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
let mut stmt = match conn.prepare(
|
||||
"SELECT room_id, conversation_id, estimated_tokens FROM conversations",
|
||||
) {
|
||||
Ok(s) => s,
|
||||
Err(_) => return Vec::new(),
|
||||
};
|
||||
|
||||
stmt.query_map([], |row| {
|
||||
Ok((
|
||||
row.get::<_, String>(0)?,
|
||||
row.get::<_, String>(1)?,
|
||||
row.get::<_, u32>(2)?,
|
||||
))
|
||||
})
|
||||
.ok()
|
||||
.map(|rows| rows.filter_map(|r| r.ok()).collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Agents
|
||||
// =========================================================================
|
||||
|
||||
/// Get the agent_id for a named agent.
|
||||
pub fn get_agent(&self, name: &str) -> Option<String> {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
conn.query_row(
|
||||
"SELECT agent_id FROM agents WHERE name = ?1",
|
||||
params![name],
|
||||
|row| row.get(0),
|
||||
)
|
||||
.ok()
|
||||
}
|
||||
|
||||
/// Store or update an agent mapping.
|
||||
pub fn upsert_agent(&self, name: &str, agent_id: &str, model: &str) {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
if let Err(e) = conn.execute(
|
||||
"INSERT INTO agents (name, agent_id, model)
|
||||
VALUES (?1, ?2, ?3)
|
||||
ON CONFLICT(name) DO UPDATE SET
|
||||
agent_id = excluded.agent_id,
|
||||
model = excluded.model",
|
||||
params![name, agent_id, model],
|
||||
) {
|
||||
warn!("Failed to upsert agent: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove an agent mapping.
|
||||
pub fn delete_agent(&self, name: &str) {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
let _ = conn.execute("DELETE FROM agents WHERE name = ?1", params![name]);
|
||||
}
|
||||
|
||||
/// Load all agent mappings (for startup recovery).
|
||||
pub fn load_all_agents(&self) -> Vec<(String, String)> {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
let mut stmt = match conn.prepare("SELECT name, agent_id FROM agents") {
|
||||
Ok(s) => s,
|
||||
Err(_) => return Vec::new(),
|
||||
};
|
||||
|
||||
stmt.query_map([], |row| {
|
||||
Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?))
|
||||
})
|
||||
.ok()
|
||||
.map(|rows| rows.filter_map(|r| r.ok()).collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_open_memory_db() {
|
||||
let store = Store::open_memory().unwrap();
|
||||
assert!(store.load_all_conversations().is_empty());
|
||||
assert!(store.load_all_agents().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conversation_crud() {
|
||||
let store = Store::open_memory().unwrap();
|
||||
|
||||
// Insert
|
||||
store.upsert_conversation("!room:x", "conv_abc", 100);
|
||||
let (conv_id, tokens) = store.get_conversation("!room:x").unwrap();
|
||||
assert_eq!(conv_id, "conv_abc");
|
||||
assert_eq!(tokens, 100);
|
||||
|
||||
// Update tokens
|
||||
store.update_tokens("!room:x", 500);
|
||||
let (_, tokens) = store.get_conversation("!room:x").unwrap();
|
||||
assert_eq!(tokens, 500);
|
||||
|
||||
// Upsert (replace conversation_id)
|
||||
store.upsert_conversation("!room:x", "conv_def", 0);
|
||||
let (conv_id, tokens) = store.get_conversation("!room:x").unwrap();
|
||||
assert_eq!(conv_id, "conv_def");
|
||||
assert_eq!(tokens, 0);
|
||||
|
||||
// Delete
|
||||
store.delete_conversation("!room:x");
|
||||
assert!(store.get_conversation("!room:x").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_agent_crud() {
|
||||
let store = Store::open_memory().unwrap();
|
||||
|
||||
store.upsert_agent("sol-orchestrator", "ag_123", "mistral-medium-latest");
|
||||
assert_eq!(
|
||||
store.get_agent("sol-orchestrator").unwrap(),
|
||||
"ag_123"
|
||||
);
|
||||
|
||||
// Update
|
||||
store.upsert_agent("sol-orchestrator", "ag_456", "mistral-medium-latest");
|
||||
assert_eq!(
|
||||
store.get_agent("sol-orchestrator").unwrap(),
|
||||
"ag_456"
|
||||
);
|
||||
|
||||
// Delete
|
||||
store.delete_agent("sol-orchestrator");
|
||||
assert!(store.get_agent("sol-orchestrator").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_all_conversations() {
|
||||
let store = Store::open_memory().unwrap();
|
||||
store.upsert_conversation("!a:x", "conv_1", 10);
|
||||
store.upsert_conversation("!b:x", "conv_2", 20);
|
||||
store.upsert_conversation("!c:x", "conv_3", 30);
|
||||
|
||||
let all = store.load_all_conversations();
|
||||
assert_eq!(all.len(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_load_all_agents() {
|
||||
let store = Store::open_memory().unwrap();
|
||||
store.upsert_agent("orch", "ag_1", "medium");
|
||||
store.upsert_agent("obs", "ag_2", "medium");
|
||||
|
||||
let all = store.load_all_agents();
|
||||
assert_eq!(all.len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nonexistent_keys_return_none() {
|
||||
let store = Store::open_memory().unwrap();
|
||||
assert!(store.get_conversation("!nope:x").is_none());
|
||||
assert!(store.get_agent("nope").is_none());
|
||||
}
|
||||
}
|
||||
82
src/sync.rs
82
src/sync.rs
@@ -14,6 +14,7 @@ use tracing::{debug, error, info, warn};
|
||||
|
||||
use opensearch::OpenSearch;
|
||||
|
||||
use crate::agents::registry::AgentRegistry;
|
||||
use crate::archive::indexer::Indexer;
|
||||
use crate::archive::schema::ArchiveDocument;
|
||||
use crate::brain::conversation::{ContextMessage, ConversationManager};
|
||||
@@ -21,6 +22,7 @@ use crate::brain::evaluator::{Engagement, Evaluator};
|
||||
use crate::brain::responder::Responder;
|
||||
use crate::config::Config;
|
||||
use crate::context::{self, ResponseContext};
|
||||
use crate::conversations::ConversationRegistry;
|
||||
use crate::matrix_utils;
|
||||
use crate::memory;
|
||||
|
||||
@@ -32,6 +34,10 @@ pub struct AppState {
|
||||
pub conversations: Arc<Mutex<ConversationManager>>,
|
||||
pub mistral: Arc<mistralai_client::v1::client::Client>,
|
||||
pub opensearch: OpenSearch,
|
||||
/// Agent registry — manages Mistral agent lifecycle.
|
||||
pub agent_registry: Arc<AgentRegistry>,
|
||||
/// Conversation registry for Mistral Conversations API.
|
||||
pub conversation_registry: Arc<ConversationRegistry>,
|
||||
/// Tracks when Sol last responded in each room (for cooldown)
|
||||
pub last_response: Arc<Mutex<HashMap<String, Instant>>>,
|
||||
/// Tracks rooms where a response is currently being generated (in-flight guard)
|
||||
@@ -104,10 +110,31 @@ async fn handle_message(
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let Some(body) = matrix_utils::extract_body(&event) else {
|
||||
return Ok(());
|
||||
// Extract text body — or image caption for m.image events
|
||||
let image_data_uri = matrix_utils::download_image_as_data_uri(
|
||||
&room.client(),
|
||||
&event,
|
||||
)
|
||||
.await;
|
||||
|
||||
let body = if let Some(ref _uri) = image_data_uri {
|
||||
// For images, use the caption/filename as the text body
|
||||
matrix_utils::extract_image(&event)
|
||||
.map(|(_, _, caption)| caption)
|
||||
.or_else(|| matrix_utils::extract_body(&event))
|
||||
.unwrap_or_default()
|
||||
} else {
|
||||
match matrix_utils::extract_body(&event) {
|
||||
Some(b) => b,
|
||||
None => return Ok(()),
|
||||
}
|
||||
};
|
||||
|
||||
// Skip if we have neither text nor image
|
||||
if body.is_empty() && image_data_uri.is_none() {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let room_name = matrix_utils::room_display_name(&room);
|
||||
let sender_name = room
|
||||
.get_member_no_sync(&event.sender)
|
||||
@@ -131,7 +158,9 @@ async fn handle_message(
|
||||
content: body.clone(),
|
||||
reply_to,
|
||||
thread_id,
|
||||
media_urls: Vec::new(),
|
||||
media_urls: matrix_utils::extract_image(&event)
|
||||
.map(|(url, _, _)| vec![url])
|
||||
.unwrap_or_default(),
|
||||
event_type: "m.room.message".into(),
|
||||
edited: false,
|
||||
redacted: false,
|
||||
@@ -242,20 +271,39 @@ async fn handle_message(
|
||||
let members = matrix_utils::room_member_names(&room).await;
|
||||
let display_sender = sender_name.as_deref().unwrap_or(&sender);
|
||||
|
||||
let response = state
|
||||
.responder
|
||||
.generate_response(
|
||||
&context,
|
||||
&body,
|
||||
display_sender,
|
||||
&room_name,
|
||||
&members,
|
||||
is_spontaneous,
|
||||
&state.mistral,
|
||||
&room,
|
||||
&response_ctx,
|
||||
)
|
||||
.await;
|
||||
let response = if state.config.agents.use_conversations_api {
|
||||
state
|
||||
.responder
|
||||
.generate_response_conversations(
|
||||
&body,
|
||||
display_sender,
|
||||
&room_id,
|
||||
is_dm,
|
||||
is_spontaneous,
|
||||
&state.mistral,
|
||||
&room,
|
||||
&response_ctx,
|
||||
&state.conversation_registry,
|
||||
image_data_uri.as_deref(),
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
state
|
||||
.responder
|
||||
.generate_response(
|
||||
&context,
|
||||
&body,
|
||||
display_sender,
|
||||
&room_name,
|
||||
&members,
|
||||
is_spontaneous,
|
||||
&state.mistral,
|
||||
&room,
|
||||
&response_ctx,
|
||||
image_data_uri.as_deref(),
|
||||
)
|
||||
.await
|
||||
};
|
||||
|
||||
if let Some(text) = response {
|
||||
// Reply with reference only when directly addressed. Spontaneous
|
||||
|
||||
86
src/tools/bridge.rs
Normal file
86
src/tools/bridge.rs
Normal file
@@ -0,0 +1,86 @@
|
||||
use std::collections::HashMap;
|
||||
use std::future::Future;
|
||||
use std::pin::Pin;
|
||||
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Maps Mistral tool call names to sunbeam-sdk async methods.
|
||||
/// Each domain has its own set of tool registrations.
|
||||
pub struct ToolBridge {
|
||||
/// tool_name → handler function
|
||||
handlers: HashMap<String, ToolHandler>,
|
||||
}
|
||||
|
||||
/// A tool handler: takes JSON arguments, returns JSON result.
|
||||
type ToolHandler = Box<dyn Fn(&str) -> Pin<Box<dyn Future<Output = Result<String, String>> + Send>> + Send + Sync>;
|
||||
|
||||
impl ToolBridge {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
handlers: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a tool handler.
|
||||
pub fn register<F, Fut>(&mut self, name: &str, handler: F)
|
||||
where
|
||||
F: Fn(String) -> Fut + Send + Sync + 'static,
|
||||
Fut: std::future::Future<Output = Result<String, String>> + Send + 'static,
|
||||
{
|
||||
self.handlers.insert(
|
||||
name.to_string(),
|
||||
Box::new(move |args: &str| {
|
||||
let args = args.to_string();
|
||||
let fut = handler(args);
|
||||
Box::pin(fut)
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
/// Execute a tool call by name.
|
||||
pub async fn execute(&self, name: &str, arguments: &str) -> Result<String, String> {
|
||||
if let Some(handler) = self.handlers.get(name) {
|
||||
info!(tool = name, "Executing tool via bridge");
|
||||
handler(arguments).await
|
||||
} else {
|
||||
warn!(tool = name, "Unknown tool in bridge");
|
||||
Err(format!("unknown tool: {name}"))
|
||||
}
|
||||
}
|
||||
|
||||
/// List registered tool names.
|
||||
pub fn tool_names(&self) -> Vec<&str> {
|
||||
self.handlers.keys().map(|k| k.as_str()).collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_bridge_register_and_execute() {
|
||||
let mut bridge = ToolBridge::new();
|
||||
bridge.register("echo", |args| async move { Ok(args) });
|
||||
|
||||
let result = bridge.execute("echo", "hello").await;
|
||||
assert_eq!(result.unwrap(), "hello");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_bridge_unknown_tool() {
|
||||
let bridge = ToolBridge::new();
|
||||
let result = bridge.execute("nonexistent", "{}").await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("unknown tool"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bridge_tool_names() {
|
||||
let mut bridge = ToolBridge::new();
|
||||
bridge.register("a", |_| async { Ok(String::new()) });
|
||||
bridge.register("b", |_| async { Ok(String::new()) });
|
||||
let names = bridge.tool_names();
|
||||
assert_eq!(names.len(), 2);
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod bridge;
|
||||
pub mod room_history;
|
||||
pub mod room_info;
|
||||
pub mod script;
|
||||
@@ -156,6 +157,21 @@ impl ToolRegistry {
|
||||
]
|
||||
}
|
||||
|
||||
/// Convert Sol's tool definitions to Mistral AgentTool format
|
||||
/// for use with the Agents API (orchestrator agent creation).
|
||||
pub fn agent_tool_definitions() -> Vec<mistralai_client::v1::agents::AgentTool> {
|
||||
Self::tool_definitions()
|
||||
.into_iter()
|
||||
.map(|t| {
|
||||
mistralai_client::v1::agents::AgentTool::function(
|
||||
t.function.name,
|
||||
t.function.description,
|
||||
t.function.parameters,
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub async fn execute(
|
||||
&self,
|
||||
name: &str,
|
||||
|
||||
Reference in New Issue
Block a user