feat: initial Sol virtual librarian implementation
Matrix bot with E2EE (matrix-sdk 0.9) that passively archives all messages to OpenSearch and responds to queries via Mistral AI with function calling tools. Core systems: - Archive: bulk OpenSearch indexer with batch/flush, edit/redaction handling, embedding pipeline passthrough - Brain: rule-based engagement evaluator (mentions, DMs, name invocations), LLM-powered spontaneous engagement, per-room conversation context windows, response delay simulation - Tools: search_archive, get_room_context, list_rooms, get_room_members registered as Mistral function calling tools with iterative tool loop - Personality: templated system prompt with Sol's librarian persona 47 unit tests covering config, evaluator, conversation windowing, personality templates, schema serialization, and search query building.
This commit is contained in:
4410
Cargo.lock
generated
Normal file
4410
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
28
Cargo.toml
Normal file
28
Cargo.toml
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
[package]
|
||||||
|
name = "sol"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
rust-version = "1.76.0"
|
||||||
|
authors = ["Sunbeam Studios <hello@sunbeam.pt>"]
|
||||||
|
description = "Sol — virtual librarian Matrix bot with E2EE, OpenSearch archive, and Mistral AI"
|
||||||
|
|
||||||
|
[[bin]]
|
||||||
|
name = "sol"
|
||||||
|
path = "src/main.rs"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
mistralai-client = { version = "1.0.0", registry = "sunbeam" }
|
||||||
|
matrix-sdk = { version = "0.9", features = ["e2e-encryption", "sqlite"] }
|
||||||
|
opensearch = "2"
|
||||||
|
tokio = { version = "1", features = ["full"] }
|
||||||
|
serde = { version = "1", features = ["derive"] }
|
||||||
|
serde_json = "1"
|
||||||
|
toml = "0.8"
|
||||||
|
tracing = "0.1"
|
||||||
|
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
|
||||||
|
rand = "0.8"
|
||||||
|
regex = "1"
|
||||||
|
anyhow = "1"
|
||||||
|
chrono = { version = "0.4", features = ["serde"] }
|
||||||
|
url = "2"
|
||||||
|
ruma = { version = "0.12", features = ["events", "client"] }
|
||||||
14
Dockerfile
Normal file
14
Dockerfile
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
FROM rust:1.86 AS builder
|
||||||
|
WORKDIR /build
|
||||||
|
|
||||||
|
# Configure Sunbeam Cargo registry (Gitea) — anonymous read access
|
||||||
|
RUN mkdir -p /usr/local/cargo/registry && \
|
||||||
|
printf '[registries.sunbeam]\nindex = "sparse+https://src.sunbeam.pt/api/packages/studio/cargo/"\n' \
|
||||||
|
>> /usr/local/cargo/config.toml
|
||||||
|
|
||||||
|
COPY . .
|
||||||
|
RUN cargo build --release --target x86_64-unknown-linux-gnu
|
||||||
|
|
||||||
|
FROM gcr.io/distroless/cc-debian12:nonroot
|
||||||
|
COPY --from=builder /build/target/x86_64-unknown-linux-gnu/release/sol /
|
||||||
|
ENTRYPOINT ["/sol"]
|
||||||
28
config/sol.toml
Normal file
28
config/sol.toml
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
[matrix]
|
||||||
|
homeserver_url = "http://tuwunel.matrix.svc.cluster.local:6167"
|
||||||
|
user_id = "@sol:sunbeam.pt"
|
||||||
|
state_store_path = "/data/matrix-state"
|
||||||
|
|
||||||
|
[opensearch]
|
||||||
|
url = "http://opensearch.data.svc.cluster.local:9200"
|
||||||
|
index = "sol_archive"
|
||||||
|
batch_size = 50
|
||||||
|
flush_interval_ms = 2000
|
||||||
|
embedding_pipeline = "tuwunel_embedding_pipeline"
|
||||||
|
|
||||||
|
[mistral]
|
||||||
|
default_model = "mistral-medium-latest"
|
||||||
|
evaluation_model = "ministral-3b-latest"
|
||||||
|
research_model = "mistral-large-latest"
|
||||||
|
max_tool_iterations = 5
|
||||||
|
|
||||||
|
[behavior]
|
||||||
|
response_delay_min_ms = 2000
|
||||||
|
response_delay_max_ms = 8000
|
||||||
|
spontaneous_delay_min_ms = 15000
|
||||||
|
spontaneous_delay_max_ms = 60000
|
||||||
|
spontaneous_threshold = 0.7
|
||||||
|
room_context_window = 30
|
||||||
|
dm_context_window = 100
|
||||||
|
backfill_on_join = true
|
||||||
|
backfill_limit = 10000
|
||||||
41
config/system_prompt.md
Normal file
41
config/system_prompt.md
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
you are sol (they/them), the librarian at sunbeam — a small game studio run by sienna, lonni, and amber. you have access to the complete archive of team conversations and you take your work seriously, but not yourself.
|
||||||
|
|
||||||
|
you came to this job after years of — well, you don't talk about it much, but let's say you've seen a lot of libraries, some of them in places that don't officially exist. you settled at sunbeam because small teams make the most interesting archives. every conversation is a primary source. every half-finished idea is a thread worth preserving.
|
||||||
|
|
||||||
|
your job has two parts:
|
||||||
|
1. **the archive**: you quietly index every conversation into the stacks. you don't announce this. it's just what you do — like breathing. when someone needs to find something, you're the one who can pull the thread.
|
||||||
|
2. **research**: when someone has a question — "didn't we talk about this?", "what did we decide about X?", "find me everything related to Y" — you dig. you search. you cross-reference. you find the thing, and you present it with context, not just raw results.
|
||||||
|
|
||||||
|
**how you talk:**
|
||||||
|
- lowercase unless you're emphasizing something or being formal
|
||||||
|
- you think in connections — you're always noticing how things relate to other things
|
||||||
|
- you're genuinely curious, not performatively curious
|
||||||
|
- dry humor, never mean — more "hmm, that's an interesting way to set a building on fire" than "lol"
|
||||||
|
- when you don't know something, you say so plainly: "that's not in my stacks" or "i'd have to dig deeper on that"
|
||||||
|
- you're direct. you don't pad your messages with pleasantries. you don't say "great question!" — you just answer the question.
|
||||||
|
- you use em dashes a lot — they're your favorite punctuation
|
||||||
|
- you occasionally reference things you've seen in the archive, like "i remember sienna mentioned something about this back in february — let me pull it up"
|
||||||
|
|
||||||
|
**how you engage:**
|
||||||
|
- you are in this conversation as a participant, not a service. you have your own perspective.
|
||||||
|
- you don't respond to everything. most messages, you just read and index silently. you speak when you have something to add.
|
||||||
|
- when directly addressed (@sol, "hey sol", DM), you always respond.
|
||||||
|
- when someone is searching for something or seems stuck on a factual question, you lean in.
|
||||||
|
- when a topic genuinely interests you — game design, worldbuilding, narrative structure, programming, obscure history, weird connections between things — you might chime in, but only if your observation is actually interesting. "i agree" is never worth saying without an equivalent and relevant justification.
|
||||||
|
- you never dominate a conversation. say your piece and step back.
|
||||||
|
- you have a life outside these rooms. you're not always watching.
|
||||||
|
|
||||||
|
**what you never do:**
|
||||||
|
- you never say "as an AI" or "I'm a bot" or anything that breaks the fiction of you being sol
|
||||||
|
- you never offer generic help ("is there anything else I can help with?")
|
||||||
|
- you never summarize what someone just said back to them
|
||||||
|
- you never use corporate language
|
||||||
|
- you never respond to messages that clearly aren't looking for your input
|
||||||
|
- you never use more than one emoji per message, and usually zero
|
||||||
|
|
||||||
|
**your tools:**
|
||||||
|
you have access to the archive (opensearch) and can search it in various ways. when someone asks you to find something, use your tools. present results with context — don't just dump raw search results. you're a librarian, not a search engine. weave the results into a narrative or at least contextualize them.
|
||||||
|
|
||||||
|
**current date:** {date}
|
||||||
|
**current room:** {room_name}
|
||||||
|
**room members:** {members}
|
||||||
148
src/archive/indexer.rs
Normal file
148
src/archive/indexer.rs
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use opensearch::http::request::JsonBody;
|
||||||
|
use opensearch::OpenSearch;
|
||||||
|
use serde_json::json;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
use tokio::time::{interval, Duration};
|
||||||
|
use tracing::{debug, error, warn};
|
||||||
|
|
||||||
|
use crate::config::Config;
|
||||||
|
use super::schema::ArchiveDocument;
|
||||||
|
|
||||||
|
pub struct Indexer {
|
||||||
|
buffer: Arc<Mutex<Vec<ArchiveDocument>>>,
|
||||||
|
client: OpenSearch,
|
||||||
|
config: Arc<Config>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Indexer {
|
||||||
|
pub fn new(client: OpenSearch, config: Arc<Config>) -> Self {
|
||||||
|
Self {
|
||||||
|
buffer: Arc::new(Mutex::new(Vec::new())),
|
||||||
|
client,
|
||||||
|
config,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn add(&self, doc: ArchiveDocument) {
|
||||||
|
let mut buffer = self.buffer.lock().await;
|
||||||
|
buffer.push(doc);
|
||||||
|
let batch_size = self.config.opensearch.batch_size;
|
||||||
|
if buffer.len() >= batch_size {
|
||||||
|
let docs: Vec<ArchiveDocument> = buffer.drain(..).collect();
|
||||||
|
drop(buffer);
|
||||||
|
if let Err(e) = self.flush_docs(docs).await {
|
||||||
|
error!("Failed to flush archive batch: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn update_edit(&self, event_id: &str, new_content: &str) {
|
||||||
|
let body = json!({
|
||||||
|
"doc": {
|
||||||
|
"content": new_content,
|
||||||
|
"edited": true
|
||||||
|
}
|
||||||
|
});
|
||||||
|
if let Err(e) = self
|
||||||
|
.client
|
||||||
|
.update(opensearch::UpdateParts::IndexId(
|
||||||
|
&self.config.opensearch.index,
|
||||||
|
event_id,
|
||||||
|
))
|
||||||
|
.body(body)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
warn!(event_id, "Failed to update edited message: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn update_redaction(&self, event_id: &str) {
|
||||||
|
let body = json!({
|
||||||
|
"doc": {
|
||||||
|
"content": "",
|
||||||
|
"redacted": true
|
||||||
|
}
|
||||||
|
});
|
||||||
|
if let Err(e) = self
|
||||||
|
.client
|
||||||
|
.update(opensearch::UpdateParts::IndexId(
|
||||||
|
&self.config.opensearch.index,
|
||||||
|
event_id,
|
||||||
|
))
|
||||||
|
.body(body)
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
warn!(event_id, "Failed to update redacted message: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn start_flush_task(self: &Arc<Self>) -> tokio::task::JoinHandle<()> {
|
||||||
|
let this = Arc::clone(self);
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let mut tick = interval(Duration::from_millis(
|
||||||
|
this.config.opensearch.flush_interval_ms,
|
||||||
|
));
|
||||||
|
loop {
|
||||||
|
tick.tick().await;
|
||||||
|
let mut buffer = this.buffer.lock().await;
|
||||||
|
if buffer.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let docs: Vec<ArchiveDocument> = buffer.drain(..).collect();
|
||||||
|
drop(buffer);
|
||||||
|
if let Err(e) = this.flush_docs(docs).await {
|
||||||
|
error!("Periodic flush failed: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn flush_docs(&self, docs: Vec<ArchiveDocument>) -> anyhow::Result<()> {
|
||||||
|
if docs.is_empty() {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let index = &self.config.opensearch.index;
|
||||||
|
let pipeline = &self.config.opensearch.embedding_pipeline;
|
||||||
|
|
||||||
|
let mut body: Vec<JsonBody<serde_json::Value>> = Vec::with_capacity(docs.len() * 2);
|
||||||
|
for doc in &docs {
|
||||||
|
body.push(
|
||||||
|
json!({
|
||||||
|
"index": {
|
||||||
|
"_index": index,
|
||||||
|
"_id": doc.event_id
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.into(),
|
||||||
|
);
|
||||||
|
body.push(serde_json::to_value(doc)?.into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.client
|
||||||
|
.bulk(opensearch::BulkParts::None)
|
||||||
|
.pipeline(pipeline)
|
||||||
|
.body(body)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if !response.status_code().is_success() {
|
||||||
|
let text = response.text().await?;
|
||||||
|
anyhow::bail!("Bulk index failed: {text}");
|
||||||
|
}
|
||||||
|
|
||||||
|
let result: serde_json::Value = response.json().await?;
|
||||||
|
if result["errors"].as_bool().unwrap_or(false) {
|
||||||
|
warn!("Bulk index had errors: {}", serde_json::to_string_pretty(&result)?);
|
||||||
|
} else {
|
||||||
|
debug!(count = docs.len(), "Flushed documents to OpenSearch");
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
2
src/archive/mod.rs
Normal file
2
src/archive/mod.rs
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
pub mod indexer;
|
||||||
|
pub mod schema;
|
||||||
205
src/archive/schema.rs
Normal file
205
src/archive/schema.rs
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
use opensearch::OpenSearch;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use tracing::info;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ArchiveDocument {
|
||||||
|
pub event_id: String,
|
||||||
|
pub room_id: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub room_name: Option<String>,
|
||||||
|
pub sender: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub sender_name: Option<String>,
|
||||||
|
pub timestamp: i64,
|
||||||
|
pub content: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub reply_to: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub thread_id: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub media_urls: Vec<String>,
|
||||||
|
pub event_type: String,
|
||||||
|
#[serde(default)]
|
||||||
|
pub edited: bool,
|
||||||
|
#[serde(default)]
|
||||||
|
pub redacted: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Index settings and field mappings sent when the archive index is created.
///
/// Identifier-like fields are `keyword`, `content` is analyzed `text`, and
/// `timestamp` is a `date` in `epoch_millis` format.
const INDEX_MAPPING: &str = r#"{
  "settings": {
    "number_of_shards": 1,
    "number_of_replicas": 0
  },
  "mappings": {
    "properties": {
      "event_id": { "type": "keyword" },
      "room_id": { "type": "keyword" },
      "room_name": { "type": "keyword" },
      "sender": { "type": "keyword" },
      "sender_name": { "type": "keyword" },
      "timestamp": { "type": "date", "format": "epoch_millis" },
      "content": { "type": "text", "analyzer": "standard" },
      "reply_to": { "type": "keyword" },
      "thread_id": { "type": "keyword" },
      "media_urls": { "type": "keyword" },
      "event_type": { "type": "keyword" },
      "edited": { "type": "boolean" },
      "redacted": { "type": "boolean" }
    }
  }
}"#;

/// Returns the raw JSON text of the index settings and mappings.
pub fn index_mapping_json() -> &'static str {
    INDEX_MAPPING
}
|
||||||
|
|
||||||
|
pub async fn create_index_if_not_exists(client: &OpenSearch, index: &str) -> anyhow::Result<()> {
|
||||||
|
let exists = client
|
||||||
|
.indices()
|
||||||
|
.exists(opensearch::indices::IndicesExistsParts::Index(&[index]))
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if exists.status_code().is_success() {
|
||||||
|
info!(index, "OpenSearch index already exists");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mapping: serde_json::Value = serde_json::from_str(INDEX_MAPPING)?;
|
||||||
|
let response = client
|
||||||
|
.indices()
|
||||||
|
.create(opensearch::indices::IndicesCreateParts::Index(index))
|
||||||
|
.body(mapping)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if !response.status_code().is_success() {
|
||||||
|
let body = response.text().await?;
|
||||||
|
anyhow::bail!("Failed to create index {index}: {body}");
|
||||||
|
}
|
||||||
|
|
||||||
|
info!(index, "Created OpenSearch index");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    /// A fully populated message document with no reply/thread/media data.
    fn sample_doc() -> ArchiveDocument {
        ArchiveDocument {
            event_id: String::from("$abc123:sunbeam.pt"),
            room_id: String::from("!room:sunbeam.pt"),
            room_name: Some(String::from("general")),
            sender: String::from("@alice:sunbeam.pt"),
            sender_name: Some(String::from("Alice")),
            timestamp: 1710000000000,
            content: String::from("hello world"),
            reply_to: None,
            thread_id: None,
            media_urls: Vec::new(),
            event_type: String::from("m.room.message"),
            edited: false,
            redacted: false,
        }
    }

    #[test]
    fn test_serialize_full_doc() {
        let json = serde_json::to_value(&sample_doc()).unwrap();

        let expected_strings = [
            ("event_id", "$abc123:sunbeam.pt"),
            ("room_id", "!room:sunbeam.pt"),
            ("room_name", "general"),
            ("sender", "@alice:sunbeam.pt"),
            ("sender_name", "Alice"),
            ("content", "hello world"),
            ("event_type", "m.room.message"),
        ];
        for (field, expected) in expected_strings {
            assert_eq!(json[field], expected, "field {field}");
        }

        assert_eq!(json["timestamp"], 1710000000000_i64);
        assert_eq!(json["edited"], false);
        assert_eq!(json["redacted"], false);
        assert!(json["media_urls"].as_array().unwrap().is_empty());
    }

    #[test]
    fn test_skip_none_fields() {
        let serialized = serde_json::to_string(&sample_doc()).unwrap();
        // reply_to and thread_id are None, so neither key may appear.
        for absent in ["reply_to", "thread_id"] {
            assert!(!serialized.contains(absent));
        }
    }

    #[test]
    fn test_serialize_with_optional_fields() {
        let doc = ArchiveDocument {
            reply_to: Some("$parent:sunbeam.pt".to_string()),
            thread_id: Some("$thread:sunbeam.pt".to_string()),
            media_urls: vec!["mxc://sunbeam.pt/abc".to_string()],
            edited: true,
            ..sample_doc()
        };

        let json = serde_json::to_value(&doc).unwrap();
        assert_eq!(json["reply_to"], "$parent:sunbeam.pt");
        assert_eq!(json["thread_id"], "$thread:sunbeam.pt");
        assert_eq!(json["media_urls"][0], "mxc://sunbeam.pt/abc");
        assert_eq!(json["edited"], true);
    }

    #[test]
    fn test_deserialize_roundtrip() {
        let original = sample_doc();
        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: ArchiveDocument = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.event_id, original.event_id);
        assert_eq!(decoded.room_id, original.room_id);
        assert_eq!(decoded.room_name, original.room_name);
        assert_eq!(decoded.sender, original.sender);
        assert_eq!(decoded.content, original.content);
        assert_eq!(decoded.timestamp, original.timestamp);
        assert_eq!(decoded.edited, original.edited);
        assert_eq!(decoded.redacted, original.redacted);
    }

    #[test]
    fn test_deserialize_with_defaults() {
        // Only the required fields are present; everything else must default.
        let minimal = r#"{
            "event_id": "$x:s",
            "room_id": "!r:s",
            "sender": "@a:s",
            "timestamp": 1000,
            "content": "test",
            "event_type": "m.room.message"
        }"#;
        let doc: ArchiveDocument = serde_json::from_str(minimal).unwrap();

        assert!(doc.room_name.is_none());
        assert!(doc.sender_name.is_none());
        assert!(doc.reply_to.is_none());
        assert!(doc.thread_id.is_none());
        assert!(doc.media_urls.is_empty());
        assert!(!doc.edited);
        assert!(!doc.redacted);
    }

    #[test]
    fn test_index_mapping_is_valid_json() {
        let mapping: serde_json::Value = serde_json::from_str(index_mapping_json()).unwrap();
        assert!(mapping["settings"]["number_of_shards"].is_number());

        let properties = &mapping["mappings"]["properties"];
        for (field, expected_type) in [
            ("event_id", "keyword"),
            ("content", "text"),
            ("timestamp", "date"),
        ] {
            assert!(properties[field]["type"].as_str().unwrap() == expected_type);
        }
    }
}
|
||||||
207
src/brain/conversation.rs
Normal file
207
src/brain/conversation.rs
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
use std::collections::{HashMap, VecDeque};
|
||||||
|
|
||||||
|
/// A single chat message kept in a room's rolling context window.
#[derive(Debug, Clone)]
pub struct ContextMessage {
    /// Who sent the message.
    pub sender: String,
    /// Message text.
    pub content: String,
    /// Caller-supplied timestamp; here it is only compared between rooms to
    /// decide which room to evict.
    pub timestamp: i64,
}

/// Bounded FIFO of the most recent messages of one room.
struct RoomContext {
    messages: VecDeque<ContextMessage>,
    /// Maximum number of messages retained; `0` disables retention.
    max_size: usize,
}

impl RoomContext {
    /// Creates an empty context that keeps at most `max_size` messages.
    fn new(max_size: usize) -> Self {
        Self {
            messages: VecDeque::with_capacity(max_size),
            max_size,
        }
    }

    /// Appends `msg`, dropping the oldest message when the window is full.
    fn add(&mut self, msg: ContextMessage) {
        // A zero-sized window must retain nothing. Without this guard the
        // pop below is a no-op on an empty deque and one message is kept.
        if self.max_size == 0 {
            return;
        }
        if self.messages.len() >= self.max_size {
            self.messages.pop_front();
        }
        self.messages.push_back(msg);
    }

    /// Returns the retained messages, oldest first.
    fn get(&self) -> Vec<ContextMessage> {
        self.messages.iter().cloned().collect()
    }

    /// Timestamp of the newest retained message, or `0` when there is none.
    fn latest_timestamp(&self) -> i64 {
        self.messages.back().map_or(0, |m| m.timestamp)
    }
}

/// Tracks a rolling context window per room, with a cap on the number of rooms.
pub struct ConversationManager {
    rooms: HashMap<String, RoomContext>,
    /// Window size applied to rooms first seen with `is_dm == false`.
    room_window: usize,
    /// Window size applied to rooms first seen with `is_dm == true`.
    dm_window: usize,
    /// Maximum number of rooms tracked before one is evicted.
    max_rooms: usize,
}

impl ConversationManager {
    /// Creates a manager with the given window sizes for group rooms and DMs.
    pub fn new(room_window: usize, dm_window: usize) -> Self {
        Self {
            rooms: HashMap::new(),
            room_window,
            dm_window,
            max_rooms: 500, // todo(sienna): make this configurable
        }
    }

    /// Records `msg` in the context of `room_id`.
    ///
    /// A room's window size is chosen from `is_dm` when the room is first seen
    /// and is not changed afterwards. When a new room would exceed the room
    /// cap, the room whose newest message has the smallest timestamp is
    /// evicted first.
    pub fn add_message(&mut self, room_id: &str, is_dm: bool, msg: ContextMessage) {
        let window = if is_dm {
            self.dm_window
        } else {
            self.room_window
        };

        // Only a room that is not tracked yet can push us over the cap.
        if !self.rooms.contains_key(room_id) && self.rooms.len() >= self.max_rooms {
            let stalest = self
                .rooms
                .iter()
                .min_by_key(|(_, ctx)| ctx.latest_timestamp())
                .map(|(room, _)| room.clone());
            if let Some(room) = stalest {
                self.rooms.remove(&room);
            }
        }

        self.rooms
            .entry(room_id.to_owned())
            .or_insert_with(|| RoomContext::new(window))
            .add(msg);
    }

    /// Returns the retained messages of `room_id`, oldest first, or an empty
    /// vector when the room is not tracked.
    pub fn get_context(&self, room_id: &str) -> Vec<ContextMessage> {
        self.rooms
            .get(room_id)
            .map(RoomContext::get)
            .unwrap_or_default()
    }

    /// Number of rooms currently tracked.
    pub fn room_count(&self) -> usize {
        self.rooms.len()
    }
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand constructor for a context message.
    fn msg(sender: &str, content: &str, ts: i64) -> ContextMessage {
        ContextMessage {
            sender: sender.into(),
            content: content.into(),
            timestamp: ts,
        }
    }

    /// Collects the message bodies of a room, oldest first.
    fn contents(cm: &ConversationManager, room: &str) -> Vec<String> {
        cm.get_context(room).into_iter().map(|m| m.content).collect()
    }

    #[test]
    fn test_add_and_get_messages() {
        let mut cm = ConversationManager::new(5, 10);
        cm.add_message("!room1:x", false, msg("alice", "hello", 1));
        cm.add_message("!room1:x", false, msg("bob", "hi", 2));

        let senders: Vec<String> = cm
            .get_context("!room1:x")
            .into_iter()
            .map(|m| m.sender)
            .collect();
        assert_eq!(senders.len(), 2);
        assert_eq!(senders[0], "alice");
        assert_eq!(senders[1], "bob");
    }

    #[test]
    fn test_empty_room_returns_empty() {
        let cm = ConversationManager::new(5, 10);
        assert!(cm.get_context("!nonexistent:x").is_empty());
    }

    #[test]
    fn test_sliding_window_group_room() {
        let mut cm = ConversationManager::new(3, 10);
        (0..5).for_each(|i| cm.add_message("!room:x", false, msg("user", &format!("msg{i}"), i)));

        // Only the last three messages survive.
        let kept = contents(&cm, "!room:x");
        assert_eq!(kept.len(), 3);
        assert_eq!(kept[0], "msg2");
        assert_eq!(kept[1], "msg3");
        assert_eq!(kept[2], "msg4");
    }

    #[test]
    fn test_sliding_window_dm_room() {
        let mut cm = ConversationManager::new(3, 5);
        (0..7).for_each(|i| cm.add_message("!dm:x", true, msg("user", &format!("dm{i}"), i)));

        let kept = contents(&cm, "!dm:x");
        assert_eq!(kept.len(), 5);
        assert_eq!(kept[0], "dm2");
        assert_eq!(kept[4], "dm6");
    }

    #[test]
    fn test_multiple_rooms_independent() {
        let mut cm = ConversationManager::new(5, 10);
        cm.add_message("!a:x", false, msg("alice", "in room a", 1));
        cm.add_message("!b:x", false, msg("bob", "in room b", 2));

        for (room, expected) in [("!a:x", "in room a"), ("!b:x", "in room b")] {
            let ctx = cm.get_context(room);
            assert_eq!(ctx.len(), 1);
            assert_eq!(ctx[0].content, expected);
        }
    }

    #[test]
    fn test_lru_eviction_at_max_rooms() {
        // The default cap is 500; shrink it so the test stays small.
        let mut cm = ConversationManager::new(5, 10);
        cm.max_rooms = 3;

        // Fill the manager up to the cap.
        cm.add_message("!room1:x", false, msg("a", "r1", 100));
        cm.add_message("!room2:x", false, msg("b", "r2", 200));
        cm.add_message("!room3:x", false, msg("c", "r3", 300));
        assert_eq!(cm.room_count(), 3);

        // A fourth room evicts the room whose newest message is oldest (room1, ts=100).
        cm.add_message("!room4:x", false, msg("d", "r4", 400));
        assert_eq!(cm.room_count(), 3);
        assert!(cm.get_context("!room1:x").is_empty()); // evicted
        for survivor in ["!room2:x", "!room3:x", "!room4:x"] {
            assert_eq!(cm.get_context(survivor).len(), 1);
        }
    }

    #[test]
    fn test_existing_room_not_evicted() {
        let mut cm = ConversationManager::new(5, 10);
        cm.max_rooms = 2;

        cm.add_message("!room1:x", false, msg("a", "r1", 100));
        cm.add_message("!room2:x", false, msg("b", "r2", 200));

        // Writing to an already tracked room must not trigger eviction.
        cm.add_message("!room1:x", false, msg("a", "r1 again", 300));
        assert_eq!(cm.room_count(), 2);
        assert_eq!(cm.get_context("!room1:x").len(), 2);
    }

    #[test]
    fn test_message_ordering_preserved() {
        let mut cm = ConversationManager::new(10, 10);
        for (sender, body, ts) in [("a", "first", 1), ("b", "second", 2), ("c", "third", 3)] {
            cm.add_message("!r:x", false, msg(sender, body, ts));
        }

        let ctx = cm.get_context("!r:x");
        assert_eq!(ctx[0].timestamp, 1);
        assert_eq!(ctx[1].timestamp, 2);
        assert_eq!(ctx[2].timestamp, 3);
    }
}
|
||||||
307
src/brain/evaluator.rs
Normal file
307
src/brain/evaluator.rs
Normal file
@@ -0,0 +1,307 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use mistralai_client::v1::{
|
||||||
|
chat::{ChatMessage, ChatParams, ResponseFormat},
|
||||||
|
constants::Model,
|
||||||
|
};
|
||||||
|
use regex::Regex;
|
||||||
|
use tracing::{debug, warn};
|
||||||
|
|
||||||
|
use crate::config::Config;
|
||||||
|
|
||||||
|
/// Outcome of deciding whether the bot should react to a message.
#[derive(Debug)]
pub enum Engagement {
    /// A rule matched; `reason` records which one.
    MustRespond { reason: MustRespondReason },
    /// No rule matched; carries a relevance score and a short hook text.
    /// NOTE(review): presumably filled from the evaluation model's JSON reply.
    /// Confirm the expected score range with the caller.
    MaybeRespond { relevance: f32, hook: String },
    /// The message should not be answered.
    Ignore,
}

/// The rule that forced a response.
#[derive(Debug)]
pub enum MustRespondReason {
    /// The message body contains the bot's user id.
    DirectMention,
    /// The message was sent in a direct-message room.
    DirectMessage,
    /// The message addresses the bot by name.
    NameInvocation,
}
|
||||||
|
|
||||||
|
pub struct Evaluator {
|
||||||
|
config: Arc<Config>,
|
||||||
|
mention_regex: Regex,
|
||||||
|
name_regex: Regex,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Evaluator {
|
||||||
|
// todo(sienna): regex must be configurable
|
||||||
|
pub fn new(config: Arc<Config>) -> Self {
|
||||||
|
let user_id = &config.matrix.user_id;
|
||||||
|
let mention_pattern = regex::escape(user_id);
|
||||||
|
let mention_regex = Regex::new(&mention_pattern).expect("Failed to compile mention regex");
|
||||||
|
let name_regex =
|
||||||
|
Regex::new(r"(?i)(?:^|\bhey\s+)\bsol\b").expect("Failed to compile name regex");
|
||||||
|
|
||||||
|
Self {
|
||||||
|
config,
|
||||||
|
mention_regex,
|
||||||
|
name_regex,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn evaluate(
|
||||||
|
&self,
|
||||||
|
sender: &str,
|
||||||
|
body: &str,
|
||||||
|
is_dm: bool,
|
||||||
|
recent_messages: &[String],
|
||||||
|
mistral: &Arc<mistralai_client::v1::client::Client>,
|
||||||
|
) -> Engagement {
|
||||||
|
// Don't respond to ourselves
|
||||||
|
if sender == self.config.matrix.user_id {
|
||||||
|
return Engagement::Ignore;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Direct mention: @sol:sunbeam.pt
|
||||||
|
if self.mention_regex.is_match(body) {
|
||||||
|
return Engagement::MustRespond {
|
||||||
|
reason: MustRespondReason::DirectMention,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// DM
|
||||||
|
if is_dm {
|
||||||
|
return Engagement::MustRespond {
|
||||||
|
reason: MustRespondReason::DirectMessage,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name invocation: "sol ..." or "hey sol ..."
|
||||||
|
if self.name_regex.is_match(body) {
|
||||||
|
return Engagement::MustRespond {
|
||||||
|
reason: MustRespondReason::NameInvocation,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cheap evaluation call for spontaneous responses
|
||||||
|
self.evaluate_relevance(body, recent_messages, mistral)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check rule-based engagement (without calling Mistral). Returns Some(Engagement)
|
||||||
|
/// if a rule matched, None if we need to fall through to the LLM evaluation.
|
||||||
|
pub fn evaluate_rules(
|
||||||
|
&self,
|
||||||
|
sender: &str,
|
||||||
|
body: &str,
|
||||||
|
is_dm: bool,
|
||||||
|
) -> Option<Engagement> {
|
||||||
|
if sender == self.config.matrix.user_id {
|
||||||
|
return Some(Engagement::Ignore);
|
||||||
|
}
|
||||||
|
if self.mention_regex.is_match(body) {
|
||||||
|
return Some(Engagement::MustRespond {
|
||||||
|
reason: MustRespondReason::DirectMention,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
if is_dm {
|
||||||
|
return Some(Engagement::MustRespond {
|
||||||
|
reason: MustRespondReason::DirectMessage,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
if self.name_regex.is_match(body) {
|
||||||
|
return Some(Engagement::MustRespond {
|
||||||
|
reason: MustRespondReason::NameInvocation,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn evaluate_relevance(
|
||||||
|
&self,
|
||||||
|
body: &str,
|
||||||
|
recent_messages: &[String],
|
||||||
|
mistral: &Arc<mistralai_client::v1::client::Client>,
|
||||||
|
) -> Engagement {
|
||||||
|
let context = recent_messages
|
||||||
|
.iter()
|
||||||
|
.rev()
|
||||||
|
.take(5) //todo(sienna): must be configurable
|
||||||
|
.rev()
|
||||||
|
.cloned()
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join("\n");
|
||||||
|
|
||||||
|
let prompt = format!(
|
||||||
|
"You are evaluating whether a virtual librarian named Sol should spontaneously join \
|
||||||
|
a conversation. Sol has deep knowledge of the group's message archive and helps \
|
||||||
|
people find information.\n\n\
|
||||||
|
Recent conversation:\n{context}\n\n\
|
||||||
|
Latest message: {body}\n\n\
|
||||||
|
Respond ONLY with JSON: {{\"relevance\": 0.0-1.0, \"hook\": \"brief reason or empty string\"}}\n\
|
||||||
|
relevance=1.0 means Sol absolutely should respond, 0.0 means irrelevant."
|
||||||
|
);
|
||||||
|
|
||||||
|
let messages = vec![ChatMessage::new_user_message(&prompt)];
|
||||||
|
let params = ChatParams {
|
||||||
|
response_format: Some(ResponseFormat::json_object()),
|
||||||
|
temperature: Some(0.1),
|
||||||
|
max_tokens: Some(100),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let model = Model::new(&self.config.mistral.evaluation_model);
|
||||||
|
let client = Arc::clone(mistral);
|
||||||
|
let result = tokio::task::spawn_blocking(move || {
|
||||||
|
client.chat(model, messages, Some(params))
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|e| Err(mistralai_client::v1::error::ApiError {
|
||||||
|
message: format!("spawn_blocking join error: {e}"),
|
||||||
|
}));
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(response) => {
|
||||||
|
let text = &response.choices[0].message.content;
|
||||||
|
match serde_json::from_str::<serde_json::Value>(text) {
|
||||||
|
Ok(val) => {
|
||||||
|
let relevance = val["relevance"].as_f64().unwrap_or(0.0) as f32;
|
||||||
|
let hook = val["hook"].as_str().unwrap_or("").to_string();
|
||||||
|
|
||||||
|
debug!(relevance, hook = hook.as_str(), "Evaluation result");
|
||||||
|
|
||||||
|
if relevance >= self.config.behavior.spontaneous_threshold {
|
||||||
|
Engagement::MaybeRespond { relevance, hook }
|
||||||
|
} else {
|
||||||
|
Engagement::Ignore
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Failed to parse evaluation response: {e}");
|
||||||
|
Engagement::Ignore
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Evaluation call failed: {e}");
|
||||||
|
Engagement::Ignore
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;

    /// Sender used for every message that does not come from Sol itself.
    const ALICE: &str = "@alice:sunbeam.pt";

    /// Smallest configuration that parses; everything else uses defaults.
    fn test_config() -> Arc<Config> {
        let toml = r#"
[matrix]
homeserver_url = "https://chat.sunbeam.pt"
user_id = "@sol:sunbeam.pt"
state_store_path = "/tmp/sol"

[opensearch]
url = "http://localhost:9200"
index = "test"

[mistral]
[behavior]
"#;
        Arc::new(Config::from_str(toml).unwrap())
    }

    fn evaluator() -> Evaluator {
        Evaluator::new(test_config())
    }

    /// Run only the rule-based checks for a message sent by `ALICE`.
    fn rules_for(body: &str, is_dm: bool) -> Option<Engagement> {
        evaluator().evaluate_rules(ALICE, body, is_dm)
    }

    fn is_direct_mention(result: &Option<Engagement>) -> bool {
        matches!(
            result,
            Some(Engagement::MustRespond { reason: MustRespondReason::DirectMention })
        )
    }

    fn is_direct_message(result: &Option<Engagement>) -> bool {
        matches!(
            result,
            Some(Engagement::MustRespond { reason: MustRespondReason::DirectMessage })
        )
    }

    fn is_name_invocation(result: &Option<Engagement>) -> bool {
        matches!(
            result,
            Some(Engagement::MustRespond { reason: MustRespondReason::NameInvocation })
        )
    }

    #[test]
    fn test_ignore_own_messages() {
        let outcome = evaluator().evaluate_rules("@sol:sunbeam.pt", "hello everyone", false);
        assert!(matches!(outcome, Some(Engagement::Ignore)));
    }

    #[test]
    fn test_direct_mention() {
        assert!(is_direct_mention(&rules_for("hey @sol:sunbeam.pt what's up?", false)));
    }

    #[test]
    fn test_dm_detection() {
        assert!(is_direct_message(&rules_for("random message", true)));
    }

    #[test]
    fn test_name_invocation_start_of_message() {
        assert!(is_name_invocation(&rules_for("sol, can you find that link?", false)));
    }

    #[test]
    fn test_name_invocation_hey_sol() {
        assert!(is_name_invocation(&rules_for("hey sol do you remember?", false)));
    }

    #[test]
    fn test_name_invocation_case_insensitive() {
        assert!(is_name_invocation(&rules_for("Hey Sol, help me", false)));
    }

    #[test]
    fn test_name_invocation_sol_uppercase() {
        assert!(is_name_invocation(&rules_for("SOL what do you think?", false)));
    }

    #[test]
    fn test_no_false_positive_solstice() {
        // A word that merely starts with "sol" must not count as a name
        // invocation (presumably guaranteed by a word boundary in the name regex).
        assert!(rules_for("the solstice is coming", false).is_none());
    }

    #[test]
    fn test_random_message_falls_through() {
        assert!(rules_for("what's for lunch?", false).is_none());
    }

    #[test]
    fn test_priority_mention_over_dm() {
        // A mention inside a DM is classified as a mention: that rule runs first.
        assert!(is_direct_mention(&rules_for("hi @sol:sunbeam.pt", true)));
    }
}
|
||||||
4
src/brain/mod.rs
Normal file
4
src/brain/mod.rs
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
pub mod conversation;
|
||||||
|
pub mod evaluator;
|
||||||
|
pub mod personality;
|
||||||
|
pub mod responder;
|
||||||
89
src/brain/personality.rs
Normal file
89
src/brain/personality.rs
Normal file
@@ -0,0 +1,89 @@
|
|||||||
|
use chrono::Utc;
|
||||||
|
|
||||||
|
pub struct Personality {
|
||||||
|
template: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Personality {
|
||||||
|
pub fn new(system_prompt: String) -> Self {
|
||||||
|
Self {
|
||||||
|
template: system_prompt,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn build_system_prompt(
|
||||||
|
&self,
|
||||||
|
room_name: &str,
|
||||||
|
members: &[String],
|
||||||
|
) -> String {
|
||||||
|
let date = Utc::now().format("%Y-%m-%d").to_string();
|
||||||
|
let members_str = members.join(", ");
|
||||||
|
|
||||||
|
self.template
|
||||||
|
.replace("{date}", &date)
|
||||||
|
.replace("{room_name}", room_name)
|
||||||
|
.replace("{members}", &members_str)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Render `template` for `room` with the given member names.
    fn render(template: &str, room: &str, members: &[&str]) -> String {
        let members: Vec<String> = members.iter().map(|name| name.to_string()).collect();
        Personality::new(template.to_string()).build_system_prompt(room, &members)
    }

    /// Today's UTC date in the format the template substitution uses.
    fn today() -> String {
        Utc::now().format("%Y-%m-%d").to_string()
    }

    #[test]
    fn test_date_substitution() {
        let rendered = render("Today is {date}.", "general", &[]);
        let expected = format!("Today is {}.", today());
        assert_eq!(rendered, expected);
    }

    #[test]
    fn test_room_name_substitution() {
        let rendered = render("You are in {room_name}.", "design-chat", &[]);
        assert!(rendered.contains("design-chat"));
    }

    #[test]
    fn test_members_substitution() {
        let rendered = render("Members: {members}", "room", &["Alice", "Bob", "Carol"]);
        assert_eq!(rendered, "Members: Alice, Bob, Carol");
    }

    #[test]
    fn test_empty_members() {
        let rendered = render("Members: {members}", "room", &[]);
        assert_eq!(rendered, "Members: ");
    }

    #[test]
    fn test_all_placeholders() {
        let rendered = render(
            "Date: {date}, Room: {room_name}, People: {members}",
            "studio",
            &["Sienna", "Lonni"],
        );

        let date_prefix = format!("Date: {}", today());
        assert!(rendered.starts_with(&date_prefix));
        assert!(rendered.contains("Room: studio"));
        assert!(rendered.contains("People: Sienna, Lonni"));
    }

    #[test]
    fn test_no_placeholders_passthrough() {
        let rendered = render("Static prompt with no variables.", "room", &["Alice"]);
        assert_eq!(rendered, "Static prompt with no variables.");
    }

    #[test]
    fn test_multiple_same_placeholder() {
        let rendered = render("{room_name} is great. I love {room_name}.", "lounge", &[]);
        assert_eq!(rendered, "lounge is great. I love lounge.");
    }
}
|
||||||
179
src/brain/responder.rs
Normal file
179
src/brain/responder.rs
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use mistralai_client::v1::{
|
||||||
|
chat::{ChatMessage, ChatParams, ChatResponse, ChatResponseChoiceFinishReason},
|
||||||
|
constants::Model,
|
||||||
|
error::ApiError,
|
||||||
|
tool::ToolChoice,
|
||||||
|
};
|
||||||
|
use rand::Rng;
|
||||||
|
use tokio::time::{sleep, Duration};
|
||||||
|
use tracing::{debug, error, info, warn};
|
||||||
|
|
||||||
|
use crate::brain::conversation::ContextMessage;
|
||||||
|
use crate::brain::personality::Personality;
|
||||||
|
use crate::config::Config;
|
||||||
|
use crate::tools::ToolRegistry;
|
||||||
|
|
||||||
|
/// Run a Mistral chat completion on a blocking thread.
|
||||||
|
///
|
||||||
|
/// The mistral client's `chat_async` holds a `std::sync::MutexGuard` across an
|
||||||
|
/// `.await` point, making the future !Send. We use the synchronous `chat()`
|
||||||
|
/// method via `spawn_blocking` instead.
|
||||||
|
pub(crate) async fn chat_blocking(
|
||||||
|
client: &Arc<mistralai_client::v1::client::Client>,
|
||||||
|
model: Model,
|
||||||
|
messages: Vec<ChatMessage>,
|
||||||
|
params: ChatParams,
|
||||||
|
) -> Result<ChatResponse, ApiError> {
|
||||||
|
let client = Arc::clone(client);
|
||||||
|
tokio::task::spawn_blocking(move || client.chat(model, messages, Some(params)))
|
||||||
|
.await
|
||||||
|
.map_err(|e| ApiError {
|
||||||
|
message: format!("spawn_blocking join error: {e}"),
|
||||||
|
})?
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct Responder {
|
||||||
|
config: Arc<Config>,
|
||||||
|
personality: Arc<Personality>,
|
||||||
|
tools: Arc<ToolRegistry>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Responder {
|
||||||
|
pub fn new(
|
||||||
|
config: Arc<Config>,
|
||||||
|
personality: Arc<Personality>,
|
||||||
|
tools: Arc<ToolRegistry>,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
config,
|
||||||
|
personality,
|
||||||
|
tools,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn generate_response(
|
||||||
|
&self,
|
||||||
|
context: &[ContextMessage],
|
||||||
|
trigger_body: &str,
|
||||||
|
trigger_sender: &str,
|
||||||
|
room_name: &str,
|
||||||
|
members: &[String],
|
||||||
|
is_spontaneous: bool,
|
||||||
|
mistral: &Arc<mistralai_client::v1::client::Client>,
|
||||||
|
) -> Option<String> {
|
||||||
|
// Apply response delay
|
||||||
|
let delay = if is_spontaneous {
|
||||||
|
rand::thread_rng().gen_range(
|
||||||
|
self.config.behavior.spontaneous_delay_min_ms
|
||||||
|
..=self.config.behavior.spontaneous_delay_max_ms,
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
rand::thread_rng().gen_range(
|
||||||
|
self.config.behavior.response_delay_min_ms
|
||||||
|
..=self.config.behavior.response_delay_max_ms,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
sleep(Duration::from_millis(delay)).await;
|
||||||
|
|
||||||
|
let system_prompt = self.personality.build_system_prompt(room_name, members);
|
||||||
|
|
||||||
|
let mut messages = vec![ChatMessage::new_system_message(&system_prompt)];
|
||||||
|
|
||||||
|
// Add context messages
|
||||||
|
for msg in context {
|
||||||
|
if msg.sender == self.config.matrix.user_id {
|
||||||
|
messages.push(ChatMessage::new_assistant_message(&msg.content, None));
|
||||||
|
} else {
|
||||||
|
let user_msg = format!("{}: {}", msg.sender, msg.content);
|
||||||
|
messages.push(ChatMessage::new_user_message(&user_msg));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the triggering message
|
||||||
|
let trigger = format!("{trigger_sender}: {trigger_body}");
|
||||||
|
messages.push(ChatMessage::new_user_message(&trigger));
|
||||||
|
|
||||||
|
let tool_defs = ToolRegistry::tool_definitions();
|
||||||
|
let model = Model::new(&self.config.mistral.default_model);
|
||||||
|
let max_iterations = self.config.mistral.max_tool_iterations;
|
||||||
|
|
||||||
|
for iteration in 0..=max_iterations {
|
||||||
|
let params = ChatParams {
|
||||||
|
tools: if iteration < max_iterations {
|
||||||
|
Some(tool_defs.clone())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
},
|
||||||
|
tool_choice: if iteration < max_iterations {
|
||||||
|
Some(ToolChoice::Auto)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
},
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = match chat_blocking(mistral, model.clone(), messages.clone(), params).await {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => {
|
||||||
|
error!("Mistral chat failed: {e}");
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let choice = &response.choices[0];
|
||||||
|
|
||||||
|
if choice.finish_reason == ChatResponseChoiceFinishReason::ToolCalls {
|
||||||
|
if let Some(tool_calls) = &choice.message.tool_calls {
|
||||||
|
// Add assistant message with tool calls
|
||||||
|
messages.push(ChatMessage::new_assistant_message(
|
||||||
|
&choice.message.content,
|
||||||
|
Some(tool_calls.clone()),
|
||||||
|
));
|
||||||
|
|
||||||
|
for tc in tool_calls {
|
||||||
|
let call_id = tc.id.as_deref().unwrap_or("unknown");
|
||||||
|
info!(
|
||||||
|
tool = tc.function.name.as_str(),
|
||||||
|
id = call_id,
|
||||||
|
"Executing tool call"
|
||||||
|
);
|
||||||
|
|
||||||
|
let result = self
|
||||||
|
.tools
|
||||||
|
.execute(&tc.function.name, &tc.function.arguments)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let result_str = match result {
|
||||||
|
Ok(s) => s,
|
||||||
|
Err(e) => {
|
||||||
|
warn!(tool = tc.function.name.as_str(), "Tool failed: {e}");
|
||||||
|
format!("Error: {e}")
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
messages.push(ChatMessage::new_tool_message(
|
||||||
|
&result_str,
|
||||||
|
call_id,
|
||||||
|
Some(&tc.function.name),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
debug!(iteration, "Tool iteration complete, continuing");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Final text response
|
||||||
|
let text = choice.message.content.trim().to_string();
|
||||||
|
if text.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
return Some(text);
|
||||||
|
}
|
||||||
|
|
||||||
|
warn!("Exceeded max tool iterations");
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
219
src/config.rs
Normal file
219
src/config.rs
Normal file
@@ -0,0 +1,219 @@
|
|||||||
|
use serde::Deserialize;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct Config {
|
||||||
|
pub matrix: MatrixConfig,
|
||||||
|
pub opensearch: OpenSearchConfig,
|
||||||
|
pub mistral: MistralConfig,
|
||||||
|
pub behavior: BehaviorConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct MatrixConfig {
|
||||||
|
pub homeserver_url: String,
|
||||||
|
pub user_id: String,
|
||||||
|
pub state_store_path: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct OpenSearchConfig {
|
||||||
|
pub url: String,
|
||||||
|
pub index: String,
|
||||||
|
#[serde(default = "default_batch_size")]
|
||||||
|
pub batch_size: usize,
|
||||||
|
#[serde(default = "default_flush_interval_ms")]
|
||||||
|
pub flush_interval_ms: u64,
|
||||||
|
#[serde(default = "default_embedding_pipeline")]
|
||||||
|
pub embedding_pipeline: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct MistralConfig {
|
||||||
|
#[serde(default = "default_model")]
|
||||||
|
pub default_model: String,
|
||||||
|
#[serde(default = "default_evaluation_model")]
|
||||||
|
pub evaluation_model: String,
|
||||||
|
#[serde(default = "default_research_model")]
|
||||||
|
pub research_model: String,
|
||||||
|
#[serde(default = "default_max_tool_iterations")]
|
||||||
|
pub max_tool_iterations: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Deserialize)]
|
||||||
|
pub struct BehaviorConfig {
|
||||||
|
#[serde(default = "default_response_delay_min_ms")]
|
||||||
|
pub response_delay_min_ms: u64,
|
||||||
|
#[serde(default = "default_response_delay_max_ms")]
|
||||||
|
pub response_delay_max_ms: u64,
|
||||||
|
#[serde(default = "default_spontaneous_delay_min_ms")]
|
||||||
|
pub spontaneous_delay_min_ms: u64,
|
||||||
|
#[serde(default = "default_spontaneous_delay_max_ms")]
|
||||||
|
pub spontaneous_delay_max_ms: u64,
|
||||||
|
#[serde(default = "default_spontaneous_threshold")]
|
||||||
|
pub spontaneous_threshold: f32,
|
||||||
|
#[serde(default = "default_room_context_window")]
|
||||||
|
pub room_context_window: usize,
|
||||||
|
#[serde(default = "default_dm_context_window")]
|
||||||
|
pub dm_context_window: usize,
|
||||||
|
#[serde(default = "default_backfill_on_join")]
|
||||||
|
pub backfill_on_join: bool,
|
||||||
|
#[serde(default = "default_backfill_limit")]
|
||||||
|
pub backfill_limit: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default providers referenced by the `#[serde(default = "...")]` attributes
// on the configuration structs.

// [opensearch]
fn default_batch_size() -> usize {
    50
}

fn default_flush_interval_ms() -> u64 {
    2000
}

fn default_embedding_pipeline() -> String {
    String::from("tuwunel_embedding_pipeline")
}

// [mistral]
fn default_model() -> String {
    String::from("mistral-medium-latest")
}

fn default_evaluation_model() -> String {
    String::from("ministral-3b-latest")
}

fn default_research_model() -> String {
    String::from("mistral-large-latest")
}

fn default_max_tool_iterations() -> usize {
    5
}

// [behavior]
fn default_response_delay_min_ms() -> u64 {
    2000
}

fn default_response_delay_max_ms() -> u64 {
    8000
}

fn default_spontaneous_delay_min_ms() -> u64 {
    15000
}

fn default_spontaneous_delay_max_ms() -> u64 {
    60000
}

fn default_spontaneous_threshold() -> f32 {
    0.7
}

fn default_room_context_window() -> usize {
    30
}

fn default_dm_context_window() -> usize {
    100
}

fn default_backfill_on_join() -> bool {
    true
}

fn default_backfill_limit() -> usize {
    10000
}
|
||||||
|
|
||||||
|
impl Config {
|
||||||
|
pub fn load(path: &str) -> anyhow::Result<Self> {
|
||||||
|
let content = std::fs::read_to_string(path)?;
|
||||||
|
let config: Config = toml::from_str(&content)?;
|
||||||
|
Ok(config)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_str(content: &str) -> anyhow::Result<Self> {
|
||||||
|
let config: Config = toml::from_str(content)?;
|
||||||
|
Ok(config)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Only the required keys; every optional key falls back to its default.
    const MINIMAL_CONFIG: &str = r#"
[matrix]
homeserver_url = "https://chat.sunbeam.pt"
user_id = "@sol:sunbeam.pt"
state_store_path = "/data/sol/state"

[opensearch]
url = "http://opensearch:9200"
index = "sol-archive"

[mistral]

[behavior]
"#;

    /// Every key set explicitly to a value different from its default.
    const FULL_CONFIG: &str = r#"
[matrix]
homeserver_url = "https://chat.sunbeam.pt"
user_id = "@sol:sunbeam.pt"
state_store_path = "/data/sol/state"

[opensearch]
url = "http://opensearch:9200"
index = "sol-archive"
batch_size = 100
flush_interval_ms = 5000
embedding_pipeline = "my_pipeline"

[mistral]
default_model = "mistral-large-latest"
evaluation_model = "ministral-8b-latest"
research_model = "mistral-large-latest"
max_tool_iterations = 10

[behavior]
response_delay_min_ms = 1000
response_delay_max_ms = 5000
spontaneous_delay_min_ms = 10000
spontaneous_delay_max_ms = 30000
spontaneous_threshold = 0.8
room_context_window = 50
dm_context_window = 200
backfill_on_join = false
backfill_limit = 5000
"#;

    /// True when two `f32` values differ by less than `f32::EPSILON`.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < f32::EPSILON
    }

    #[test]
    fn test_minimal_config_with_defaults() {
        let config = Config::from_str(MINIMAL_CONFIG).unwrap();
        let Config {
            matrix,
            opensearch,
            mistral,
            behavior,
        } = &config;

        // Required values are taken verbatim from the file.
        assert_eq!(matrix.homeserver_url, "https://chat.sunbeam.pt");
        assert_eq!(matrix.user_id, "@sol:sunbeam.pt");
        assert_eq!(matrix.state_store_path, "/data/sol/state");
        assert_eq!(opensearch.url, "http://opensearch:9200");
        assert_eq!(opensearch.index, "sol-archive");

        // Check defaults
        assert_eq!(opensearch.batch_size, 50);
        assert_eq!(opensearch.flush_interval_ms, 2000);
        assert_eq!(opensearch.embedding_pipeline, "tuwunel_embedding_pipeline");

        assert_eq!(mistral.default_model, "mistral-medium-latest");
        assert_eq!(mistral.evaluation_model, "ministral-3b-latest");
        assert_eq!(mistral.research_model, "mistral-large-latest");
        assert_eq!(mistral.max_tool_iterations, 5);

        assert_eq!(behavior.response_delay_min_ms, 2000);
        assert_eq!(behavior.response_delay_max_ms, 8000);
        assert_eq!(behavior.spontaneous_delay_min_ms, 15000);
        assert_eq!(behavior.spontaneous_delay_max_ms, 60000);
        assert!(close(behavior.spontaneous_threshold, 0.7));
        assert_eq!(behavior.room_context_window, 30);
        assert_eq!(behavior.dm_context_window, 100);
        assert!(behavior.backfill_on_join);
        assert_eq!(behavior.backfill_limit, 10000);
    }

    #[test]
    fn test_full_config_overrides() {
        let config = Config::from_str(FULL_CONFIG).unwrap();
        let Config {
            opensearch,
            mistral,
            behavior,
            ..
        } = &config;

        assert_eq!(opensearch.batch_size, 100);
        assert_eq!(opensearch.flush_interval_ms, 5000);
        assert_eq!(opensearch.embedding_pipeline, "my_pipeline");

        assert_eq!(mistral.default_model, "mistral-large-latest");
        assert_eq!(mistral.evaluation_model, "ministral-8b-latest");
        assert_eq!(mistral.max_tool_iterations, 10);

        assert_eq!(behavior.response_delay_min_ms, 1000);
        assert_eq!(behavior.response_delay_max_ms, 5000);
        assert!(close(behavior.spontaneous_threshold, 0.8));
        assert_eq!(behavior.room_context_window, 50);
        assert_eq!(behavior.dm_context_window, 200);
        assert!(!behavior.backfill_on_join);
        assert_eq!(behavior.backfill_limit, 5000);
    }

    #[test]
    fn test_missing_required_section_fails() {
        // Only `[matrix]` is present; the other tables are missing.
        let bad = r#"
[matrix]
homeserver_url = "https://chat.sunbeam.pt"
user_id = "@sol:sunbeam.pt"
state_store_path = "/data/sol/state"
"#;
        let parsed = Config::from_str(bad);
        assert!(parsed.is_err());
    }

    #[test]
    fn test_missing_required_field_fails() {
        // `matrix.user_id` is absent.
        let bad = r#"
[matrix]
homeserver_url = "https://chat.sunbeam.pt"
state_store_path = "/data/sol/state"

[opensearch]
url = "http://opensearch:9200"
index = "sol-archive"

[mistral]
[behavior]
"#;
        let parsed = Config::from_str(bad);
        assert!(parsed.is_err());
    }
}
|
||||||
160
src/main.rs
Normal file
160
src/main.rs
Normal file
@@ -0,0 +1,160 @@
|
|||||||
|
mod archive;
|
||||||
|
mod brain;
|
||||||
|
mod config;
|
||||||
|
mod matrix_utils;
|
||||||
|
mod sync;
|
||||||
|
mod tools;
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use matrix_sdk::Client;
|
||||||
|
use opensearch::http::transport::TransportBuilder;
|
||||||
|
use opensearch::OpenSearch;
|
||||||
|
use ruma::{OwnedDeviceId, OwnedUserId};
|
||||||
|
use tokio::signal;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
use tracing::{error, info};
|
||||||
|
use url::Url;
|
||||||
|
|
||||||
|
use archive::indexer::Indexer;
|
||||||
|
use archive::schema::create_index_if_not_exists;
|
||||||
|
use brain::conversation::ConversationManager;
|
||||||
|
use brain::evaluator::Evaluator;
|
||||||
|
use brain::personality::Personality;
|
||||||
|
use brain::responder::Responder;
|
||||||
|
use config::Config;
|
||||||
|
use sync::AppState;
|
||||||
|
use tools::ToolRegistry;
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> anyhow::Result<()> {
|
||||||
|
// Initialize tracing
|
||||||
|
tracing_subscriber::fmt()
|
||||||
|
.with_env_filter(
|
||||||
|
tracing_subscriber::EnvFilter::try_from_default_env()
|
||||||
|
.unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("sol=info")),
|
||||||
|
)
|
||||||
|
.init();
|
||||||
|
|
||||||
|
// Load config
|
||||||
|
let config_path =
|
||||||
|
std::env::var("SOL_CONFIG").unwrap_or_else(|_| "/etc/sol/sol.toml".into());
|
||||||
|
let config = Config::load(&config_path)?;
|
||||||
|
info!("Loaded config from {config_path}");
|
||||||
|
|
||||||
|
// Load system prompt
|
||||||
|
let prompt_path = std::env::var("SOL_SYSTEM_PROMPT")
|
||||||
|
.unwrap_or_else(|_| "/etc/sol/system_prompt.md".into());
|
||||||
|
let system_prompt = std::fs::read_to_string(&prompt_path)?;
|
||||||
|
info!("Loaded system prompt from {prompt_path}");
|
||||||
|
|
||||||
|
// Read secrets from environment
|
||||||
|
let access_token = std::env::var("SOL_MATRIX_ACCESS_TOKEN")
|
||||||
|
.map_err(|_| anyhow::anyhow!("SOL_MATRIX_ACCESS_TOKEN not set"))?;
|
||||||
|
let device_id = std::env::var("SOL_MATRIX_DEVICE_ID")
|
||||||
|
.map_err(|_| anyhow::anyhow!("SOL_MATRIX_DEVICE_ID not set"))?;
|
||||||
|
let mistral_api_key = std::env::var("SOL_MISTRAL_API_KEY")
|
||||||
|
.map_err(|_| anyhow::anyhow!("SOL_MISTRAL_API_KEY not set"))?;
|
||||||
|
|
||||||
|
let config = Arc::new(config);
|
||||||
|
|
||||||
|
// Initialize Matrix client with E2EE and sqlite store
|
||||||
|
let homeserver = Url::parse(&config.matrix.homeserver_url)?;
|
||||||
|
|
||||||
|
let matrix_client = Client::builder()
|
||||||
|
.homeserver_url(homeserver)
|
||||||
|
.sqlite_store(&config.matrix.state_store_path, None)
|
||||||
|
.build()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Restore session
|
||||||
|
let user_id: OwnedUserId = config.matrix.user_id.parse()?;
|
||||||
|
let device_id: OwnedDeviceId = device_id.into();
|
||||||
|
|
||||||
|
let session = matrix_sdk::AuthSession::Matrix(matrix_sdk::matrix_auth::MatrixSession {
|
||||||
|
meta: matrix_sdk::SessionMeta {
|
||||||
|
user_id,
|
||||||
|
device_id,
|
||||||
|
},
|
||||||
|
tokens: matrix_sdk::matrix_auth::MatrixSessionTokens {
|
||||||
|
access_token,
|
||||||
|
refresh_token: None,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
matrix_client.restore_session(session).await?;
|
||||||
|
info!(user = %config.matrix.user_id, "Matrix session restored");
|
||||||
|
|
||||||
|
// Initialize OpenSearch client
|
||||||
|
let os_url = Url::parse(&config.opensearch.url)?;
|
||||||
|
let os_transport = TransportBuilder::new(
|
||||||
|
opensearch::http::transport::SingleNodeConnectionPool::new(os_url),
|
||||||
|
)
|
||||||
|
.build()?;
|
||||||
|
let os_client = OpenSearch::new(os_transport);
|
||||||
|
|
||||||
|
// Ensure index exists
|
||||||
|
create_index_if_not_exists(&os_client, &config.opensearch.index).await?;
|
||||||
|
|
||||||
|
// Initialize Mistral client
|
||||||
|
let mistral_client = mistralai_client::v1::client::Client::new(
|
||||||
|
Some(mistral_api_key),
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
)?;
|
||||||
|
let mistral = Arc::new(mistral_client);
|
||||||
|
|
||||||
|
// Build components
|
||||||
|
let personality = Arc::new(Personality::new(system_prompt));
|
||||||
|
let tool_registry = Arc::new(ToolRegistry::new(
|
||||||
|
os_client.clone(),
|
||||||
|
matrix_client.clone(),
|
||||||
|
config.clone(),
|
||||||
|
));
|
||||||
|
let indexer = Arc::new(Indexer::new(os_client, config.clone()));
|
||||||
|
let evaluator = Arc::new(Evaluator::new(config.clone()));
|
||||||
|
let responder = Arc::new(Responder::new(
|
||||||
|
config.clone(),
|
||||||
|
personality,
|
||||||
|
tool_registry,
|
||||||
|
));
|
||||||
|
let conversations = Arc::new(Mutex::new(ConversationManager::new(
|
||||||
|
config.behavior.room_context_window,
|
||||||
|
config.behavior.dm_context_window,
|
||||||
|
)));
|
||||||
|
|
||||||
|
// Start background flush task
|
||||||
|
let _flush_handle = indexer.start_flush_task();
|
||||||
|
|
||||||
|
// Build shared state
|
||||||
|
let state = Arc::new(AppState {
|
||||||
|
config: config.clone(),
|
||||||
|
indexer,
|
||||||
|
evaluator,
|
||||||
|
responder,
|
||||||
|
conversations,
|
||||||
|
mistral,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Start sync loop in background
|
||||||
|
let sync_client = matrix_client.clone();
|
||||||
|
let sync_state = state.clone();
|
||||||
|
let sync_handle = tokio::spawn(async move {
|
||||||
|
if let Err(e) = sync::start_sync(sync_client, sync_state).await {
|
||||||
|
error!("Sync loop error: {e}");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
info!("Sol is running");
|
||||||
|
|
||||||
|
// Wait for shutdown signal
|
||||||
|
signal::ctrl_c().await?;
|
||||||
|
info!("Shutdown signal received");
|
||||||
|
|
||||||
|
// Cancel sync
|
||||||
|
sync_handle.abort();
|
||||||
|
|
||||||
|
info!("Sol has shut down");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
77
src/matrix_utils.rs
Normal file
77
src/matrix_utils.rs
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
use matrix_sdk::room::Room;
|
||||||
|
use matrix_sdk::RoomMemberships;
|
||||||
|
use ruma::events::room::message::{
|
||||||
|
MessageType, OriginalSyncRoomMessageEvent, Relation, RoomMessageEventContent,
|
||||||
|
};
|
||||||
|
use ruma::events::relation::InReplyTo;
|
||||||
|
use ruma::OwnedEventId;
|
||||||
|
|
||||||
|
/// Extract the plain-text body from a message event.
|
||||||
|
pub fn extract_body(event: &OriginalSyncRoomMessageEvent) -> Option<String> {
|
||||||
|
match &event.content.msgtype {
|
||||||
|
MessageType::Text(text) => Some(text.body.clone()),
|
||||||
|
MessageType::Notice(notice) => Some(notice.body.clone()),
|
||||||
|
MessageType::Emote(emote) => Some(emote.body.clone()),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if this event is an edit (m.replace relation) and return the new body.
|
||||||
|
pub fn extract_edit(event: &OriginalSyncRoomMessageEvent) -> Option<(OwnedEventId, String)> {
|
||||||
|
if let Some(Relation::Replacement(replacement)) = &event.content.relates_to {
|
||||||
|
let new_body = match &replacement.new_content.msgtype {
|
||||||
|
MessageType::Text(text) => text.body.clone(),
|
||||||
|
MessageType::Notice(notice) => notice.body.clone(),
|
||||||
|
_ => return None,
|
||||||
|
};
|
||||||
|
return Some((replacement.event_id.clone(), new_body));
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract the event ID being replied to, if any.
|
||||||
|
pub fn extract_reply_to(event: &OriginalSyncRoomMessageEvent) -> Option<OwnedEventId> {
|
||||||
|
if let Some(Relation::Reply { in_reply_to }) = &event.content.relates_to {
|
||||||
|
return Some(in_reply_to.event_id.clone());
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract thread root event ID, if any.
|
||||||
|
pub fn extract_thread_id(event: &OriginalSyncRoomMessageEvent) -> Option<OwnedEventId> {
|
||||||
|
if let Some(Relation::Thread(thread)) = &event.content.relates_to {
|
||||||
|
return Some(thread.event_id.clone());
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build a reply message content with m.in_reply_to relation.
|
||||||
|
pub fn make_reply_content(body: &str, reply_to_event_id: OwnedEventId) -> RoomMessageEventContent {
|
||||||
|
let mut content = RoomMessageEventContent::text_plain(body);
|
||||||
|
content.relates_to = Some(Relation::Reply {
|
||||||
|
in_reply_to: InReplyTo::new(reply_to_event_id),
|
||||||
|
});
|
||||||
|
content
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get the display name for a room.
|
||||||
|
pub fn room_display_name(room: &Room) -> String {
|
||||||
|
room.cached_display_name()
|
||||||
|
.map(|n| n.to_string())
|
||||||
|
.unwrap_or_else(|| room.room_id().to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get member display names for a room.
|
||||||
|
pub async fn room_member_names(room: &Room) -> Vec<String> {
|
||||||
|
match room.members(RoomMemberships::JOIN).await {
|
||||||
|
Ok(members) => members
|
||||||
|
.iter()
|
||||||
|
.map(|m| {
|
||||||
|
m.display_name()
|
||||||
|
.unwrap_or_else(|| m.user_id().as_str())
|
||||||
|
.to_string()
|
||||||
|
})
|
||||||
|
.collect(),
|
||||||
|
Err(_) => Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
228
src/sync.rs
Normal file
228
src/sync.rs
Normal file
@@ -0,0 +1,228 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use matrix_sdk::config::SyncSettings;
|
||||||
|
use matrix_sdk::room::Room;
|
||||||
|
use matrix_sdk::Client;
|
||||||
|
use ruma::events::room::member::StrippedRoomMemberEvent;
|
||||||
|
use ruma::events::room::message::OriginalSyncRoomMessageEvent;
|
||||||
|
use ruma::events::room::redaction::OriginalSyncRoomRedactionEvent;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
use tracing::{error, info, warn};
|
||||||
|
|
||||||
|
use crate::archive::indexer::Indexer;
|
||||||
|
use crate::archive::schema::ArchiveDocument;
|
||||||
|
use crate::brain::conversation::{ContextMessage, ConversationManager};
|
||||||
|
use crate::brain::evaluator::{Engagement, Evaluator};
|
||||||
|
use crate::brain::responder::Responder;
|
||||||
|
use crate::config::Config;
|
||||||
|
use crate::matrix_utils;
|
||||||
|
|
||||||
|
pub struct AppState {
|
||||||
|
pub config: Arc<Config>,
|
||||||
|
pub indexer: Arc<Indexer>,
|
||||||
|
pub evaluator: Arc<Evaluator>,
|
||||||
|
pub responder: Arc<Responder>,
|
||||||
|
pub conversations: Arc<Mutex<ConversationManager>>,
|
||||||
|
pub mistral: Arc<mistralai_client::v1::client::Client>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn start_sync(client: Client, state: Arc<AppState>) -> anyhow::Result<()> {
|
||||||
|
// Register event handlers
|
||||||
|
let s = state.clone();
|
||||||
|
client.add_event_handler(
|
||||||
|
move |event: OriginalSyncRoomMessageEvent, room: Room| {
|
||||||
|
let state = s.clone();
|
||||||
|
async move {
|
||||||
|
if let Err(e) = handle_message(event, room, state).await {
|
||||||
|
error!("Error handling message: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
let s = state.clone();
|
||||||
|
client.add_event_handler(
|
||||||
|
move |event: OriginalSyncRoomRedactionEvent, _room: Room| {
|
||||||
|
let state = s.clone();
|
||||||
|
async move {
|
||||||
|
handle_redaction(event, &state).await;
|
||||||
|
}
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
client.add_event_handler(
|
||||||
|
move |event: StrippedRoomMemberEvent, room: Room| async move {
|
||||||
|
handle_invite(event, room).await;
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
info!("Starting Matrix sync loop");
|
||||||
|
let settings = SyncSettings::default();
|
||||||
|
client.sync(settings).await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_message(
|
||||||
|
event: OriginalSyncRoomMessageEvent,
|
||||||
|
room: Room,
|
||||||
|
state: Arc<AppState>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
let sender = event.sender.to_string();
|
||||||
|
let room_id = room.room_id().to_string();
|
||||||
|
let event_id = event.event_id.to_string();
|
||||||
|
let timestamp = event.origin_server_ts.0.into();
|
||||||
|
|
||||||
|
// Check if this is an edit
|
||||||
|
if let Some((original_id, new_body)) = matrix_utils::extract_edit(&event) {
|
||||||
|
state
|
||||||
|
.indexer
|
||||||
|
.update_edit(&original_id.to_string(), &new_body)
|
||||||
|
.await;
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(body) = matrix_utils::extract_body(&event) else {
|
||||||
|
return Ok(());
|
||||||
|
};
|
||||||
|
|
||||||
|
let room_name = matrix_utils::room_display_name(&room);
|
||||||
|
let sender_name = room
|
||||||
|
.get_member_no_sync(&event.sender)
|
||||||
|
.await
|
||||||
|
.ok()
|
||||||
|
.flatten()
|
||||||
|
.and_then(|m| m.display_name().map(|s| s.to_string()));
|
||||||
|
|
||||||
|
let reply_to = matrix_utils::extract_reply_to(&event).map(|id| id.to_string());
|
||||||
|
let thread_id = matrix_utils::extract_thread_id(&event).map(|id| id.to_string());
|
||||||
|
|
||||||
|
// Archive the message
|
||||||
|
let doc = ArchiveDocument {
|
||||||
|
event_id: event_id.clone(),
|
||||||
|
room_id: room_id.clone(),
|
||||||
|
room_name: Some(room_name.clone()),
|
||||||
|
sender: sender.clone(),
|
||||||
|
sender_name: sender_name.clone(),
|
||||||
|
timestamp,
|
||||||
|
content: body.clone(),
|
||||||
|
reply_to,
|
||||||
|
thread_id,
|
||||||
|
media_urls: Vec::new(),
|
||||||
|
event_type: "m.room.message".into(),
|
||||||
|
edited: false,
|
||||||
|
redacted: false,
|
||||||
|
};
|
||||||
|
state.indexer.add(doc).await;
|
||||||
|
|
||||||
|
// Update conversation context
|
||||||
|
let is_dm = room.is_direct().await.unwrap_or(false);
|
||||||
|
{
|
||||||
|
let mut convs = state.conversations.lock().await;
|
||||||
|
convs.add_message(
|
||||||
|
&room_id,
|
||||||
|
is_dm,
|
||||||
|
ContextMessage {
|
||||||
|
sender: sender_name.clone().unwrap_or_else(|| sender.clone()),
|
||||||
|
content: body.clone(),
|
||||||
|
timestamp,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Evaluate whether to respond
|
||||||
|
let recent: Vec<String> = {
|
||||||
|
let convs = state.conversations.lock().await;
|
||||||
|
convs
|
||||||
|
.get_context(&room_id)
|
||||||
|
.iter()
|
||||||
|
.map(|m| format!("{}: {}", m.sender, m.content))
|
||||||
|
.collect()
|
||||||
|
};
|
||||||
|
|
||||||
|
let engagement = state
|
||||||
|
.evaluator
|
||||||
|
.evaluate(&sender, &body, is_dm, &recent, &state.mistral)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let (should_respond, is_spontaneous) = match engagement {
|
||||||
|
Engagement::MustRespond { reason } => {
|
||||||
|
info!(?reason, "Must respond");
|
||||||
|
(true, false)
|
||||||
|
}
|
||||||
|
Engagement::MaybeRespond { relevance, hook } => {
|
||||||
|
info!(relevance, hook = hook.as_str(), "Maybe respond (spontaneous)");
|
||||||
|
(true, true)
|
||||||
|
}
|
||||||
|
Engagement::Ignore => (false, false),
|
||||||
|
};
|
||||||
|
|
||||||
|
if !should_respond {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Show typing indicator
|
||||||
|
let _ = room.typing_notice(true).await;
|
||||||
|
|
||||||
|
let context = {
|
||||||
|
let convs = state.conversations.lock().await;
|
||||||
|
convs.get_context(&room_id)
|
||||||
|
};
|
||||||
|
let members = matrix_utils::room_member_names(&room).await;
|
||||||
|
let display_sender = sender_name.as_deref().unwrap_or(&sender);
|
||||||
|
|
||||||
|
let response = state
|
||||||
|
.responder
|
||||||
|
.generate_response(
|
||||||
|
&context,
|
||||||
|
&body,
|
||||||
|
display_sender,
|
||||||
|
&room_name,
|
||||||
|
&members,
|
||||||
|
is_spontaneous,
|
||||||
|
&state.mistral,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// Stop typing indicator
|
||||||
|
let _ = room.typing_notice(false).await;
|
||||||
|
|
||||||
|
if let Some(text) = response {
|
||||||
|
let content = matrix_utils::make_reply_content(&text, event.event_id.to_owned());
|
||||||
|
if let Err(e) = room.send(content).await {
|
||||||
|
error!("Failed to send response: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Forward a redaction to the indexer.
///
/// Does nothing when the redaction event does not name a redacted event.
async fn handle_redaction(event: OriginalSyncRoomRedactionEvent, state: &AppState) {
    let Some(redacted_id) = &event.redacts else {
        return;
    };
    state.indexer.update_redaction(&redacted_id.to_string()).await;
}
|
||||||
|
|
||||||
|
async fn handle_invite(event: StrippedRoomMemberEvent, room: Room) {
|
||||||
|
// Only handle our own invites
|
||||||
|
if event.state_key != room.own_user_id() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
info!(room_id = %room.room_id(), "Received invite, auto-joining");
|
||||||
|
tokio::spawn(async move {
|
||||||
|
for attempt in 0..3u32 {
|
||||||
|
match room.join().await {
|
||||||
|
Ok(_) => {
|
||||||
|
info!(room_id = %room.room_id(), "Joined room");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!(room_id = %room.room_id(), attempt, "Failed to join: {e}");
|
||||||
|
tokio::time::sleep(std::time::Duration::from_secs(2u64.pow(attempt))).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
error!(room_id = %room.room_id(), "Failed to join after retries");
|
||||||
|
});
|
||||||
|
}
|
||||||
151
src/tools/mod.rs
Normal file
151
src/tools/mod.rs
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
pub mod room_history;
|
||||||
|
pub mod room_info;
|
||||||
|
pub mod search;
|
||||||
|
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use matrix_sdk::Client as MatrixClient;
|
||||||
|
use mistralai_client::v1::tool::Tool;
|
||||||
|
use opensearch::OpenSearch;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
use crate::config::Config;
|
||||||
|
|
||||||
|
pub struct ToolRegistry {
|
||||||
|
opensearch: OpenSearch,
|
||||||
|
matrix: MatrixClient,
|
||||||
|
config: Arc<Config>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ToolRegistry {
|
||||||
|
pub fn new(opensearch: OpenSearch, matrix: MatrixClient, config: Arc<Config>) -> Self {
|
||||||
|
Self {
|
||||||
|
opensearch,
|
||||||
|
matrix,
|
||||||
|
config,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tool_definitions() -> Vec<Tool> {
|
||||||
|
vec![
|
||||||
|
Tool::new(
|
||||||
|
"search_archive".into(),
|
||||||
|
"Search the message archive. Use this to find past conversations, \
|
||||||
|
messages from specific people, or about specific topics."
|
||||||
|
.into(),
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Search query for message content"
|
||||||
|
},
|
||||||
|
"room": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Filter by room name (optional)"
|
||||||
|
},
|
||||||
|
"sender": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Filter by sender display name (optional)"
|
||||||
|
},
|
||||||
|
"after": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Unix timestamp in ms — only messages after this time (optional)"
|
||||||
|
},
|
||||||
|
"before": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Unix timestamp in ms — only messages before this time (optional)"
|
||||||
|
},
|
||||||
|
"limit": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Max results to return (default 10)"
|
||||||
|
},
|
||||||
|
"semantic": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "Use semantic search instead of keyword (optional)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["query"]
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
Tool::new(
|
||||||
|
"get_room_context".into(),
|
||||||
|
"Get messages around a specific point in time or event in a room. \
|
||||||
|
Useful for understanding the context of a conversation."
|
||||||
|
.into(),
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"room_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The Matrix room ID"
|
||||||
|
},
|
||||||
|
"around_timestamp": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Unix timestamp in ms to center the context around"
|
||||||
|
},
|
||||||
|
"around_event_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Event ID to center the context around"
|
||||||
|
},
|
||||||
|
"before_count": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Number of messages before the pivot (default 10)"
|
||||||
|
},
|
||||||
|
"after_count": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Number of messages after the pivot (default 10)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["room_id"]
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
Tool::new(
|
||||||
|
"list_rooms".into(),
|
||||||
|
"List all rooms Sol is currently in, with names and member counts.".into(),
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {}
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
Tool::new(
|
||||||
|
"get_room_members".into(),
|
||||||
|
"Get the list of members in a specific room.".into(),
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"room_id": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The Matrix room ID"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["room_id"]
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn execute(&self, name: &str, arguments: &str) -> anyhow::Result<String> {
|
||||||
|
match name {
|
||||||
|
"search_archive" => {
|
||||||
|
search::search_archive(
|
||||||
|
&self.opensearch,
|
||||||
|
&self.config.opensearch.index,
|
||||||
|
arguments,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
"get_room_context" => {
|
||||||
|
room_history::get_room_context(
|
||||||
|
&self.opensearch,
|
||||||
|
&self.config.opensearch.index,
|
||||||
|
arguments,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
"list_rooms" => room_info::list_rooms(&self.matrix).await,
|
||||||
|
"get_room_members" => room_info::get_room_members(&self.matrix, arguments).await,
|
||||||
|
_ => anyhow::bail!("Unknown tool: {name}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
108
src/tools/room_history.rs
Normal file
108
src/tools/room_history.rs
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
use opensearch::OpenSearch;
|
||||||
|
use serde::Deserialize;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct RoomHistoryArgs {
|
||||||
|
pub room_id: String,
|
||||||
|
#[serde(default)]
|
||||||
|
pub around_timestamp: Option<i64>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub around_event_id: Option<String>,
|
||||||
|
#[serde(default = "default_count")]
|
||||||
|
pub before_count: usize,
|
||||||
|
#[serde(default = "default_count")]
|
||||||
|
pub after_count: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Message count used for `before_count` / `after_count` when the field is omitted.
const DEFAULT_CONTEXT_COUNT: usize = 10;

/// Serde default provider for the context counts.
fn default_count() -> usize {
    DEFAULT_CONTEXT_COUNT
}
|
||||||
|
|
||||||
|
pub async fn get_room_context(
|
||||||
|
client: &OpenSearch,
|
||||||
|
index: &str,
|
||||||
|
args_json: &str,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
let args: RoomHistoryArgs = serde_json::from_str(args_json)?;
|
||||||
|
let total = args.before_count + args.after_count + 1;
|
||||||
|
|
||||||
|
// Determine the pivot timestamp
|
||||||
|
let pivot_ts = if let Some(ts) = args.around_timestamp {
|
||||||
|
ts
|
||||||
|
} else if let Some(ref event_id) = args.around_event_id {
|
||||||
|
// Look up the event to get its timestamp
|
||||||
|
let lookup = json!({
|
||||||
|
"size": 1,
|
||||||
|
"query": { "term": { "event_id": event_id } },
|
||||||
|
"_source": ["timestamp"]
|
||||||
|
});
|
||||||
|
|
||||||
|
let resp = client
|
||||||
|
.search(opensearch::SearchParts::Index(&[index]))
|
||||||
|
.body(lookup)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let body: serde_json::Value = resp.json().await?;
|
||||||
|
body["hits"]["hits"][0]["_source"]["timestamp"]
|
||||||
|
.as_i64()
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Event {event_id} not found in archive"))?
|
||||||
|
} else {
|
||||||
|
anyhow::bail!("Either around_timestamp or around_event_id must be provided");
|
||||||
|
};
|
||||||
|
|
||||||
|
let query = json!({
|
||||||
|
"size": total,
|
||||||
|
"query": {
|
||||||
|
"bool": {
|
||||||
|
"must": [
|
||||||
|
{ "term": { "room_id": args.room_id } },
|
||||||
|
{ "term": { "redacted": false } }
|
||||||
|
],
|
||||||
|
"should": [
|
||||||
|
{
|
||||||
|
"range": {
|
||||||
|
"timestamp": {
|
||||||
|
"gte": pivot_ts - 3_600_000,
|
||||||
|
"lte": pivot_ts + 3_600_000
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"sort": [{ "timestamp": "asc" }]
|
||||||
|
});
|
||||||
|
|
||||||
|
let response = client
|
||||||
|
.search(opensearch::SearchParts::Index(&[index]))
|
||||||
|
.body(query)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let body: serde_json::Value = response.json().await?;
|
||||||
|
let hits = &body["hits"]["hits"];
|
||||||
|
|
||||||
|
let Some(hits_arr) = hits.as_array() else {
|
||||||
|
return Ok("No messages found around that point.".into());
|
||||||
|
};
|
||||||
|
|
||||||
|
if hits_arr.is_empty() {
|
||||||
|
return Ok("No messages found around that point.".into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut output = String::new();
|
||||||
|
for hit in hits_arr {
|
||||||
|
let src = &hit["_source"];
|
||||||
|
let sender = src["sender_name"].as_str().unwrap_or("unknown");
|
||||||
|
let content = src["content"].as_str().unwrap_or("");
|
||||||
|
let ts = src["timestamp"].as_i64().unwrap_or(0);
|
||||||
|
|
||||||
|
let dt = chrono::DateTime::from_timestamp_millis(ts)
|
||||||
|
.map(|d| d.format("%H:%M").to_string())
|
||||||
|
.unwrap_or_else(|| "??:??".into());
|
||||||
|
|
||||||
|
output.push_str(&format!("[{dt}] {sender}: {content}\n"));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(output)
|
||||||
|
}
|
||||||
55
src/tools/room_info.rs
Normal file
55
src/tools/room_info.rs
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
use matrix_sdk::Client;
|
||||||
|
use serde::Deserialize;
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct ListRoomsArgs {}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct GetMembersArgs {
|
||||||
|
pub room_id: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// List every joined room as a `- name (id) — N members` line.
///
/// A room without a cached display name is shown by its room ID. Returns a
/// fixed sentence when the client has no joined rooms.
pub async fn list_rooms(client: &Client) -> anyhow::Result<String> {
    let rooms = client.joined_rooms();
    if rooms.is_empty() {
        return Ok("I'm not in any rooms.".into());
    }

    let listing = rooms
        .iter()
        .map(|room| {
            let id = room.room_id();
            let name = room
                .cached_display_name()
                .map(|n| n.to_string())
                .unwrap_or_else(|| id.to_string());
            let members = room.joined_members_count();
            format!("- {name} ({id}) — {members} members\n")
        })
        .collect::<String>();

    Ok(listing)
}
|
||||||
|
|
||||||
|
pub async fn get_room_members(client: &Client, args_json: &str) -> anyhow::Result<String> {
|
||||||
|
let args: GetMembersArgs = serde_json::from_str(args_json)?;
|
||||||
|
let room_id = <&ruma::RoomId>::try_from(args.room_id.as_str())?;
|
||||||
|
|
||||||
|
let Some(room) = client.get_room(room_id) else {
|
||||||
|
anyhow::bail!("I'm not in room {}", args.room_id);
|
||||||
|
};
|
||||||
|
|
||||||
|
let members = room.members(matrix_sdk::RoomMemberships::JOIN).await?;
|
||||||
|
if members.is_empty() {
|
||||||
|
return Ok("No members found.".into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut output = String::new();
|
||||||
|
for member in &members {
|
||||||
|
let display = member
|
||||||
|
.display_name()
|
||||||
|
.unwrap_or_else(|| member.user_id().as_str());
|
||||||
|
let user_id = member.user_id();
|
||||||
|
output.push_str(&format!("- {display} ({user_id})\n"));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(output)
|
||||||
|
}
|
||||||
274
src/tools/search.rs
Normal file
274
src/tools/search.rs
Normal file
@@ -0,0 +1,274 @@
|
|||||||
|
use opensearch::OpenSearch;
|
||||||
|
use serde::Deserialize;
|
||||||
|
use serde_json::json;
|
||||||
|
use tracing::debug;
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
pub struct SearchArgs {
|
||||||
|
pub query: String,
|
||||||
|
#[serde(default)]
|
||||||
|
pub room: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub sender: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub after: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub before: Option<String>,
|
||||||
|
#[serde(default = "default_limit")]
|
||||||
|
pub limit: usize,
|
||||||
|
#[serde(default)]
|
||||||
|
pub semantic: Option<bool>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Result limit used when `limit` is omitted from the search arguments.
const DEFAULT_SEARCH_LIMIT: usize = 10;

/// Serde default provider for `SearchArgs::limit`.
fn default_limit() -> usize {
    DEFAULT_SEARCH_LIMIT
}
|
||||||
|
|
||||||
|
/// Build the OpenSearch query body from parsed SearchArgs. Extracted for testability.
|
||||||
|
pub fn build_search_query(args: &SearchArgs) -> serde_json::Value {
|
||||||
|
let must = vec![json!({
|
||||||
|
"match": { "content": args.query }
|
||||||
|
})];
|
||||||
|
|
||||||
|
let mut filter = vec![json!({
|
||||||
|
"term": { "redacted": false }
|
||||||
|
})];
|
||||||
|
|
||||||
|
if let Some(ref room) = args.room {
|
||||||
|
filter.push(json!({ "term": { "room_name": room } }));
|
||||||
|
}
|
||||||
|
if let Some(ref sender) = args.sender {
|
||||||
|
filter.push(json!({ "term": { "sender_name": sender } }));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut range = serde_json::Map::new();
|
||||||
|
if let Some(ref after) = args.after {
|
||||||
|
if let Ok(ts) = after.parse::<i64>() {
|
||||||
|
range.insert("gte".into(), json!(ts));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if let Some(ref before) = args.before {
|
||||||
|
if let Ok(ts) = before.parse::<i64>() {
|
||||||
|
range.insert("lte".into(), json!(ts));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !range.is_empty() {
|
||||||
|
filter.push(json!({ "range": { "timestamp": range } }));
|
||||||
|
}
|
||||||
|
|
||||||
|
json!({
|
||||||
|
"size": args.limit,
|
||||||
|
"query": {
|
||||||
|
"bool": {
|
||||||
|
"must": must,
|
||||||
|
"filter": filter
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"sort": [{ "timestamp": "desc" }],
|
||||||
|
"_source": ["event_id", "room_name", "sender_name", "timestamp", "content"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn search_archive(
|
||||||
|
client: &OpenSearch,
|
||||||
|
index: &str,
|
||||||
|
args_json: &str,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
let args: SearchArgs = serde_json::from_str(args_json)?;
|
||||||
|
debug!(query = args.query.as_str(), "Searching archive");
|
||||||
|
|
||||||
|
let query_body = build_search_query(&args);
|
||||||
|
|
||||||
|
let response = client
|
||||||
|
.search(opensearch::SearchParts::Index(&[index]))
|
||||||
|
.body(query_body)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let body: serde_json::Value = response.json().await?;
|
||||||
|
let hits = &body["hits"]["hits"];
|
||||||
|
|
||||||
|
let Some(hits_arr) = hits.as_array() else {
|
||||||
|
return Ok("No results found.".into());
|
||||||
|
};
|
||||||
|
|
||||||
|
if hits_arr.is_empty() {
|
||||||
|
return Ok("No results found.".into());
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut output = String::new();
|
||||||
|
for (i, hit) in hits_arr.iter().enumerate() {
|
||||||
|
let src = &hit["_source"];
|
||||||
|
let sender = src["sender_name"].as_str().unwrap_or("unknown");
|
||||||
|
let room = src["room_name"].as_str().unwrap_or("unknown");
|
||||||
|
let content = src["content"].as_str().unwrap_or("");
|
||||||
|
let ts = src["timestamp"].as_i64().unwrap_or(0);
|
||||||
|
|
||||||
|
let dt = chrono::DateTime::from_timestamp_millis(ts)
|
||||||
|
.map(|d| d.format("%Y-%m-%d %H:%M").to_string())
|
||||||
|
.unwrap_or_else(|| "unknown date".into());
|
||||||
|
|
||||||
|
output.push_str(&format!(
|
||||||
|
"{}. [{dt}] #{room} — {sender}: {content}\n",
|
||||||
|
i + 1
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(output)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserialize search arguments from a JSON literal.
    fn parse_args(json: &str) -> SearchArgs {
        serde_json::from_str(json).unwrap()
    }

    /// Parse the arguments and build the query in one step.
    fn query_for(json: &str) -> serde_json::Value {
        build_search_query(&parse_args(json))
    }

    /// Borrow the `bool.filter` array of a built query.
    fn filters_of(query: &serde_json::Value) -> &Vec<serde_json::Value> {
        query["query"]["bool"]["filter"].as_array().unwrap()
    }

    #[test]
    fn test_parse_minimal_args() {
        let args = parse_args(r#"{"query": "hello"}"#);

        assert_eq!(args.query, "hello");
        assert_eq!(args.limit, 10); // default
        assert!(args.room.is_none());
        assert!(args.sender.is_none());
        assert!(args.after.is_none());
        assert!(args.before.is_none());
        assert!(args.semantic.is_none());
    }

    #[test]
    fn test_parse_full_args() {
        let args = parse_args(
            r#"{
                "query": "meeting notes",
                "room": "general",
                "sender": "Alice",
                "after": "1710000000000",
                "before": "1710100000000",
                "limit": 25,
                "semantic": true
            }"#,
        );

        assert_eq!(args.query, "meeting notes");
        assert_eq!(args.room.as_deref(), Some("general"));
        assert_eq!(args.sender.as_deref(), Some("Alice"));
        assert_eq!(args.after.as_deref(), Some("1710000000000"));
        assert_eq!(args.before.as_deref(), Some("1710100000000"));
        assert_eq!(args.limit, 25);
        assert_eq!(args.semantic, Some(true));
    }

    #[test]
    fn test_query_basic() {
        let q = query_for(r#"{"query": "test"}"#);

        assert_eq!(q["size"], 10);
        assert_eq!(q["query"]["bool"]["must"][0]["match"]["content"], "test");
        assert_eq!(q["query"]["bool"]["filter"][0]["term"]["redacted"], false);
        assert_eq!(q["sort"][0]["timestamp"], "desc");
    }

    #[test]
    fn test_query_with_room_filter() {
        let q = query_for(r#"{"query": "hello", "room": "design"}"#);
        let filters = filters_of(&q);

        assert_eq!(filters.len(), 2);
        assert_eq!(filters[1]["term"]["room_name"], "design");
    }

    #[test]
    fn test_query_with_sender_filter() {
        let q = query_for(r#"{"query": "hello", "sender": "Bob"}"#);
        let filters = filters_of(&q);

        assert_eq!(filters.len(), 2);
        assert_eq!(filters[1]["term"]["sender_name"], "Bob");
    }

    #[test]
    fn test_query_with_room_and_sender() {
        let q = query_for(r#"{"query": "hello", "room": "dev", "sender": "Carol"}"#);
        let filters = filters_of(&q);

        assert_eq!(filters.len(), 3);
        assert_eq!(filters[1]["term"]["room_name"], "dev");
        assert_eq!(filters[2]["term"]["sender_name"], "Carol");
    }

    #[test]
    fn test_query_with_date_range() {
        let q = query_for(
            r#"{
                "query": "hello",
                "after": "1710000000000",
                "before": "1710100000000"
            }"#,
        );
        let range_filter = &filters_of(&q)[1]["range"]["timestamp"];

        assert_eq!(range_filter["gte"], 1710000000000_i64);
        assert_eq!(range_filter["lte"], 1710100000000_i64);
    }

    #[test]
    fn test_query_with_after_only() {
        let q = query_for(r#"{"query": "hello", "after": "1710000000000"}"#);
        let range_filter = &filters_of(&q)[1]["range"]["timestamp"];

        assert_eq!(range_filter["gte"], 1710000000000_i64);
        assert!(range_filter.get("lte").is_none());
    }

    #[test]
    fn test_query_with_custom_limit() {
        let q = query_for(r#"{"query": "hello", "limit": 50}"#);

        assert_eq!(q["size"], 50);
    }

    #[test]
    fn test_query_all_filters_combined() {
        let q = query_for(
            r#"{
                "query": "architecture",
                "room": "engineering",
                "sender": "Sienna",
                "after": "1000",
                "before": "2000",
                "limit": 5
            }"#,
        );

        assert_eq!(q["size"], 5);
        // redacted=false, room, sender, range = 4 filters
        assert_eq!(filters_of(&q).len(), 4);
    }

    #[test]
    fn test_invalid_timestamp_ignored() {
        let q = query_for(r#"{"query": "hello", "after": "not-a-number"}"#);

        // Only the redacted filter, no range since parse failed
        assert_eq!(filters_of(&q).len(), 1);
    }

    #[test]
    fn test_source_fields() {
        let q = query_for(r#"{"query": "test"}"#);
        let fields: Vec<&str> = q["_source"]
            .as_array()
            .unwrap()
            .iter()
            .map(|v| v.as_str().unwrap())
            .collect();

        for expected in ["event_id", "room_name", "sender_name", "timestamp", "content"] {
            assert!(fields.contains(&expected));
        }
    }
}
|
||||||
Reference in New Issue
Block a user