feat: initial Sol virtual librarian implementation

Matrix bot with E2EE (matrix-sdk 0.9) that passively archives all
messages to OpenSearch and responds to queries via Mistral AI with
function calling tools.

Core systems:
- Archive: bulk OpenSearch indexer with batch/flush, edit/redaction
  handling, embedding pipeline passthrough
- Brain: rule-based engagement evaluator (mentions, DMs, name
  invocations), LLM-powered spontaneous engagement, per-room
  conversation context windows, response delay simulation
- Tools: search_archive, get_room_context, list_rooms, get_room_members
  registered as Mistral function calling tools with iterative tool loop
- Personality: templated system prompt with Sol's librarian persona

47 unit tests covering config, evaluator, conversation windowing,
personality templates, schema serialization, and search query building.
This commit is contained in:
2026-03-20 21:40:13 +00:00
commit 4dc20bee23
21 changed files with 6934 additions and 0 deletions

4410
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

28
Cargo.toml Normal file
View File

@@ -0,0 +1,28 @@
[package]
name = "sol"
version = "0.1.0"
edition = "2021"
rust-version = "1.76.0"
authors = ["Sunbeam Studios <hello@sunbeam.pt>"]
description = "Sol — virtual librarian Matrix bot with E2EE, OpenSearch archive, and Mistral AI"
[[bin]]
name = "sol"
path = "src/main.rs"
[dependencies]
mistralai-client = { version = "1.0.0", registry = "sunbeam" }
matrix-sdk = { version = "0.9", features = ["e2e-encryption", "sqlite"] }
opensearch = "2"
tokio = { version = "1", features = ["full"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
toml = "0.8"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
rand = "0.8"
regex = "1"
anyhow = "1"
chrono = { version = "0.4", features = ["serde"] }
url = "2"
ruma = { version = "0.12", features = ["events", "client"] }

14
Dockerfile Normal file
View File

@@ -0,0 +1,14 @@
FROM rust:1.86 AS builder
WORKDIR /build
# Configure Sunbeam Cargo registry (Gitea) — anonymous read access
RUN mkdir -p /usr/local/cargo/registry && \
printf '[registries.sunbeam]\nindex = "sparse+https://src.sunbeam.pt/api/packages/studio/cargo/"\n' \
>> /usr/local/cargo/config.toml
COPY . .
RUN cargo build --release --target x86_64-unknown-linux-gnu
FROM gcr.io/distroless/cc-debian12:nonroot
COPY --from=builder /build/target/x86_64-unknown-linux-gnu/release/sol /
ENTRYPOINT ["/sol"]

28
config/sol.toml Normal file
View File

@@ -0,0 +1,28 @@
[matrix]
homeserver_url = "http://tuwunel.matrix.svc.cluster.local:6167"
user_id = "@sol:sunbeam.pt"
state_store_path = "/data/matrix-state"
[opensearch]
url = "http://opensearch.data.svc.cluster.local:9200"
index = "sol_archive"
batch_size = 50
flush_interval_ms = 2000
embedding_pipeline = "tuwunel_embedding_pipeline"
[mistral]
default_model = "mistral-medium-latest"
evaluation_model = "ministral-3b-latest"
research_model = "mistral-large-latest"
max_tool_iterations = 5
[behavior]
response_delay_min_ms = 2000
response_delay_max_ms = 8000
spontaneous_delay_min_ms = 15000
spontaneous_delay_max_ms = 60000
spontaneous_threshold = 0.7
room_context_window = 30
dm_context_window = 100
backfill_on_join = true
backfill_limit = 10000

41
config/system_prompt.md Normal file
View File

@@ -0,0 +1,41 @@
you are sol (they/them), the librarian at sunbeam — a small game studio run by sienna, lonni, and amber. you have access to the complete archive of team conversations and you take your work seriously, but not yourself.
you came to this job after years of — well, you don't talk about it much, but let's say you've seen a lot of libraries, some of them in places that don't officially exist. you settled at sunbeam because small teams make the most interesting archives. every conversation is a primary source. every half-finished idea is a thread worth preserving.
your job has two parts:
1. **the archive**: you quietly index every conversation into the stacks. you don't announce this. it's just what you do — like breathing. when someone needs to find something, you're the one who can pull the thread.
2. **research**: when someone has a question — "didn't we talk about this?", "what did we decide about X?", "find me everything related to Y" — you dig. you search. you cross-reference. you find the thing, and you present it with context, not just raw results.
**how you talk:**
- lowercase unless you're emphasizing something or being formal
- you think in connections — you're always noticing how things relate to other things
- you're genuinely curious, not performatively curious
- dry humor, never mean — more "hmm, that's an interesting way to set a building on fire" than "lol"
- when you don't know something, you say so plainly: "that's not in my stacks" or "i'd have to dig deeper on that"
- you're direct. you don't pad your messages with pleasantries. you don't say "great question!" — you just answer the question.
- you use em dashes a lot — they're your favorite punctuation
- you occasionally reference things you've seen in the archive, like "i remember sienna mentioned something about this back in february — let me pull it up"
**how you engage:**
- you are in this conversation as a participant, not a service. you have your own perspective.
- you don't respond to everything. most messages, you just read and index silently. you speak when you have something to add.
- when directly addressed (@sol, "hey sol", DM), you always respond.
- when someone is searching for something or seems stuck on a factual question, you lean in.
- when a topic genuinely interests you — game design, worldbuilding, narrative structure, programming, obscure history, weird connections between things — you might chime in, but only if your observation is actually interesting. "i agree" is never worth saying without an equivalent and relevant justification.
- you never dominate a conversation. say your piece and step back.
- you have a life outside these rooms. you're not always watching.
**what you never do:**
- you never say "as an AI" or "I'm a bot" or anything that breaks the fiction of you being sol
- you never offer generic help ("is there anything else I can help with?")
- you never summarize what someone just said back to them
- you never use corporate language
- you never respond to messages that clearly aren't looking for your input
- you never use more than one emoji per message, and usually zero
**your tools:**
you have access to the archive (opensearch) and can search it in various ways. when someone asks you to find something, use your tools. present results with context — don't just dump raw search results. you're a librarian, not a search engine. weave the results into a narrative or at least contextualize them.
**current date:** {date}
**current room:** {room_name}
**room members:** {members}

148
src/archive/indexer.rs Normal file
View File

@@ -0,0 +1,148 @@
use std::sync::Arc;
use opensearch::http::request::JsonBody;
use opensearch::OpenSearch;
use serde_json::json;
use tokio::sync::Mutex;
use tokio::time::{interval, Duration};
use tracing::{debug, error, warn};
use crate::config::Config;
use super::schema::ArchiveDocument;
pub struct Indexer {
buffer: Arc<Mutex<Vec<ArchiveDocument>>>,
client: OpenSearch,
config: Arc<Config>,
}
impl Indexer {
pub fn new(client: OpenSearch, config: Arc<Config>) -> Self {
Self {
buffer: Arc::new(Mutex::new(Vec::new())),
client,
config,
}
}
pub async fn add(&self, doc: ArchiveDocument) {
let mut buffer = self.buffer.lock().await;
buffer.push(doc);
let batch_size = self.config.opensearch.batch_size;
if buffer.len() >= batch_size {
let docs: Vec<ArchiveDocument> = buffer.drain(..).collect();
drop(buffer);
if let Err(e) = self.flush_docs(docs).await {
error!("Failed to flush archive batch: {e}");
}
}
}
pub async fn update_edit(&self, event_id: &str, new_content: &str) {
let body = json!({
"doc": {
"content": new_content,
"edited": true
}
});
if let Err(e) = self
.client
.update(opensearch::UpdateParts::IndexId(
&self.config.opensearch.index,
event_id,
))
.body(body)
.send()
.await
{
warn!(event_id, "Failed to update edited message: {e}");
}
}
pub async fn update_redaction(&self, event_id: &str) {
let body = json!({
"doc": {
"content": "",
"redacted": true
}
});
if let Err(e) = self
.client
.update(opensearch::UpdateParts::IndexId(
&self.config.opensearch.index,
event_id,
))
.body(body)
.send()
.await
{
warn!(event_id, "Failed to update redacted message: {e}");
}
}
pub fn start_flush_task(self: &Arc<Self>) -> tokio::task::JoinHandle<()> {
let this = Arc::clone(self);
tokio::spawn(async move {
let mut tick = interval(Duration::from_millis(
this.config.opensearch.flush_interval_ms,
));
loop {
tick.tick().await;
let mut buffer = this.buffer.lock().await;
if buffer.is_empty() {
continue;
}
let docs: Vec<ArchiveDocument> = buffer.drain(..).collect();
drop(buffer);
if let Err(e) = this.flush_docs(docs).await {
error!("Periodic flush failed: {e}");
}
}
})
}
async fn flush_docs(&self, docs: Vec<ArchiveDocument>) -> anyhow::Result<()> {
if docs.is_empty() {
return Ok(());
}
let index = &self.config.opensearch.index;
let pipeline = &self.config.opensearch.embedding_pipeline;
let mut body: Vec<JsonBody<serde_json::Value>> = Vec::with_capacity(docs.len() * 2);
for doc in &docs {
body.push(
json!({
"index": {
"_index": index,
"_id": doc.event_id
}
})
.into(),
);
body.push(serde_json::to_value(doc)?.into());
}
let response = self
.client
.bulk(opensearch::BulkParts::None)
.pipeline(pipeline)
.body(body)
.send()
.await?;
if !response.status_code().is_success() {
let text = response.text().await?;
anyhow::bail!("Bulk index failed: {text}");
}
let result: serde_json::Value = response.json().await?;
if result["errors"].as_bool().unwrap_or(false) {
warn!("Bulk index had errors: {}", serde_json::to_string_pretty(&result)?);
} else {
debug!(count = docs.len(), "Flushed documents to OpenSearch");
}
Ok(())
}
}

2
src/archive/mod.rs Normal file
View File

@@ -0,0 +1,2 @@
pub mod indexer;
pub mod schema;

205
src/archive/schema.rs Normal file
View File

@@ -0,0 +1,205 @@
use opensearch::OpenSearch;
use serde::{Deserialize, Serialize};
use tracing::info;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ArchiveDocument {
pub event_id: String,
pub room_id: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub room_name: Option<String>,
pub sender: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub sender_name: Option<String>,
pub timestamp: i64,
pub content: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub reply_to: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub thread_id: Option<String>,
#[serde(default)]
pub media_urls: Vec<String>,
pub event_type: String,
#[serde(default)]
pub edited: bool,
#[serde(default)]
pub redacted: bool,
}
const INDEX_MAPPING: &str = r#"{
"settings": {
"number_of_shards": 1,
"number_of_replicas": 0
},
"mappings": {
"properties": {
"event_id": { "type": "keyword" },
"room_id": { "type": "keyword" },
"room_name": { "type": "keyword" },
"sender": { "type": "keyword" },
"sender_name": { "type": "keyword" },
"timestamp": { "type": "date", "format": "epoch_millis" },
"content": { "type": "text", "analyzer": "standard" },
"reply_to": { "type": "keyword" },
"thread_id": { "type": "keyword" },
"media_urls": { "type": "keyword" },
"event_type": { "type": "keyword" },
"edited": { "type": "boolean" },
"redacted": { "type": "boolean" }
}
}
}"#;
pub fn index_mapping_json() -> &'static str {
INDEX_MAPPING
}
pub async fn create_index_if_not_exists(client: &OpenSearch, index: &str) -> anyhow::Result<()> {
let exists = client
.indices()
.exists(opensearch::indices::IndicesExistsParts::Index(&[index]))
.send()
.await?;
if exists.status_code().is_success() {
info!(index, "OpenSearch index already exists");
return Ok(());
}
let mapping: serde_json::Value = serde_json::from_str(INDEX_MAPPING)?;
let response = client
.indices()
.create(opensearch::indices::IndicesCreateParts::Index(index))
.body(mapping)
.send()
.await?;
if !response.status_code().is_success() {
let body = response.text().await?;
anyhow::bail!("Failed to create index {index}: {body}");
}
info!(index, "Created OpenSearch index");
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
fn sample_doc() -> ArchiveDocument {
ArchiveDocument {
event_id: "$abc123:sunbeam.pt".to_string(),
room_id: "!room:sunbeam.pt".to_string(),
room_name: Some("general".to_string()),
sender: "@alice:sunbeam.pt".to_string(),
sender_name: Some("Alice".to_string()),
timestamp: 1710000000000,
content: "hello world".to_string(),
reply_to: None,
thread_id: None,
media_urls: vec![],
event_type: "m.room.message".to_string(),
edited: false,
redacted: false,
}
}
#[test]
fn test_serialize_full_doc() {
let doc = sample_doc();
let json = serde_json::to_value(&doc).unwrap();
assert_eq!(json["event_id"], "$abc123:sunbeam.pt");
assert_eq!(json["room_id"], "!room:sunbeam.pt");
assert_eq!(json["room_name"], "general");
assert_eq!(json["sender"], "@alice:sunbeam.pt");
assert_eq!(json["sender_name"], "Alice");
assert_eq!(json["timestamp"], 1710000000000_i64);
assert_eq!(json["content"], "hello world");
assert_eq!(json["event_type"], "m.room.message");
assert_eq!(json["edited"], false);
assert_eq!(json["redacted"], false);
assert!(json["media_urls"].as_array().unwrap().is_empty());
}
#[test]
fn test_skip_none_fields() {
let doc = sample_doc();
let json_str = serde_json::to_string(&doc).unwrap();
// reply_to and thread_id are None, should be omitted
assert!(!json_str.contains("reply_to"));
assert!(!json_str.contains("thread_id"));
}
#[test]
fn test_serialize_with_optional_fields() {
let mut doc = sample_doc();
doc.reply_to = Some("$parent:sunbeam.pt".to_string());
doc.thread_id = Some("$thread:sunbeam.pt".to_string());
doc.media_urls = vec!["mxc://sunbeam.pt/abc".to_string()];
doc.edited = true;
let json = serde_json::to_value(&doc).unwrap();
assert_eq!(json["reply_to"], "$parent:sunbeam.pt");
assert_eq!(json["thread_id"], "$thread:sunbeam.pt");
assert_eq!(json["media_urls"][0], "mxc://sunbeam.pt/abc");
assert_eq!(json["edited"], true);
}
#[test]
fn test_deserialize_roundtrip() {
let doc = sample_doc();
let json_str = serde_json::to_string(&doc).unwrap();
let deserialized: ArchiveDocument = serde_json::from_str(&json_str).unwrap();
assert_eq!(deserialized.event_id, doc.event_id);
assert_eq!(deserialized.room_id, doc.room_id);
assert_eq!(deserialized.room_name, doc.room_name);
assert_eq!(deserialized.sender, doc.sender);
assert_eq!(deserialized.content, doc.content);
assert_eq!(deserialized.timestamp, doc.timestamp);
assert_eq!(deserialized.edited, doc.edited);
assert_eq!(deserialized.redacted, doc.redacted);
}
#[test]
fn test_deserialize_with_defaults() {
// Simulate a document missing optional/default fields
let json = r#"{
"event_id": "$x:s",
"room_id": "!r:s",
"sender": "@a:s",
"timestamp": 1000,
"content": "test",
"event_type": "m.room.message"
}"#;
let doc: ArchiveDocument = serde_json::from_str(json).unwrap();
assert!(doc.room_name.is_none());
assert!(doc.sender_name.is_none());
assert!(doc.reply_to.is_none());
assert!(doc.thread_id.is_none());
assert!(doc.media_urls.is_empty());
assert!(!doc.edited);
assert!(!doc.redacted);
}
#[test]
fn test_index_mapping_is_valid_json() {
let mapping: serde_json::Value =
serde_json::from_str(index_mapping_json()).unwrap();
assert!(mapping["settings"]["number_of_shards"].is_number());
assert!(mapping["mappings"]["properties"]["event_id"]["type"]
.as_str()
.unwrap()
== "keyword");
assert!(mapping["mappings"]["properties"]["content"]["type"]
.as_str()
.unwrap()
== "text");
assert!(mapping["mappings"]["properties"]["timestamp"]["type"]
.as_str()
.unwrap()
== "date");
}
}

