feat: initial Sol virtual librarian implementation

Matrix bot with E2EE (matrix-sdk 0.9) that passively archives all
messages to OpenSearch and responds to queries via Mistral AI with
function calling tools.

Core systems:
- Archive: bulk OpenSearch indexer with batch/flush, edit/redaction
  handling, embedding pipeline passthrough
- Brain: rule-based engagement evaluator (mentions, DMs, name
  invocations), LLM-powered spontaneous engagement, per-room
  conversation context windows, response delay simulation
- Tools: search_archive, get_room_context, list_rooms, get_room_members
  registered as Mistral function calling tools with iterative tool loop
- Personality: templated system prompt with Sol's librarian persona

47 unit tests covering config, evaluator, conversation windowing,
personality templates, schema serialization, and search query building.
This commit is contained in:
2026-03-20 21:40:13 +00:00
commit 4dc20bee23
21 changed files with 6934 additions and 0 deletions

148
src/archive/indexer.rs Normal file
View File

@@ -0,0 +1,148 @@
use std::sync::Arc;
use opensearch::http::request::JsonBody;
use opensearch::OpenSearch;
use serde_json::json;
use tokio::sync::Mutex;
use tokio::time::{interval, Duration};
use tracing::{debug, error, warn};
use crate::config::Config;
use super::schema::ArchiveDocument;
/// Buffers [`ArchiveDocument`]s in memory and writes them to OpenSearch
/// through the bulk API, either when the buffer reaches the configured batch
/// size or when the periodic flush task fires.
pub struct Indexer {
/// Documents waiting for the next flush; guarded by an async mutex and
/// shared with the task spawned by `start_flush_task`.
buffer: Arc<Mutex<Vec<ArchiveDocument>>>,
/// OpenSearch client used for bulk indexing and partial document updates.
client: OpenSearch,
/// Application configuration; only the `opensearch` section is read here
/// (index name, batch size, flush interval, ingest pipeline).
config: Arc<Config>,
}
impl Indexer {
/// Creates an indexer with an empty buffer that writes through `client`
/// using the settings in `config.opensearch`.
pub fn new(client: OpenSearch, config: Arc<Config>) -> Self {
    let buffer = Arc::new(Mutex::new(Vec::new()));
    Self {
        buffer,
        client,
        config,
    }
}
/// Queues `doc` for indexing and flushes the whole buffer once it holds at
/// least `opensearch.batch_size` documents.
///
/// A failed flush is only logged: the drained documents have already been
/// moved out of the buffer and are not re-queued.
pub async fn add(&self, doc: ArchiveDocument) {
    let threshold = self.config.opensearch.batch_size;
    // Keep the guard inside this scope so the lock is released before the
    // network round-trip below.
    let ready = {
        let mut pending = self.buffer.lock().await;
        pending.push(doc);
        if pending.len() < threshold {
            return;
        }
        pending.drain(..).collect::<Vec<ArchiveDocument>>()
    };
    if let Err(e) = self.flush_docs(ready).await {
        error!("Failed to flush archive batch: {e}");
    }
}
pub async fn update_edit(&self, event_id: &str, new_content: &str) {
let body = json!({
"doc": {
"content": new_content,
"edited": true
}
});
if let Err(e) = self
.client
.update(opensearch::UpdateParts::IndexId(
&self.config.opensearch.index,
event_id,
))
.body(body)
.send()
.await
{
warn!(event_id, "Failed to update edited message: {e}");
}
}
pub async fn update_redaction(&self, event_id: &str) {
let body = json!({
"doc": {
"content": "",
"redacted": true
}
});
if let Err(e) = self
.client
.update(opensearch::UpdateParts::IndexId(
&self.config.opensearch.index,
event_id,
))
.body(body)
.send()
.await
{
warn!(event_id, "Failed to update redacted message: {e}");
}
}
/// Spawns a background task that flushes the buffer every
/// `opensearch.flush_interval_ms` milliseconds.
///
/// The task loops forever; the returned handle can be used to abort or await
/// it. Flush failures are logged and the drained documents are not re-queued.
pub fn start_flush_task(self: &Arc<Self>) -> tokio::task::JoinHandle<()> {
    let this = Arc::clone(self);
    tokio::spawn(async move {
        // `tokio::time::interval` panics on a zero period, so clamp the
        // configured value: `flush_interval_ms = 0` must not kill the task.
        let period_ms = this.config.opensearch.flush_interval_ms.max(1);
        let mut tick = interval(Duration::from_millis(period_ms));
        loop {
            tick.tick().await;
            // Drain under the lock, then release it before the network call.
            let docs = {
                let mut buffer = this.buffer.lock().await;
                if buffer.is_empty() {
                    continue;
                }
                buffer.drain(..).collect::<Vec<ArchiveDocument>>()
            };
            if let Err(e) = this.flush_docs(docs).await {
                error!("Periodic flush failed: {e}");
            }
        }
    })
}
/// Sends `docs` to OpenSearch in a single bulk request routed through the
/// configured ingest pipeline.
///
/// Each document becomes an `index` action keyed by its `event_id`.
///
/// # Errors
/// Returns an error if a document cannot be serialized, the request fails,
/// the response status is not a success, or the response body cannot be
/// read. Per-item failures reported in the bulk response (`"errors": true`)
/// are only logged as a warning; the call still returns `Ok(())`.
async fn flush_docs(&self, docs: Vec<ArchiveDocument>) -> anyhow::Result<()> {
    if docs.is_empty() {
        return Ok(());
    }

    let os_cfg = &self.config.opensearch;
    let index = &os_cfg.index;
    let pipeline = &os_cfg.embedding_pipeline;

    // The bulk body alternates action line / source line, hence two entries
    // per document.
    let mut ops: Vec<JsonBody<serde_json::Value>> = Vec::with_capacity(docs.len() * 2);
    for doc in &docs {
        let action = json!({
            "index": {
                "_index": index,
                "_id": doc.event_id
            }
        });
        let source = serde_json::to_value(doc)?;
        ops.push(action.into());
        ops.push(source.into());
    }

    let response = self
        .client
        .bulk(opensearch::BulkParts::None)
        .pipeline(pipeline)
        .body(ops)
        .send()
        .await?;

    if !response.status_code().is_success() {
        let text = response.text().await?;
        anyhow::bail!("Bulk index failed: {text}");
    }

    let result: serde_json::Value = response.json().await?;
    let had_item_errors = result["errors"].as_bool().unwrap_or(false);
    if had_item_errors {
        warn!("Bulk index had errors: {}", serde_json::to_string_pretty(&result)?);
    } else {
        debug!(count = docs.len(), "Flushed documents to OpenSearch");
    }
    Ok(())
}
}

2
src/archive/mod.rs Normal file
View File

@@ -0,0 +1,2 @@
/// Buffered bulk indexer that writes archived messages to OpenSearch.
pub mod indexer;
/// Archive document type, index mapping, and index creation helper.
pub mod schema;

205
src/archive/schema.rs Normal file
View File

@@ -0,0 +1,205 @@
use opensearch::OpenSearch;
use serde::{Deserialize, Serialize};
use tracing::info;
/// One archived message as stored in the OpenSearch index.
///
/// Optional fields are omitted from the serialized JSON when `None`;
/// `media_urls`, `edited` and `redacted` fall back to their defaults when
/// absent during deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchiveDocument {
/// Event identifier; the indexer also uses it as the OpenSearch `_id`.
pub event_id: String,
/// Identifier of the room the message belongs to.
pub room_id: String,
/// Room name, when one is available.
#[serde(skip_serializing_if = "Option::is_none")]
pub room_name: Option<String>,
/// Identifier of the sender.
pub sender: String,
/// Sender's name, when one is available.
#[serde(skip_serializing_if = "Option::is_none")]
pub sender_name: Option<String>,
/// Message time; the index mapping declares this field as `epoch_millis`.
pub timestamp: i64,
/// Message text; blanked by the indexer when the message is redacted.
pub content: String,
/// Event id this message replies to, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub reply_to: Option<String>,
/// Thread identifier, if any.
#[serde(skip_serializing_if = "Option::is_none")]
pub thread_id: Option<String>,
/// URLs of attached media; empty when there are none.
#[serde(default)]
pub media_urls: Vec<String>,
/// Event type string (e.g. `m.room.message` in the tests).
pub event_type: String,
/// Set to `true` by the indexer when the content has been replaced by an edit.
#[serde(default)]
pub edited: bool,
/// Set to `true` by the indexer when the message has been redacted.
#[serde(default)]
pub redacted: bool,
}
/// Settings and field mappings sent as the request body when the archive
/// index is created. The properties mirror the fields of [`ArchiveDocument`]:
/// `content` is analyzed text, `timestamp` is an `epoch_millis` date, the two
/// flags are booleans and everything else is a `keyword`.
const INDEX_MAPPING: &str = r#"{
"settings": {
"number_of_shards": 1,
"number_of_replicas": 0
},
"mappings": {
"properties": {
"event_id": { "type": "keyword" },
"room_id": { "type": "keyword" },
"room_name": { "type": "keyword" },
"sender": { "type": "keyword" },
"sender_name": { "type": "keyword" },
"timestamp": { "type": "date", "format": "epoch_millis" },
"content": { "type": "text", "analyzer": "standard" },
"reply_to": { "type": "keyword" },
"thread_id": { "type": "keyword" },
"media_urls": { "type": "keyword" },
"event_type": { "type": "keyword" },
"edited": { "type": "boolean" },
"redacted": { "type": "boolean" }
}
}
}"#;
/// Returns the raw JSON text of the archive index settings and mappings.
pub fn index_mapping_json() -> &'static str {
INDEX_MAPPING
}
/// Creates `index` with the archive mapping unless the existence check
/// already reports success.
///
/// # Errors
/// Returns an error if either request fails, the embedded mapping is not
/// valid JSON, or the create request returns a non-success status (the
/// response body is included in the message).
pub async fn create_index_if_not_exists(client: &OpenSearch, index: &str) -> anyhow::Result<()> {
    let already_present = client
        .indices()
        .exists(opensearch::indices::IndicesExistsParts::Index(&[index]))
        .send()
        .await?
        .status_code()
        .is_success();
    if already_present {
        info!(index, "OpenSearch index already exists");
        return Ok(());
    }

    let mapping = serde_json::from_str::<serde_json::Value>(INDEX_MAPPING)?;
    let created = client
        .indices()
        .create(opensearch::indices::IndicesCreateParts::Index(index))
        .body(mapping)
        .send()
        .await?;

    if created.status_code().is_success() {
        info!(index, "Created OpenSearch index");
        return Ok(());
    }

    let body = created.text().await?;
    anyhow::bail!("Failed to create index {index}: {body}");
}
// Unit tests for `ArchiveDocument` (de)serialization and the index mapping.
#[cfg(test)]
mod tests {
use super::*;
// Fully populated document except for `reply_to` / `thread_id`, which are `None`.
fn sample_doc() -> ArchiveDocument {
ArchiveDocument {
event_id: "$abc123:sunbeam.pt".to_string(),
room_id: "!room:sunbeam.pt".to_string(),
room_name: Some("general".to_string()),
sender: "@alice:sunbeam.pt".to_string(),
sender_name: Some("Alice".to_string()),
timestamp: 1710000000000,
content: "hello world".to_string(),
reply_to: None,
thread_id: None,
media_urls: vec![],
event_type: "m.room.message".to_string(),
edited: false,
redacted: false,
}
}
// Every populated field appears in the JSON under its own name.
#[test]
fn test_serialize_full_doc() {
let doc = sample_doc();
let json = serde_json::to_value(&doc).unwrap();
assert_eq!(json["event_id"], "$abc123:sunbeam.pt");
assert_eq!(json["room_id"], "!room:sunbeam.pt");
assert_eq!(json["room_name"], "general");
assert_eq!(json["sender"], "@alice:sunbeam.pt");
assert_eq!(json["sender_name"], "Alice");
assert_eq!(json["timestamp"], 1710000000000_i64);
assert_eq!(json["content"], "hello world");
assert_eq!(json["event_type"], "m.room.message");
assert_eq!(json["edited"], false);
assert_eq!(json["redacted"], false);
assert!(json["media_urls"].as_array().unwrap().is_empty());
}
// `None` optionals are skipped entirely rather than serialized as null.
#[test]
fn test_skip_none_fields() {
let doc = sample_doc();
let json_str = serde_json::to_string(&doc).unwrap();
// reply_to and thread_id are None, should be omitted
assert!(!json_str.contains("reply_to"));
assert!(!json_str.contains("thread_id"));
}
// Optionals that are set are serialized with their values.
#[test]
fn test_serialize_with_optional_fields() {
let mut doc = sample_doc();
doc.reply_to = Some("$parent:sunbeam.pt".to_string());
doc.thread_id = Some("$thread:sunbeam.pt".to_string());
doc.media_urls = vec!["mxc://sunbeam.pt/abc".to_string()];
doc.edited = true;
let json = serde_json::to_value(&doc).unwrap();
assert_eq!(json["reply_to"], "$parent:sunbeam.pt");
assert_eq!(json["thread_id"], "$thread:sunbeam.pt");
assert_eq!(json["media_urls"][0], "mxc://sunbeam.pt/abc");
assert_eq!(json["edited"], true);
}
// Serialize then deserialize yields the same field values.
#[test]
fn test_deserialize_roundtrip() {
let doc = sample_doc();
let json_str = serde_json::to_string(&doc).unwrap();
let deserialized: ArchiveDocument = serde_json::from_str(&json_str).unwrap();
assert_eq!(deserialized.event_id, doc.event_id);
assert_eq!(deserialized.room_id, doc.room_id);
assert_eq!(deserialized.room_name, doc.room_name);
assert_eq!(deserialized.sender, doc.sender);
assert_eq!(deserialized.content, doc.content);
assert_eq!(deserialized.timestamp, doc.timestamp);
assert_eq!(deserialized.edited, doc.edited);
assert_eq!(deserialized.redacted, doc.redacted);
}
// Missing optional / defaulted fields deserialize to `None`, empty, or `false`.
#[test]
fn test_deserialize_with_defaults() {
// Simulate a document missing optional/default fields
let json = r#"{
"event_id": "$x:s",
"room_id": "!r:s",
"sender": "@a:s",
"timestamp": 1000,
"content": "test",
"event_type": "m.room.message"
}"#;
let doc: ArchiveDocument = serde_json::from_str(json).unwrap();
assert!(doc.room_name.is_none());
assert!(doc.sender_name.is_none());
assert!(doc.reply_to.is_none());
assert!(doc.thread_id.is_none());
assert!(doc.media_urls.is_empty());
assert!(!doc.edited);
assert!(!doc.redacted);
}
// The embedded mapping parses as JSON and declares the expected field types.
#[test]
fn test_index_mapping_is_valid_json() {
let mapping: serde_json::Value =
serde_json::from_str(index_mapping_json()).unwrap();
assert!(mapping["settings"]["number_of_shards"].is_number());
assert!(mapping["mappings"]["properties"]["event_id"]["type"]
.as_str()
.unwrap()
== "keyword");
assert!(mapping["mappings"]["properties"]["content"]["type"]
.as_str()
.unwrap()
== "text");
assert!(mapping["mappings"]["properties"]["timestamp"]["type"]
.as_str()
.unwrap()
== "date");
}
}

