When `append_conversation` fails (422, 404, etc.), the stale mapping is deleted and a fresh conversation is created automatically. This prevents Sol from being permanently stuck after a hung research session or a Mistral API error.
341 lines
11 KiB
Rust
341 lines
11 KiB
Rust
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
|
|
use mistralai_client::v1::client::Client as MistralClient;
|
|
use mistralai_client::v1::conversations::{
|
|
AppendConversationRequest, ConversationInput, ConversationResponse,
|
|
CreateConversationRequest,
|
|
};
|
|
use tokio::sync::Mutex;
|
|
use tracing::{debug, info, warn};
|
|
|
|
use crate::persistence::Store;
|
|
|
|
/// Maps Matrix room IDs to Mistral conversation IDs.
|
|
/// Group rooms get a shared conversation; DMs are per-room (already unique per DM pair).
|
|
/// State is persisted to SQLite so mappings survive reboots.
|
|
pub struct ConversationRegistry {
|
|
/// room_id → conversation_id (in-memory cache, backed by SQLite)
|
|
mapping: Mutex<HashMap<String, ConversationState>>,
|
|
/// Agent ID to use when creating new conversations (orchestrator or None for model-only).
|
|
agent_id: Mutex<Option<String>>,
|
|
/// Model to use when no agent is configured.
|
|
model: String,
|
|
/// Token budget before compaction triggers (% of model context window).
|
|
compaction_threshold: u32,
|
|
/// SQLite persistence.
|
|
store: Arc<Store>,
|
|
}
|
|
|
|
/// Cached bookkeeping for a single room's Mistral conversation.
struct ConversationState {
    /// The Mistral conversation this room currently appends to.
    conversation_id: String,
    /// Running token estimate for the conversation. It grows with each
    /// response's reported usage and starts over when the room is reset.
    estimated_tokens: u32,
}
|
|
|
|
impl ConversationRegistry {
|
|
pub fn new(model: String, compaction_threshold: u32, store: Arc<Store>) -> Self {
|
|
// Load existing mappings from SQLite
|
|
let persisted = store.load_all_conversations();
|
|
let mut mapping = HashMap::new();
|
|
for (room_id, conversation_id, estimated_tokens) in persisted {
|
|
mapping.insert(
|
|
room_id,
|
|
ConversationState {
|
|
conversation_id,
|
|
estimated_tokens,
|
|
},
|
|
);
|
|
}
|
|
let count = mapping.len();
|
|
if count > 0 {
|
|
info!(count, "Restored conversation mappings from database");
|
|
}
|
|
|
|
Self {
|
|
mapping: Mutex::new(mapping),
|
|
agent_id: Mutex::new(None),
|
|
model,
|
|
compaction_threshold,
|
|
store,
|
|
}
|
|
}
|
|
|
|
/// Set the orchestrator agent ID (called after agent registry creates it).
|
|
pub async fn set_agent_id(&self, agent_id: String) {
|
|
let mut id = self.agent_id.lock().await;
|
|
*id = Some(agent_id);
|
|
}
|
|
|
|
/// Get or create a conversation for a room. Returns the conversation ID.
|
|
/// If a conversation doesn't exist yet, creates one with the first message.
|
|
/// `context_hint` is prepended to the first message on new conversations,
|
|
/// giving the agent recent conversation history for continuity after resets.
|
|
pub async fn send_message(
|
|
&self,
|
|
room_id: &str,
|
|
message: ConversationInput,
|
|
is_dm: bool,
|
|
mistral: &MistralClient,
|
|
context_hint: Option<&str>,
|
|
) -> Result<ConversationResponse, String> {
|
|
let mut mapping = self.mapping.lock().await;
|
|
|
|
// Try to append to existing conversation; if it fails, drop and recreate
|
|
if let Some(state) = mapping.get_mut(room_id) {
|
|
let req = AppendConversationRequest {
|
|
inputs: message.clone(),
|
|
completion_args: None,
|
|
handoff_execution: None,
|
|
store: Some(true),
|
|
tool_confirmations: None,
|
|
stream: false,
|
|
};
|
|
|
|
match mistral
|
|
.append_conversation_async(&state.conversation_id, &req)
|
|
.await
|
|
{
|
|
Ok(response) => {
|
|
state.estimated_tokens += response.usage.total_tokens;
|
|
self.store.update_tokens(room_id, state.estimated_tokens);
|
|
|
|
debug!(
|
|
room = room_id,
|
|
conversation_id = state.conversation_id.as_str(),
|
|
tokens = state.estimated_tokens,
|
|
"Appended to conversation"
|
|
);
|
|
|
|
return Ok(response);
|
|
}
|
|
Err(e) => {
|
|
warn!(
|
|
room = room_id,
|
|
conversation_id = state.conversation_id.as_str(),
|
|
error = e.message.as_str(),
|
|
"Conversation corrupted — dropping and creating fresh"
|
|
);
|
|
self.store.delete_conversation(room_id);
|
|
mapping.remove(room_id);
|
|
// Fall through to create a new conversation below
|
|
}
|
|
}
|
|
}
|
|
|
|
{
|
|
// New conversation — create (with optional context hint for continuity)
|
|
let agent_id = self.agent_id.lock().await.clone();
|
|
|
|
let inputs = if let Some(hint) = context_hint {
|
|
// Prepend recent conversation history to the first message
|
|
match message {
|
|
ConversationInput::Text(text) => {
|
|
ConversationInput::Text(format!(
|
|
"[recent conversation for context]\n{hint}\n\n[current message]\n{text}"
|
|
))
|
|
}
|
|
other => other,
|
|
}
|
|
} else {
|
|
message
|
|
};
|
|
|
|
let req = CreateConversationRequest {
|
|
inputs,
|
|
model: if agent_id.is_none() {
|
|
Some(self.model.clone())
|
|
} else {
|
|
None
|
|
},
|
|
agent_id,
|
|
agent_version: None,
|
|
name: Some(format!("sol-{}", room_id)),
|
|
description: None,
|
|
instructions: None,
|
|
completion_args: None,
|
|
tools: None,
|
|
handoff_execution: None,
|
|
metadata: None,
|
|
store: Some(true),
|
|
stream: false,
|
|
};
|
|
|
|
let response = mistral
|
|
.create_conversation_async(&req)
|
|
.await
|
|
.map_err(|e| format!("create_conversation failed: {}", e.message))?;
|
|
|
|
let conv_id = response.conversation_id.clone();
|
|
let tokens = response.usage.total_tokens;
|
|
info!(
|
|
room = room_id,
|
|
conversation_id = conv_id.as_str(),
|
|
"Created new conversation"
|
|
);
|
|
|
|
self.store.upsert_conversation(room_id, &conv_id, tokens);
|
|
|
|
mapping.insert(
|
|
room_id.to_string(),
|
|
ConversationState {
|
|
conversation_id: conv_id,
|
|
estimated_tokens: tokens,
|
|
},
|
|
);
|
|
|
|
Ok(response)
|
|
}
|
|
}
|
|
|
|
/// Send a function result back to a conversation.
|
|
pub async fn send_function_result(
|
|
&self,
|
|
room_id: &str,
|
|
entries: Vec<mistralai_client::v1::conversations::ConversationEntry>,
|
|
mistral: &MistralClient,
|
|
) -> Result<ConversationResponse, String> {
|
|
let mapping = self.mapping.lock().await;
|
|
let state = mapping
|
|
.get(room_id)
|
|
.ok_or_else(|| format!("no conversation for room {room_id}"))?;
|
|
|
|
let req = AppendConversationRequest {
|
|
inputs: ConversationInput::Entries(entries),
|
|
completion_args: None,
|
|
handoff_execution: None,
|
|
store: Some(true),
|
|
tool_confirmations: None,
|
|
stream: false,
|
|
};
|
|
|
|
mistral
|
|
.append_conversation_async(&state.conversation_id, &req)
|
|
.await
|
|
.map_err(|e| format!("append_conversation (function result) failed: {}", e.message))
|
|
}
|
|
|
|
/// Check if a room's conversation needs compaction.
|
|
pub async fn needs_compaction(&self, room_id: &str) -> bool {
|
|
let mapping = self.mapping.lock().await;
|
|
if let Some(state) = mapping.get(room_id) {
|
|
state.estimated_tokens >= self.compaction_threshold
|
|
} else {
|
|
false
|
|
}
|
|
}
|
|
|
|
/// Reset ALL conversations (e.g., after agent recreation).
|
|
/// Clears both in-memory mappings and SQLite.
|
|
pub async fn reset_all(&self) {
|
|
let mut mapping = self.mapping.lock().await;
|
|
let count = mapping.len();
|
|
mapping.clear();
|
|
self.store.delete_all_conversations();
|
|
info!(count, "Reset all conversations");
|
|
}
|
|
|
|
/// Reset a room's conversation (e.g., after compaction).
|
|
/// Removes the mapping so the next message creates a fresh conversation.
|
|
pub async fn reset(&self, room_id: &str) {
|
|
let mut mapping = self.mapping.lock().await;
|
|
if let Some(state) = mapping.remove(room_id) {
|
|
self.store.delete_conversation(room_id);
|
|
info!(
|
|
room = room_id,
|
|
conversation_id = state.conversation_id.as_str(),
|
|
tokens = state.estimated_tokens,
|
|
"Reset conversation (compaction)"
|
|
);
|
|
}
|
|
}
|
|
|
|
/// Get the conversation ID for a room, if one exists.
|
|
pub async fn get_conversation_id(&self, room_id: &str) -> Option<String> {
|
|
let mapping = self.mapping.lock().await;
|
|
mapping.get(room_id).map(|s| s.conversation_id.clone())
|
|
}
|
|
|
|
/// Number of active conversations.
|
|
pub async fn active_count(&self) -> usize {
|
|
self.mapping.lock().await.len()
|
|
}
|
|
}
|
|
|
|
/// Merge multiple buffered user messages into a single conversation input.
///
/// Messages keep their arrival order and are separated by `\n`.
/// - DMs (`is_dm == true`): bodies are concatenated verbatim and senders are dropped.
/// - Group rooms: each message is rendered as `<sender_matrix_id> body`, so the
///   agent can attribute it. Only the first line of a multi-line body carries
///   the prefix.
///
/// An empty `messages` slice yields an empty string.
pub fn merge_user_messages(
    messages: &[(String, String)], // (sender_matrix_id, body)
    is_dm: bool,
) -> String {
    let mut merged = String::new();

    for (position, (sender, body)) in messages.iter().enumerate() {
        // Separator goes *between* messages, never before the first one.
        if position > 0 {
            merged.push('\n');
        }

        if is_dm {
            merged.push_str(body);
        } else {
            merged.push_str(&format!("<{sender}> {body}"));
        }
    }

    merged
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Build owned `(sender, body)` pairs from borrowed string slices.
    fn pairs(raw: &[(&str, &str)]) -> Vec<(String, String)> {
        raw.iter()
            .map(|&(sender, body)| (sender.to_string(), body.to_string()))
            .collect()
    }

    #[test]
    fn test_merge_dm_messages() {
        let msgs = pairs(&[
            ("@alice:example.com", "hello"),
            ("@alice:example.com", "how are you?"),
        ]);
        assert_eq!(merge_user_messages(&msgs, true), "hello\nhow are you?");
    }

    #[test]
    fn test_merge_group_messages() {
        let msgs = pairs(&[
            ("@sienna:sunbeam.pt", "what's the error rate?"),
            ("@lonni:sunbeam.pt", "also check memory"),
            ("@sienna:sunbeam.pt", "and disk too"),
        ]);
        let expected = "<@sienna:sunbeam.pt> what's the error rate?\n<@lonni:sunbeam.pt> also check memory\n<@sienna:sunbeam.pt> and disk too";
        assert_eq!(merge_user_messages(&msgs, false), expected);
    }

    #[test]
    fn test_merge_empty() {
        assert_eq!(merge_user_messages(&[], true), "");
    }

    #[test]
    fn test_merge_single_dm() {
        let msgs = pairs(&[("@user:x", "hi")]);
        assert_eq!(merge_user_messages(&msgs, true), "hi");
    }

    #[test]
    fn test_merge_single_group() {
        let msgs = pairs(&[("@user:x", "hi")]);
        assert_eq!(merge_user_messages(&msgs, false), "<@user:x> hi");
    }
}
|