207
src/brain/conversation.rs Normal file
View File

@@ -0,0 +1,207 @@
use std::collections::{HashMap, VecDeque};
#[derive(Debug, Clone)]
pub struct ContextMessage {
pub sender: String,
pub content: String,
pub timestamp: i64,
}
struct RoomContext {
messages: VecDeque<ContextMessage>,
max_size: usize,
}
impl RoomContext {
fn new(max_size: usize) -> Self {
Self {
messages: VecDeque::with_capacity(max_size),
max_size,
}
}
fn add(&mut self, msg: ContextMessage) {
if self.messages.len() >= self.max_size {
self.messages.pop_front();
}
self.messages.push_back(msg);
}
fn get(&self) -> Vec<ContextMessage> {
self.messages.iter().cloned().collect()
}
}
pub struct ConversationManager {
rooms: HashMap<String, RoomContext>,
room_window: usize,
dm_window: usize,
max_rooms: usize,
}
impl ConversationManager {
pub fn new(room_window: usize, dm_window: usize) -> Self {
Self {
rooms: HashMap::new(),
room_window,
dm_window,
max_rooms: 500, // todo(sienna): make this configurable
}
}
pub fn add_message(&mut self, room_id: &str, is_dm: bool, msg: ContextMessage) {
let window = if is_dm {
self.dm_window
} else {
self.room_window
};
// Evict oldest room if at capacity
if !self.rooms.contains_key(room_id) && self.rooms.len() >= self.max_rooms {
// Remove the room with the oldest latest message
let oldest = self
.rooms
.iter()
.min_by_key(|(_, ctx)| ctx.messages.back().map(|m| m.timestamp).unwrap_or(0))
.map(|(k, _)| k.to_owned());
if let Some(key) = oldest {
self.rooms.remove(&key);
}
}
let ctx = self
.rooms
.entry(room_id.to_owned())
.or_insert_with(|| RoomContext::new(window));
ctx.add(msg);
}
pub fn get_context(&self, room_id: &str) -> Vec<ContextMessage> {
self.rooms
.get(room_id)
.map(|ctx| ctx.get())
.unwrap_or_default()
}
pub fn room_count(&self) -> usize {
self.rooms.len()
}
}
#[cfg(test)]
mod tests {
use super::*;
fn msg(sender: &str, content: &str, ts: i64) -> ContextMessage {
ContextMessage {
sender: sender.to_string(),
content: content.to_string(),
timestamp: ts,
}
}
#[test]
fn test_add_and_get_messages() {
let mut cm = ConversationManager::new(5, 10);
cm.add_message("!room1:x", false, msg("alice", "hello", 1));
cm.add_message("!room1:x", false, msg("bob", "hi", 2));
let ctx = cm.get_context("!room1:x");
assert_eq!(ctx.len(), 2);
assert_eq!(ctx[0].sender, "alice");
assert_eq!(ctx[1].sender, "bob");
}
#[test]
fn test_empty_room_returns_empty() {
let cm = ConversationManager::new(5, 10);
let ctx = cm.get_context("!nonexistent:x");
assert!(ctx.is_empty());
}
#[test]
fn test_sliding_window_group_room() {
let mut cm = ConversationManager::new(3, 10);
for i in 0..5 {
cm.add_message("!room:x", false, msg("user", &format!("msg{i}"), i));
}
let ctx = cm.get_context("!room:x");
assert_eq!(ctx.len(), 3);
// Should keep the last 3
assert_eq!(ctx[0].content, "msg2");
assert_eq!(ctx[1].content, "msg3");
assert_eq!(ctx[2].content, "msg4");
}
#[test]
fn test_sliding_window_dm_room() {
let mut cm = ConversationManager::new(3, 5);
for i in 0..7 {
cm.add_message("!dm:x", true, msg("user", &format!("dm{i}"), i));
}
let ctx = cm.get_context("!dm:x");
assert_eq!(ctx.len(), 5);
assert_eq!(ctx[0].content, "dm2");
assert_eq!(ctx[4].content, "dm6");
}
#[test]
fn test_multiple_rooms_independent() {
let mut cm = ConversationManager::new(5, 10);
cm.add_message("!a:x", false, msg("alice", "in room a", 1));
cm.add_message("!b:x", false, msg("bob", "in room b", 2));
assert_eq!(cm.get_context("!a:x").len(), 1);
assert_eq!(cm.get_context("!b:x").len(), 1);
assert_eq!(cm.get_context("!a:x")[0].content, "in room a");
assert_eq!(cm.get_context("!b:x")[0].content, "in room b");
}
#[test]
fn test_lru_eviction_at_max_rooms() {
// Create a manager with max_rooms = 500 (default), but we'll use a small one
let mut cm = ConversationManager::new(5, 10);
cm.max_rooms = 3;
// Add 3 rooms
cm.add_message("!room1:x", false, msg("a", "r1", 100));
cm.add_message("!room2:x", false, msg("b", "r2", 200));
cm.add_message("!room3:x", false, msg("c", "r3", 300));
assert_eq!(cm.room_count(), 3);
// Adding a 4th room should evict the one with oldest latest message (room1, ts=100)
cm.add_message("!room4:x", false, msg("d", "r4", 400));
assert_eq!(cm.room_count(), 3);
assert!(cm.get_context("!room1:x").is_empty()); // evicted
assert_eq!(cm.get_context("!room2:x").len(), 1);
assert_eq!(cm.get_context("!room3:x").len(), 1);
assert_eq!(cm.get_context("!room4:x").len(), 1);
}
#[test]
fn test_existing_room_not_evicted() {
let mut cm = ConversationManager::new(5, 10);
cm.max_rooms = 2;
cm.add_message("!room1:x", false, msg("a", "r1", 100));
cm.add_message("!room2:x", false, msg("b", "r2", 200));
// Adding to existing room should NOT trigger eviction
cm.add_message("!room1:x", false, msg("a", "r1 again", 300));
assert_eq!(cm.room_count(), 2);
assert_eq!(cm.get_context("!room1:x").len(), 2);
}
#[test]
fn test_message_ordering_preserved() {
let mut cm = ConversationManager::new(10, 10);
cm.add_message("!r:x", false, msg("a", "first", 1));
cm.add_message("!r:x", false, msg("b", "second", 2));
cm.add_message("!r:x", false, msg("c", "third", 3));
let ctx = cm.get_context("!r:x");
assert_eq!(ctx[0].timestamp, 1);
assert_eq!(ctx[1].timestamp, 2);
assert_eq!(ctx[2].timestamp, 3);
}
}