207
src/brain/conversation.rs Normal file
View File

@@ -0,0 +1,207 @@
use std::collections::{HashMap, VecDeque};
/// A single chat message kept in a room's rolling context window.
#[derive(Debug, Clone)]
pub struct ContextMessage {
    /// Sender identifier as supplied by the caller.
    pub sender: String,
    /// Message text.
    pub content: String,
    /// Message timestamp; also used to decide which room is evicted first.
    pub timestamp: i64,
}

/// Bounded FIFO of the most recent messages of one room.
struct RoomContext {
    /// Retained messages, oldest at the front.
    messages: VecDeque<ContextMessage>,
    /// Maximum number of messages retained.
    max_size: usize,
}

impl RoomContext {
    /// Creates an empty window that retains at most `max_size` messages.
    fn new(max_size: usize) -> Self {
        Self {
            messages: VecDeque::with_capacity(max_size),
            max_size,
        }
    }

    /// Appends `msg`, dropping the oldest entries so that no more than
    /// `max_size` messages are kept.
    fn add(&mut self, msg: ContextMessage) {
        // A zero-sized window retains nothing. Without this guard the push
        // below would still store one message.
        if self.max_size == 0 {
            return;
        }
        while self.messages.len() >= self.max_size {
            self.messages.pop_front();
        }
        self.messages.push_back(msg);
    }

    /// Returns a copy of the retained messages, oldest first.
    fn get(&self) -> Vec<ContextMessage> {
        self.messages.iter().cloned().collect()
    }
}
/// Keeps a bounded recent-message window per room, with a cap on the number
/// of rooms tracked at once.
pub struct ConversationManager {
    /// Per-room context windows keyed by room id.
    rooms: HashMap<String, RoomContext>,
    /// Window size used when a non-DM room is first seen.
    room_window: usize,
    /// Window size used when a DM room is first seen.
    dm_window: usize,
    /// Maximum number of rooms tracked before one is evicted.
    max_rooms: usize,
}

impl ConversationManager {
    /// Creates an empty manager with the given window sizes and a room cap
    /// of 500.
    pub fn new(room_window: usize, dm_window: usize) -> Self {
        Self {
            rooms: HashMap::new(),
            room_window,
            dm_window,
            max_rooms: 500, // todo(sienna): make this configurable
        }
    }

    /// Records `msg` for `room_id`.
    ///
    /// The window size is chosen from `is_dm` only when the room is first
    /// seen. If a new room would exceed the cap, the room whose newest
    /// message has the smallest timestamp is dropped first.
    pub fn add_message(&mut self, room_id: &str, is_dm: bool, msg: ContextMessage) {
        let window = if is_dm {
            self.dm_window
        } else {
            self.room_window
        };

        let is_new_room = !self.rooms.contains_key(room_id);
        if is_new_room && self.rooms.len() >= self.max_rooms {
            // Rooms without any message sort as timestamp 0.
            let stalest = self
                .rooms
                .iter()
                .min_by_key(|(_, ctx)| ctx.messages.back().map_or(0, |m| m.timestamp))
                .map(|(id, _)| id.clone());
            if let Some(id) = stalest {
                self.rooms.remove(&id);
            }
        }

        self.rooms
            .entry(room_id.to_owned())
            .or_insert_with(|| RoomContext::new(window))
            .add(msg);
    }

    /// Returns a copy of the retained messages for `room_id`, oldest first,
    /// or an empty vector for an unknown room.
    pub fn get_context(&self, room_id: &str) -> Vec<ContextMessage> {
        match self.rooms.get(room_id) {
            Some(ctx) => ctx.get(),
            None => Vec::new(),
        }
    }

    /// Number of rooms currently tracked.
    pub fn room_count(&self) -> usize {
        self.rooms.len()
    }
}
// Unit tests for per-room context windows and room eviction.
#[cfg(test)]
mod tests {
use super::*;
// Shorthand constructor for a `ContextMessage`.
fn msg(sender: &str, content: &str, ts: i64) -> ContextMessage {
ContextMessage {
sender: sender.to_string(),
content: content.to_string(),
timestamp: ts,
}
}
#[test]
fn test_add_and_get_messages() {
let mut cm = ConversationManager::new(5, 10);
cm.add_message("!room1:x", false, msg("alice", "hello", 1));
cm.add_message("!room1:x", false, msg("bob", "hi", 2));
let ctx = cm.get_context("!room1:x");
assert_eq!(ctx.len(), 2);
assert_eq!(ctx[0].sender, "alice");
assert_eq!(ctx[1].sender, "bob");
}
#[test]
fn test_empty_room_returns_empty() {
let cm = ConversationManager::new(5, 10);
let ctx = cm.get_context("!nonexistent:x");
assert!(ctx.is_empty());
}
// Non-DM rooms use the first constructor argument as their window size.
#[test]
fn test_sliding_window_group_room() {
let mut cm = ConversationManager::new(3, 10);
for i in 0..5 {
cm.add_message("!room:x", false, msg("user", &format!("msg{i}"), i));
}
let ctx = cm.get_context("!room:x");
assert_eq!(ctx.len(), 3);
// Should keep the last 3
assert_eq!(ctx[0].content, "msg2");
assert_eq!(ctx[1].content, "msg3");
assert_eq!(ctx[2].content, "msg4");
}
// DM rooms use the second constructor argument as their window size.
#[test]
fn test_sliding_window_dm_room() {
let mut cm = ConversationManager::new(3, 5);
for i in 0..7 {
cm.add_message("!dm:x", true, msg("user", &format!("dm{i}"), i));
}
let ctx = cm.get_context("!dm:x");
assert_eq!(ctx.len(), 5);
assert_eq!(ctx[0].content, "dm2");
assert_eq!(ctx[4].content, "dm6");
}
#[test]
fn test_multiple_rooms_independent() {
let mut cm = ConversationManager::new(5, 10);
cm.add_message("!a:x", false, msg("alice", "in room a", 1));
cm.add_message("!b:x", false, msg("bob", "in room b", 2));
assert_eq!(cm.get_context("!a:x").len(), 1);
assert_eq!(cm.get_context("!b:x").len(), 1);
assert_eq!(cm.get_context("!a:x")[0].content, "in room a");
assert_eq!(cm.get_context("!b:x")[0].content, "in room b");
}
#[test]
fn test_lru_eviction_at_max_rooms() {
// Create a manager with max_rooms = 500 (default), but we'll use a small one
let mut cm = ConversationManager::new(5, 10);
cm.max_rooms = 3;
// Add 3 rooms
cm.add_message("!room1:x", false, msg("a", "r1", 100));
cm.add_message("!room2:x", false, msg("b", "r2", 200));
cm.add_message("!room3:x", false, msg("c", "r3", 300));
assert_eq!(cm.room_count(), 3);
// Adding a 4th room should evict the one with oldest latest message (room1, ts=100)
cm.add_message("!room4:x", false, msg("d", "r4", 400));
assert_eq!(cm.room_count(), 3);
assert!(cm.get_context("!room1:x").is_empty()); // evicted
assert_eq!(cm.get_context("!room2:x").len(), 1);
assert_eq!(cm.get_context("!room3:x").len(), 1);
assert_eq!(cm.get_context("!room4:x").len(), 1);
}
#[test]
fn test_existing_room_not_evicted() {
let mut cm = ConversationManager::new(5, 10);
cm.max_rooms = 2;
cm.add_message("!room1:x", false, msg("a", "r1", 100));
cm.add_message("!room2:x", false, msg("b", "r2", 200));
// Adding to existing room should NOT trigger eviction
cm.add_message("!room1:x", false, msg("a", "r1 again", 300));
assert_eq!(cm.room_count(), 2);
assert_eq!(cm.get_context("!room1:x").len(), 2);
}
// Messages come back in insertion order.
#[test]
fn test_message_ordering_preserved() {
let mut cm = ConversationManager::new(10, 10);
cm.add_message("!r:x", false, msg("a", "first", 1));
cm.add_message("!r:x", false, msg("b", "second", 2));
cm.add_message("!r:x", false, msg("c", "third", 3));
let ctx = cm.get_context("!r:x");
assert_eq!(ctx[0].timestamp, 1);
assert_eq!(ctx[1].timestamp, 2);
assert_eq!(ctx[2].timestamp, 3);
}
}

307
src/brain/evaluator.rs Normal file
View File

