Files
sol/src/conversations.rs
Sienna Meridian Satterwhite 2a1d7a003d auto-recover corrupted conversations on API error
When `append_conversation` fails (422, 404, etc.), the stale mapping
is deleted and a fresh conversation is created automatically. This
prevents Sol from being permanently stuck after a hung research
session or a Mistral API error.
2026-03-23 09:53:29 +00:00

341 lines
11 KiB
Rust

use std::collections::HashMap;
use std::sync::Arc;
use mistralai_client::v1::client::Client as MistralClient;
use mistralai_client::v1::conversations::{
AppendConversationRequest, ConversationInput, ConversationResponse,
CreateConversationRequest,
};
use tokio::sync::Mutex;
use tracing::{debug, info, warn};
use crate::persistence::Store;
/// Maps Matrix room IDs to Mistral conversation IDs.
/// Group rooms get a shared conversation; DMs are per-room (already unique per DM pair).
/// State is persisted to SQLite so mappings survive reboots.
pub struct ConversationRegistry {
    /// room_id → conversation_id (in-memory cache, backed by SQLite)
    mapping: Mutex<HashMap<String, ConversationState>>,
    /// Agent ID to use when creating new conversations (orchestrator or None for model-only).
    /// Starts as None at construction; populated later via `set_agent_id`.
    agent_id: Mutex<Option<String>>,
    /// Model to use when no agent is configured.
    model: String,
    /// Token budget before compaction triggers (% of model context window).
    /// NOTE(review): `needs_compaction` compares this directly against the raw
    /// token estimate, which reads as an absolute token count rather than a
    /// percentage — confirm the intended units.
    compaction_threshold: u32,
    /// SQLite persistence.
    store: Arc<Store>,
}
/// Per-room cached state: the remote Mistral conversation ID plus a running
/// token estimate used to decide when compaction is due.
struct ConversationState {
    conversation_id: String,
    /// Estimated token count (incremented per message, reset on compaction).
    estimated_tokens: u32,
}
impl ConversationRegistry {
    /// Construct a registry, rehydrating room → conversation mappings from
    /// SQLite so state survives process restarts. Logs the restored count
    /// when any mappings were found.
    pub fn new(model: String, compaction_threshold: u32, store: Arc<Store>) -> Self {
        // Load existing mappings from SQLite
        let persisted = store.load_all_conversations();
        let mut mapping = HashMap::new();
        for (room_id, conversation_id, estimated_tokens) in persisted {
            mapping.insert(
                room_id,
                ConversationState {
                    conversation_id,
                    estimated_tokens,
                },
            );
        }
        let count = mapping.len();
        if count > 0 {
            info!(count, "Restored conversation mappings from database");
        }
        Self {
            mapping: Mutex::new(mapping),
            // No agent known at construction time; set_agent_id is called
            // later once the agent registry has created the orchestrator.
            agent_id: Mutex::new(None),
            model,
            compaction_threshold,
            store,
        }
    }
    /// Set the orchestrator agent ID (called after agent registry creates it).
    /// Conversations created after this point are bound to the agent instead
    /// of the bare model (see `send_message`'s create path).
    pub async fn set_agent_id(&self, agent_id: String) {
        let mut id = self.agent_id.lock().await;
        *id = Some(agent_id);
    }
    /// Get or create a conversation for a room. Returns the conversation ID.
    /// If a conversation doesn't exist yet, creates one with the first message.
    /// `context_hint` is prepended to the first message on new conversations,
    /// giving the agent recent conversation history for continuity after resets.
    ///
    /// Recovery: if appending to an existing conversation fails for ANY error,
    /// the stale mapping is dropped from both memory and SQLite and a fresh
    /// conversation is created within the same call, so one bad conversation
    /// ID cannot permanently wedge a room.
    ///
    /// NOTE(review): `is_dm` is not read anywhere in this body — confirm
    /// whether it is reserved for future use or can be dropped by callers.
    ///
    /// The registry-wide `mapping` lock is held across the Mistral API awaits
    /// (safe with `tokio::sync::Mutex`), which serializes sends for ALL rooms,
    /// not just this one.
    pub async fn send_message(
        &self,
        room_id: &str,
        message: ConversationInput,
        is_dm: bool,
        mistral: &MistralClient,
        context_hint: Option<&str>,
    ) -> Result<ConversationResponse, String> {
        let mut mapping = self.mapping.lock().await;
        // Try to append to existing conversation; if it fails, drop and recreate
        if let Some(state) = mapping.get_mut(room_id) {
            let req = AppendConversationRequest {
                // Cloned because `message` is still needed below if the append
                // fails and we fall through to creating a new conversation.
                inputs: message.clone(),
                completion_args: None,
                handoff_execution: None,
                store: Some(true),
                tool_confirmations: None,
                stream: false,
            };
            match mistral
                .append_conversation_async(&state.conversation_id, &req)
                .await
            {
                Ok(response) => {
                    // Accumulate usage for compaction decisions and persist it
                    // so the estimate also survives restarts.
                    state.estimated_tokens += response.usage.total_tokens;
                    self.store.update_tokens(room_id, state.estimated_tokens);
                    debug!(
                        room = room_id,
                        conversation_id = state.conversation_id.as_str(),
                        tokens = state.estimated_tokens,
                        "Appended to conversation"
                    );
                    return Ok(response);
                }
                Err(e) => {
                    // Any append error (422, 404, ...) is treated as a
                    // corrupted/stale conversation: forget it and start over.
                    warn!(
                        room = room_id,
                        conversation_id = state.conversation_id.as_str(),
                        error = e.message.as_str(),
                        "Conversation corrupted — dropping and creating fresh"
                    );
                    self.store.delete_conversation(room_id);
                    mapping.remove(room_id);
                    // Fall through to create a new conversation below
                }
            }
        }
        {
            // New conversation — create (with optional context hint for continuity)
            let agent_id = self.agent_id.lock().await.clone();
            let inputs = if let Some(hint) = context_hint {
                // Prepend recent conversation history to the first message
                match message {
                    ConversationInput::Text(text) => {
                        ConversationInput::Text(format!(
                            "[recent conversation for context]\n{hint}\n\n[current message]\n{text}"
                        ))
                    }
                    // Non-text inputs are sent as-is; the hint is dropped.
                    other => other,
                }
            } else {
                message
            };
            let req = CreateConversationRequest {
                inputs,
                // Model and agent_id are mutually exclusive here: when an agent
                // is configured, the model comes from the agent definition.
                model: if agent_id.is_none() {
                    Some(self.model.clone())
                } else {
                    None
                },
                agent_id,
                agent_version: None,
                name: Some(format!("sol-{}", room_id)),
                description: None,
                instructions: None,
                completion_args: None,
                tools: None,
                handoff_execution: None,
                metadata: None,
                store: Some(true),
                stream: false,
            };
            let response = mistral
                .create_conversation_async(&req)
                .await
                .map_err(|e| format!("create_conversation failed: {}", e.message))?;
            let conv_id = response.conversation_id.clone();
            let tokens = response.usage.total_tokens;
            info!(
                room = room_id,
                conversation_id = conv_id.as_str(),
                "Created new conversation"
            );
            // Persist before caching so SQLite stays the source of truth if we
            // crash between the two writes.
            self.store.upsert_conversation(room_id, &conv_id, tokens);
            mapping.insert(
                room_id.to_string(),
                ConversationState {
                    conversation_id: conv_id,
                    estimated_tokens: tokens,
                },
            );
            Ok(response)
        }
    }
    /// Send a function result back to a conversation.
    /// Errors if the room has no mapping — unlike `send_message`, this does
    /// NOT auto-create a conversation (a tool result with no prior
    /// conversation has nothing to attach to). Token usage is not tracked
    /// on this path.
    pub async fn send_function_result(
        &self,
        room_id: &str,
        entries: Vec<mistralai_client::v1::conversations::ConversationEntry>,
        mistral: &MistralClient,
    ) -> Result<ConversationResponse, String> {
        let mapping = self.mapping.lock().await;
        let state = mapping
            .get(room_id)
            .ok_or_else(|| format!("no conversation for room {room_id}"))?;
        let req = AppendConversationRequest {
            inputs: ConversationInput::Entries(entries),
            completion_args: None,
            handoff_execution: None,
            store: Some(true),
            tool_confirmations: None,
            stream: false,
        };
        mistral
            .append_conversation_async(&state.conversation_id, &req)
            .await
            .map_err(|e| format!("append_conversation (function result) failed: {}", e.message))
    }
    /// Check if a room's conversation needs compaction.
    /// Rooms without a conversation never need compaction.
    /// NOTE(review): compares the raw token estimate directly against
    /// `compaction_threshold`, treating it as an absolute token budget even
    /// though the field doc calls it a percentage — confirm intended units.
    pub async fn needs_compaction(&self, room_id: &str) -> bool {
        let mapping = self.mapping.lock().await;
        if let Some(state) = mapping.get(room_id) {
            state.estimated_tokens >= self.compaction_threshold
        } else {
            false
        }
    }
    /// Reset ALL conversations (e.g., after agent recreation).
    /// Clears both in-memory mappings and SQLite.
    pub async fn reset_all(&self) {
        let mut mapping = self.mapping.lock().await;
        let count = mapping.len();
        mapping.clear();
        self.store.delete_all_conversations();
        info!(count, "Reset all conversations");
    }
    /// Reset a room's conversation (e.g., after compaction).
    /// Removes the mapping so the next message creates a fresh conversation.
    /// No-op (and no log line) if the room has no conversation.
    pub async fn reset(&self, room_id: &str) {
        let mut mapping = self.mapping.lock().await;
        if let Some(state) = mapping.remove(room_id) {
            self.store.delete_conversation(room_id);
            info!(
                room = room_id,
                conversation_id = state.conversation_id.as_str(),
                tokens = state.estimated_tokens,
                "Reset conversation (compaction)"
            );
        }
    }
    /// Get the conversation ID for a room, if one exists.
    pub async fn get_conversation_id(&self, room_id: &str) -> Option<String> {
        let mapping = self.mapping.lock().await;
        mapping.get(room_id).map(|s| s.conversation_id.clone())
    }
    /// Number of active conversations.
    pub async fn active_count(&self) -> usize {
        self.mapping.lock().await.len()
    }
}
/// Merge multiple buffered user messages into a single conversation input.
/// For DMs: raw text concatenation with newlines.
/// For group rooms: prefix each line with `<sender_matrix_id>`.
pub fn merge_user_messages(
    messages: &[(String, String)], // (sender_matrix_id, body)
    is_dm: bool,
) -> String {
    // A single pipeline covers both flavors; joining an empty iterator
    // already yields "", so no explicit empty-slice check is needed.
    messages
        .iter()
        .map(|(sender, body)| {
            if is_dm {
                // DMs: the sender is implicit, keep the body as-is.
                body.clone()
            } else {
                // Group rooms: attribute each line to its sender.
                format!("<{sender}> {body}")
            }
        })
        .collect::<Vec<String>>()
        .join("\n")
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Build owned (sender, body) pairs from string literals, so each test
    /// avoids repeating `.to_string()` boilerplate.
    fn pairs(raw: &[(&str, &str)]) -> Vec<(String, String)> {
        raw.iter()
            .map(|(sender, body)| (sender.to_string(), body.to_string()))
            .collect()
    }

    #[test]
    fn test_merge_dm_messages() {
        let msgs = pairs(&[
            ("@alice:example.com", "hello"),
            ("@alice:example.com", "how are you?"),
        ]);
        assert_eq!(merge_user_messages(&msgs, true), "hello\nhow are you?");
    }

    #[test]
    fn test_merge_group_messages() {
        let msgs = pairs(&[
            ("@sienna:sunbeam.pt", "what's the error rate?"),
            ("@lonni:sunbeam.pt", "also check memory"),
            ("@sienna:sunbeam.pt", "and disk too"),
        ]);
        assert_eq!(
            merge_user_messages(&msgs, false),
            "<@sienna:sunbeam.pt> what's the error rate?\n<@lonni:sunbeam.pt> also check memory\n<@sienna:sunbeam.pt> and disk too"
        );
    }

    #[test]
    fn test_merge_empty() {
        assert_eq!(merge_user_messages(&[], true), "");
    }

    #[test]
    fn test_merge_single_dm() {
        let msgs = pairs(&[("@user:x", "hi")]);
        assert_eq!(merge_user_messages(&msgs, true), "hi");
    }

    #[test]
    fn test_merge_single_group() {
        let msgs = pairs(&[("@user:x", "hi")]);
        assert_eq!(merge_user_messages(&msgs, false), "<@user:x> hi");
    }
}