307
src/brain/evaluator.rs Normal file
View File

@@ -0,0 +1,307 @@
use std::sync::Arc;
use mistralai_client::v1::{
chat::{ChatMessage, ChatParams, ResponseFormat},
constants::Model,
};
use regex::Regex;
use tracing::{debug, warn};
use crate::config::Config;
#[derive(Debug)]
pub enum Engagement {
MustRespond { reason: MustRespondReason },
MaybeRespond { relevance: f32, hook: String },
Ignore,
}
#[derive(Debug)]
pub enum MustRespondReason {
DirectMention,
DirectMessage,
NameInvocation,
}
pub struct Evaluator {
config: Arc<Config>,
mention_regex: Regex,
name_regex: Regex,
}
impl Evaluator {
// todo(sienna): regex must be configrable
pub fn new(config: Arc<Config>) -> Self {
let user_id = &config.matrix.user_id;
let mention_pattern = regex::escape(user_id);
let mention_regex = Regex::new(&mention_pattern).expect("Failed to compile mention regex");
let name_regex =
Regex::new(r"(?i)(?:^|\bhey\s+)\bsol\b").expect("Failed to compile name regex");
Self {
config,
mention_regex,
name_regex,
}
}
pub async fn evaluate(
&self,
sender: &str,
body: &str,
is_dm: bool,
recent_messages: &[String],
mistral: &Arc<mistralai_client::v1::client::Client>,
) -> Engagement {
// Don't respond to ourselves
if sender == self.config.matrix.user_id {
return Engagement::Ignore;
}
// Direct mention: @sol:sunbeam.pt
if self.mention_regex.is_match(body) {
return Engagement::MustRespond {
reason: MustRespondReason::DirectMention,
};
}
// DM
if is_dm {
return Engagement::MustRespond {
reason: MustRespondReason::DirectMessage,
};
}
// Name invocation: "sol ..." or "hey sol ..."
if self.name_regex.is_match(body) {
return Engagement::MustRespond {
reason: MustRespondReason::NameInvocation,
};
}
// Cheap evaluation call for spontaneous responses
self.evaluate_relevance(body, recent_messages, mistral)
.await
}
/// Check rule-based engagement (without calling Mistral). Returns Some(Engagement)
/// if a rule matched, None if we need to fall through to the LLM evaluation.
pub fn evaluate_rules(
&self,
sender: &str,
body: &str,
is_dm: bool,
) -> Option<Engagement> {
if sender == self.config.matrix.user_id {
return Some(Engagement::Ignore);
}
if self.mention_regex.is_match(body) {
return Some(Engagement::MustRespond {
reason: MustRespondReason::DirectMention,
});
}
if is_dm {
return Some(Engagement::MustRespond {
reason: MustRespondReason::DirectMessage,
});
}
if self.name_regex.is_match(body) {
return Some(Engagement::MustRespond {
reason: MustRespondReason::NameInvocation,
});
}
None
}
async fn evaluate_relevance(
&self,
body: &str,
recent_messages: &[String],
mistral: &Arc<mistralai_client::v1::client::Client>,
) -> Engagement {
let context = recent_messages
.iter()
.rev()
.take(5) //todo(sienna): must be configurable
.rev()
.cloned()
.collect::<Vec<_>>()
.join("\n");
let prompt = format!(
"You are evaluating whether a virtual librarian named Sol should spontaneously join \
a conversation. Sol has deep knowledge of the group's message archive and helps \
people find information.\n\n\
Recent conversation:\n{context}\n\n\
Latest message: {body}\n\n\
Respond ONLY with JSON: {{\"relevance\": 0.0-1.0, \"hook\": \"brief reason or empty string\"}}\n\
relevance=1.0 means Sol absolutely should respond, 0.0 means irrelevant."
);
let messages = vec![ChatMessage::new_user_message(&prompt)];
let params = ChatParams {
response_format: Some(ResponseFormat::json_object()),
temperature: Some(0.1),
max_tokens: Some(100),
..Default::default()
};
let model = Model::new(&self.config.mistral.evaluation_model);
let client = Arc::clone(mistral);
let result = tokio::task::spawn_blocking(move || {
client.chat(model, messages, Some(params))
})
.await
.unwrap_or_else(|e| Err(mistralai_client::v1::error::ApiError {
message: format!("spawn_blocking join error: {e}"),
}));
match result {
Ok(response) => {
let text = &response.choices[0].message.content;
match serde_json::from_str::<serde_json::Value>(text) {
Ok(val) => {
let relevance = val["relevance"].as_f64().unwrap_or(0.0) as f32;
let hook = val["hook"].as_str().unwrap_or("").to_string();
debug!(relevance, hook = hook.as_str(), "Evaluation result");
if relevance >= self.config.behavior.spontaneous_threshold {
Engagement::MaybeRespond { relevance, hook }
} else {
Engagement::Ignore
}
}
Err(e) => {
warn!("Failed to parse evaluation response: {e}");
Engagement::Ignore
}
}
}
Err(e) => {
warn!("Evaluation call failed: {e}");
Engagement::Ignore
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::Config;
fn test_config() -> Arc<Config> {
let toml = r#"
[matrix]
homeserver_url = "https://chat.sunbeam.pt"
user_id = "@sol:sunbeam.pt"
state_store_path = "/tmp/sol"
[opensearch]
url = "http://localhost:9200"
index = "test"
[mistral]
[behavior]
"#;
Arc::new(Config::from_str(toml).unwrap())
}
fn evaluator() -> Evaluator {
Evaluator::new(test_config())
}
#[test]
fn test_ignore_own_messages() {
let ev = evaluator();
let result = ev.evaluate_rules("@sol:sunbeam.pt", "hello everyone", false);
assert!(matches!(result, Some(Engagement::Ignore)));
}
#[test]
fn test_direct_mention() {
let ev = evaluator();
let result = ev.evaluate_rules("@alice:sunbeam.pt", "hey @sol:sunbeam.pt what's up?", false);
assert!(matches!(
result,
Some(Engagement::MustRespond { reason: MustRespondReason::DirectMention })
));
}
#[test]
fn test_dm_detection() {
let ev = evaluator();
let result = ev.evaluate_rules("@alice:sunbeam.pt", "random message", true);
assert!(matches!(
result,
Some(Engagement::MustRespond { reason: MustRespondReason::DirectMessage })
));
}
#[test]
fn test_name_invocation_start_of_message() {
let ev = evaluator();
let result = ev.evaluate_rules("@alice:sunbeam.pt", "sol, can you find that link?", false);
assert!(matches!(
result,
Some(Engagement::MustRespond { reason: MustRespondReason::NameInvocation })
));
}
#[test]
fn test_name_invocation_hey_sol() {
let ev = evaluator();
let result = ev.evaluate_rules("@alice:sunbeam.pt", "hey sol do you remember?", false);
assert!(matches!(
result,
Some(Engagement::MustRespond { reason: MustRespondReason::NameInvocation })
));
}
#[test]
fn test_name_invocation_case_insensitive() {
let ev = evaluator();
let result = ev.evaluate_rules("@alice:sunbeam.pt", "Hey Sol, help me", false);
assert!(matches!(
result,
Some(Engagement::MustRespond { reason: MustRespondReason::NameInvocation })
));
}
#[test]
fn test_name_invocation_sol_uppercase() {
let ev = evaluator();
let result = ev.evaluate_rules("@alice:sunbeam.pt", "SOL what do you think?", false);
assert!(matches!(
result,
Some(Engagement::MustRespond { reason: MustRespondReason::NameInvocation })
));
}
#[test]
fn test_no_false_positive_solstice() {
let ev = evaluator();
// "solstice" should NOT trigger name invocation — \b boundary prevents it
let result = ev.evaluate_rules("@alice:sunbeam.pt", "the solstice is coming", false);
assert!(result.is_none());
}
#[test]
fn test_random_message_falls_through() {
let ev = evaluator();
let result = ev.evaluate_rules("@alice:sunbeam.pt", "what's for lunch?", false);
assert!(result.is_none());
}
#[test]
fn test_priority_mention_over_dm() {
// When both mention and DM are true, mention should match first
let ev = evaluator();
let result = ev.evaluate_rules("@alice:sunbeam.pt", "hi @sol:sunbeam.pt", true);
assert!(matches!(
result,
Some(Engagement::MustRespond { reason: MustRespondReason::DirectMention })
));
}
}

4
src/brain/mod.rs Normal file
View File

@@ -0,0 +1,4 @@
pub mod conversation;
pub mod evaluator;
pub mod personality;
pub mod responder;

89
src/brain/personality.rs Normal file
View File

@@ -0,0 +1,89 @@
use chrono::Utc;
pub struct Personality {
template: String,
}
impl Personality {
pub fn new(system_prompt: String) -> Self {
Self {
template: system_prompt,
}
}
pub fn build_system_prompt(
&self,
room_name: &str,
members: &[String],
) -> String {
let date = Utc::now().format("%Y-%m-%d").to_string();
let members_str = members.join(", ");
self.template
.replace("{date}", &date)
.replace("{room_name}", room_name)
.replace("{members}", &members_str)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_date_substitution() {
let p = Personality::new("Today is {date}.".to_string());
let result = p.build_system_prompt("general", &[]);
let today = Utc::now().format("%Y-%m-%d").to_string();
assert_eq!(result, format!("Today is {today}."));
}
#[test]
fn test_room_name_substitution() {
let p = Personality::new("You are in {room_name}.".to_string());
let result = p.build_system_prompt("design-chat", &[]);
assert!(result.contains("design-chat"));
}
#[test]
fn test_members_substitution() {
let p = Personality::new("Members: {members}".to_string());
let members = vec!["Alice".to_string(), "Bob".to_string(), "Carol".to_string()];
let result = p.build_system_prompt("room", &members);
assert_eq!(result, "Members: Alice, Bob, Carol");
}
#[test]
fn test_empty_members() {
let p = Personality::new("Members: {members}".to_string());
let result = p.build_system_prompt("room", &[]);
assert_eq!(result, "Members: ");
}
#[test]
fn test_all_placeholders() {
let template = "Date: {date}, Room: {room_name}, People: {members}".to_string();
let p = Personality::new(template);
let members = vec!["Sienna".to_string(), "Lonni".to_string()];
let result = p.build_system_prompt("studio", &members);
let today = Utc::now().format("%Y-%m-%d").to_string();
assert!(result.starts_with(&format!("Date: {today}")));
assert!(result.contains("Room: studio"));
assert!(result.contains("People: Sienna, Lonni"));
}
#[test]
fn test_no_placeholders_passthrough() {
let p = Personality::new("Static prompt with no variables.".to_string());
let result = p.build_system_prompt("room", &["Alice".to_string()]);
assert_eq!(result, "Static prompt with no variables.");
}
#[test]
fn test_multiple_same_placeholder() {
let p = Personality::new("{room_name} is great. I love {room_name}.".to_string());
let result = p.build_system_prompt("lounge", &[]);
assert_eq!(result, "lounge is great. I love lounge.");
}
}

179
src/brain/responder.rs Normal file
View File

@@ -0,0 +1,179 @@
use std::sync::Arc;
use mistralai_client::v1::{
chat::{ChatMessage, ChatParams, ChatResponse, ChatResponseChoiceFinishReason},
constants::Model,
error::ApiError,
tool::ToolChoice,
};
use rand::Rng;
use tokio::time::{sleep, Duration};
use tracing::{debug, error, info, warn};
use crate::brain::conversation::ContextMessage;
use crate::brain::personality::Personality;
use crate::config::Config;
use crate::tools::ToolRegistry;
/// Run a Mistral chat completion on a blocking thread.
///
/// The mistral client's `chat_async` holds a `std::sync::MutexGuard` across an
/// `.await` point, making the future !Send. We use the synchronous `chat()`
/// method via `spawn_blocking` instead.
pub(crate) async fn chat_blocking(
client: &Arc<mistralai_client::v1::client::Client>,
model: Model,
messages: Vec<ChatMessage>,
params: ChatParams,
) -> Result<ChatResponse, ApiError> {
let client = Arc::clone(client);
tokio::task::spawn_blocking(move || client.chat(model, messages, Some(params)))
.await
.map_err(|e| ApiError {
message: format!("spawn_blocking join error: {e}"),
})?
}
pub struct Responder {
config: Arc<Config>,
personality: Arc<Personality>,
tools: Arc<ToolRegistry>,
}
impl Responder {
pub fn new(
config: Arc<Config>,
personality: Arc<Personality>,
tools: Arc<ToolRegistry>,
) -> Self {
Self {
config,
personality,
tools,
}
}
pub async fn generate_response(
&self,
context: &[ContextMessage],
trigger_body: &str,
trigger_sender: &str,
room_name: &str,
members: &[String],
is_spontaneous: bool,
mistral: &Arc<mistralai_client::v1::client::Client>,
) -> Option<String> {
// Apply response delay
let delay = if is_spontaneous {
rand::thread_rng().gen_range(
self.config.behavior.spontaneous_delay_min_ms
..=self.config.behavior.spontaneous_delay_max_ms,
)
} else {
rand::thread_rng().gen_range(
self.config.behavior.response_delay_min_ms
..=self.config.behavior.response_delay_max_ms,
)
};
sleep(Duration::from_millis(delay)).await;
let system_prompt = self.personality.build_system_prompt(room_name, members);
let mut messages = vec![ChatMessage::new_system_message(&system_prompt)];
// Add context messages
for msg in context {
if msg.sender == self.config.matrix.user_id {
messages.push(ChatMessage::new_assistant_message(&msg.content, None));
} else {
let user_msg = format!("{}: {}", msg.sender, msg.content);
messages.push(ChatMessage::new_user_message(&user_msg));
}
}
// Add the triggering message
let trigger = format!("{trigger_sender}: {trigger_body}");
messages.push(ChatMessage::new_user_message(&trigger));
let tool_defs = ToolRegistry::tool_definitions();
let model = Model::new(&self.config.mistral.default_model);
let max_iterations = self.config.mistral.max_tool_iterations;
for iteration in 0..=max_iterations {
let params = ChatParams {
tools: if iteration < max_iterations {
Some(tool_defs.clone())
} else {
None
},
tool_choice: if iteration < max_iterations {
Some(ToolChoice::Auto)
} else {
None
},
..Default::default()
};
let response = match chat_blocking(mistral, model.clone(), messages.clone(), params).await {
Ok(r) => r,
Err(e) => {
error!("Mistral chat failed: {e}");
return None;
}
};
let choice = &response.choices[0];
if choice.finish_reason == ChatResponseChoiceFinishReason::ToolCalls {
if let Some(tool_calls) = &choice.message.tool_calls {
// Add assistant message with tool calls
messages.push(ChatMessage::new_assistant_message(
&choice.message.content,
Some(tool_calls.clone()),
));
for tc in tool_calls {
let call_id = tc.id.as_deref().unwrap_or("unknown");
info!(
tool = tc.function.name.as_str(),
id = call_id,
"Executing tool call"
);
let result = self
.tools
.execute(&tc.function.name, &tc.function.arguments)
.await;
let result_str = match result {
Ok(s) => s,
Err(e) => {
warn!(tool = tc.function.name.as_str(), "Tool failed: {e}");
format!("Error: {e}")
}
};
messages.push(ChatMessage::new_tool_message(
&result_str,
call_id,
Some(&tc.function.name),
));
}
debug!(iteration, "Tool iteration complete, continuing");
continue;
}
}
// Final text response
let text = choice.message.content.trim().to_string();
if text.is_empty() {
return None;
}
return Some(text);
}
warn!("Exceeded max tool iterations");
None
}
}

219
src/config.rs Normal file
View File

@@ -0,0 +1,219 @@
use serde::Deserialize;
#[derive(Debug, Clone, Deserialize)]
pub struct Config {
pub matrix: MatrixConfig,
pub opensearch: OpenSearchConfig,
pub mistral: MistralConfig,
pub behavior: BehaviorConfig,
}
#[derive(Debug, Clone, Deserialize)]
pub struct MatrixConfig {
pub homeserver_url: String,
pub user_id: String,
pub state_store_path: String,
}
#[derive(Debug, Clone, Deserialize)]
pub struct OpenSearchConfig {
pub url: String,
pub index: String,
#[serde(default = "default_batch_size")]
pub batch_size: usize,
#[serde(default = "default_flush_interval_ms")]
pub flush_interval_ms: u64,
#[serde(default = "default_embedding_pipeline")]
pub embedding_pipeline: String,
}
#[derive(Debug, Clone, Deserialize)]
pub struct MistralConfig {
#[serde(default = "default_model")]
pub default_model: String,
#[serde(default = "default_evaluation_model")]
pub evaluation_model: String,
#[serde(default = "default_research_model")]
pub research_model: String,
#[serde(default = "default_max_tool_iterations")]
pub max_tool_iterations: usize,
}
#[derive(Debug, Clone, Deserialize)]
pub struct BehaviorConfig {
#[serde(default = "default_response_delay_min_ms")]
pub response_delay_min_ms: u64,
#[serde(default = "default_response_delay_max_ms")]
pub response_delay_max_ms: u64,
#[serde(default = "default_spontaneous_delay_min_ms")]
pub spontaneous_delay_min_ms: u64,
#[serde(default = "default_spontaneous_delay_max_ms")]
pub spontaneous_delay_max_ms: u64,
#[serde(default = "default_spontaneous_threshold")]
pub spontaneous_threshold: f32,
#[serde(default = "default_room_context_window")]
pub room_context_window: usize,
#[serde(default = "default_dm_context_window")]
pub dm_context_window: usize,
#[serde(default = "default_backfill_on_join")]
pub backfill_on_join: bool,
#[serde(default = "default_backfill_limit")]
pub backfill_limit: usize,
}
fn default_batch_size() -> usize { 50 }
fn default_flush_interval_ms() -> u64 { 2000 }
fn default_embedding_pipeline() -> String { "tuwunel_embedding_pipeline".into() }
fn default_model() -> String { "mistral-medium-latest".into() }
fn default_evaluation_model() -> String { "ministral-3b-latest".into() }
fn default_research_model() -> String { "mistral-large-latest".into() }
fn default_max_tool_iterations() -> usize { 5 }
fn default_response_delay_min_ms() -> u64 { 2000 }
fn default_response_delay_max_ms() -> u64 { 8000 }
fn default_spontaneous_delay_min_ms() -> u64 { 15000 }
fn default_spontaneous_delay_max_ms() -> u64 { 60000 }
fn default_spontaneous_threshold() -> f32 { 0.7 }
fn default_room_context_window() -> usize { 30 }
fn default_dm_context_window() -> usize { 100 }
fn default_backfill_on_join() -> bool { true }
fn default_backfill_limit() -> usize { 10000 }
impl Config {
pub fn load(path: &str) -> anyhow::Result<Self> {
let content = std::fs::read_to_string(path)?;
let config: Config = toml::from_str(&content)?;
Ok(config)
}
pub fn from_str(content: &str) -> anyhow::Result<Self> {
let config: Config = toml::from_str(content)?;
Ok(config)
}
}
#[cfg(test)]
mod tests {
use super::*;
const MINIMAL_CONFIG: &str = r#"
[matrix]
homeserver_url = "https://chat.sunbeam.pt"
user_id = "@sol:sunbeam.pt"
state_store_path = "/data/sol/state"
[opensearch]
url = "http://opensearch:9200"
index = "sol-archive"
[mistral]
[behavior]
"#;
const FULL_CONFIG: &str = r#"
[matrix]
homeserver_url = "https://chat.sunbeam.pt"
user_id = "@sol:sunbeam.pt"
state_store_path = "/data/sol/state"
[opensearch]
url = "http://opensearch:9200"
index = "sol-archive"
batch_size = 100
flush_interval_ms = 5000
embedding_pipeline = "my_pipeline"
[mistral]
default_model = "mistral-large-latest"
evaluation_model = "ministral-8b-latest"
research_model = "mistral-large-latest"
max_tool_iterations = 10
[behavior]
response_delay_min_ms = 1000
response_delay_max_ms = 5000
spontaneous_delay_min_ms = 10000
spontaneous_delay_max_ms = 30000
spontaneous_threshold = 0.8
room_context_window = 50
dm_context_window = 200
backfill_on_join = false
backfill_limit = 5000
"#;
#[test]
fn test_minimal_config_with_defaults() {
let config = Config::from_str(MINIMAL_CONFIG).unwrap();
assert_eq!(config.matrix.homeserver_url, "https://chat.sunbeam.pt");
assert_eq!(config.matrix.user_id, "@sol:sunbeam.pt");
assert_eq!(config.matrix.state_store_path, "/data/sol/state");
assert_eq!(config.opensearch.url, "http://opensearch:9200");
assert_eq!(config.opensearch.index, "sol-archive");
// Check defaults
assert_eq!(config.opensearch.batch_size, 50);
assert_eq!(config.opensearch.flush_interval_ms, 2000);
assert_eq!(config.opensearch.embedding_pipeline, "tuwunel_embedding_pipeline");
assert_eq!(config.mistral.default_model, "mistral-medium-latest");
assert_eq!(config.mistral.evaluation_model, "ministral-3b-latest");
assert_eq!(config.mistral.research_model, "mistral-large-latest");
assert_eq!(config.mistral.max_tool_iterations, 5);
assert_eq!(config.behavior.response_delay_min_ms, 2000);
assert_eq!(config.behavior.response_delay_max_ms, 8000);
assert_eq!(config.behavior.spontaneous_delay_min_ms, 15000);
assert_eq!(config.behavior.spontaneous_delay_max_ms, 60000);
assert!((config.behavior.spontaneous_threshold - 0.7).abs() < f32::EPSILON);
assert_eq!(config.behavior.room_context_window, 30);
assert_eq!(config.behavior.dm_context_window, 100);
assert!(config.behavior.backfill_on_join);
assert_eq!(config.behavior.backfill_limit, 10000);
}
#[test]
fn test_full_config_overrides() {
let config = Config::from_str(FULL_CONFIG).unwrap();
assert_eq!(config.opensearch.batch_size, 100);
assert_eq!(config.opensearch.flush_interval_ms, 5000);
assert_eq!(config.opensearch.embedding_pipeline, "my_pipeline");
assert_eq!(config.mistral.default_model, "mistral-large-latest");
assert_eq!(config.mistral.evaluation_model, "ministral-8b-latest");
assert_eq!(config.mistral.max_tool_iterations, 10);
assert_eq!(config.behavior.response_delay_min_ms, 1000);
assert_eq!(config.behavior.response_delay_max_ms, 5000);
assert!((config.behavior.spontaneous_threshold - 0.8).abs() < f32::EPSILON);
assert_eq!(config.behavior.room_context_window, 50);
assert_eq!(config.behavior.dm_context_window, 200);
assert!(!config.behavior.backfill_on_join);
assert_eq!(config.behavior.backfill_limit, 5000);
}
#[test]
fn test_missing_required_section_fails() {
let bad = r#"
[matrix]
homeserver_url = "https://chat.sunbeam.pt"
user_id = "@sol:sunbeam.pt"
state_store_path = "/data/sol/state"
"#;
assert!(Config::from_str(bad).is_err());
}
#[test]
fn test_missing_required_field_fails() {
let bad = r#"
[matrix]
homeserver_url = "https://chat.sunbeam.pt"
state_store_path = "/data/sol/state"
[opensearch]
url = "http://opensearch:9200"
index = "sol-archive"
[mistral]
[behavior]
"#;
assert!(Config::from_str(bad).is_err());
}
}

160
src/main.rs Normal file
View File

@@ -0,0 +1,160 @@
mod archive;
mod brain;
mod config;
mod matrix_utils;
mod sync;
mod tools;
use std::sync::Arc;
use matrix_sdk::Client;
use opensearch::http::transport::TransportBuilder;
use opensearch::OpenSearch;
use ruma::{OwnedDeviceId, OwnedUserId};
use tokio::signal;
use tokio::sync::Mutex;
use tracing::{error, info};
use url::Url;
use archive::indexer::Indexer;
use archive::schema::create_index_if_not_exists;
use brain::conversation::ConversationManager;
use brain::evaluator::Evaluator;
use brain::personality::Personality;
use brain::responder::Responder;
use config::Config;
use sync::AppState;
use tools::ToolRegistry;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
// Initialize tracing
tracing_subscriber::fmt()
.with_env_filter(
tracing_subscriber::EnvFilter::try_from_default_env()
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("sol=info")),
)
.init();
// Load config
let config_path =
std::env::var("SOL_CONFIG").unwrap_or_else(|_| "/etc/sol/sol.toml".into());
let config = Config::load(&config_path)?;
info!("Loaded config from {config_path}");
// Load system prompt
let prompt_path = std::env::var("SOL_SYSTEM_PROMPT")
.unwrap_or_else(|_| "/etc/sol/system_prompt.md".into());
let system_prompt = std::fs::read_to_string(&prompt_path)?;
info!("Loaded system prompt from {prompt_path}");
// Read secrets from environment
let access_token = std::env::var("SOL_MATRIX_ACCESS_TOKEN")
.map_err(|_| anyhow::anyhow!("SOL_MATRIX_ACCESS_TOKEN not set"))?;
let device_id = std::env::var("SOL_MATRIX_DEVICE_ID")
.map_err(|_| anyhow::anyhow!("SOL_MATRIX_DEVICE_ID not set"))?;
let mistral_api_key = std::env::var("SOL_MISTRAL_API_KEY")
.map_err(|_| anyhow::anyhow!("SOL_MISTRAL_API_KEY not set"))?;
let config = Arc::new(config);
// Initialize Matrix client with E2EE and sqlite store
let homeserver = Url::parse(&config.matrix.homeserver_url)?;
let matrix_client = Client::builder()
.homeserver_url(homeserver)
.sqlite_store(&config.matrix.state_store_path, None)
.build()
.await?;
// Restore session
let user_id: OwnedUserId = config.matrix.user_id.parse()?;
let device_id: OwnedDeviceId = device_id.into();
let session = matrix_sdk::AuthSession::Matrix(matrix_sdk::matrix_auth::MatrixSession {
meta: matrix_sdk::SessionMeta {
user_id,
device_id,
},
tokens: matrix_sdk::matrix_auth::MatrixSessionTokens {
access_token,
refresh_token: None,
},
});
matrix_client.restore_session(session).await?;
info!(user = %config.matrix.user_id, "Matrix session restored");
// Initialize OpenSearch client
let os_url = Url::parse(&config.opensearch.url)?;
let os_transport = TransportBuilder::new(
opensearch::http::transport::SingleNodeConnectionPool::new(os_url),
)
.build()?;
let os_client = OpenSearch::new(os_transport);
// Ensure index exists
create_index_if_not_exists(&os_client, &config.opensearch.index).await?;
// Initialize Mistral client
let mistral_client = mistralai_client::v1::client::Client::new(
Some(mistral_api_key),
None,
None,
None,
)?;
let mistral = Arc::new(mistral_client);
// Build components
let personality = Arc::new(Personality::new(system_prompt));
let tool_registry = Arc::new(ToolRegistry::new(
os_client.clone(),
matrix_client.clone(),
config.clone(),
));
let indexer = Arc::new(Indexer::new(os_client, config.clone()));
let evaluator = Arc::new(Evaluator::new(config.clone()));
let responder = Arc::new(Responder::new(
config.clone(),
personality,
tool_registry,
));
let conversations = Arc::new(Mutex::new(ConversationManager::new(
config.behavior.room_context_window,
config.behavior.dm_context_window,
)));
// Start background flush task
let _flush_handle = indexer.start_flush_task();
// Build shared state
let state = Arc::new(AppState {
config: config.clone(),
indexer,
evaluator,
responder,
conversations,
mistral,
});
// Start sync loop in background
let sync_client = matrix_client.clone();
let sync_state = state.clone();
let sync_handle = tokio::spawn(async move {
if let Err(e) = sync::start_sync(sync_client, sync_state).await {
error!("Sync loop error: {e}");
}
});
info!("Sol is running");
// Wait for shutdown signal
signal::ctrl_c().await?;
info!("Shutdown signal received");
// Cancel sync
sync_handle.abort();
info!("Sol has shut down");
Ok(())
}

77
src/matrix_utils.rs Normal file
View File

@@ -0,0 +1,77 @@
use matrix_sdk::room::Room;
use matrix_sdk::RoomMemberships;
use ruma::events::room::message::{
MessageType, OriginalSyncRoomMessageEvent, Relation, RoomMessageEventContent,
};
use ruma::events::relation::InReplyTo;
use ruma::OwnedEventId;
/// Extract the plain-text body from a message event.
pub fn extract_body(event: &OriginalSyncRoomMessageEvent) -> Option<String> {
match &event.content.msgtype {
MessageType::Text(text) => Some(text.body.clone()),
MessageType::Notice(notice) => Some(notice.body.clone()),
MessageType::Emote(emote) => Some(emote.body.clone()),
_ => None,
}
}
/// Check if this event is an edit (m.replace relation) and return the new body.
pub fn extract_edit(event: &OriginalSyncRoomMessageEvent) -> Option<(OwnedEventId, String)> {
if let Some(Relation::Replacement(replacement)) = &event.content.relates_to {
let new_body = match &replacement.new_content.msgtype {
MessageType::Text(text) => text.body.clone(),
MessageType::Notice(notice) => notice.body.clone(),
_ => return None,
};
return Some((replacement.event_id.clone(), new_body));
}
None
}
/// Extract the event ID being replied to, if any.
pub fn extract_reply_to(event: &OriginalSyncRoomMessageEvent) -> Option<OwnedEventId> {
if let Some(Relation::Reply { in_reply_to }) = &event.content.relates_to {
return Some(in_reply_to.event_id.clone());
}
None
}
/// Extract thread root event ID, if any.
pub fn extract_thread_id(event: &OriginalSyncRoomMessageEvent) -> Option<OwnedEventId> {
if let Some(Relation::Thread(thread)) = &event.content.relates_to {
return Some(thread.event_id.clone());
}
None
}
/// Build a reply message content with m.in_reply_to relation.
pub fn make_reply_content(body: &str, reply_to_event_id: OwnedEventId) -> RoomMessageEventContent {
let mut content = RoomMessageEventContent::text_plain(body);
content.relates_to = Some(Relation::Reply {
in_reply_to: InReplyTo::new(reply_to_event_id),
});
content
}
/// Get the display name for a room.
pub fn room_display_name(room: &Room) -> String {
room.cached_display_name()
.map(|n| n.to_string())
.unwrap_or_else(|| room.room_id().to_string())
}
/// Get member display names for a room.
pub async fn room_member_names(room: &Room) -> Vec<String> {
match room.members(RoomMemberships::JOIN).await {
Ok(members) => members
.iter()
.map(|m| {
m.display_name()
.unwrap_or_else(|| m.user_id().as_str())
.to_string()
})
.collect(),
Err(_) => Vec::new(),
}
}

228
src/sync.rs Normal file
View File

@@ -0,0 +1,228 @@
use std::sync::Arc;
use matrix_sdk::config::SyncSettings;
use matrix_sdk::room::Room;
use matrix_sdk::Client;
use ruma::events::room::member::StrippedRoomMemberEvent;
use ruma::events::room::message::OriginalSyncRoomMessageEvent;
use ruma::events::room::redaction::OriginalSyncRoomRedactionEvent;
use tokio::sync::Mutex;
use tracing::{error, info, warn};
use crate::archive::indexer::Indexer;
use crate::archive::schema::ArchiveDocument;
use crate::brain::conversation::{ContextMessage, ConversationManager};
use crate::brain::evaluator::{Engagement, Evaluator};
use crate::brain::responder::Responder;
use crate::config::Config;
use crate::matrix_utils;
pub struct AppState {
pub config: Arc<Config>,
pub indexer: Arc<Indexer>,
pub evaluator: Arc<Evaluator>,
pub responder: Arc<Responder>,
pub conversations: Arc<Mutex<ConversationManager>>,
pub mistral: Arc<mistralai_client::v1::client::Client>,
}
pub async fn start_sync(client: Client, state: Arc<AppState>) -> anyhow::Result<()> {
// Register event handlers
let s = state.clone();
client.add_event_handler(
move |event: OriginalSyncRoomMessageEvent, room: Room| {
let state = s.clone();
async move {
if let Err(e) = handle_message(event, room, state).await {
error!("Error handling message: {e}");
}
}
},
);
let s = state.clone();
client.add_event_handler(
move |event: OriginalSyncRoomRedactionEvent, _room: Room| {
let state = s.clone();
async move {
handle_redaction(event, &state).await;
}
},
);
client.add_event_handler(
move |event: StrippedRoomMemberEvent, room: Room| async move {
handle_invite(event, room).await;
},
);
info!("Starting Matrix sync loop");
let settings = SyncSettings::default();
client.sync(settings).await?;
Ok(())
}
async fn handle_message(
event: OriginalSyncRoomMessageEvent,
room: Room,
state: Arc<AppState>,
) -> anyhow::Result<()> {
let sender = event.sender.to_string();
let room_id = room.room_id().to_string();
let event_id = event.event_id.to_string();
let timestamp = event.origin_server_ts.0.into();
// Check if this is an edit
if let Some((original_id, new_body)) = matrix_utils::extract_edit(&event) {
state
.indexer
.update_edit(&original_id.to_string(), &new_body)
.await;
return Ok(());
}
let Some(body) = matrix_utils::extract_body(&event) else {
return Ok(());
};
let room_name = matrix_utils::room_display_name(&room);
let sender_name = room
.get_member_no_sync(&event.sender)
.await
.ok()
.flatten()
.and_then(|m| m.display_name().map(|s| s.to_string()));
let reply_to = matrix_utils::extract_reply_to(&event).map(|id| id.to_string());
let thread_id = matrix_utils::extract_thread_id(&event).map(|id| id.to_string());
// Archive the message
let doc = ArchiveDocument {
event_id: event_id.clone(),
room_id: room_id.clone(),
room_name: Some(room_name.clone()),
sender: sender.clone(),
sender_name: sender_name.clone(),
timestamp,
content: body.clone(),
reply_to,
thread_id,
media_urls: Vec::new(),
event_type: "m.room.message".into(),
edited: false,
redacted: false,
};
state.indexer.add(doc).await;
// Update conversation context
let is_dm = room.is_direct().await.unwrap_or(false);
{
let mut convs = state.conversations.lock().await;
convs.add_message(
&room_id,
is_dm,
ContextMessage {
sender: sender_name.clone().unwrap_or_else(|| sender.clone()),
content: body.clone(),
timestamp,
},
);
}
// Evaluate whether to respond
let recent: Vec<String> = {
let convs = state.conversations.lock().await;
convs
.get_context(&room_id)
.iter()
.map(|m| format!("{}: {}", m.sender, m.content))
.collect()
};
let engagement = state
.evaluator
.evaluate(&sender, &body, is_dm, &recent, &state.mistral)
.await;
let (should_respond, is_spontaneous) = match engagement {
Engagement::MustRespond { reason } => {
info!(?reason, "Must respond");
(true, false)
}
Engagement::MaybeRespond { relevance, hook } => {
info!(relevance, hook = hook.as_str(), "Maybe respond (spontaneous)");
(true, true)
}
Engagement::Ignore => (false, false),
};
if !should_respond {
return Ok(());
}
// Show typing indicator
let _ = room.typing_notice(true).await;
let context = {
let convs = state.conversations.lock().await;
convs.get_context(&room_id)
};
let members = matrix_utils::room_member_names(&room).await;
let display_sender = sender_name.as_deref().unwrap_or(&sender);
let response = state
.responder
.generate_response(
&context,
&body,
display_sender,
&room_name,
&members,
is_spontaneous,
&state.mistral,
)
.await;
// Stop typing indicator
let _ = room.typing_notice(false).await;
if let Some(text) = response {
let content = matrix_utils::make_reply_content(&text, event.event_id.to_owned());
if let Err(e) = room.send(content).await {
error!("Failed to send response: {e}");
}
}
Ok(())
}
async fn handle_redaction(event: OriginalSyncRoomRedactionEvent, state: &AppState) {
if let Some(redacted_id) = &event.redacts {
state.indexer.update_redaction(&redacted_id.to_string()).await;
}
}
async fn handle_invite(event: StrippedRoomMemberEvent, room: Room) {
// Only handle our own invites
if event.state_key != room.own_user_id() {
return;
}
info!(room_id = %room.room_id(), "Received invite, auto-joining");
tokio::spawn(async move {
for attempt in 0..3u32 {
match room.join().await {
Ok(_) => {
info!(room_id = %room.room_id(), "Joined room");
return;
}
Err(e) => {
warn!(room_id = %room.room_id(), attempt, "Failed to join: {e}");
tokio::time::sleep(std::time::Duration::from_secs(2u64.pow(attempt))).await;
}
}
}
error!(room_id = %room.room_id(), "Failed to join after retries");
});
}

151
src/tools/mod.rs Normal file
View File

@@ -0,0 +1,151 @@
pub mod room_history;
pub mod room_info;
pub mod search;
use std::sync::Arc;
use matrix_sdk::Client as MatrixClient;
use mistralai_client::v1::tool::Tool;
use opensearch::OpenSearch;
use serde_json::json;
use crate::config::Config;
pub struct ToolRegistry {
opensearch: OpenSearch,
matrix: MatrixClient,
config: Arc<Config>,
}
impl ToolRegistry {
pub fn new(opensearch: OpenSearch, matrix: MatrixClient, config: Arc<Config>) -> Self {
Self {
opensearch,
matrix,
config,
}
}
pub fn tool_definitions() -> Vec<Tool> {
vec![
Tool::new(
"search_archive".into(),
"Search the message archive. Use this to find past conversations, \
messages from specific people, or about specific topics."
.into(),
json!({
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Search query for message content"
},
"room": {
"type": "string",
"description": "Filter by room name (optional)"
},
"sender": {
"type": "string",
"description": "Filter by sender display name (optional)"
},
"after": {
"type": "string",
"description": "Unix timestamp in ms — only messages after this time (optional)"
},
"before": {
"type": "string",
"description": "Unix timestamp in ms — only messages before this time (optional)"
},
"limit": {
"type": "integer",
"description": "Max results to return (default 10)"
},
"semantic": {
"type": "boolean",
"description": "Use semantic search instead of keyword (optional)"
}
},
"required": ["query"]
}),
),
Tool::new(
"get_room_context".into(),
"Get messages around a specific point in time or event in a room. \
Useful for understanding the context of a conversation."
.into(),
json!({
"type": "object",
"properties": {
"room_id": {
"type": "string",
"description": "The Matrix room ID"
},
"around_timestamp": {
"type": "integer",
"description": "Unix timestamp in ms to center the context around"
},
"around_event_id": {
"type": "string",
"description": "Event ID to center the context around"
},
"before_count": {
"type": "integer",
"description": "Number of messages before the pivot (default 10)"
},
"after_count": {
"type": "integer",
"description": "Number of messages after the pivot (default 10)"
}
},
"required": ["room_id"]
}),
),
Tool::new(
"list_rooms".into(),
"List all rooms Sol is currently in, with names and member counts.".into(),
json!({
"type": "object",
"properties": {}
}),
),
Tool::new(
"get_room_members".into(),
"Get the list of members in a specific room.".into(),
json!({
"type": "object",
"properties": {
"room_id": {
"type": "string",
"description": "The Matrix room ID"
}
},
"required": ["room_id"]
}),
),
]
}
pub async fn execute(&self, name: &str, arguments: &str) -> anyhow::Result<String> {
match name {
"search_archive" => {
search::search_archive(
&self.opensearch,
&self.config.opensearch.index,
arguments,
)
.await
}
"get_room_context" => {
room_history::get_room_context(
&self.opensearch,
&self.config.opensearch.index,
arguments,
)
.await
}
"list_rooms" => room_info::list_rooms(&self.matrix).await,
"get_room_members" => room_info::get_room_members(&self.matrix, arguments).await,
_ => anyhow::bail!("Unknown tool: {name}"),
}
}
}