@@ -0,0 +1,307 @@
use std::sync::Arc;
use mistralai_client::v1::{
chat::{ChatMessage, ChatParams, ResponseFormat},
constants::Model,
};
use regex::Regex;
use tracing::{debug, warn};
use crate::config::Config;
/// Outcome of evaluating whether the bot should react to a message.
#[derive(Debug)]
pub enum Engagement {
/// A rule matched (mention, DM, or name invocation).
MustRespond { reason: MustRespondReason },
/// The LLM relevance score reached the configured spontaneous threshold;
/// `hook` carries the model's brief reason (possibly empty).
MaybeRespond { relevance: f32, hook: String },
/// No response: own message, low relevance, or a failed evaluation.
Ignore,
}
/// Which rule produced an [`Engagement::MustRespond`].
#[derive(Debug)]
pub enum MustRespondReason {
/// The message body contains the bot's configured user id.
DirectMention,
/// The message was sent in a DM.
DirectMessage,
/// The message body matches the name pattern ("sol ..." / "hey sol ...").
NameInvocation,
}
/// Decides whether a message warrants a response, first by fixed rules and
/// then, if none match, by asking the evaluation model for a relevance score.
pub struct Evaluator {
/// Application configuration (bot user id, evaluation model, threshold).
config: Arc<Config>,
/// Matches the literal (escaped) bot user id anywhere in a message body.
mention_regex: Regex,
/// Matches "sol" at the start of a message or after "hey", case-insensitively.
name_regex: Regex,
}
impl Evaluator {
// todo(sienna): regex must be configurable
/// Builds an evaluator for the bot identified by `config.matrix.user_id`.
///
/// # Panics
/// Panics if either regular expression fails to compile.
pub fn new(config: Arc<Config>) -> Self {
    // The user id is escaped so it is matched literally.
    let mention_regex = Regex::new(&regex::escape(&config.matrix.user_id))
        .expect("Failed to compile mention regex");
    let name_regex =
        Regex::new(r"(?i)(?:^|\bhey\s+)\bsol\b").expect("Failed to compile name regex");
    Self {
        config,
        mention_regex,
        name_regex,
    }
}
/// Decides how to engage with a message.
///
/// The rule cascade in [`Evaluator::evaluate_rules`] is applied first (own
/// message, direct mention, DM, name invocation). Only when no rule matches
/// is the evaluation model asked for a relevance score, using
/// `recent_messages` as conversation context.
///
/// Delegating to `evaluate_rules` keeps a single copy of the rule logic so
/// the two entry points cannot drift apart.
pub async fn evaluate(
    &self,
    sender: &str,
    body: &str,
    is_dm: bool,
    recent_messages: &[String],
    mistral: &Arc<mistralai_client::v1::client::Client>,
) -> Engagement {
    match self.evaluate_rules(sender, body, is_dm) {
        Some(decision) => decision,
        // Cheap evaluation call for spontaneous responses
        None => {
            self.evaluate_relevance(body, recent_messages, mistral)
                .await
        }
    }
}
/// Applies the rule-based checks only, without calling Mistral.
///
/// Returns `Some(Engagement)` when a rule decides the outcome and `None`
/// when the caller should fall through to the LLM evaluation. Rules are
/// checked in order: own message (ignore), direct mention, DM, name
/// invocation.
pub fn evaluate_rules(
    &self,
    sender: &str,
    body: &str,
    is_dm: bool,
) -> Option<Engagement> {
    // Never react to our own messages.
    if sender == self.config.matrix.user_id {
        return Some(Engagement::Ignore);
    }

    let reason = if self.mention_regex.is_match(body) {
        MustRespondReason::DirectMention
    } else if is_dm {
        MustRespondReason::DirectMessage
    } else if self.name_regex.is_match(body) {
        MustRespondReason::NameInvocation
    } else {
        return None;
    };
    Some(Engagement::MustRespond { reason })
}
/// Asks the evaluation model whether the bot should join the conversation
/// spontaneously.
///
/// The prompt contains the last five entries of `recent_messages` plus
/// `body`, and requests a JSON object with `relevance` and `hook`. Returns
/// [`Engagement::MaybeRespond`] when the reported relevance reaches
/// `behavior.spontaneous_threshold`. Every failure path (request error,
/// empty choice list, unparsable JSON) is logged and yields
/// [`Engagement::Ignore`].
async fn evaluate_relevance(
    &self,
    body: &str,
    recent_messages: &[String],
    mistral: &Arc<mistralai_client::v1::client::Client>,
) -> Engagement {
    let context = recent_messages
        .iter()
        .rev()
        .take(5) //todo(sienna): must be configurable
        .rev()
        .cloned()
        .collect::<Vec<_>>()
        .join("\n");
    let prompt = format!(
        "You are evaluating whether a virtual librarian named Sol should spontaneously join \
        a conversation. Sol has deep knowledge of the group's message archive and helps \
        people find information.\n\n\
        Recent conversation:\n{context}\n\n\
        Latest message: {body}\n\n\
        Respond ONLY with JSON: {{\"relevance\": 0.0-1.0, \"hook\": \"brief reason or empty string\"}}\n\
        relevance=1.0 means Sol absolutely should respond, 0.0 means irrelevant."
    );
    let messages = vec![ChatMessage::new_user_message(&prompt)];
    let params = ChatParams {
        response_format: Some(ResponseFormat::json_object()),
        temperature: Some(0.1),
        max_tokens: Some(100),
        ..Default::default()
    };
    let model = Model::new(&self.config.mistral.evaluation_model);
    let client = Arc::clone(mistral);
    // The synchronous client call runs on a blocking thread; a join failure
    // is folded into the same error type as an API failure.
    let result = tokio::task::spawn_blocking(move || {
        client.chat(model, messages, Some(params))
    })
    .await
    .unwrap_or_else(|e| Err(mistralai_client::v1::error::ApiError {
        message: format!("spawn_blocking join error: {e}"),
    }));

    let response = match result {
        Ok(response) => response,
        Err(e) => {
            warn!("Evaluation call failed: {e}");
            return Engagement::Ignore;
        }
    };

    // Indexing `choices[0]` would panic on an empty list; treat that case as
    // a failed evaluation instead.
    let text = match response.choices.first() {
        Some(choice) => &choice.message.content,
        None => {
            warn!("Evaluation response contained no choices");
            return Engagement::Ignore;
        }
    };

    match serde_json::from_str::<serde_json::Value>(text) {
        Ok(val) => {
            let relevance = val["relevance"].as_f64().unwrap_or(0.0) as f32;
            let hook = val["hook"].as_str().unwrap_or("").to_string();
            debug!(relevance, hook = hook.as_str(), "Evaluation result");
            if relevance >= self.config.behavior.spontaneous_threshold {
                Engagement::MaybeRespond { relevance, hook }
            } else {
                Engagement::Ignore
            }
        }
        Err(e) => {
            warn!("Failed to parse evaluation response: {e}");
            Engagement::Ignore
        }
    }
}
}
// Unit tests for the rule-based part of the evaluator (no Mistral calls).
#[cfg(test)]
mod tests {
use super::*;
use crate::config::Config;
// Minimal config; the bot user id is "@sol:sunbeam.pt".
fn test_config() -> Arc<Config> {
let toml = r#"
[matrix]
homeserver_url = "https://chat.sunbeam.pt"
user_id = "@sol:sunbeam.pt"
state_store_path = "/tmp/sol"
[opensearch]
url = "http://localhost:9200"
index = "test"
[mistral]
[behavior]
"#;
Arc::new(Config::from_str(toml).unwrap())
}
fn evaluator() -> Evaluator {
Evaluator::new(test_config())
}
#[test]
fn test_ignore_own_messages() {
let ev = evaluator();
let result = ev.evaluate_rules("@sol:sunbeam.pt", "hello everyone", false);
assert!(matches!(result, Some(Engagement::Ignore)));
}
#[test]
fn test_direct_mention() {
let ev = evaluator();
let result = ev.evaluate_rules("@alice:sunbeam.pt", "hey @sol:sunbeam.pt what's up?", false);
assert!(matches!(
result,
Some(Engagement::MustRespond { reason: MustRespondReason::DirectMention })
));
}
#[test]
fn test_dm_detection() {
let ev = evaluator();
let result = ev.evaluate_rules("@alice:sunbeam.pt", "random message", true);
assert!(matches!(
result,
Some(Engagement::MustRespond { reason: MustRespondReason::DirectMessage })
));
}
#[test]
fn test_name_invocation_start_of_message() {
let ev = evaluator();
let result = ev.evaluate_rules("@alice:sunbeam.pt", "sol, can you find that link?", false);
assert!(matches!(
result,
Some(Engagement::MustRespond { reason: MustRespondReason::NameInvocation })
));
}
#[test]
fn test_name_invocation_hey_sol() {
let ev = evaluator();
let result = ev.evaluate_rules("@alice:sunbeam.pt", "hey sol do you remember?", false);
assert!(matches!(
result,
Some(Engagement::MustRespond { reason: MustRespondReason::NameInvocation })
));
}
#[test]
fn test_name_invocation_case_insensitive() {
let ev = evaluator();
let result = ev.evaluate_rules("@alice:sunbeam.pt", "Hey Sol, help me", false);
assert!(matches!(
result,
Some(Engagement::MustRespond { reason: MustRespondReason::NameInvocation })
));
}
#[test]
fn test_name_invocation_sol_uppercase() {
let ev = evaluator();
let result = ev.evaluate_rules("@alice:sunbeam.pt", "SOL what do you think?", false);
assert!(matches!(
result,
Some(Engagement::MustRespond { reason: MustRespondReason::NameInvocation })
));
}
#[test]
fn test_no_false_positive_solstice() {
let ev = evaluator();
// "solstice" should NOT trigger name invocation — \b boundary prevents it
let result = ev.evaluate_rules("@alice:sunbeam.pt", "the solstice is coming", false);
assert!(result.is_none());
}
// No rule matches, so the caller would fall through to the LLM evaluation.
#[test]
fn test_random_message_falls_through() {
let ev = evaluator();
let result = ev.evaluate_rules("@alice:sunbeam.pt", "what's for lunch?", false);
assert!(result.is_none());
}
#[test]
fn test_priority_mention_over_dm() {
// When both mention and DM are true, mention should match first
let ev = evaluator();
let result = ev.evaluate_rules("@alice:sunbeam.pt", "hi @sol:sunbeam.pt", true);
assert!(matches!(
result,
Some(Engagement::MustRespond { reason: MustRespondReason::DirectMention })
));
}
}

4
src/brain/mod.rs Normal file
View File

@@ -0,0 +1,4 @@
/// Per-room rolling context windows.
pub mod conversation;
/// Rule-based and LLM-based engagement evaluation.
pub mod evaluator;
/// System prompt templating.
pub mod personality;
/// Response generation with the tool-calling loop.
pub mod responder;

89
src/brain/personality.rs Normal file
View File

@@ -0,0 +1,89 @@
use chrono::Utc;
/// Holds the system prompt template used to build per-room system prompts.
pub struct Personality {
/// Template text; may contain the placeholders `{date}`, `{room_name}`
/// and `{members}`.
template: String,
}
impl Personality {
    /// Wraps `system_prompt` as the template for later substitution.
    pub fn new(system_prompt: String) -> Self {
        Self {
            template: system_prompt,
        }
    }

