feat: initial Sol virtual librarian implementation
Matrix bot with E2EE (matrix-sdk 0.9) that passively archives all messages to OpenSearch and responds to queries via Mistral AI with function calling tools. Core systems: - Archive: bulk OpenSearch indexer with batch/flush, edit/redaction handling, embedding pipeline passthrough - Brain: rule-based engagement evaluator (mentions, DMs, name invocations), LLM-powered spontaneous engagement, per-room conversation context windows, response delay simulation - Tools: search_archive, get_room_context, list_rooms, get_room_members registered as Mistral function calling tools with iterative tool loop - Personality: templated system prompt with Sol's librarian persona 47 unit tests covering config, evaluator, conversation windowing, personality templates, schema serialization, and search query building.
This commit is contained in:
151
src/tools/mod.rs
Normal file
151
src/tools/mod.rs
Normal file
@@ -0,0 +1,151 @@
|
||||
pub mod room_history;
|
||||
pub mod room_info;
|
||||
pub mod search;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use matrix_sdk::Client as MatrixClient;
|
||||
use mistralai_client::v1::tool::Tool;
|
||||
use opensearch::OpenSearch;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::config::Config;
|
||||
|
||||
pub struct ToolRegistry {
|
||||
opensearch: OpenSearch,
|
||||
matrix: MatrixClient,
|
||||
config: Arc<Config>,
|
||||
}
|
||||
|
||||
impl ToolRegistry {
|
||||
pub fn new(opensearch: OpenSearch, matrix: MatrixClient, config: Arc<Config>) -> Self {
|
||||
Self {
|
||||
opensearch,
|
||||
matrix,
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tool_definitions() -> Vec<Tool> {
|
||||
vec![
|
||||
Tool::new(
|
||||
"search_archive".into(),
|
||||
"Search the message archive. Use this to find past conversations, \
|
||||
messages from specific people, or about specific topics."
|
||||
.into(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query for message content"
|
||||
},
|
||||
"room": {
|
||||
"type": "string",
|
||||
"description": "Filter by room name (optional)"
|
||||
},
|
||||
"sender": {
|
||||
"type": "string",
|
||||
"description": "Filter by sender display name (optional)"
|
||||
},
|
||||
"after": {
|
||||
"type": "string",
|
||||
"description": "Unix timestamp in ms — only messages after this time (optional)"
|
||||
},
|
||||
"before": {
|
||||
"type": "string",
|
||||
"description": "Unix timestamp in ms — only messages before this time (optional)"
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Max results to return (default 10)"
|
||||
},
|
||||
"semantic": {
|
||||
"type": "boolean",
|
||||
"description": "Use semantic search instead of keyword (optional)"
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}),
|
||||
),
|
||||
Tool::new(
|
||||
"get_room_context".into(),
|
||||
"Get messages around a specific point in time or event in a room. \
|
||||
Useful for understanding the context of a conversation."
|
||||
.into(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"room_id": {
|
||||
"type": "string",
|
||||
"description": "The Matrix room ID"
|
||||
},
|
||||
"around_timestamp": {
|
||||
"type": "integer",
|
||||
"description": "Unix timestamp in ms to center the context around"
|
||||
},
|
||||
"around_event_id": {
|
||||
"type": "string",
|
||||
"description": "Event ID to center the context around"
|
||||
},
|
||||
"before_count": {
|
||||
"type": "integer",
|
||||
"description": "Number of messages before the pivot (default 10)"
|
||||
},
|
||||
"after_count": {
|
||||
"type": "integer",
|
||||
"description": "Number of messages after the pivot (default 10)"
|
||||
}
|
||||
},
|
||||
"required": ["room_id"]
|
||||
}),
|
||||
),
|
||||
Tool::new(
|
||||
"list_rooms".into(),
|
||||
"List all rooms Sol is currently in, with names and member counts.".into(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
}),
|
||||
),
|
||||
Tool::new(
|
||||
"get_room_members".into(),
|
||||
"Get the list of members in a specific room.".into(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"room_id": {
|
||||
"type": "string",
|
||||
"description": "The Matrix room ID"
|
||||
}
|
||||
},
|
||||
"required": ["room_id"]
|
||||
}),
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
pub async fn execute(&self, name: &str, arguments: &str) -> anyhow::Result<String> {
|
||||
match name {
|
||||
"search_archive" => {
|
||||
search::search_archive(
|
||||
&self.opensearch,
|
||||
&self.config.opensearch.index,
|
||||
arguments,
|
||||
)
|
||||
.await
|
||||
}
|
||||
"get_room_context" => {
|
||||
room_history::get_room_context(
|
||||
&self.opensearch,
|
||||
&self.config.opensearch.index,
|
||||
arguments,
|
||||
)
|
||||
.await
|
||||
}
|
||||
"list_rooms" => room_info::list_rooms(&self.matrix).await,
|
||||
"get_room_members" => room_info::get_room_members(&self.matrix, arguments).await,
|
||||
_ => anyhow::bail!("Unknown tool: {name}"),
|
||||
}
|
||||
}
|
||||
}
|
||||
108
src/tools/room_history.rs
Normal file
108
src/tools/room_history.rs
Normal file
@@ -0,0 +1,108 @@
|
||||
use opensearch::OpenSearch;
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct RoomHistoryArgs {
|
||||
pub room_id: String,
|
||||
#[serde(default)]
|
||||
pub around_timestamp: Option<i64>,
|
||||
#[serde(default)]
|
||||
pub around_event_id: Option<String>,
|
||||
#[serde(default = "default_count")]
|
||||
pub before_count: usize,
|
||||
#[serde(default = "default_count")]
|
||||
pub after_count: usize,
|
||||
}
|
||||
|
||||
fn default_count() -> usize { 10 }
|
||||
|
||||
pub async fn get_room_context(
|
||||
client: &OpenSearch,
|
||||
index: &str,
|
||||
args_json: &str,
|
||||
) -> anyhow::Result<String> {
|
||||
let args: RoomHistoryArgs = serde_json::from_str(args_json)?;
|
||||
let total = args.before_count + args.after_count + 1;
|
||||
|
||||
// Determine the pivot timestamp
|
||||
let pivot_ts = if let Some(ts) = args.around_timestamp {
|
||||
ts
|
||||
} else if let Some(ref event_id) = args.around_event_id {
|
||||
// Look up the event to get its timestamp
|
||||
let lookup = json!({
|
||||
"size": 1,
|
||||
"query": { "term": { "event_id": event_id } },
|
||||
"_source": ["timestamp"]
|
||||
});
|
||||
|
||||
let resp = client
|
||||
.search(opensearch::SearchParts::Index(&[index]))
|
||||
.body(lookup)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let body: serde_json::Value = resp.json().await?;
|
||||
body["hits"]["hits"][0]["_source"]["timestamp"]
|
||||
.as_i64()
|
||||
.ok_or_else(|| anyhow::anyhow!("Event {event_id} not found in archive"))?
|
||||
} else {
|
||||
anyhow::bail!("Either around_timestamp or around_event_id must be provided");
|
||||
};
|
||||
|
||||
let query = json!({
|
||||
"size": total,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": [
|
||||
{ "term": { "room_id": args.room_id } },
|
||||
{ "term": { "redacted": false } }
|
||||
],
|
||||
"should": [
|
||||
{
|
||||
"range": {
|
||||
"timestamp": {
|
||||
"gte": pivot_ts - 3_600_000,
|
||||
"lte": pivot_ts + 3_600_000
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
"sort": [{ "timestamp": "asc" }]
|
||||
});
|
||||
|
||||
let response = client
|
||||
.search(opensearch::SearchParts::Index(&[index]))
|
||||
.body(query)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let body: serde_json::Value = response.json().await?;
|
||||
let hits = &body["hits"]["hits"];
|
||||
|
||||
let Some(hits_arr) = hits.as_array() else {
|
||||
return Ok("No messages found around that point.".into());
|
||||
};
|
||||
|
||||
if hits_arr.is_empty() {
|
||||
return Ok("No messages found around that point.".into());
|
||||
}
|
||||
|
||||
let mut output = String::new();
|
||||
for hit in hits_arr {
|
||||
let src = &hit["_source"];
|
||||
let sender = src["sender_name"].as_str().unwrap_or("unknown");
|
||||
let content = src["content"].as_str().unwrap_or("");
|
||||
let ts = src["timestamp"].as_i64().unwrap_or(0);
|
||||
|
||||
let dt = chrono::DateTime::from_timestamp_millis(ts)
|
||||
.map(|d| d.format("%H:%M").to_string())
|
||||
.unwrap_or_else(|| "??:??".into());
|
||||
|
||||
output.push_str(&format!("[{dt}] {sender}: {content}\n"));
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
55
src/tools/room_info.rs
Normal file
55
src/tools/room_info.rs
Normal file
@@ -0,0 +1,55 @@
|
||||
use matrix_sdk::Client;
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct ListRoomsArgs {}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct GetMembersArgs {
|
||||
pub room_id: String,
|
||||
}
|
||||
|
||||
pub async fn list_rooms(client: &Client) -> anyhow::Result<String> {
|
||||
let rooms = client.joined_rooms();
|
||||
if rooms.is_empty() {
|
||||
return Ok("I'm not in any rooms.".into());
|
||||
}
|
||||
|
||||
let mut output = String::new();
|
||||
for room in &rooms {
|
||||
let name = match room.cached_display_name() {
|
||||
Some(n) => n.to_string(),
|
||||
None => room.room_id().to_string(),
|
||||
};
|
||||
let id = room.room_id();
|
||||
let members = room.joined_members_count();
|
||||
output.push_str(&format!("- {name} ({id}) — {members} members\n"));
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
pub async fn get_room_members(client: &Client, args_json: &str) -> anyhow::Result<String> {
|
||||
let args: GetMembersArgs = serde_json::from_str(args_json)?;
|
||||
let room_id = <&ruma::RoomId>::try_from(args.room_id.as_str())?;
|
||||
|
||||
let Some(room) = client.get_room(room_id) else {
|
||||
anyhow::bail!("I'm not in room {}", args.room_id);
|
||||
};
|
||||
|
||||
let members = room.members(matrix_sdk::RoomMemberships::JOIN).await?;
|
||||
if members.is_empty() {
|
||||
return Ok("No members found.".into());
|
||||
}
|
||||
|
||||
let mut output = String::new();
|
||||
for member in &members {
|
||||
let display = member
|
||||
.display_name()
|
||||
.unwrap_or_else(|| member.user_id().as_str());
|
||||
let user_id = member.user_id();
|
||||
output.push_str(&format!("- {display} ({user_id})\n"));
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
274
src/tools/search.rs
Normal file
274
src/tools/search.rs
Normal file
@@ -0,0 +1,274 @@
|
||||
use opensearch::OpenSearch;
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use tracing::debug;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct SearchArgs {
|
||||
pub query: String,
|
||||
#[serde(default)]
|
||||
pub room: Option<String>,
|
||||
#[serde(default)]
|
||||
pub sender: Option<String>,
|
||||
#[serde(default)]
|
||||
pub after: Option<String>,
|
||||
#[serde(default)]
|
||||
pub before: Option<String>,
|
||||
#[serde(default = "default_limit")]
|
||||
pub limit: usize,
|
||||
#[serde(default)]
|
||||
pub semantic: Option<bool>,
|
||||
}
|
||||
|
||||
fn default_limit() -> usize { 10 }
|
||||
|
||||
/// Build the OpenSearch query body from parsed SearchArgs. Extracted for testability.
|
||||
pub fn build_search_query(args: &SearchArgs) -> serde_json::Value {
|
||||
let must = vec![json!({
|
||||
"match": { "content": args.query }
|
||||
})];
|
||||
|
||||
let mut filter = vec![json!({
|
||||
"term": { "redacted": false }
|
||||
})];
|
||||
|
||||
if let Some(ref room) = args.room {
|
||||
filter.push(json!({ "term": { "room_name": room } }));
|
||||
}
|
||||
if let Some(ref sender) = args.sender {
|
||||
filter.push(json!({ "term": { "sender_name": sender } }));
|
||||
}
|
||||
|
||||
let mut range = serde_json::Map::new();
|
||||
if let Some(ref after) = args.after {
|
||||
if let Ok(ts) = after.parse::<i64>() {
|
||||
range.insert("gte".into(), json!(ts));
|
||||
}
|
||||
}
|
||||
if let Some(ref before) = args.before {
|
||||
if let Ok(ts) = before.parse::<i64>() {
|
||||
range.insert("lte".into(), json!(ts));
|
||||
}
|
||||
}
|
||||
if !range.is_empty() {
|
||||
filter.push(json!({ "range": { "timestamp": range } }));
|
||||
}
|
||||
|
||||
json!({
|
||||
"size": args.limit,
|
||||
"query": {
|
||||
"bool": {
|
||||
"must": must,
|
||||
"filter": filter
|
||||
}
|
||||
},
|
||||
"sort": [{ "timestamp": "desc" }],
|
||||
"_source": ["event_id", "room_name", "sender_name", "timestamp", "content"]
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn search_archive(
|
||||
client: &OpenSearch,
|
||||
index: &str,
|
||||
args_json: &str,
|
||||
) -> anyhow::Result<String> {
|
||||
let args: SearchArgs = serde_json::from_str(args_json)?;
|
||||
debug!(query = args.query.as_str(), "Searching archive");
|
||||
|
||||
let query_body = build_search_query(&args);
|
||||
|
||||
let response = client
|
||||
.search(opensearch::SearchParts::Index(&[index]))
|
||||
.body(query_body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let body: serde_json::Value = response.json().await?;
|
||||
let hits = &body["hits"]["hits"];
|
||||
|
||||
let Some(hits_arr) = hits.as_array() else {
|
||||
return Ok("No results found.".into());
|
||||
};
|
||||
|
||||
if hits_arr.is_empty() {
|
||||
return Ok("No results found.".into());
|
||||
}
|
||||
|
||||
let mut output = String::new();
|
||||
for (i, hit) in hits_arr.iter().enumerate() {
|
||||
let src = &hit["_source"];
|
||||
let sender = src["sender_name"].as_str().unwrap_or("unknown");
|
||||
let room = src["room_name"].as_str().unwrap_or("unknown");
|
||||
let content = src["content"].as_str().unwrap_or("");
|
||||
let ts = src["timestamp"].as_i64().unwrap_or(0);
|
||||
|
||||
let dt = chrono::DateTime::from_timestamp_millis(ts)
|
||||
.map(|d| d.format("%Y-%m-%d %H:%M").to_string())
|
||||
.unwrap_or_else(|| "unknown date".into());
|
||||
|
||||
output.push_str(&format!(
|
||||
"{}. [{dt}] #{room} — {sender}: {content}\n",
|
||||
i + 1
|
||||
));
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn parse_args(json: &str) -> SearchArgs {
|
||||
serde_json::from_str(json).unwrap()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_minimal_args() {
|
||||
let args = parse_args(r#"{"query": "hello"}"#);
|
||||
assert_eq!(args.query, "hello");
|
||||
assert!(args.room.is_none());
|
||||
assert!(args.sender.is_none());
|
||||
assert!(args.after.is_none());
|
||||
assert!(args.before.is_none());
|
||||
assert_eq!(args.limit, 10); // default
|
||||
assert!(args.semantic.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_full_args() {
|
||||
let args = parse_args(r#"{
|
||||
"query": "meeting notes",
|
||||
"room": "general",
|
||||
"sender": "Alice",
|
||||
"after": "1710000000000",
|
||||
"before": "1710100000000",
|
||||
"limit": 25,
|
||||
"semantic": true
|
||||
}"#);
|
||||
assert_eq!(args.query, "meeting notes");
|
||||
assert_eq!(args.room.as_deref(), Some("general"));
|
||||
assert_eq!(args.sender.as_deref(), Some("Alice"));
|
||||
assert_eq!(args.after.as_deref(), Some("1710000000000"));
|
||||
assert_eq!(args.before.as_deref(), Some("1710100000000"));
|
||||
assert_eq!(args.limit, 25);
|
||||
assert_eq!(args.semantic, Some(true));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_query_basic() {
|
||||
let args = parse_args(r#"{"query": "test"}"#);
|
||||
let q = build_search_query(&args);
|
||||
|
||||
assert_eq!(q["size"], 10);
|
||||
assert_eq!(q["query"]["bool"]["must"][0]["match"]["content"], "test");
|
||||
assert_eq!(q["query"]["bool"]["filter"][0]["term"]["redacted"], false);
|
||||
assert_eq!(q["sort"][0]["timestamp"], "desc");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_query_with_room_filter() {
|
||||
let args = parse_args(r#"{"query": "hello", "room": "design"}"#);
|
||||
let q = build_search_query(&args);
|
||||
|
||||
let filters = q["query"]["bool"]["filter"].as_array().unwrap();
|
||||
assert_eq!(filters.len(), 2);
|
||||
assert_eq!(filters[1]["term"]["room_name"], "design");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_query_with_sender_filter() {
|
||||
let args = parse_args(r#"{"query": "hello", "sender": "Bob"}"#);
|
||||
let q = build_search_query(&args);
|
||||
|
||||
let filters = q["query"]["bool"]["filter"].as_array().unwrap();
|
||||
assert_eq!(filters.len(), 2);
|
||||
assert_eq!(filters[1]["term"]["sender_name"], "Bob");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_query_with_room_and_sender() {
|
||||
let args = parse_args(r#"{"query": "hello", "room": "dev", "sender": "Carol"}"#);
|
||||
let q = build_search_query(&args);
|
||||
|
||||
let filters = q["query"]["bool"]["filter"].as_array().unwrap();
|
||||
assert_eq!(filters.len(), 3);
|
||||
assert_eq!(filters[1]["term"]["room_name"], "dev");
|
||||
assert_eq!(filters[2]["term"]["sender_name"], "Carol");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_query_with_date_range() {
|
||||
let args = parse_args(r#"{
|
||||
"query": "hello",
|
||||
"after": "1710000000000",
|
||||
"before": "1710100000000"
|
||||
}"#);
|
||||
let q = build_search_query(&args);
|
||||
|
||||
let filters = q["query"]["bool"]["filter"].as_array().unwrap();
|
||||
let range_filter = &filters[1]["range"]["timestamp"];
|
||||
assert_eq!(range_filter["gte"], 1710000000000_i64);
|
||||
assert_eq!(range_filter["lte"], 1710100000000_i64);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_query_with_after_only() {
|
||||
let args = parse_args(r#"{"query": "hello", "after": "1710000000000"}"#);
|
||||
let q = build_search_query(&args);
|
||||
|
||||
let filters = q["query"]["bool"]["filter"].as_array().unwrap();
|
||||
let range_filter = &filters[1]["range"]["timestamp"];
|
||||
assert_eq!(range_filter["gte"], 1710000000000_i64);
|
||||
assert!(range_filter.get("lte").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_query_with_custom_limit() {
|
||||
let args = parse_args(r#"{"query": "hello", "limit": 50}"#);
|
||||
let q = build_search_query(&args);
|
||||
assert_eq!(q["size"], 50);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_query_all_filters_combined() {
|
||||
let args = parse_args(r#"{
|
||||
"query": "architecture",
|
||||
"room": "engineering",
|
||||
"sender": "Sienna",
|
||||
"after": "1000",
|
||||
"before": "2000",
|
||||
"limit": 5
|
||||
}"#);
|
||||
let q = build_search_query(&args);
|
||||
|
||||
assert_eq!(q["size"], 5);
|
||||
let filters = q["query"]["bool"]["filter"].as_array().unwrap();
|
||||
// redacted=false, room, sender, range = 4 filters
|
||||
assert_eq!(filters.len(), 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_timestamp_ignored() {
|
||||
let args = parse_args(r#"{"query": "hello", "after": "not-a-number"}"#);
|
||||
let q = build_search_query(&args);
|
||||
|
||||
let filters = q["query"]["bool"]["filter"].as_array().unwrap();
|
||||
// Only the redacted filter, no range since parse failed
|
||||
assert_eq!(filters.len(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_source_fields() {
|
||||
let args = parse_args(r#"{"query": "test"}"#);
|
||||
let q = build_search_query(&args);
|
||||
|
||||
let source = q["_source"].as_array().unwrap();
|
||||
let fields: Vec<&str> = source.iter().map(|v| v.as_str().unwrap()).collect();
|
||||
assert!(fields.contains(&"event_id"));
|
||||
assert!(fields.contains(&"room_name"));
|
||||
assert!(fields.contains(&"sender_name"));
|
||||
assert!(fields.contains(&"timestamp"));
|
||||
assert!(fields.contains(&"content"));
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user