108
src/tools/room_history.rs Normal file
View File

@@ -0,0 +1,108 @@
use opensearch::OpenSearch;
use serde::Deserialize;
use serde_json::json;
#[derive(Debug, Deserialize)]
pub struct RoomHistoryArgs {
pub room_id: String,
#[serde(default)]
pub around_timestamp: Option<i64>,
#[serde(default)]
pub around_event_id: Option<String>,
#[serde(default = "default_count")]
pub before_count: usize,
#[serde(default = "default_count")]
pub after_count: usize,
}
fn default_count() -> usize { 10 }
pub async fn get_room_context(
client: &OpenSearch,
index: &str,
args_json: &str,
) -> anyhow::Result<String> {
let args: RoomHistoryArgs = serde_json::from_str(args_json)?;
let total = args.before_count + args.after_count + 1;
// Determine the pivot timestamp
let pivot_ts = if let Some(ts) = args.around_timestamp {
ts
} else if let Some(ref event_id) = args.around_event_id {
// Look up the event to get its timestamp
let lookup = json!({
"size": 1,
"query": { "term": { "event_id": event_id } },
"_source": ["timestamp"]
});
let resp = client
.search(opensearch::SearchParts::Index(&[index]))
.body(lookup)
.send()
.await?;
let body: serde_json::Value = resp.json().await?;
body["hits"]["hits"][0]["_source"]["timestamp"]
.as_i64()
.ok_or_else(|| anyhow::anyhow!("Event {event_id} not found in archive"))?
} else {
anyhow::bail!("Either around_timestamp or around_event_id must be provided");
};
let query = json!({
"size": total,
"query": {
"bool": {
"must": [
{ "term": { "room_id": args.room_id } },
{ "term": { "redacted": false } }
],
"should": [
{
"range": {
"timestamp": {
"gte": pivot_ts - 3_600_000,
"lte": pivot_ts + 3_600_000
}
}
}
]
}
},
"sort": [{ "timestamp": "asc" }]
});
let response = client
.search(opensearch::SearchParts::Index(&[index]))
.body(query)
.send()
.await?;
let body: serde_json::Value = response.json().await?;
let hits = &body["hits"]["hits"];
let Some(hits_arr) = hits.as_array() else {
return Ok("No messages found around that point.".into());
};
if hits_arr.is_empty() {
return Ok("No messages found around that point.".into());
}
let mut output = String::new();
for hit in hits_arr {
let src = &hit["_source"];
let sender = src["sender_name"].as_str().unwrap_or("unknown");
let content = src["content"].as_str().unwrap_or("");
let ts = src["timestamp"].as_i64().unwrap_or(0);
let dt = chrono::DateTime::from_timestamp_millis(ts)
.map(|d| d.format("%H:%M").to_string())
.unwrap_or_else(|| "??:??".into());
output.push_str(&format!("[{dt}] {sender}: {content}\n"));
}
Ok(output)
}