    /// Renders the template for one room.
    ///
    /// Every occurrence of `{date}` is replaced by today's UTC date
    /// (`YYYY-MM-DD`), `{room_name}` by `room_name`, and `{members}` by the
    /// member names joined with `", "`. Any other text, including unknown
    /// `{...}` sequences, is copied through unchanged.
    ///
    /// Substitution is done in a single pass over the template, so
    /// placeholder-looking text inside a substituted value is never expanded
    /// again (for example a room literally named `{members}`).
    pub fn build_system_prompt(
        &self,
        room_name: &str,
        members: &[String],
    ) -> String {
        let date = Utc::now().format("%Y-%m-%d").to_string();
        let members_str = members.join(", ");
        let substitutions: [(&str, &str); 3] = [
            ("{date}", date.as_str()),
            ("{room_name}", room_name),
            ("{members}", members_str.as_str()),
        ];

        let mut prompt = String::with_capacity(self.template.len());
        let mut rest = self.template.as_str();
        while let Some(open) = rest.find('{') {
            prompt.push_str(&rest[..open]);
            let tail = &rest[open..];
            match substitutions
                .iter()
                .find(|(key, _)| tail.starts_with(*key))
            {
                Some((key, value)) => {
                    prompt.push_str(value);
                    rest = &tail[key.len()..];
                }
                None => {
                    // Not a known placeholder: keep the brace and move on.
                    prompt.push('{');
                    rest = &tail[1..];
                }
            }
        }
        prompt.push_str(rest);
        prompt
    }
}
// Unit tests for system prompt placeholder substitution.
#[cfg(test)]
mod tests {
use super::*;
// NOTE(review): compares against a second `Utc::now()` call, so this could
// fail if the UTC date changes between the two calls.
#[test]
fn test_date_substitution() {
let p = Personality::new("Today is {date}.".to_string());
let result = p.build_system_prompt("general", &[]);
let today = Utc::now().format("%Y-%m-%d").to_string();
assert_eq!(result, format!("Today is {today}."));
}
#[test]
fn test_room_name_substitution() {
let p = Personality::new("You are in {room_name}.".to_string());
let result = p.build_system_prompt("design-chat", &[]);
assert!(result.contains("design-chat"));
}
// Members are joined with ", ".
#[test]
fn test_members_substitution() {
let p = Personality::new("Members: {members}".to_string());
let members = vec!["Alice".to_string(), "Bob".to_string(), "Carol".to_string()];
let result = p.build_system_prompt("room", &members);
assert_eq!(result, "Members: Alice, Bob, Carol");
}
// An empty member list substitutes an empty string.
#[test]
fn test_empty_members() {
let p = Personality::new("Members: {members}".to_string());
let result = p.build_system_prompt("room", &[]);
assert_eq!(result, "Members: ");
}
#[test]
fn test_all_placeholders() {
let template = "Date: {date}, Room: {room_name}, People: {members}".to_string();
let p = Personality::new(template);
let members = vec!["Sienna".to_string(), "Lonni".to_string()];
let result = p.build_system_prompt("studio", &members);
let today = Utc::now().format("%Y-%m-%d").to_string();
assert!(result.starts_with(&format!("Date: {today}")));
assert!(result.contains("Room: studio"));
assert!(result.contains("People: Sienna, Lonni"));
}
// A template without placeholders is returned unchanged.
#[test]
fn test_no_placeholders_passthrough() {
let p = Personality::new("Static prompt with no variables.".to_string());
let result = p.build_system_prompt("room", &["Alice".to_string()]);
assert_eq!(result, "Static prompt with no variables.");
}
// Every occurrence of a placeholder is replaced, not just the first.
#[test]
fn test_multiple_same_placeholder() {
let p = Personality::new("{room_name} is great. I love {room_name}.".to_string());
let result = p.build_system_prompt("lounge", &[]);
assert_eq!(result, "lounge is great. I love lounge.");
}
}

179
src/brain/responder.rs Normal file
View File

@@ -0,0 +1,179 @@
use std::sync::Arc;
use mistralai_client::v1::{
chat::{ChatMessage, ChatParams, ChatResponse, ChatResponseChoiceFinishReason},
constants::Model,
error::ApiError,
tool::ToolChoice,
};
use rand::Rng;
use tokio::time::{sleep, Duration};
use tracing::{debug, error, info, warn};
use crate::brain::conversation::ContextMessage;
use crate::brain::personality::Personality;
use crate::config::Config;
use crate::tools::ToolRegistry;
/// Executes a Mistral chat completion on tokio's blocking thread pool.
///
/// The client's `chat_async` keeps a `std::sync::MutexGuard` alive across an
/// `.await`, which makes its future `!Send`. The synchronous `chat()` is
/// therefore called inside `spawn_blocking` instead.
///
/// # Errors
/// Returns the API error from `chat()`, or an `ApiError` describing the join
/// failure if the blocking task could not be joined.
pub(crate) async fn chat_blocking(
    client: &Arc<mistralai_client::v1::client::Client>,
    model: Model,
    messages: Vec<ChatMessage>,
    params: ChatParams,
) -> Result<ChatResponse, ApiError> {
    let worker_client = Arc::clone(client);
    let joined = tokio::task::spawn_blocking(move || {
        worker_client.chat(model, messages, Some(params))
    })
    .await;
    match joined {
        Ok(outcome) => outcome,
        Err(e) => Err(ApiError {
            message: format!("spawn_blocking join error: {e}"),
        }),
    }
}
/// Generates chat responses through Mistral, running tool calls in a loop
/// until the model returns text.
pub struct Responder {
/// Application configuration (delays, model name, tool iteration limit).
config: Arc<Config>,
/// Builds the system prompt for each request.
personality: Arc<Personality>,
/// Executes the tool calls requested by the model.
tools: Arc<ToolRegistry>,
}
impl Responder {
/// Creates a responder from its shared configuration, personality and
/// tool registry.
pub fn new(
config: Arc<Config>,
personality: Arc<Personality>,
tools: Arc<ToolRegistry>,
) -> Self {
Self {
config,
personality,
tools,
}
}
pub async fn generate_response(
&self,
context: &[ContextMessage],
trigger_body: &str,
trigger_sender: &str,
room_name: &str,
members: &[String],
is_spontaneous: bool,
mistral: &Arc<mistralai_client::v1::client::Client>,
) -> Option<String> {
// Apply response delay
let delay = if is_spontaneous {
rand::thread_rng().gen_range(
self.config.behavior.spontaneous_delay_min_ms
..=self.config.behavior.spontaneous_delay_max_ms,
)
} else {
rand::thread_rng().gen_range(
self.config.behavior.response_delay_min_ms
..=self.config.behavior.response_delay_max_ms,
)
};
sleep(Duration::from_millis(delay)).await;
let system_prompt = self.personality.build_system_prompt(room_name, members);
let mut messages = vec![ChatMessage::new_system_message(&system_prompt)];
// Add context messages
for msg in context {
if msg.sender == self.config.matrix.user_id {
messages.push(ChatMessage::new_assistant_message(&msg.content, None));
} else {
let user_msg = format!("{}: {}", msg.sender, msg.content);
messages.push(ChatMessage::new_user_message(&user_msg));
}
}
// Add the triggering message
let trigger = format!("{trigger_sender}: {trigger_body}");
messages.push(ChatMessage::new_user_message(&trigger));
let tool_defs = ToolRegistry::tool_definitions();
let model = Model::new(&self.config.mistral.default_model);
let max_iterations = self.config.mistral.max_tool_iterations;
for iteration in 0..=max_iterations {
let params = ChatParams {
tools: if iteration < max_iterations {
Some(tool_defs.clone())
} else {
None
},
tool_choice: if iteration < max_iterations {
Some(ToolChoice::Auto)
} else {
None
},
..Default::default()
};
let response = match chat_blocking(mistral, model.clone(), messages.clone(), params).await {
Ok(r) => r,
Err(e) => {
error!("Mistral chat failed: {e}");
return None;
}
};
let choice = &response.choices[0];
if choice.finish_reason == ChatResponseChoiceFinishReason::ToolCalls {
if let Some(tool_calls) = &choice.message.tool_calls {
// Add assistant message with tool calls
messages.push(ChatMessage::new_assistant_message(
&choice.message.content,
Some(tool_calls.clone()),
));
for tc in tool_calls {
let call_id = tc.id.as_deref().unwrap_or("unknown");
info!(
tool = tc.function.name.as_str(),
id = call_id,
"Executing tool call"
);
let result = self
.tools
.execute(&tc.function.name, &tc.function.arguments)
.await;
let result_str = match result {
Ok(s) => s,
Err(e) => {
warn!(tool = tc.function.name.as_str(), "Tool failed: {e}");
format!("Error: {e}")
}
};
messages.push(ChatMessage::new_tool_message(
&result_str,
call_id,
Some(&tc.function.name),
));
}
debug!(iteration, "Tool iteration complete, continuing");
continue;
}
}
// Final text response
let text = choice.message.content.trim().to_string();
if text.is_empty() {
return None;
}
return Some(text);
}
warn!("Exceeded max tool iterations");
None
}
}

219
src/config.rs Normal file
View File

@@ -0,0 +1,219 @@
use serde::Deserialize;
/// Top-level application configuration, deserialized from TOML.
/// All four sections must be present, though `[mistral]` and `[behavior]`
/// may be empty because every field in them has a default.
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
/// Matrix connection and identity settings.
pub matrix: MatrixConfig,
/// OpenSearch archive settings.
pub opensearch: OpenSearchConfig,
/// Mistral model selection and tool-loop limit.
pub mistral: MistralConfig,
/// Response timing, thresholds and context window sizes.
pub behavior: BehaviorConfig,
}
/// `[matrix]` section; every field is required.
#[derive(Debug, Clone, Deserialize)]
pub struct MatrixConfig {
/// Homeserver base URL.
pub homeserver_url: String,
/// The bot's own user id; used to ignore its own messages and to detect
/// direct mentions.
pub user_id: String,
/// Path for the state store — presumably the Matrix SDK's; its use is not
/// visible in this file.
pub state_store_path: String,
}
/// `[opensearch]` section; `url` and `index` are required.
#[derive(Debug, Clone, Deserialize)]
pub struct OpenSearchConfig {
/// OpenSearch endpoint URL.
pub url: String,
/// Name of the archive index.
pub index: String,
/// Buffered document count that triggers an immediate flush (default 50).
#[serde(default = "default_batch_size")]
pub batch_size: usize,
/// Period of the background flush task in milliseconds (default 2000).
#[serde(default = "default_flush_interval_ms")]
pub flush_interval_ms: u64,
/// Ingest pipeline passed to bulk index requests
/// (default `tuwunel_embedding_pipeline`).
#[serde(default = "default_embedding_pipeline")]
pub embedding_pipeline: String,
}
/// `[mistral]` section; every field has a default.
#[derive(Debug, Clone, Deserialize)]
pub struct MistralConfig {
/// Model used to generate responses (default `mistral-medium-latest`).
#[serde(default = "default_model")]
pub default_model: String,
/// Model used for the spontaneous-engagement relevance check
/// (default `ministral-3b-latest`).
#[serde(default = "default_evaluation_model")]
pub evaluation_model: String,
/// Default `mistral-large-latest`. NOTE(review): no use of this field is
/// visible in the code shown here — confirm where it is read.
#[serde(default = "default_research_model")]
pub research_model: String,
/// Number of tool-calling rounds offered before the model is asked for a
/// plain answer (default 5).
#[serde(default = "default_max_tool_iterations")]
pub max_tool_iterations: usize,
}
/// `[behavior]` section; every field has a default.
#[derive(Debug, Clone, Deserialize)]
pub struct BehaviorConfig {
/// Lower bound of the random delay before a non-spontaneous response, in
/// milliseconds (default 2000).
#[serde(default = "default_response_delay_min_ms")]
pub response_delay_min_ms: u64,
/// Upper bound of that delay, in milliseconds (default 8000).
#[serde(default = "default_response_delay_max_ms")]
pub response_delay_max_ms: u64,
/// Lower bound of the random delay before a spontaneous response, in
/// milliseconds (default 15000).
#[serde(default = "default_spontaneous_delay_min_ms")]
pub spontaneous_delay_min_ms: u64,
/// Upper bound of that delay, in milliseconds (default 60000).
#[serde(default = "default_spontaneous_delay_max_ms")]
pub spontaneous_delay_max_ms: u64,
/// Minimum relevance score at which the evaluator returns `MaybeRespond`
/// (default 0.7).
#[serde(default = "default_spontaneous_threshold")]
pub spontaneous_threshold: f32,
/// Default 30 — presumably the context window size for group rooms;
/// confirm against the caller that builds `ConversationManager`.
#[serde(default = "default_room_context_window")]
pub room_context_window: usize,
/// Default 100 — presumably the context window size for DMs; confirm
/// against the caller that builds `ConversationManager`.
#[serde(default = "default_dm_context_window")]
pub dm_context_window: usize,
/// Default `true`. NOTE(review): backfill logic is not visible in the code
/// shown here — confirm semantics.
#[serde(default = "default_backfill_on_join")]
pub backfill_on_join: bool,
/// Default 10000. NOTE(review): backfill logic is not visible in the code
/// shown here — confirm semantics.
#[serde(default = "default_backfill_limit")]
pub backfill_limit: usize,
}
// Serde default providers for the optional configuration fields. Each one is
// referenced by name from a `#[serde(default = "...")]` attribute.

/// Default for `opensearch.batch_size`.
fn default_batch_size() -> usize {
    50
}

/// Default for `opensearch.flush_interval_ms`.
fn default_flush_interval_ms() -> u64 {
    2000
}

/// Default for `opensearch.embedding_pipeline`.
fn default_embedding_pipeline() -> String {
    String::from("tuwunel_embedding_pipeline")
}

/// Default for `mistral.default_model`.
fn default_model() -> String {
    String::from("mistral-medium-latest")
}

/// Default for `mistral.evaluation_model`.
fn default_evaluation_model() -> String {
    String::from("ministral-3b-latest")
}

/// Default for `mistral.research_model`.
fn default_research_model() -> String {
    String::from("mistral-large-latest")
}

/// Default for `mistral.max_tool_iterations`.
fn default_max_tool_iterations() -> usize {
    5
}

/// Default for `behavior.response_delay_min_ms`.
fn default_response_delay_min_ms() -> u64 {
    2000
}

/// Default for `behavior.response_delay_max_ms`.
fn default_response_delay_max_ms() -> u64 {
    8000
}

/// Default for `behavior.spontaneous_delay_min_ms`.
fn default_spontaneous_delay_min_ms() -> u64 {
    15000
}

/// Default for `behavior.spontaneous_delay_max_ms`.
fn default_spontaneous_delay_max_ms() -> u64 {
    60000
}

/// Default for `behavior.spontaneous_threshold`.
fn default_spontaneous_threshold() -> f32 {
    0.7
}

/// Default for `behavior.room_context_window`.
fn default_room_context_window() -> usize {
    30
}

/// Default for `behavior.dm_context_window`.
fn default_dm_context_window() -> usize {
    100
}

/// Default for `behavior.backfill_on_join`.
fn default_backfill_on_join() -> bool {
    true
}

/// Default for `behavior.backfill_limit`.
fn default_backfill_limit() -> usize {
    10000
}
impl Config {
pub fn load(path: &str) -> anyhow::Result<Self> {
let content = std::fs::read_to_string(path)?;
let config: Config = toml::from_str(&content)?;
Ok(config)
}
pub fn from_str(content: &str) -> anyhow::Result<Self> {
let config: Config = toml::from_str(content)?;
Ok(config)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smallest accepted configuration: required keys only, with empty
    /// `[mistral]` and `[behavior]` tables so every default applies.
    const MINIMAL_CONFIG: &str = r#"
[matrix]
homeserver_url = "https://chat.sunbeam.pt"
user_id = "@sol:sunbeam.pt"
state_store_path = "/data/sol/state"
[opensearch]
url = "http://opensearch:9200"
index = "sol-archive"
[mistral]
[behavior]
"#;

    /// Configuration that sets every optional key explicitly.
    const FULL_CONFIG: &str = r#"
[matrix]
homeserver_url = "https://chat.sunbeam.pt"
user_id = "@sol:sunbeam.pt"
state_store_path = "/data/sol/state"
[opensearch]
url = "http://opensearch:9200"
index = "sol-archive"
batch_size = 100
flush_interval_ms = 5000
embedding_pipeline = "my_pipeline"
[mistral]
default_model = "mistral-large-latest"
evaluation_model = "ministral-8b-latest"
research_model = "mistral-large-latest"
max_tool_iterations = 10
[behavior]
response_delay_min_ms = 1000
response_delay_max_ms = 5000
spontaneous_delay_min_ms = 10000
spontaneous_delay_max_ms = 30000
spontaneous_threshold = 0.8
room_context_window = 50
dm_context_window = 200
backfill_on_join = false
backfill_limit = 5000
"#;

    /// Required keys are read verbatim and every omitted key takes its default.
    #[test]
    fn test_minimal_config_with_defaults() {
        let config = Config::from_str(MINIMAL_CONFIG).unwrap();
        assert_eq!(config.matrix.homeserver_url, "https://chat.sunbeam.pt");
        assert_eq!(config.matrix.user_id, "@sol:sunbeam.pt");
        assert_eq!(config.matrix.state_store_path, "/data/sol/state");
        assert_eq!(config.opensearch.url, "http://opensearch:9200");
        assert_eq!(config.opensearch.index, "sol-archive");
        // Check defaults
        assert_eq!(config.opensearch.batch_size, 50);
        assert_eq!(config.opensearch.flush_interval_ms, 2000);
        assert_eq!(config.opensearch.embedding_pipeline, "tuwunel_embedding_pipeline");
        assert_eq!(config.mistral.default_model, "mistral-medium-latest");
        assert_eq!(config.mistral.evaluation_model, "ministral-3b-latest");
        assert_eq!(config.mistral.research_model, "mistral-large-latest");
        assert_eq!(config.mistral.max_tool_iterations, 5);
        assert_eq!(config.behavior.response_delay_min_ms, 2000);
        assert_eq!(config.behavior.response_delay_max_ms, 8000);
        assert_eq!(config.behavior.spontaneous_delay_min_ms, 15000);
        assert_eq!(config.behavior.spontaneous_delay_max_ms, 60000);
        assert!((config.behavior.spontaneous_threshold - 0.7).abs() < f32::EPSILON);
        assert_eq!(config.behavior.room_context_window, 30);
        assert_eq!(config.behavior.dm_context_window, 100);
        assert!(config.behavior.backfill_on_join);
        assert_eq!(config.behavior.backfill_limit, 10000);
    }

    /// Every optional key set in `FULL_CONFIG` replaces its default.
    #[test]
    fn test_full_config_overrides() {
        let config = Config::from_str(FULL_CONFIG).unwrap();
        assert_eq!(config.opensearch.batch_size, 100);
        assert_eq!(config.opensearch.flush_interval_ms, 5000);
        assert_eq!(config.opensearch.embedding_pipeline, "my_pipeline");
        assert_eq!(config.mistral.default_model, "mistral-large-latest");
        assert_eq!(config.mistral.evaluation_model, "ministral-8b-latest");
        assert_eq!(config.mistral.research_model, "mistral-large-latest");
        assert_eq!(config.mistral.max_tool_iterations, 10);
        assert_eq!(config.behavior.response_delay_min_ms, 1000);
        assert_eq!(config.behavior.response_delay_max_ms, 5000);
        // The spontaneous delay bounds are set in the fixture, so pin them too.
        assert_eq!(config.behavior.spontaneous_delay_min_ms, 10000);
        assert_eq!(config.behavior.spontaneous_delay_max_ms, 30000);
        assert!((config.behavior.spontaneous_threshold - 0.8).abs() < f32::EPSILON);
        assert_eq!(config.behavior.room_context_window, 50);
        assert_eq!(config.behavior.dm_context_window, 200);
        assert!(!config.behavior.backfill_on_join);
        assert_eq!(config.behavior.backfill_limit, 5000);
    }

    /// Parsing fails when the tables after `[matrix]` are missing entirely.
    #[test]
    fn test_missing_required_section_fails() {
        let bad = r#"
[matrix]
homeserver_url = "https://chat.sunbeam.pt"
user_id = "@sol:sunbeam.pt"
state_store_path = "/data/sol/state"
"#;
        assert!(Config::from_str(bad).is_err());
    }

    /// Parsing fails when `user_id` is missing from `[matrix]`.
    #[test]
    fn test_missing_required_field_fails() {
        let bad = r#"
[matrix]
homeserver_url = "https://chat.sunbeam.pt"
state_store_path = "/data/sol/state"
[opensearch]
url = "http://opensearch:9200"
index = "sol-archive"
[mistral]
[behavior]
"#;
        assert!(Config::from_str(bad).is_err());
    }
}

160
src/main.rs Normal file
View File

@@ -0,0 +1,160 @@
mod archive;
mod brain;
mod config;
mod matrix_utils;
mod sync;
mod tools;
use std::sync::Arc;
use matrix_sdk::Client;
use opensearch::http::transport::TransportBuilder;
use opensearch::OpenSearch;
use ruma::{OwnedDeviceId, OwnedUserId};
use tokio::signal;
use tokio::sync::Mutex;
use tracing::{error, info};
use url::Url;
use archive::indexer::Indexer;
use archive::schema::create_index_if_not_exists;
use brain::conversation::ConversationManager;
use brain::evaluator::Evaluator;
use brain::personality::Personality;
use brain::responder::Responder;
use config::Config;
use sync::AppState;
use tools::ToolRegistry;
/// Process entry point: loads configuration and secrets, wires up the
/// Matrix, OpenSearch and Mistral clients, starts the background tasks and
/// then runs until either a shutdown signal arrives or the sync loop stops.
///
/// # Errors
///
/// Returns an error if any start-up step fails (config, prompt, secrets,
/// client construction, session restore, index creation) or if the Matrix
/// sync task ends on its own, so a supervisor can restart the process
/// instead of leaving an idle bot running.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Initialize tracing
    tracing_subscriber::fmt()
        .with_env_filter(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("sol=info")),
        )
        .init();

    // Load config
    let config_path =
        std::env::var("SOL_CONFIG").unwrap_or_else(|_| "/etc/sol/sol.toml".into());
    let config = Config::load(&config_path)?;
    info!("Loaded config from {config_path}");

    // Load system prompt
    let prompt_path = std::env::var("SOL_SYSTEM_PROMPT")
        .unwrap_or_else(|_| "/etc/sol/system_prompt.md".into());
    let system_prompt = std::fs::read_to_string(&prompt_path)?;
    info!("Loaded system prompt from {prompt_path}");

    // Read secrets from environment
    let access_token = std::env::var("SOL_MATRIX_ACCESS_TOKEN")
        .map_err(|_| anyhow::anyhow!("SOL_MATRIX_ACCESS_TOKEN not set"))?;
    let device_id = std::env::var("SOL_MATRIX_DEVICE_ID")
        .map_err(|_| anyhow::anyhow!("SOL_MATRIX_DEVICE_ID not set"))?;
    let mistral_api_key = std::env::var("SOL_MISTRAL_API_KEY")
        .map_err(|_| anyhow::anyhow!("SOL_MISTRAL_API_KEY not set"))?;

    let config = Arc::new(config);

    // Initialize Matrix client with E2EE and sqlite store
    let homeserver = Url::parse(&config.matrix.homeserver_url)?;
    let matrix_client = Client::builder()
        .homeserver_url(homeserver)
        .sqlite_store(&config.matrix.state_store_path, None)
        .build()
        .await?;

    // Restore session
    let user_id: OwnedUserId = config.matrix.user_id.parse()?;
    let device_id: OwnedDeviceId = device_id.into();
    let session = matrix_sdk::AuthSession::Matrix(matrix_sdk::matrix_auth::MatrixSession {
        meta: matrix_sdk::SessionMeta {
            user_id,
            device_id,
        },
        tokens: matrix_sdk::matrix_auth::MatrixSessionTokens {
            access_token,
            refresh_token: None,
        },
    });
    matrix_client.restore_session(session).await?;
    info!(user = %config.matrix.user_id, "Matrix session restored");

    // Initialize OpenSearch client
    let os_url = Url::parse(&config.opensearch.url)?;
    let os_transport = TransportBuilder::new(
        opensearch::http::transport::SingleNodeConnectionPool::new(os_url),
    )
    .build()?;
    let os_client = OpenSearch::new(os_transport);

    // Ensure index exists
    create_index_if_not_exists(&os_client, &config.opensearch.index).await?;

    // Initialize Mistral client
    let mistral_client = mistralai_client::v1::client::Client::new(
        Some(mistral_api_key),
        None,
        None,
        None,
    )?;
    let mistral = Arc::new(mistral_client);

    // Build components
    let personality = Arc::new(Personality::new(system_prompt));
    let tool_registry = Arc::new(ToolRegistry::new(
        os_client.clone(),
        matrix_client.clone(),
        config.clone(),
    ));
    let indexer = Arc::new(Indexer::new(os_client, config.clone()));
    let evaluator = Arc::new(Evaluator::new(config.clone()));
    let responder = Arc::new(Responder::new(
        config.clone(),
        personality,
        tool_registry,
    ));
    let conversations = Arc::new(Mutex::new(ConversationManager::new(
        config.behavior.room_context_window,
        config.behavior.dm_context_window,
    )));

    // Start background flush task
    let _flush_handle = indexer.start_flush_task();

    // Build shared state
    let state = Arc::new(AppState {
        config: config.clone(),
        indexer,
        evaluator,
        responder,
        conversations,
        mistral,
    });

    // Start sync loop in background
    let sync_client = matrix_client.clone();
    let sync_state = state.clone();
    let mut sync_handle = tokio::spawn(async move {
        if let Err(e) = sync::start_sync(sync_client, sync_state).await {
            error!("Sync loop error: {e}");
        }
    });

    info!("Sol is running");

    // Wait for whichever comes first: an operator shutdown signal, or the
    // sync task ending by itself. Without the second arm a failed sync loop
    // would leave the process alive but doing nothing until Ctrl-C.
    tokio::select! {
        signal_result = signal::ctrl_c() => {
            signal_result?;
            info!("Shutdown signal received");
            // Cancel sync
            sync_handle.abort();
        }
        join_result = &mut sync_handle => {
            if let Err(e) = join_result {
                error!("Sync task terminated abnormally: {e}");
            }
            anyhow::bail!("Matrix sync loop exited unexpectedly");
        }
    }

    info!("Sol has shut down");
    Ok(())
}