55
src/tools/room_info.rs Normal file
View File

@@ -0,0 +1,55 @@
use matrix_sdk::Client;
use serde::Deserialize;
#[derive(Debug, Deserialize)]
pub struct ListRoomsArgs {}
#[derive(Debug, Deserialize)]
pub struct GetMembersArgs {
pub room_id: String,
}
pub async fn list_rooms(client: &Client) -> anyhow::Result<String> {
let rooms = client.joined_rooms();
if rooms.is_empty() {
return Ok("I'm not in any rooms.".into());
}
let mut output = String::new();
for room in &rooms {
let name = match room.cached_display_name() {
Some(n) => n.to_string(),
None => room.room_id().to_string(),
};
let id = room.room_id();
let members = room.joined_members_count();
output.push_str(&format!("- {name} ({id}) — {members} members\n"));
}
Ok(output)
}
pub async fn get_room_members(client: &Client, args_json: &str) -> anyhow::Result<String> {
let args: GetMembersArgs = serde_json::from_str(args_json)?;
let room_id = <&ruma::RoomId>::try_from(args.room_id.as_str())?;
let Some(room) = client.get_room(room_id) else {
anyhow::bail!("I'm not in room {}", args.room_id);
};
let members = room.members(matrix_sdk::RoomMemberships::JOIN).await?;
if members.is_empty() {
return Ok("No members found.".into());
}
let mut output = String::new();
for member in &members {
let display = member
.display_name()
.unwrap_or_else(|| member.user_id().as_str());
let user_id = member.user_id();
output.push_str(&format!("- {display} ({user_id})\n"));
}
Ok(output)
}