77
src/matrix_utils.rs Normal file
View File

@@ -0,0 +1,77 @@
use matrix_sdk::room::Room;
use matrix_sdk::RoomMemberships;
use ruma::events::room::message::{
MessageType, OriginalSyncRoomMessageEvent, Relation, RoomMessageEventContent,
};
use ruma::events::relation::InReplyTo;
use ruma::OwnedEventId;
/// Returns the plain-text body of a text, notice or emote message.
///
/// Any other message type yields `None`.
pub fn extract_body(event: &OriginalSyncRoomMessageEvent) -> Option<String> {
    let body = match &event.content.msgtype {
        MessageType::Text(content) => &content.body,
        MessageType::Notice(content) => &content.body,
        MessageType::Emote(content) => &content.body,
        _ => return None,
    };
    Some(body.clone())
}
/// Check if this event is an edit (`m.replace` relation) and return the
/// edited event's ID together with the replacement body.
///
/// Returns `None` when the event carries no replacement relation, or when
/// the replacement content is not a text, notice or emote message. The
/// accepted message types mirror [`extract_body`], so an edit to any message
/// that function can read is recognised here as well.
pub fn extract_edit(event: &OriginalSyncRoomMessageEvent) -> Option<(OwnedEventId, String)> {
    let Some(Relation::Replacement(replacement)) = &event.content.relates_to else {
        return None;
    };
    let new_body = match &replacement.new_content.msgtype {
        MessageType::Text(text) => text.body.clone(),
        MessageType::Notice(notice) => notice.body.clone(),
        // Emotes are readable by `extract_body`, so their edits must be
        // treated as edits too rather than falling through as new messages.
        MessageType::Emote(emote) => emote.body.clone(),
        _ => return None,
    };
    Some((replacement.event_id.clone(), new_body))
}
/// Returns the ID of the event this message replies to.
///
/// Only a plain reply relation is considered; every other relation (or no
/// relation at all) yields `None`.
pub fn extract_reply_to(event: &OriginalSyncRoomMessageEvent) -> Option<OwnedEventId> {
    match &event.content.relates_to {
        Some(Relation::Reply { in_reply_to }) => Some(in_reply_to.event_id.clone()),
        _ => None,
    }
}
/// Returns the event ID carried by the message's thread relation.
///
/// Messages without a thread relation yield `None`.
pub fn extract_thread_id(event: &OriginalSyncRoomMessageEvent) -> Option<OwnedEventId> {
    match &event.content.relates_to {
        Some(Relation::Thread(thread)) => Some(thread.event_id.clone()),
        _ => None,
    }
}
/// Creates plain-text message content whose relation marks it as a reply
/// to `reply_to_event_id`.
pub fn make_reply_content(body: &str, reply_to_event_id: OwnedEventId) -> RoomMessageEventContent {
    let relation = Relation::Reply {
        in_reply_to: InReplyTo::new(reply_to_event_id),
    };
    let mut reply = RoomMessageEventContent::text_plain(body);
    reply.relates_to = Some(relation);
    reply
}
/// Returns the room's cached display name, falling back to the room ID
/// when no name is cached.
pub fn room_display_name(room: &Room) -> String {
    match room.cached_display_name() {
        Some(name) => name.to_string(),
        None => room.room_id().to_string(),
    }
}
/// Get member display names for a room.
pub async fn room_member_names(room: &Room) -> Vec<String> {
match room.members(RoomMemberships::JOIN).await {
Ok(members) => members
.iter()
.map(|m| {
m.display_name()
.unwrap_or_else(|| m.user_id().as_str())
.to_string()
})
.collect(),
Err(_) => Vec::new(),
}
}

228
src/sync.rs Normal file
View File

@@ -0,0 +1,228 @@
use std::sync::Arc;
use matrix_sdk::config::SyncSettings;
use matrix_sdk::room::Room;
use matrix_sdk::Client;
use ruma::events::room::member::StrippedRoomMemberEvent;
use ruma::events::room::message::OriginalSyncRoomMessageEvent;
use ruma::events::room::redaction::OriginalSyncRoomRedactionEvent;
use tokio::sync::Mutex;
use tracing::{error, info, warn};
use crate::archive::indexer::Indexer;
use crate::archive::schema::ArchiveDocument;
use crate::brain::conversation::{ContextMessage, ConversationManager};
use crate::brain::evaluator::{Engagement, Evaluator};
use crate::brain::responder::Responder;
use crate::config::Config;
use crate::matrix_utils;
/// Shared application state handed to the Matrix event handlers.
///
/// Every component is reference-counted so the state can be cloned cheaply
/// into each handler invocation.
pub struct AppState {
    /// Parsed configuration.
    pub config: Arc<Config>,
    /// Archive indexer that receives new messages, edits and redactions.
    pub indexer: Arc<Indexer>,
    /// Decides whether an incoming message warrants a response.
    pub evaluator: Arc<Evaluator>,
    /// Generates the response text.
    pub responder: Arc<Responder>,
    /// Per-room conversation context, guarded by an async mutex.
    pub conversations: Arc<Mutex<ConversationManager>>,
    /// Mistral API client passed to the evaluator and responder.
    pub mistral: Arc<mistralai_client::v1::client::Client>,
}
pub async fn start_sync(client: Client, state: Arc<AppState>) -> anyhow::Result<()> {
// Register event handlers
let s = state.clone();
client.add_event_handler(
move |event: OriginalSyncRoomMessageEvent, room: Room| {
let state = s.clone();
async move {
if let Err(e) = handle_message(event, room, state).await {
error!("Error handling message: {e}");
}
}
},
);
let s = state.clone();
client.add_event_handler(
move |event: OriginalSyncRoomRedactionEvent, _room: Room| {
let state = s.clone();
async move {
handle_redaction(event, &state).await;
}
},
);
client.add_event_handler(
move |event: StrippedRoomMemberEvent, room: Room| async move {
handle_invite(event, room).await;
},
);
info!("Starting Matrix sync loop");
let settings = SyncSettings::default();
client.sync(settings).await?;
Ok(())
}
/// Processes one room message: applies edits to the archive, or archives
/// the new message, records it in the conversation context, asks the
/// evaluator whether to engage and, if so, sends a generated reply.
///
/// Messages whose body cannot be extracted are skipped silently.
async fn handle_message(
    event: OriginalSyncRoomMessageEvent,
    room: Room,
    state: Arc<AppState>,
) -> anyhow::Result<()> {
    // An edit only patches the archived original; nothing else happens.
    if let Some((original_id, new_body)) = matrix_utils::extract_edit(&event) {
        state
            .indexer
            .update_edit(&original_id.to_string(), &new_body)
            .await;
        return Ok(());
    }

    let body = match matrix_utils::extract_body(&event) {
        Some(text) => text,
        None => return Ok(()),
    };

    let sender = event.sender.to_string();
    let room_id = room.room_id().to_string();
    let timestamp = event.origin_server_ts.0.into();
    let room_name = matrix_utils::room_display_name(&room);
    let sender_name = match room.get_member_no_sync(&event.sender).await {
        Ok(Some(member)) => member.display_name().map(str::to_owned),
        _ => None,
    };

    // Archive the message.
    let archived = ArchiveDocument {
        event_id: event.event_id.to_string(),
        room_id: room_id.clone(),
        room_name: Some(room_name.clone()),
        sender: sender.clone(),
        sender_name: sender_name.clone(),
        timestamp,
        content: body.clone(),
        reply_to: matrix_utils::extract_reply_to(&event).map(|id| id.to_string()),
        thread_id: matrix_utils::extract_thread_id(&event).map(|id| id.to_string()),
        media_urls: Vec::new(),
        event_type: "m.room.message".into(),
        edited: false,
        redacted: false,
    };
    state.indexer.add(archived).await;

    // Record the message in the room's conversation context.
    let is_dm = room.is_direct().await.unwrap_or(false);
    let speaker = sender_name.clone().unwrap_or_else(|| sender.clone());
    state.conversations.lock().await.add_message(
        &room_id,
        is_dm,
        ContextMessage {
            sender: speaker,
            content: body.clone(),
            timestamp,
        },
    );

    // Decide whether to engage, based on the recent context lines.
    let recent: Vec<String> = state
        .conversations
        .lock()
        .await
        .get_context(&room_id)
        .iter()
        .map(|m| format!("{}: {}", m.sender, m.content))
        .collect();
    let engagement = state
        .evaluator
        .evaluate(&sender, &body, is_dm, &recent, &state.mistral)
        .await;
    let is_spontaneous = match engagement {
        Engagement::MustRespond { reason } => {
            info!(?reason, "Must respond");
            false
        }
        Engagement::MaybeRespond { relevance, hook } => {
            info!(relevance, hook = hook.as_str(), "Maybe respond (spontaneous)");
            true
        }
        Engagement::Ignore => return Ok(()),
    };

    // Show the typing indicator while the reply is generated; failures to
    // toggle it are deliberately ignored.
    let _ = room.typing_notice(true).await;
    let context = state.conversations.lock().await.get_context(&room_id);
    let members = matrix_utils::room_member_names(&room).await;
    let display_sender = sender_name.as_deref().unwrap_or(&sender);
    let reply = state
        .responder
        .generate_response(
            &context,
            &body,
            display_sender,
            &room_name,
            &members,
            is_spontaneous,
            &state.mistral,
        )
        .await;
    let _ = room.typing_notice(false).await;

    if let Some(text) = reply {
        let content = matrix_utils::make_reply_content(&text, event.event_id.clone());
        if let Err(e) = room.send(content).await {
            error!("Failed to send response: {e}");
        }
    }
    Ok(())
}
/// Marks the redacted event in the archive.
///
/// Redaction events that do not name a target are ignored.
async fn handle_redaction(event: OriginalSyncRoomRedactionEvent, state: &AppState) {
    let Some(target) = &event.redacts else {
        return;
    };
    let target_id = target.to_string();
    state.indexer.update_redaction(&target_id).await;
}
/// Auto-joins a room when the stripped member event targets our own user.
///
/// The join runs in a detached task so the sync loop is not blocked. It is
/// attempted up to three times with exponential backoff (1 s, then 2 s)
/// between attempts; the final failure is reported immediately rather than
/// after one more pointless sleep.
async fn handle_invite(event: StrippedRoomMemberEvent, room: Room) {
    const MAX_ATTEMPTS: u32 = 3;

    // Only handle our own invites
    if event.state_key != room.own_user_id() {
        return;
    }
    info!(room_id = %room.room_id(), "Received invite, auto-joining");
    tokio::spawn(async move {
        for attempt in 0..MAX_ATTEMPTS {
            match room.join().await {
                Ok(_) => {
                    info!(room_id = %room.room_id(), "Joined room");
                    return;
                }
                Err(e) => {
                    warn!(room_id = %room.room_id(), attempt, "Failed to join: {e}");
                    // Back off only when another attempt will follow.
                    if attempt + 1 < MAX_ATTEMPTS {
                        tokio::time::sleep(std::time::Duration::from_secs(2u64.pow(attempt)))
                            .await;
                    }
                }
            }
        }
        error!(room_id = %room.room_id(), "Failed to join after retries");
    });
}

151
src/tools/mod.rs Normal file
View File

@@ -0,0 +1,151 @@
pub mod room_history;
pub mod room_info;
pub mod search;
use std::sync::Arc;
use matrix_sdk::Client as MatrixClient;
use mistralai_client::v1::tool::Tool;
use opensearch::OpenSearch;
use serde_json::json;
use crate::config::Config;
/// Holds the clients and configuration needed to execute tool calls.
pub struct ToolRegistry {
    /// OpenSearch client used by the archive-backed tools.
    opensearch: OpenSearch,
    /// Matrix client used by the room-information tools.
    matrix: MatrixClient,
    /// Shared configuration; supplies the OpenSearch index name.
    config: Arc<Config>,
}
impl ToolRegistry {
    /// Creates a registry from the clients and configuration the tools need.
    pub fn new(opensearch: OpenSearch, matrix: MatrixClient, config: Arc<Config>) -> Self {
        Self {
            opensearch,
            matrix,
            config,
        }
    }

    /// Returns the function-calling definitions (name, description and JSON
    /// schema of the arguments) for every tool [`ToolRegistry::execute`]
    /// can dispatch.
    pub fn tool_definitions() -> Vec<Tool> {
        vec![
            Tool::new(
                "search_archive".into(),
                "Search the message archive. Use this to find past conversations, \
                 messages from specific people, or about specific topics."
                    .into(),
                json!({
                    "type": "object",
                    "properties": {
                        "query": {
                            "type": "string",
                            "description": "Search query for message content"
                        },
                        "room": {
                            "type": "string",
                            "description": "Filter by room name (optional)"
                        },
                        "sender": {
                            "type": "string",
                            "description": "Filter by sender display name (optional)"
                        },
                        "after": {
                            "type": "string",
                            "description": "Unix timestamp in ms — only messages after this time (optional)"
                        },
                        "before": {
                            "type": "string",
                            "description": "Unix timestamp in ms — only messages before this time (optional)"
                        },
                        "limit": {
                            "type": "integer",
                            "description": "Max results to return (default 10)"
                        },
                        "semantic": {
                            "type": "boolean",
                            "description": "Use semantic search instead of keyword (optional)"
                        }
                    },
                    "required": ["query"]
                }),
            ),
            Tool::new(
                "get_room_context".into(),
                // The implementation rejects calls that name no pivot, so
                // the description states that requirement up front.
                "Get messages around a specific point in time or event in a room. \
                 Useful for understanding the context of a conversation. \
                 Provide either around_timestamp or around_event_id; \
                 the call fails if neither is given."
                    .into(),
                json!({
                    "type": "object",
                    "properties": {
                        "room_id": {
                            "type": "string",
                            "description": "The Matrix room ID"
                        },
                        "around_timestamp": {
                            "type": "integer",
                            "description": "Unix timestamp in ms to center the context around"
                        },
                        "around_event_id": {
                            "type": "string",
                            "description": "Event ID to center the context around"
                        },
                        "before_count": {
                            "type": "integer",
                            "description": "Number of messages before the pivot (default 10)"
                        },
                        "after_count": {
                            "type": "integer",
                            "description": "Number of messages after the pivot (default 10)"
                        }
                    },
                    "required": ["room_id"]
                }),
            ),
            Tool::new(
                "list_rooms".into(),
                "List all rooms Sol is currently in, with names and member counts.".into(),
                json!({
                    "type": "object",
                    "properties": {}
                }),
            ),
            Tool::new(
                "get_room_members".into(),
                "Get the list of members in a specific room.".into(),
                json!({
                    "type": "object",
                    "properties": {
                        "room_id": {
                            "type": "string",
                            "description": "The Matrix room ID"
                        }
                    },
                    "required": ["room_id"]
                }),
            ),
        ]
    }

    /// Runs the tool called `name` with its JSON-encoded `arguments` and
    /// returns the tool's textual output.
    ///
    /// # Errors
    ///
    /// Returns an error for an unknown tool name, or whatever error the
    /// selected tool reports.
    pub async fn execute(&self, name: &str, arguments: &str) -> anyhow::Result<String> {
        match name {
            "search_archive" => {
                search::search_archive(
                    &self.opensearch,
                    &self.config.opensearch.index,
                    arguments,
                )
                .await
            }
            "get_room_context" => {
                room_history::get_room_context(
                    &self.opensearch,
                    &self.config.opensearch.index,
                    arguments,
                )
                .await
            }
            "list_rooms" => room_info::list_rooms(&self.matrix).await,
            "get_room_members" => room_info::get_room_members(&self.matrix, arguments).await,
            _ => anyhow::bail!("Unknown tool: {name}"),
        }
    }
}

108
src/tools/room_history.rs Normal file
View File

@@ -0,0 +1,108 @@
use opensearch::OpenSearch;
use serde::Deserialize;
use serde_json::json;
/// Arguments of the `get_room_context` tool, deserialized from JSON.
#[derive(Debug, Deserialize)]
pub struct RoomHistoryArgs {
    /// Room whose messages are searched.
    pub room_id: String,
    /// Pivot given directly as a timestamp; takes precedence over
    /// `around_event_id` when both are present.
    #[serde(default)]
    pub around_timestamp: Option<i64>,
    /// Pivot given as an event ID; its timestamp is looked up in the archive.
    #[serde(default)]
    pub around_event_id: Option<String>,
    /// Number of messages wanted before the pivot (defaults to 10).
    #[serde(default = "default_count")]
    pub before_count: usize,
    /// Number of messages wanted after the pivot (defaults to 10).
    #[serde(default = "default_count")]
    pub after_count: usize,
}
/// Serde fallback for `before_count` and `after_count`.
fn default_count() -> usize {
    10
}
/// Returns the archived, non-redacted messages of a room around a pivot.
///
/// The pivot is `around_timestamp` when given, otherwise the archived
/// timestamp of `around_event_id`. Up to `before_count` messages strictly
/// before the pivot and up to `after_count + 1` messages at or after it
/// (so the pivot message itself is included) are returned in chronological
/// order, one `[HH:MM] sender: content` line each. Only messages within one
/// hour of the pivot are considered.
///
/// # Errors
///
/// Fails if the arguments are not valid JSON for [`RoomHistoryArgs`], if
/// neither pivot argument is supplied, if the pivot event is not in the
/// archive, or if an OpenSearch request fails or returns an error status.
pub async fn get_room_context(
    client: &OpenSearch,
    index: &str,
    args_json: &str,
) -> anyhow::Result<String> {
    // Half-width, in milliseconds, of the window searched around the pivot.
    const WINDOW_MS: i64 = 3_600_000;

    let args: RoomHistoryArgs = serde_json::from_str(args_json)?;

    // Determine the pivot timestamp
    let pivot_ts = if let Some(ts) = args.around_timestamp {
        ts
    } else if let Some(ref event_id) = args.around_event_id {
        // Look up the event to get its timestamp
        let lookup = json!({
            "size": 1,
            "query": { "term": { "event_id": event_id } },
            "_source": ["timestamp"]
        });
        fetch_context_hits(client, index, lookup)
            .await?
            .first()
            .and_then(|hit| hit["_source"]["timestamp"].as_i64())
            .ok_or_else(|| anyhow::anyhow!("Event {event_id} not found in archive"))?
    } else {
        anyhow::bail!("Either around_timestamp or around_event_id must be provided");
    };

    // Messages before the pivot: fetched newest-first so the size limit
    // keeps the ones closest to the pivot, then flipped to chronological.
    let mut hits = if args.before_count > 0 {
        let range = json!({ "gte": pivot_ts.saturating_sub(WINDOW_MS), "lt": pivot_ts });
        let query = context_window_query(&args.room_id, range, "desc", args.before_count);
        let mut before = fetch_context_hits(client, index, query).await?;
        before.reverse();
        before
    } else {
        Vec::new()
    };

    // The pivot message and what follows it, oldest-first.
    let range = json!({ "gte": pivot_ts, "lte": pivot_ts.saturating_add(WINDOW_MS) });
    let query = context_window_query(
        &args.room_id,
        range,
        "asc",
        args.after_count.saturating_add(1),
    );
    hits.extend(fetch_context_hits(client, index, query).await?);

    if hits.is_empty() {
        return Ok("No messages found around that point.".into());
    }

    let mut output = String::new();
    for hit in &hits {
        let src = &hit["_source"];
        let sender = src["sender_name"].as_str().unwrap_or("unknown");
        let content = src["content"].as_str().unwrap_or("");
        let ts = src["timestamp"].as_i64().unwrap_or(0);
        let dt = chrono::DateTime::from_timestamp_millis(ts)
            .map(|d| d.format("%H:%M").to_string())
            .unwrap_or_else(|| "??:??".into());
        output.push_str(&format!("[{dt}] {sender}: {content}\n"));
    }
    Ok(output)
}

/// Builds a query for one side of the context window: non-redacted messages
/// of `room_id` whose timestamp lies in `range`, sorted by timestamp in
/// `order` and limited to `size` hits. All clauses are hard filters.
fn context_window_query(
    room_id: &str,
    range: serde_json::Value,
    order: &str,
    size: usize,
) -> serde_json::Value {
    json!({
        "size": size,
        "query": {
            "bool": {
                "filter": [
                    { "term": { "room_id": room_id } },
                    { "term": { "redacted": false } },
                    { "range": { "timestamp": range } }
                ]
            }
        },
        "sort": [{ "timestamp": order }]
    })
}

/// Runs a search and returns its hit objects, turning an HTTP error status
/// into an error instead of an empty result.
async fn fetch_context_hits(
    client: &OpenSearch,
    index: &str,
    query: serde_json::Value,
) -> anyhow::Result<Vec<serde_json::Value>> {
    let response = client
        .search(opensearch::SearchParts::Index(&[index]))
        .body(query)
        .send()
        .await?
        .error_for_status_code()?;
    let body: serde_json::Value = response.json().await?;
    Ok(body["hits"]["hits"].as_array().cloned().unwrap_or_default())
}

55
src/tools/room_info.rs Normal file
View File