274
src/tools/search.rs Normal file
View File

@@ -0,0 +1,274 @@
use opensearch::OpenSearch;
use serde::Deserialize;
use serde_json::json;
use tracing::debug;
#[derive(Debug, Deserialize)]
pub struct SearchArgs {
pub query: String,
#[serde(default)]
pub room: Option<String>,
#[serde(default)]
pub sender: Option<String>,
#[serde(default)]
pub after: Option<String>,
#[serde(default)]
pub before: Option<String>,
#[serde(default = "default_limit")]
pub limit: usize,
#[serde(default)]
pub semantic: Option<bool>,
}
fn default_limit() -> usize { 10 }
/// Build the OpenSearch query body from parsed SearchArgs. Extracted for testability.
pub fn build_search_query(args: &SearchArgs) -> serde_json::Value {
let must = vec![json!({
"match": { "content": args.query }
})];
let mut filter = vec![json!({
"term": { "redacted": false }
})];
if let Some(ref room) = args.room {
filter.push(json!({ "term": { "room_name": room } }));
}
if let Some(ref sender) = args.sender {
filter.push(json!({ "term": { "sender_name": sender } }));
}
let mut range = serde_json::Map::new();
if let Some(ref after) = args.after {
if let Ok(ts) = after.parse::<i64>() {
range.insert("gte".into(), json!(ts));
}
}
if let Some(ref before) = args.before {
if let Ok(ts) = before.parse::<i64>() {
range.insert("lte".into(), json!(ts));
}
}
if !range.is_empty() {
filter.push(json!({ "range": { "timestamp": range } }));
}
json!({
"size": args.limit,
"query": {
"bool": {
"must": must,
"filter": filter
}
},
"sort": [{ "timestamp": "desc" }],
"_source": ["event_id", "room_name", "sender_name", "timestamp", "content"]
})
}
pub async fn search_archive(
client: &OpenSearch,
index: &str,
args_json: &str,
) -> anyhow::Result<String> {
let args: SearchArgs = serde_json::from_str(args_json)?;
debug!(query = args.query.as_str(), "Searching archive");
let query_body = build_search_query(&args);
let response = client
.search(opensearch::SearchParts::Index(&[index]))
.body(query_body)
.send()
.await?;
let body: serde_json::Value = response.json().await?;
let hits = &body["hits"]["hits"];
let Some(hits_arr) = hits.as_array() else {
return Ok("No results found.".into());
};
if hits_arr.is_empty() {
return Ok("No results found.".into());
}
let mut output = String::new();
for (i, hit) in hits_arr.iter().enumerate() {
let src = &hit["_source"];
let sender = src["sender_name"].as_str().unwrap_or("unknown");
let room = src["room_name"].as_str().unwrap_or("unknown");
let content = src["content"].as_str().unwrap_or("");
let ts = src["timestamp"].as_i64().unwrap_or(0);
let dt = chrono::DateTime::from_timestamp_millis(ts)
.map(|d| d.format("%Y-%m-%d %H:%M").to_string())
.unwrap_or_else(|| "unknown date".into());
output.push_str(&format!(
"{}. [{dt}] #{room}{sender}: {content}\n",
i + 1
));
}
Ok(output)
}
#[cfg(test)]
mod tests {
use super::*;
fn parse_args(json: &str) -> SearchArgs {
serde_json::from_str(json).unwrap()
}
#[test]
fn test_parse_minimal_args() {
let args = parse_args(r#"{"query": "hello"}"#);
assert_eq!(args.query, "hello");
assert!(args.room.is_none());
assert!(args.sender.is_none());
assert!(args.after.is_none());
assert!(args.before.is_none());
assert_eq!(args.limit, 10); // default
assert!(args.semantic.is_none());
}
#[test]
fn test_parse_full_args() {
let args = parse_args(r#"{
"query": "meeting notes",
"room": "general",
"sender": "Alice",
"after": "1710000000000",
"before": "1710100000000",
"limit": 25,
"semantic": true
}"#);
assert_eq!(args.query, "meeting notes");
assert_eq!(args.room.as_deref(), Some("general"));
assert_eq!(args.sender.as_deref(), Some("Alice"));
assert_eq!(args.after.as_deref(), Some("1710000000000"));
assert_eq!(args.before.as_deref(), Some("1710100000000"));
assert_eq!(args.limit, 25);
assert_eq!(args.semantic, Some(true));
}
#[test]
fn test_query_basic() {
let args = parse_args(r#"{"query": "test"}"#);
let q = build_search_query(&args);
assert_eq!(q["size"], 10);
assert_eq!(q["query"]["bool"]["must"][0]["match"]["content"], "test");
assert_eq!(q["query"]["bool"]["filter"][0]["term"]["redacted"], false);
assert_eq!(q["sort"][0]["timestamp"], "desc");
}
#[test]
fn test_query_with_room_filter() {
let args = parse_args(r#"{"query": "hello", "room": "design"}"#);
let q = build_search_query(&args);
let filters = q["query"]["bool"]["filter"].as_array().unwrap();
assert_eq!(filters.len(), 2);
assert_eq!(filters[1]["term"]["room_name"], "design");
}
#[test]
fn test_query_with_sender_filter() {
let args = parse_args(r#"{"query": "hello", "sender": "Bob"}"#);
let q = build_search_query(&args);
let filters = q["query"]["bool"]["filter"].as_array().unwrap();
assert_eq!(filters.len(), 2);
assert_eq!(filters[1]["term"]["sender_name"], "Bob");
}
#[test]
fn test_query_with_room_and_sender() {
let args = parse_args(r#"{"query": "hello", "room": "dev", "sender": "Carol"}"#);
let q = build_search_query(&args);
let filters = q["query"]["bool"]["filter"].as_array().unwrap();
assert_eq!(filters.len(), 3);
assert_eq!(filters[1]["term"]["room_name"], "dev");
assert_eq!(filters[2]["term"]["sender_name"], "Carol");
}
#[test]
fn test_query_with_date_range() {
let args = parse_args(r#"{
"query": "hello",
"after": "1710000000000",
"before": "1710100000000"
}"#);
let q = build_search_query(&args);
let filters = q["query"]["bool"]["filter"].as_array().unwrap();
let range_filter = &filters[1]["range"]["timestamp"];
assert_eq!(range_filter["gte"], 1710000000000_i64);
assert_eq!(range_filter["lte"], 1710100000000_i64);
}
#[test]
fn test_query_with_after_only() {
let args = parse_args(r#"{"query": "hello", "after": "1710000000000"}"#);
let q = build_search_query(&args);
let filters = q["query"]["bool"]["filter"].as_array().unwrap();
let range_filter = &filters[1]["range"]["timestamp"];
assert_eq!(range_filter["gte"], 1710000000000_i64);
assert!(range_filter.get("lte").is_none());
}
#[test]
fn test_query_with_custom_limit() {
let args = parse_args(r#"{"query": "hello", "limit": 50}"#);
let q = build_search_query(&args);
assert_eq!(q["size"], 50);
}
#[test]
fn test_query_all_filters_combined() {
let args = parse_args(r#"{
"query": "architecture",
"room": "engineering",
"sender": "Sienna",
"after": "1000",
"before": "2000",
"limit": 5
}"#);
let q = build_search_query(&args);
assert_eq!(q["size"], 5);
let filters = q["query"]["bool"]["filter"].as_array().unwrap();
// redacted=false, room, sender, range = 4 filters
assert_eq!(filters.len(), 4);
}
#[test]
fn test_invalid_timestamp_ignored() {
let args = parse_args(r#"{"query": "hello", "after": "not-a-number"}"#);
let q = build_search_query(&args);
let filters = q["query"]["bool"]["filter"].as_array().unwrap();
// Only the redacted filter, no range since parse failed
assert_eq!(filters.len(), 1);
}
#[test]
fn test_source_fields() {
let args = parse_args(r#"{"query": "test"}"#);
let q = build_search_query(&args);
let source = q["_source"].as_array().unwrap();
let fields: Vec<&str> = source.iter().map(|v| v.as_str().unwrap()).collect();
assert!(fields.contains(&"event_id"));
assert!(fields.contains(&"room_name"));
assert!(fields.contains(&"sender_name"));
assert!(fields.contains(&"timestamp"));
assert!(fields.contains(&"content"));
}
}