@@ -0,0 +1,55 @@
use matrix_sdk::Client;
use serde::Deserialize;
/// Argument payload of the `list_rooms` tool; it has no fields.
// NOTE(review): `list_rooms` in this file takes no arguments and never
// deserializes this type — confirm whether anything else uses it.
#[derive(Debug, Deserialize)]
pub struct ListRoomsArgs {}
/// Arguments of the `get_room_members` tool, deserialized from JSON.
#[derive(Debug, Deserialize)]
pub struct GetMembersArgs {
    /// Matrix room ID of the room whose members are listed.
    pub room_id: String,
}
/// Lists every joined room as `- name (room_id) — N members`, one per line.
///
/// A room without a cached display name is shown by its room ID. When the
/// client has joined no rooms a fixed sentence is returned instead.
pub async fn list_rooms(client: &Client) -> anyhow::Result<String> {
    let joined = client.joined_rooms();
    if joined.is_empty() {
        return Ok("I'm not in any rooms.".into());
    }

    let listing: String = joined
        .iter()
        .map(|room| {
            let id = room.room_id();
            let name = room
                .cached_display_name()
                .map(|n| n.to_string())
                .unwrap_or_else(|| id.to_string());
            let members = room.joined_members_count();
            format!("- {name} ({id}) — {members} members\n")
        })
        .collect();
    Ok(listing)
}
pub async fn get_room_members(client: &Client, args_json: &str) -> anyhow::Result<String> {
let args: GetMembersArgs = serde_json::from_str(args_json)?;
let room_id = <&ruma::RoomId>::try_from(args.room_id.as_str())?;
let Some(room) = client.get_room(room_id) else {
anyhow::bail!("I'm not in room {}", args.room_id);
};
let members = room.members(matrix_sdk::RoomMemberships::JOIN).await?;
if members.is_empty() {
return Ok("No members found.".into());
}
let mut output = String::new();
for member in &members {
let display = member
.display_name()
.unwrap_or_else(|| member.user_id().as_str());
let user_id = member.user_id();
output.push_str(&format!("- {display} ({user_id})\n"));
}
Ok(output)
}

274
src/tools/search.rs Normal file
View File

@@ -0,0 +1,274 @@
use opensearch::OpenSearch;
use serde::Deserialize;
use serde_json::json;
use tracing::debug;
/// Arguments of the `search_archive` tool, deserialized from JSON.
#[derive(Debug, Deserialize)]
pub struct SearchArgs {
    /// Text matched against message content.
    pub query: String,
    /// Optional exact-term filter on the room name.
    #[serde(default)]
    pub room: Option<String>,
    /// Optional exact-term filter on the sender's display name.
    #[serde(default)]
    pub sender: Option<String>,
    /// Optional lower timestamp bound, as a decimal string; values that do
    /// not parse as an integer are ignored.
    #[serde(default)]
    pub after: Option<String>,
    /// Optional upper timestamp bound, as a decimal string; values that do
    /// not parse as an integer are ignored.
    #[serde(default)]
    pub before: Option<String>,
    /// Maximum number of hits to return (defaults to 10).
    #[serde(default = "default_limit")]
    pub limit: usize,
    /// Semantic-search flag.
    // NOTE(review): accepted but not consulted by `build_search_query` in this file.
    #[serde(default)]
    pub semantic: Option<bool>,
}
/// Serde fallback for `limit`.
fn default_limit() -> usize {
    10
}
/// Build the OpenSearch query body from parsed SearchArgs. Extracted for testability.
pub fn build_search_query(args: &SearchArgs) -> serde_json::Value {
let must = vec![json!({
"match": { "content": args.query }
})];
let mut filter = vec![json!({
"term": { "redacted": false }
})];
if let Some(ref room) = args.room {
filter.push(json!({ "term": { "room_name": room } }));
}
if let Some(ref sender) = args.sender {
filter.push(json!({ "term": { "sender_name": sender } }));
}
let mut range = serde_json::Map::new();
if let Some(ref after) = args.after {
if let Ok(ts) = after.parse::<i64>() {
range.insert("gte".into(), json!(ts));
}
}
if let Some(ref before) = args.before {
if let Ok(ts) = before.parse::<i64>() {
range.insert("lte".into(), json!(ts));
}
}
if !range.is_empty() {
filter.push(json!({ "range": { "timestamp": range } }));
}
json!({
"size": args.limit,
"query": {
"bool": {
"must": must,
"filter": filter
}
},
"sort": [{ "timestamp": "desc" }],
"_source": ["event_id", "room_name", "sender_name", "timestamp", "content"]
})
}
/// Searches the archive and formats the hits as a numbered list.
///
/// Each line reads `N. [YYYY-MM-DD HH:MM] #room <sender>: content`. When
/// nothing matches, a fixed "No results found." sentence is returned.
///
/// # Errors
///
/// Fails if the arguments are not valid JSON for [`SearchArgs`], or if the
/// OpenSearch request fails or returns an error status.
pub async fn search_archive(
    client: &OpenSearch,
    index: &str,
    args_json: &str,
) -> anyhow::Result<String> {
    let args: SearchArgs = serde_json::from_str(args_json)?;
    debug!(query = args.query.as_str(), "Searching archive");
    let query_body = build_search_query(&args);
    let response = client
        .search(opensearch::SearchParts::Index(&[index]))
        .body(query_body)
        .send()
        .await?
        // Surface server-side failures instead of reporting them as
        // "No results found."
        .error_for_status_code()?;
    let body: serde_json::Value = response.json().await?;
    let hits = &body["hits"]["hits"];
    let Some(hits_arr) = hits.as_array() else {
        return Ok("No results found.".into());
    };
    if hits_arr.is_empty() {
        return Ok("No results found.".into());
    }
    let mut output = String::new();
    for (i, hit) in hits_arr.iter().enumerate() {
        let src = &hit["_source"];
        let sender = src["sender_name"].as_str().unwrap_or("unknown");
        let room = src["room_name"].as_str().unwrap_or("unknown");
        let content = src["content"].as_str().unwrap_or("");
        let ts = src["timestamp"].as_i64().unwrap_or(0);
        let dt = chrono::DateTime::from_timestamp_millis(ts)
            .map(|d| d.format("%Y-%m-%d %H:%M").to_string())
            .unwrap_or_else(|| "unknown date".into());
        // Room and sender are delimited so the two names cannot run together.
        output.push_str(&format!(
            "{}. [{dt}] #{room} <{sender}>: {content}\n",
            i + 1
        ));
    }
    Ok(output)
}
// Unit tests for argument parsing and query construction; none of them
// contacts OpenSearch.
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes `SearchArgs` from a JSON literal, panicking on bad input.
    fn parse_args(json: &str) -> SearchArgs {
        serde_json::from_str(json).unwrap()
    }

    /// Only `query` is required; every other field takes its default.
    #[test]
    fn test_parse_minimal_args() {
        let args = parse_args(r#"{"query": "hello"}"#);
        assert_eq!(args.query, "hello");
        assert!(args.room.is_none());
        assert!(args.sender.is_none());
        assert!(args.after.is_none());
        assert!(args.before.is_none());
        assert_eq!(args.limit, 10); // default
        assert!(args.semantic.is_none());
    }

    /// Every field is read when all of them are supplied.
    #[test]
    fn test_parse_full_args() {
        let args = parse_args(r#"{
"query": "meeting notes",
"room": "general",
"sender": "Alice",
"after": "1710000000000",
"before": "1710100000000",
"limit": 25,
"semantic": true
}"#);
        assert_eq!(args.query, "meeting notes");
        assert_eq!(args.room.as_deref(), Some("general"));
        assert_eq!(args.sender.as_deref(), Some("Alice"));
        assert_eq!(args.after.as_deref(), Some("1710000000000"));
        assert_eq!(args.before.as_deref(), Some("1710100000000"));
        assert_eq!(args.limit, 25);
        assert_eq!(args.semantic, Some(true));
    }

    /// A bare query yields the content match, the redaction filter, the
    /// default size and a newest-first sort.
    #[test]
    fn test_query_basic() {
        let args = parse_args(r#"{"query": "test"}"#);
        let q = build_search_query(&args);
        assert_eq!(q["size"], 10);
        assert_eq!(q["query"]["bool"]["must"][0]["match"]["content"], "test");
        assert_eq!(q["query"]["bool"]["filter"][0]["term"]["redacted"], false);
        assert_eq!(q["sort"][0]["timestamp"], "desc");
    }

    /// The room filter is appended after the redaction filter.
    #[test]
    fn test_query_with_room_filter() {
        let args = parse_args(r#"{"query": "hello", "room": "design"}"#);
        let q = build_search_query(&args);
        let filters = q["query"]["bool"]["filter"].as_array().unwrap();
        assert_eq!(filters.len(), 2);
        assert_eq!(filters[1]["term"]["room_name"], "design");
    }

    /// The sender filter is appended after the redaction filter.
    #[test]
    fn test_query_with_sender_filter() {
        let args = parse_args(r#"{"query": "hello", "sender": "Bob"}"#);
        let q = build_search_query(&args);
        let filters = q["query"]["bool"]["filter"].as_array().unwrap();
        assert_eq!(filters.len(), 2);
        assert_eq!(filters[1]["term"]["sender_name"], "Bob");
    }

    /// Room comes before sender when both filters are present.
    #[test]
    fn test_query_with_room_and_sender() {
        let args = parse_args(r#"{"query": "hello", "room": "dev", "sender": "Carol"}"#);
        let q = build_search_query(&args);
        let filters = q["query"]["bool"]["filter"].as_array().unwrap();
        assert_eq!(filters.len(), 3);
        assert_eq!(filters[1]["term"]["room_name"], "dev");
        assert_eq!(filters[2]["term"]["sender_name"], "Carol");
    }

    /// Both bounds end up in a single timestamp range filter.
    #[test]
    fn test_query_with_date_range() {
        let args = parse_args(r#"{
"query": "hello",
"after": "1710000000000",
"before": "1710100000000"
}"#);
        let q = build_search_query(&args);
        let filters = q["query"]["bool"]["filter"].as_array().unwrap();
        let range_filter = &filters[1]["range"]["timestamp"];
        assert_eq!(range_filter["gte"], 1710000000000_i64);
        assert_eq!(range_filter["lte"], 1710100000000_i64);
    }

    /// A lone `after` produces a range with `gte` and no `lte`.
    #[test]
    fn test_query_with_after_only() {
        let args = parse_args(r#"{"query": "hello", "after": "1710000000000"}"#);
        let q = build_search_query(&args);
        let filters = q["query"]["bool"]["filter"].as_array().unwrap();
        let range_filter = &filters[1]["range"]["timestamp"];
        assert_eq!(range_filter["gte"], 1710000000000_i64);
        assert!(range_filter.get("lte").is_none());
    }

    /// `limit` is passed through as the query size.
    #[test]
    fn test_query_with_custom_limit() {
        let args = parse_args(r#"{"query": "hello", "limit": 50}"#);
        let q = build_search_query(&args);
        assert_eq!(q["size"], 50);
    }

    /// All optional filters together give four filter clauses.
    #[test]
    fn test_query_all_filters_combined() {
        let args = parse_args(r#"{
"query": "architecture",
"room": "engineering",
"sender": "Sienna",
"after": "1000",
"before": "2000",
"limit": 5
}"#);
        let q = build_search_query(&args);
        assert_eq!(q["size"], 5);
        let filters = q["query"]["bool"]["filter"].as_array().unwrap();
        // redacted=false, room, sender, range = 4 filters
        assert_eq!(filters.len(), 4);
    }

    /// A bound that does not parse as an integer is dropped silently.
    #[test]
    fn test_invalid_timestamp_ignored() {
        let args = parse_args(r#"{"query": "hello", "after": "not-a-number"}"#);
        let q = build_search_query(&args);
        let filters = q["query"]["bool"]["filter"].as_array().unwrap();
        // Only the redacted filter, no range since parse failed
        assert_eq!(filters.len(), 1);
    }

    /// `_source` requests exactly the fields the result formatter reads.
    #[test]
    fn test_source_fields() {
        let args = parse_args(r#"{"query": "test"}"#);
        let q = build_search_query(&args);
        let source = q["_source"].as_array().unwrap();
        let fields: Vec<&str> = source.iter().map(|v| v.as_str().unwrap()).collect();
        assert!(fields.contains(&"event_id"));
        assert!(fields.contains(&"room_name"));
        assert!(fields.contains(&"sender_name"));
        assert!(fields.contains(&"timestamp"));
        assert!(fields.contains(&"content"));
    }
}