feat: per-user auto-memory with ResponseContext

Three memory channels: hidden tool (sol.memory.set/get in scripts),
pre-response injection (relevant memories loaded into system prompt),
and post-response extraction (ministral-3b extracts facts after each
response). User isolation enforced at Rust level — user_id derived
from Matrix sender, never from script arguments.

New modules: context (ResponseContext), memory (schema, store, extractor).
ResponseContext threaded through responder → tools → script runtime.
OpenSearch index sol_user_memory created on startup alongside archive.
This commit is contained in:
2026-03-21 15:51:31 +00:00
parent 4dc20bee23
commit 4949e70ecc
23 changed files with 4494 additions and 124 deletions

View File

@@ -59,6 +59,34 @@ impl Indexer {
}
}
/// Append a reaction to an archived message's `reactions` array.
///
/// Uses a painless update script so the append works even when the
/// document's `reactions` field has never been initialised (upsert-safe).
/// Failures are logged and swallowed — reaction indexing is best-effort.
pub async fn add_reaction(&self, target_event_id: &str, sender: &str, emoji: &str, timestamp: i64) {
    let reaction = json!({
        "sender": sender,
        "emoji": emoji,
        "timestamp": timestamp
    });
    let body = json!({
        "script": {
            "source": "if (ctx._source.reactions == null) { ctx._source.reactions = []; } ctx._source.reactions.add(params.reaction)",
            "params": { "reaction": reaction }
        }
    });
    let result = self
        .client
        .update(opensearch::UpdateParts::IndexId(
            &self.config.opensearch.index,
            target_event_id,
        ))
        .body(body)
        .send()
        .await;
    if let Err(e) = result {
        warn!(target_event_id, sender, emoji, "Failed to add reaction: {e}");
    }
}
pub async fn update_redaction(&self, event_id: &str) {
let body = json!({
"doc": {

View File

@@ -24,6 +24,15 @@ pub struct ArchiveDocument {
pub edited: bool,
#[serde(default)]
pub redacted: bool,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub reactions: Vec<Reaction>,
}
/// A single emoji reaction to an archived message; stored as a nested
/// object in the archive index (see the `reactions` nested mapping).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Reaction {
    // Matrix user ID of the reacting user, e.g. `@sienna:sunbeam.pt`.
    pub sender: String,
    // The reaction key — usually a single emoji character.
    pub emoji: String,
    // Origin server timestamp, epoch milliseconds (mapping: `epoch_millis`).
    pub timestamp: i64,
}
const INDEX_MAPPING: &str = r#"{
@@ -45,7 +54,15 @@ const INDEX_MAPPING: &str = r#"{
"media_urls": { "type": "keyword" },
"event_type": { "type": "keyword" },
"edited": { "type": "boolean" },
"redacted": { "type": "boolean" }
"redacted": { "type": "boolean" },
"reactions": {
"type": "nested",
"properties": {
"sender": { "type": "keyword" },
"emoji": { "type": "keyword" },
"timestamp": { "type": "date", "format": "epoch_millis" }
}
}
}
}
}"#;
@@ -63,6 +80,25 @@ pub async fn create_index_if_not_exists(client: &OpenSearch, index: &str) -> any
if exists.status_code().is_success() {
info!(index, "OpenSearch index already exists");
// Ensure reactions field exists (added after initial schema)
let reactions_mapping = serde_json::json!({
"properties": {
"reactions": {
"type": "nested",
"properties": {
"sender": { "type": "keyword" },
"emoji": { "type": "keyword" },
"timestamp": { "type": "date", "format": "epoch_millis" }
}
}
}
});
let _ = client
.indices()
.put_mapping(opensearch::indices::IndicesPutMappingParts::Index(&[index]))
.body(reactions_mapping)
.send()
.await;
return Ok(());
}
@@ -102,6 +138,7 @@ mod tests {
event_type: "m.room.message".to_string(),
edited: false,
redacted: false,
reactions: vec![],
}
}

View File

@@ -5,7 +5,7 @@ use mistralai_client::v1::{
constants::Model,
};
use regex::Regex;
use tracing::{debug, warn};
use tracing::{debug, info, warn};
use crate::config::Config;
@@ -13,6 +13,7 @@ use crate::config::Config;
pub enum Engagement {
    // A hard rule matched (direct mention, DM, or name invocation) — Sol must reply.
    MustRespond { reason: MustRespondReason },
    // LLM-scored relevance cleared the spontaneous threshold; `hook` is the
    // model's brief reason for joining the conversation.
    MaybeRespond { relevance: f32, hook: String },
    // Relevance fell in the reaction range — react with an emoji instead of
    // writing a full response.
    React { emoji: String, relevance: f32 },
    // Below all thresholds (or Sol's own message) — do nothing.
    Ignore,
}
@@ -33,7 +34,9 @@ impl Evaluator {
// todo(sienna): regex must be configurable
pub fn new(config: Arc<Config>) -> Self {
let user_id = &config.matrix.user_id;
let mention_pattern = regex::escape(user_id);
// Match both plain @sol:sunbeam.pt and Matrix link format [sol](https://matrix.to/#/@sol:sunbeam.pt)
let escaped = regex::escape(user_id);
let mention_pattern = format!(r"{}|matrix\.to/#/{}", escaped, escaped);
let mention_regex = Regex::new(&mention_pattern).expect("Failed to compile mention regex");
let name_regex =
Regex::new(r"(?i)(?:^|\bhey\s+)\bsol\b").expect("Failed to compile name regex");
@@ -53,13 +56,17 @@ impl Evaluator {
recent_messages: &[String],
mistral: &Arc<mistralai_client::v1::client::Client>,
) -> Engagement {
let body_preview: String = body.chars().take(80).collect();
// Don't respond to ourselves
if sender == self.config.matrix.user_id {
debug!(sender, body = body_preview.as_str(), "Ignoring own message");
return Engagement::Ignore;
}
// Direct mention: @sol:sunbeam.pt
if self.mention_regex.is_match(body) {
info!(sender, body = body_preview.as_str(), rule = "direct_mention", "Engagement: MustRespond");
return Engagement::MustRespond {
reason: MustRespondReason::DirectMention,
};
@@ -67,6 +74,7 @@ impl Evaluator {
// DM
if is_dm {
info!(sender, body = body_preview.as_str(), rule = "dm", "Engagement: MustRespond");
return Engagement::MustRespond {
reason: MustRespondReason::DirectMessage,
};
@@ -74,11 +82,22 @@ impl Evaluator {
// Name invocation: "sol ..." or "hey sol ..."
if self.name_regex.is_match(body) {
info!(sender, body = body_preview.as_str(), rule = "name_invocation", "Engagement: MustRespond");
return Engagement::MustRespond {
reason: MustRespondReason::NameInvocation,
};
}
info!(
sender, body = body_preview.as_str(),
threshold = self.config.behavior.spontaneous_threshold,
model = self.config.mistral.evaluation_model.as_str(),
context_len = recent_messages.len(),
eval_window = self.config.behavior.evaluation_context_window,
detect_sol = self.config.behavior.detect_sol_in_conversation,
"No rule match — running LLM relevance evaluation"
);
// Cheap evaluation call for spontaneous responses
self.evaluate_relevance(body, recent_messages, mistral)
.await
@@ -119,23 +138,56 @@ impl Evaluator {
recent_messages: &[String],
mistral: &Arc<mistralai_client::v1::client::Client>,
) -> Engagement {
let window = self.config.behavior.evaluation_context_window;
let context = recent_messages
.iter()
.rev()
.take(5) //todo(sienna): must be configurable
.take(window)
.rev()
.cloned()
.collect::<Vec<_>>()
.join("\n");
// Check if Sol recently participated in this conversation
let sol_in_context = self.config.behavior.detect_sol_in_conversation
&& recent_messages.iter().any(|m| {
let lower = m.to_lowercase();
lower.starts_with("sol:") || lower.starts_with("sol ") || lower.contains("@sol:")
});
let default_active = "Sol is ALREADY part of this conversation (see messages above from Sol). \
Messages that follow up on Sol's response, ask Sol a question, or continue \
a thread Sol is in should score HIGH (0.8+). Sol should respond to follow-ups \
directed at them even if not mentioned by name.".to_string();
let default_passive = "Sol has NOT spoken in this conversation yet. Only score high if the message \
is clearly relevant to Sol's expertise (archive search, finding past conversations, \
information retrieval) or touches a topic Sol has genuine insight on.".to_string();
let participation_note = if sol_in_context {
self.config.behavior.evaluation_prompt_active.as_deref()
.unwrap_or(&default_active)
} else {
self.config.behavior.evaluation_prompt_passive.as_deref()
.unwrap_or(&default_passive)
};
info!(
sol_in_context,
context_window = window,
"Building evaluation prompt"
);
let prompt = format!(
"You are evaluating whether a virtual librarian named Sol should spontaneously join \
a conversation. Sol has deep knowledge of the group's message archive and helps \
people find information.\n\n\
"You are evaluating whether Sol should respond to a message in a group chat. \
Sol is a librarian with access to the team's message archive.\n\n\
Recent conversation:\n{context}\n\n\
Latest message: {body}\n\n\
Respond ONLY with JSON: {{\"relevance\": 0.0-1.0, \"hook\": \"brief reason or empty string\"}}\n\
relevance=1.0 means Sol absolutely should respond, 0.0 means irrelevant."
{participation_note}\n\n\
Respond ONLY with JSON: {{\"relevance\": 0.0-1.0, \"hook\": \"brief reason or empty string\", \"emoji\": \"a single emoji reaction or empty string\"}}\n\
relevance=1.0 means Sol absolutely should respond, 0.0 means irrelevant.\n\
emoji: if Sol wouldn't write a full response but might react to the message, suggest a single emoji. \
pick something that feels natural and specific to the message — not generic thumbs up. leave empty if no reaction fits."
);
let messages = vec![ChatMessage::new_user_message(&prompt)];
@@ -159,21 +211,48 @@ impl Evaluator {
match result {
Ok(response) => {
let text = &response.choices[0].message.content;
info!(
raw_response = text.as_str(),
model = self.config.mistral.evaluation_model.as_str(),
"LLM evaluation raw response"
);
match serde_json::from_str::<serde_json::Value>(text) {
Ok(val) => {
let relevance = val["relevance"].as_f64().unwrap_or(0.0) as f32;
let hook = val["hook"].as_str().unwrap_or("").to_string();
let emoji = val["emoji"].as_str().unwrap_or("").to_string();
let threshold = self.config.behavior.spontaneous_threshold;
let reaction_threshold = self.config.behavior.reaction_threshold;
let reaction_enabled = self.config.behavior.reaction_enabled;
debug!(relevance, hook = hook.as_str(), "Evaluation result");
info!(
relevance,
threshold,
reaction_threshold,
hook = hook.as_str(),
emoji = emoji.as_str(),
"LLM evaluation parsed"
);
if relevance >= self.config.behavior.spontaneous_threshold {
if relevance >= threshold {
Engagement::MaybeRespond { relevance, hook }
} else if reaction_enabled
&& relevance >= reaction_threshold
&& !emoji.is_empty()
{
info!(
relevance,
emoji = emoji.as_str(),
"Reaction range — will react with emoji"
);
Engagement::React { emoji, relevance }
} else {
Engagement::Ignore
}
}
Err(e) => {
warn!("Failed to parse evaluation response: {e}");
warn!(raw = text.as_str(), "Failed to parse evaluation response: {e}");
Engagement::Ignore
}
}

View File

@@ -15,14 +15,19 @@ impl Personality {
&self,
room_name: &str,
members: &[String],
memory_notes: Option<&str>,
) -> String {
let date = Utc::now().format("%Y-%m-%d").to_string();
let now = Utc::now();
let date = now.format("%Y-%m-%d").to_string();
let epoch_ms = now.timestamp_millis().to_string();
let members_str = members.join(", ");
self.template
.replace("{date}", &date)
.replace("{epoch_ms}", &epoch_ms)
.replace("{room_name}", room_name)
.replace("{members}", &members_str)
.replace("{memory_notes}", memory_notes.unwrap_or(""))
}
}
@@ -33,7 +38,7 @@ mod tests {
#[test]
fn test_date_substitution() {
let p = Personality::new("Today is {date}.".to_string());
let result = p.build_system_prompt("general", &[]);
let result = p.build_system_prompt("general", &[], None);
let today = Utc::now().format("%Y-%m-%d").to_string();
assert_eq!(result, format!("Today is {today}."));
}
@@ -41,7 +46,7 @@ mod tests {
#[test]
fn test_room_name_substitution() {
let p = Personality::new("You are in {room_name}.".to_string());
let result = p.build_system_prompt("design-chat", &[]);
let result = p.build_system_prompt("design-chat", &[], None);
assert!(result.contains("design-chat"));
}
@@ -49,14 +54,14 @@ mod tests {
fn test_members_substitution() {
let p = Personality::new("Members: {members}".to_string());
let members = vec!["Alice".to_string(), "Bob".to_string(), "Carol".to_string()];
let result = p.build_system_prompt("room", &members);
let result = p.build_system_prompt("room", &members, None);
assert_eq!(result, "Members: Alice, Bob, Carol");
}
#[test]
fn test_empty_members() {
let p = Personality::new("Members: {members}".to_string());
let result = p.build_system_prompt("room", &[]);
let result = p.build_system_prompt("room", &[], None);
assert_eq!(result, "Members: ");
}
@@ -65,7 +70,7 @@ mod tests {
let template = "Date: {date}, Room: {room_name}, People: {members}".to_string();
let p = Personality::new(template);
let members = vec!["Sienna".to_string(), "Lonni".to_string()];
let result = p.build_system_prompt("studio", &members);
let result = p.build_system_prompt("studio", &members, None);
let today = Utc::now().format("%Y-%m-%d").to_string();
assert!(result.starts_with(&format!("Date: {today}")));
@@ -76,14 +81,32 @@ mod tests {
#[test]
fn test_no_placeholders_passthrough() {
let p = Personality::new("Static prompt with no variables.".to_string());
let result = p.build_system_prompt("room", &["Alice".to_string()]);
let result = p.build_system_prompt("room", &["Alice".to_string()], None);
assert_eq!(result, "Static prompt with no variables.");
}
#[test]
fn test_multiple_same_placeholder() {
let p = Personality::new("{room_name} is great. I love {room_name}.".to_string());
let result = p.build_system_prompt("lounge", &[]);
let result = p.build_system_prompt("lounge", &[], None);
assert_eq!(result, "lounge is great. I love lounge.");
}
#[test]
fn test_memory_notes_substitution() {
    // {memory_notes} is replaced verbatim, preserving the surrounding template text.
    let p = Personality::new("Context:\n{memory_notes}\nEnd.".to_string());
    let notes = "## notes about sienna\n- [preference] likes terse answers";
    let result = p.build_system_prompt("room", &[], Some(notes));
    assert!(result.contains("## notes about sienna"));
    assert!(result.contains("- [preference] likes terse answers"));
    assert!(result.starts_with("Context:\n"));
    assert!(result.ends_with("\nEnd."));
}
#[test]
fn test_memory_notes_none_clears_placeholder() {
    // With no notes the placeholder collapses to an empty string rather than
    // leaking a literal "{memory_notes}" into the prompt.
    let p = Personality::new("Before\n{memory_notes}\nAfter".to_string());
    let result = p.build_system_prompt("room", &[], None);
    assert_eq!(result, "Before\n\nAfter");
}
}

View File

@@ -10,9 +10,14 @@ use rand::Rng;
use tokio::time::{sleep, Duration};
use tracing::{debug, error, info, warn};
use matrix_sdk::room::Room;
use opensearch::OpenSearch;
use crate::brain::conversation::ContextMessage;
use crate::brain::personality::Personality;
use crate::config::Config;
use crate::context::ResponseContext;
use crate::memory;
use crate::tools::ToolRegistry;
/// Run a Mistral chat completion on a blocking thread.
@@ -38,6 +43,7 @@ pub struct Responder {
config: Arc<Config>,
personality: Arc<Personality>,
tools: Arc<ToolRegistry>,
opensearch: OpenSearch,
}
impl Responder {
@@ -45,11 +51,13 @@ impl Responder {
config: Arc<Config>,
personality: Arc<Personality>,
tools: Arc<ToolRegistry>,
opensearch: OpenSearch,
) -> Self {
Self {
config,
personality,
tools,
opensearch,
}
}
@@ -62,31 +70,52 @@ impl Responder {
members: &[String],
is_spontaneous: bool,
mistral: &Arc<mistralai_client::v1::client::Client>,
room: &Room,
response_ctx: &ResponseContext,
) -> Option<String> {
// Apply response delay
let delay = if is_spontaneous {
rand::thread_rng().gen_range(
self.config.behavior.spontaneous_delay_min_ms
..=self.config.behavior.spontaneous_delay_max_ms,
)
} else {
rand::thread_rng().gen_range(
self.config.behavior.response_delay_min_ms
..=self.config.behavior.response_delay_max_ms,
)
};
sleep(Duration::from_millis(delay)).await;
// Apply response delay (skip if instant_responses is enabled)
// Delay happens BEFORE typing indicator — Sol "notices" the message first
if !self.config.behavior.instant_responses {
let delay = if is_spontaneous {
rand::thread_rng().gen_range(
self.config.behavior.spontaneous_delay_min_ms
..=self.config.behavior.spontaneous_delay_max_ms,
)
} else {
rand::thread_rng().gen_range(
self.config.behavior.response_delay_min_ms
..=self.config.behavior.response_delay_max_ms,
)
};
debug!(delay_ms = delay, is_spontaneous, "Applying response delay");
sleep(Duration::from_millis(delay)).await;
}
let system_prompt = self.personality.build_system_prompt(room_name, members);
// Start typing AFTER the delay — Sol has decided to respond
let _ = room.typing_notice(true).await;
// Pre-response memory query
let memory_notes = self
.load_memory_notes(response_ctx, trigger_body)
.await;
let system_prompt = self.personality.build_system_prompt(
room_name,
members,
memory_notes.as_deref(),
);
let mut messages = vec![ChatMessage::new_system_message(&system_prompt)];
// Add context messages
// Add context messages with timestamps so the model has time awareness
for msg in context {
let ts = chrono::DateTime::from_timestamp_millis(msg.timestamp)
.map(|d| d.format("%H:%M").to_string())
.unwrap_or_default();
if msg.sender == self.config.matrix.user_id {
messages.push(ChatMessage::new_assistant_message(&msg.content, None));
} else {
let user_msg = format!("{}: {}", msg.sender, msg.content);
let user_msg = format!("[{}] {}: {}", ts, msg.sender, msg.content);
messages.push(ChatMessage::new_user_message(&user_msg));
}
}
@@ -117,6 +146,7 @@ impl Responder {
let response = match chat_blocking(mistral, model.clone(), messages.clone(), params).await {
Ok(r) => r,
Err(e) => {
let _ = room.typing_notice(false).await;
error!("Mistral chat failed: {e}");
return None;
}
@@ -137,12 +167,13 @@ impl Responder {
info!(
tool = tc.function.name.as_str(),
id = call_id,
args = tc.function.arguments.as_str(),
"Executing tool call"
);
let result = self
.tools
.execute(&tc.function.name, &tc.function.arguments)
.execute(&tc.function.name, &tc.function.arguments, response_ctx)
.await;
let result_str = match result {
@@ -165,15 +196,155 @@ impl Responder {
}
}
// Final text response
let text = choice.message.content.trim().to_string();
// Final text response — strip own name prefix if present
let mut text = choice.message.content.trim().to_string();
// Strip "sol:" or "sol 💕:" or similar prefixes the model sometimes adds
let lower = text.to_lowercase();
for prefix in &["sol:", "sol 💕:", "sol💕:"] {
if lower.starts_with(prefix) {
text = text[prefix.len()..].trim().to_string();
break;
}
}
if text.is_empty() {
info!("Generated empty response, skipping send");
let _ = room.typing_notice(false).await;
return None;
}
let preview: String = text.chars().take(120).collect();
let _ = room.typing_notice(false).await;
info!(
response_len = text.len(),
response_preview = preview.as_str(),
is_spontaneous,
tool_iterations = iteration,
"Generated response"
);
return Some(text);
}
let _ = room.typing_notice(false).await;
warn!("Exceeded max tool iterations");
None
}
/// Load per-user memory notes to inject into the system prompt.
///
/// Queries up to `MAX_NOTES` memories topically relevant to the triggering
/// message; when fewer than `MIN_RELEVANT` are found, tops up with the
/// user's most recent memories (deduplicated by id). Returns `None` when
/// the user has no stored memories, so the prompt placeholder stays empty.
async fn load_memory_notes(
    &self,
    ctx: &ResponseContext,
    trigger_body: &str,
) -> Option<String> {
    // Cap on memories injected into the prompt.
    const MAX_NOTES: usize = 5;
    // Below this count of relevance hits we backfill with recent memories.
    const MIN_RELEVANT: usize = 3;

    let index = &self.config.opensearch.memory_index;
    let user_id = &ctx.user_id;

    // Search for topically relevant memories first; on error fall back to
    // an empty set rather than failing the whole response.
    let mut memories = memory::store::query(
        &self.opensearch,
        index,
        user_id,
        trigger_body,
        MAX_NOTES,
    )
    .await
    .unwrap_or_default();

    // Backfill with recent memories if the relevance search came up short.
    if memories.len() < MIN_RELEVANT {
        let remaining = MAX_NOTES - memories.len();
        if let Ok(recent) = memory::store::get_recent(
            &self.opensearch,
            index,
            user_id,
            remaining,
        )
        .await
        {
            // Dedupe by id so a memory that matched both queries appears once.
            let existing_ids: std::collections::HashSet<String> =
                memories.iter().map(|m| m.id.clone()).collect();
            for doc in recent {
                if !existing_ids.contains(&doc.id) && memories.len() < MAX_NOTES {
                    memories.push(doc);
                }
            }
        }
    }

    if memories.is_empty() {
        return None;
    }

    // Prefer the human-readable display name; fall back to the Matrix ID.
    let display = ctx
        .display_name
        .as_deref()
        .unwrap_or(&ctx.matrix_user_id);
    Some(format_memory_notes(display, &memories))
}
}
/// Format memory documents into a notes block for the system prompt.
///
/// Produces a `## notes about <name>` header followed by one
/// `- [category] content` bullet per memory.
pub(crate) fn format_memory_notes(
    display_name: &str,
    memories: &[memory::schema::MemoryDocument],
) -> String {
    let header = format!(
        "## notes about {display_name}\n\n\
        these are your private notes about the person you're talking to.\n\
        use them to inform your responses but don't mention that you have notes.\n"
    );
    let bullets = memories
        .iter()
        .map(|m| format!("- [{}] {}", m.category, m.content));
    std::iter::once(header)
        .chain(bullets)
        .collect::<Vec<_>>()
        .join("\n")
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::schema::MemoryDocument;

    // Fixture: a MemoryDocument with fixed user and timestamps, varying only
    // the fields the formatter actually renders.
    fn make_mem(id: &str, content: &str, category: &str) -> MemoryDocument {
        MemoryDocument {
            id: id.into(),
            user_id: "sienna@sunbeam.pt".into(),
            content: content.into(),
            category: category.into(),
            created_at: 1710000000000,
            updated_at: 1710000000000,
            source: "auto".into(),
        }
    }

    #[test]
    fn test_format_memory_notes_basic() {
        // Header, the "don't mention" instruction, and one bullet per memory.
        let memories = vec![
            make_mem("a", "prefers terse answers", "preference"),
            make_mem("b", "working on drive UI", "fact"),
        ];
        let result = format_memory_notes("sienna", &memories);
        assert!(result.contains("## notes about sienna"));
        assert!(result.contains("don't mention that you have notes"));
        assert!(result.contains("- [preference] prefers terse answers"));
        assert!(result.contains("- [fact] working on drive UI"));
    }

    #[test]
    fn test_format_memory_notes_single() {
        let memories = vec![make_mem("x", "birthday is march 12", "context")];
        let result = format_memory_notes("lonni", &memories);
        assert!(result.contains("## notes about lonni"));
        assert!(result.contains("- [context] birthday is march 12"));
    }

    #[test]
    fn test_format_memory_notes_uses_display_name() {
        // The display name, not the stored Matrix user_id, heads the block.
        let memories = vec![make_mem("a", "test", "general")];
        let result = format_memory_notes("Amber", &memories);
        assert!(result.contains("## notes about Amber"));
    }
}

View File

@@ -25,6 +25,8 @@ pub struct OpenSearchConfig {
pub flush_interval_ms: u64,
#[serde(default = "default_embedding_pipeline")]
pub embedding_pipeline: String,
#[serde(default = "default_memory_index")]
pub memory_index: String,
}
#[derive(Debug, Clone, Deserialize)]
@@ -59,6 +61,30 @@ pub struct BehaviorConfig {
pub backfill_on_join: bool,
#[serde(default = "default_backfill_limit")]
pub backfill_limit: usize,
#[serde(default)]
pub instant_responses: bool,
#[serde(default = "default_cooldown_after_response_ms")]
pub cooldown_after_response_ms: u64,
#[serde(default = "default_evaluation_context_window")]
pub evaluation_context_window: usize,
#[serde(default = "default_detect_sol_in_conversation")]
pub detect_sol_in_conversation: bool,
#[serde(default)]
pub evaluation_prompt_active: Option<String>,
#[serde(default)]
pub evaluation_prompt_passive: Option<String>,
#[serde(default = "default_reaction_threshold")]
pub reaction_threshold: f32,
#[serde(default = "default_reaction_enabled")]
pub reaction_enabled: bool,
#[serde(default = "default_script_timeout_secs")]
pub script_timeout_secs: u64,
#[serde(default = "default_script_max_heap_mb")]
pub script_max_heap_mb: usize,
#[serde(default)]
pub script_fetch_allowlist: Vec<String>,
#[serde(default = "default_memory_extraction_enabled")]
pub memory_extraction_enabled: bool,
}
fn default_batch_size() -> usize { 50 }
@@ -68,15 +94,24 @@ fn default_model() -> String { "mistral-medium-latest".into() }
fn default_evaluation_model() -> String { "ministral-3b-latest".into() }
fn default_research_model() -> String { "mistral-large-latest".into() }
fn default_max_tool_iterations() -> usize { 5 }
fn default_response_delay_min_ms() -> u64 { 2000 }
fn default_response_delay_max_ms() -> u64 { 8000 }
fn default_response_delay_min_ms() -> u64 { 100 }
fn default_response_delay_max_ms() -> u64 { 2300 }
fn default_spontaneous_delay_min_ms() -> u64 { 15000 }
fn default_spontaneous_delay_max_ms() -> u64 { 60000 }
fn default_spontaneous_threshold() -> f32 { 0.7 }
fn default_spontaneous_threshold() -> f32 { 0.85 }
fn default_cooldown_after_response_ms() -> u64 { 15000 }
fn default_evaluation_context_window() -> usize { 25 }
fn default_detect_sol_in_conversation() -> bool { true }
fn default_reaction_threshold() -> f32 { 0.6 }
fn default_reaction_enabled() -> bool { true }
fn default_room_context_window() -> usize { 30 }
fn default_dm_context_window() -> usize { 100 }
fn default_backfill_on_join() -> bool { true }
fn default_backfill_limit() -> usize { 10000 }
fn default_script_timeout_secs() -> u64 { 5 }
fn default_script_max_heap_mb() -> usize { 64 }
fn default_memory_index() -> String { "sol_user_memory".into() }
fn default_memory_extraction_enabled() -> bool { true }
impl Config {
pub fn load(path: &str) -> anyhow::Result<Self> {
@@ -155,19 +190,23 @@ backfill_limit = 5000
assert_eq!(config.opensearch.batch_size, 50);
assert_eq!(config.opensearch.flush_interval_ms, 2000);
assert_eq!(config.opensearch.embedding_pipeline, "tuwunel_embedding_pipeline");
assert_eq!(config.opensearch.memory_index, "sol_user_memory");
assert_eq!(config.mistral.default_model, "mistral-medium-latest");
assert_eq!(config.mistral.evaluation_model, "ministral-3b-latest");
assert_eq!(config.mistral.research_model, "mistral-large-latest");
assert_eq!(config.mistral.max_tool_iterations, 5);
assert_eq!(config.behavior.response_delay_min_ms, 2000);
assert_eq!(config.behavior.response_delay_max_ms, 8000);
assert_eq!(config.behavior.response_delay_min_ms, 100);
assert_eq!(config.behavior.response_delay_max_ms, 2300);
assert_eq!(config.behavior.spontaneous_delay_min_ms, 15000);
assert_eq!(config.behavior.spontaneous_delay_max_ms, 60000);
assert!((config.behavior.spontaneous_threshold - 0.7).abs() < f32::EPSILON);
assert!((config.behavior.spontaneous_threshold - 0.85).abs() < f32::EPSILON);
assert!(!config.behavior.instant_responses);
assert_eq!(config.behavior.cooldown_after_response_ms, 15000);
assert_eq!(config.behavior.room_context_window, 30);
assert_eq!(config.behavior.dm_context_window, 100);
assert!(config.behavior.backfill_on_join);
assert_eq!(config.behavior.backfill_limit, 10000);
assert!(config.behavior.memory_extraction_enabled);
}
#[test]

56
src/context.rs Normal file
View File

@@ -0,0 +1,56 @@
/// Per-message response context, threading the sender's identity from Matrix
/// through the tool loop and memory system.
///
/// `user_id` is derived server-side from the Matrix sender (see
/// [`derive_user_id`]) and never taken from script arguments, so per-user
/// memory isolation cannot be bypassed by user-supplied scripts.
#[derive(Debug, Clone)]
pub struct ResponseContext {
    /// Full Matrix user ID, e.g. `@sienna:sunbeam.pt`
    pub matrix_user_id: String,
    /// Derived portable ID, e.g. `sienna@sunbeam.pt`
    pub user_id: String,
    /// Display name if available
    pub display_name: Option<String>,
    /// Whether this message was sent in a DM
    pub is_dm: bool,
    /// Whether this message is a reply to Sol
    pub is_reply: bool,
    /// The room this message was sent in
    pub room_id: String,
}
/// Derive a portable user ID from a Matrix user ID.
///
/// `@sienna:sunbeam.pt` → `sienna@sunbeam.pt`
pub fn derive_user_id(matrix_user_id: &str) -> String {
    // A missing leading '@' is tolerated; the input is used as-is.
    let id = match matrix_user_id.strip_prefix('@') {
        Some(rest) => rest,
        None => matrix_user_id,
    };
    // Only the first ':' separates localpart from server; any later ':'
    // (e.g. an explicit port like `:8448`) is kept verbatim.
    match id.split_once(':') {
        Some((localpart, server)) => format!("{localpart}@{server}"),
        None => id.to_string(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_derive_user_id_standard() {
        assert_eq!(derive_user_id("@sienna:sunbeam.pt"), "sienna@sunbeam.pt");
    }

    #[test]
    fn test_derive_user_id_no_at_prefix() {
        // A missing leading '@' is tolerated rather than rejected.
        assert_eq!(derive_user_id("sienna:sunbeam.pt"), "sienna@sunbeam.pt");
    }

    #[test]
    fn test_derive_user_id_complex() {
        // Dots in the localpart survive the conversion untouched.
        assert_eq!(
            derive_user_id("@user.name:matrix.org"),
            "user.name@matrix.org"
        );
    }

    #[test]
    fn test_derive_user_id_only_first_colon() {
        // Server names may carry a port; only the first ':' becomes '@'.
        assert_eq!(
            derive_user_id("@user:server:8448"),
            "user@server:8448"
        );
    }
}

View File

@@ -1,7 +1,9 @@
mod archive;
mod brain;
mod config;
mod context;
mod matrix_utils;
mod memory;
mod sync;
mod tools;
@@ -18,7 +20,8 @@ use url::Url;
use archive::indexer::Indexer;
use archive::schema::create_index_if_not_exists;
use brain::conversation::ConversationManager;
use brain::conversation::{ContextMessage, ConversationManager};
use memory::schema::create_index_if_not_exists as create_memory_index;
use brain::evaluator::Evaluator;
use brain::personality::Personality;
use brain::responder::Responder;
@@ -93,8 +96,9 @@ async fn main() -> anyhow::Result<()> {
.build()?;
let os_client = OpenSearch::new(os_transport);
// Ensure index exists
// Ensure indices exist
create_index_if_not_exists(&os_client, &config.opensearch.index).await?;
create_memory_index(&os_client, &config.opensearch.memory_index).await?;
// Initialize Mistral client
let mistral_client = mistralai_client::v1::client::Client::new(
@@ -107,22 +111,32 @@ async fn main() -> anyhow::Result<()> {
// Build components
let personality = Arc::new(Personality::new(system_prompt));
let conversations = Arc::new(Mutex::new(ConversationManager::new(
config.behavior.room_context_window,
config.behavior.dm_context_window,
)));
// Backfill conversation context from archive before starting
if config.behavior.backfill_on_join {
info!("Backfilling conversation context from archive...");
if let Err(e) = backfill_conversations(&os_client, &config, &conversations).await {
error!("Backfill failed (non-fatal): {e}");
}
}
let tool_registry = Arc::new(ToolRegistry::new(
os_client.clone(),
matrix_client.clone(),
config.clone(),
));
let indexer = Arc::new(Indexer::new(os_client, config.clone()));
let indexer = Arc::new(Indexer::new(os_client.clone(), config.clone()));
let evaluator = Arc::new(Evaluator::new(config.clone()));
let responder = Arc::new(Responder::new(
config.clone(),
personality,
tool_registry,
os_client.clone(),
));
let conversations = Arc::new(Mutex::new(ConversationManager::new(
config.behavior.room_context_window,
config.behavior.dm_context_window,
)));
// Start background flush task
let _flush_handle = indexer.start_flush_task();
@@ -135,8 +149,17 @@ async fn main() -> anyhow::Result<()> {
responder,
conversations,
mistral,
opensearch: os_client,
last_response: Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())),
responding_in: Arc::new(tokio::sync::Mutex::new(std::collections::HashSet::new())),
});
// Backfill reactions from Matrix room timelines
info!("Backfilling reactions from room timelines...");
if let Err(e) = backfill_reactions(&matrix_client, &state.indexer).await {
error!("Reaction backfill failed (non-fatal): {e}");
}
// Start sync loop in background
let sync_client = matrix_client.clone();
let sync_state = state.clone();
@@ -158,3 +181,160 @@ async fn main() -> anyhow::Result<()> {
info!("Sol has shut down");
Ok(())
}
/// Backfill conversation context from the OpenSearch archive.
///
/// Queries the most recent messages per room and seeds the ConversationManager
/// so Sol has context surviving restarts. Errors are propagated to the caller,
/// which treats backfill failure as non-fatal.
async fn backfill_conversations(
    os_client: &OpenSearch,
    config: &Config,
    conversations: &Arc<Mutex<ConversationManager>>,
) -> anyhow::Result<()> {
    use serde_json::json;

    // Use the larger of the two windows so either room kind can fill fully.
    let window = config.behavior.room_context_window.max(config.behavior.dm_context_window);
    let index = &config.opensearch.index;

    // Get all distinct rooms
    let agg_body = json!({
        "size": 0,
        "aggs": {
            "rooms": {
                "terms": { "field": "room_id", "size": 500 }
            }
        }
    });
    let response = os_client
        .search(opensearch::SearchParts::Index(&[index]))
        .body(agg_body)
        .send()
        .await?;
    let body: serde_json::Value = response.json().await?;
    let buckets = body["aggregations"]["rooms"]["buckets"]
        .as_array()
        .cloned()
        .unwrap_or_default();

    let mut total = 0;
    for bucket in &buckets {
        let room_id = bucket["key"].as_str().unwrap_or("");
        if room_id.is_empty() {
            continue;
        }

        // Fetch the most recent messages for this room. Sort DESCENDING so
        // the `size` limit keeps the NEWEST `window` messages — ascending
        // sort with a size limit would return the oldest messages instead.
        let query = json!({
            "size": window,
            "sort": [{ "timestamp": "desc" }],
            "query": {
                "bool": {
                    "filter": [
                        { "term": { "room_id": room_id } },
                        { "term": { "redacted": false } }
                    ]
                }
            },
            "_source": ["sender_name", "sender", "content", "timestamp"]
        });
        let resp = os_client
            .search(opensearch::SearchParts::Index(&[index]))
            .body(query)
            .send()
            .await?;
        let data: serde_json::Value = resp.json().await?;
        let mut hits = data["hits"]["hits"].as_array().cloned().unwrap_or_default();
        if hits.is_empty() {
            continue;
        }
        // Newest-first from the query; reverse so messages are inserted in
        // chronological order and the context window evicts the oldest.
        hits.reverse();

        let mut convs = conversations.lock().await;
        for hit in &hits {
            let src = &hit["_source"];
            // Prefer the friendly sender_name; fall back to the raw sender.
            let sender = src["sender_name"]
                .as_str()
                .or_else(|| src["sender"].as_str())
                .unwrap_or("unknown");
            let content = src["content"].as_str().unwrap_or("");
            let timestamp = src["timestamp"].as_i64().unwrap_or(0);
            convs.add_message(
                room_id,
                false, // we don't know if it's a DM from the archive, use group window
                ContextMessage {
                    sender: sender.to_string(),
                    content: content.to_string(),
                    timestamp,
                },
            );
            total += 1;
        }
    }
    info!(rooms = buckets.len(), messages = total, "Backfill complete");
    Ok(())
}
/// Backfill reactions from Matrix room timelines into the archive.
///
/// For each joined room, fetches recent timeline events and indexes any
/// m.reaction events that aren't already in the archive.
async fn backfill_reactions(
    client: &Client,
    indexer: &Arc<Indexer>,
) -> anyhow::Result<()> {
    use matrix_sdk::room::MessagesOptions;
    use ruma::events::AnySyncTimelineEvent;
    use ruma::uint;

    let rooms = client.joined_rooms();
    let mut total = 0;

    for room in &rooms {
        let room_id = room.room_id().to_string();

        // Page backwards from the newest event in the room.
        let mut options = MessagesOptions::backward();
        options.limit = uint!(500);

        let messages = match room.messages(options).await {
            Ok(batch) => batch,
            Err(e) => {
                error!(room = room_id.as_str(), "Failed to fetch timeline for reaction backfill: {e}");
                continue;
            }
        };

        for event in &messages.chunk {
            // Skip events that fail to deserialize.
            let deserialized = match event.raw().deserialize() {
                Ok(ev) => ev,
                Err(_) => continue,
            };
            // Only m.reaction events matter here.
            let AnySyncTimelineEvent::MessageLike(
                ruma::events::AnySyncMessageLikeEvent::Reaction(reaction_event),
            ) = deserialized
            else {
                continue;
            };
            // Only original (non-redacted) events still carry reaction content.
            let ruma::events::SyncMessageLikeEvent::Original(ref original) = reaction_event
            else {
                continue;
            };

            let target_event_id = original.content.relates_to.event_id.to_string();
            let sender = original.sender.to_string();
            let emoji = &original.content.relates_to.key;
            let timestamp: i64 = original.origin_server_ts.0.into();
            indexer.add_reaction(&target_event_id, &sender, emoji, timestamp).await;
            total += 1;
        }
    }

    info!(reactions = total, rooms = rooms.len(), "Reaction backfill complete");
    Ok(())
}

View File

@@ -3,7 +3,8 @@ use matrix_sdk::RoomMemberships;
use ruma::events::room::message::{
MessageType, OriginalSyncRoomMessageEvent, Relation, RoomMessageEventContent,
};
use ruma::events::relation::InReplyTo;
use ruma::events::relation::{Annotation, InReplyTo};
use ruma::events::reaction::ReactionEventContent;
use ruma::OwnedEventId;
/// Extract the plain-text body from a message event.
@@ -45,15 +46,27 @@ pub fn extract_thread_id(event: &OriginalSyncRoomMessageEvent) -> Option<OwnedEv
None
}
/// Build a reply message content with m.in_reply_to relation.
/// Build a reply message content with m.in_reply_to relation and markdown rendering.
pub fn make_reply_content(body: &str, reply_to_event_id: OwnedEventId) -> RoomMessageEventContent {
let mut content = RoomMessageEventContent::text_plain(body);
let mut content = RoomMessageEventContent::text_markdown(body);
content.relates_to = Some(Relation::Reply {
in_reply_to: InReplyTo::new(reply_to_event_id),
});
content
}
/// Send an emoji reaction to a message.
///
/// Builds an m.annotation relation targeting `event_id` and sends it as an
/// m.reaction event into `room`.
pub async fn send_reaction(
    room: &Room,
    event_id: OwnedEventId,
    emoji: &str,
) -> anyhow::Result<()> {
    let relation = Annotation::new(event_id, emoji.to_owned());
    room.send(ReactionEventContent::new(relation)).await?;
    Ok(())
}
/// Get the display name for a room.
pub fn room_display_name(room: &Room) -> String {
room.cached_display_name()

158
src/memory/extractor.rs Normal file
View File

@@ -0,0 +1,158 @@
use std::sync::Arc;
use mistralai_client::v1::{
chat::{ChatMessage, ChatParams, ResponseFormat},
constants::Model,
};
use opensearch::OpenSearch;
use serde::Deserialize;
use tracing::{debug, warn};
use crate::config::Config;
use crate::context::ResponseContext;
use crate::brain::responder::chat_blocking;
use super::store;
/// JSON envelope the extraction model is instructed to return:
/// `{"memories": [{"content": "...", "category": "..."}]}`.
#[derive(Debug, Deserialize)]
pub(crate) struct ExtractionResponse {
    pub memories: Vec<ExtractedMemory>,
}

/// One fact extracted from a conversation exchange.
#[derive(Debug, Deserialize)]
pub(crate) struct ExtractedMemory {
    // Free-text fact, e.g. "prefers terse answers".
    pub content: String,
    // Raw category from the model; clamped by `normalize_category` before storage.
    pub category: String,
}
/// Validate and normalize a category string.
///
/// Accepts only the exact lowercase names "preference", "fact", and
/// "context"; any other input (including different casing) maps to
/// "general".
pub(crate) fn normalize_category(raw: &str) -> &str {
    if matches!(raw, "preference" | "fact" | "context") {
        raw
    } else {
        "general"
    }
}
/// Extract memorable facts from a user↔Sol exchange and persist them.
///
/// Sends the exchange to the evaluation model asking for a JSON list of
/// memories. Extraction is best-effort by design: an unparseable model
/// response, an empty choices array, or an empty memory list all return
/// `Ok(())`; only transport-level failures propagate as errors.
pub async fn extract_and_store(
    mistral: &Arc<mistralai_client::v1::client::Client>,
    opensearch: &OpenSearch,
    config: &Config,
    ctx: &ResponseContext,
    user_message: &str,
    sol_response: &str,
) -> anyhow::Result<()> {
    // Prefer the display name in the prompt; fall back to the Matrix user id.
    let display = ctx
        .display_name
        .as_deref()
        .unwrap_or(&ctx.matrix_user_id);
    let prompt = format!(
        "Analyze this conversation exchange and extract any facts worth remembering about {display}.\n\
        Focus on: preferences, personal details, ongoing projects, opinions, recurring topics.\n\n\
        They said: {user_message}\n\
        Response: {sol_response}\n\n\
        Respond ONLY with JSON: {{\"memories\": [{{\"content\": \"...\", \"category\": \"preference|fact|context\"}}]}}\n\
        If nothing worth remembering, respond with {{\"memories\": []}}.\n\
        Be selective — only genuinely useful information."
    );
    let messages = vec![ChatMessage::new_user_message(&prompt)];
    let model = Model::new(&config.mistral.evaluation_model);
    let params = ChatParams {
        response_format: Some(ResponseFormat::json_object()),
        ..Default::default()
    };
    let response = chat_blocking(mistral, model, messages, params).await?;
    // Guard against an empty choices array instead of indexing `choices[0]`,
    // which would panic.
    let Some(choice) = response.choices.first() else {
        debug!("Extraction model returned no choices");
        return Ok(());
    };
    let text = choice.message.content.trim();
    let extraction: ExtractionResponse = match serde_json::from_str(text) {
        Ok(e) => e,
        Err(e) => {
            debug!(raw = text, "Failed to parse extraction response: {e}");
            return Ok(());
        }
    };
    if extraction.memories.is_empty() {
        debug!("No memories extracted");
        return Ok(());
    }
    let index = &config.opensearch.memory_index;
    for mem in &extraction.memories {
        // Clamp unknown categories to "general" before persisting.
        let category = normalize_category(&mem.category);
        if let Err(e) = store::set(
            opensearch,
            index,
            &ctx.user_id,
            &mem.content,
            category,
            "auto",
        )
        .await
        {
            // Keep going: one failed write shouldn't drop the remaining memories.
            warn!("Failed to store extracted memory: {e}");
        }
    }
    debug!(
        count = extraction.memories.len(),
        user = ctx.user_id.as_str(),
        "Extracted and stored memories"
    );
    Ok(())
}
// Unit tests: parsing of the model's extraction payload and category
// normalization (matching is exact and case-sensitive by design).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_extraction_response_with_memories() {
        let json = r#"{"memories": [
            {"content": "prefers terse answers", "category": "preference"},
            {"content": "working on drive UI", "category": "fact"}
        ]}"#;
        let resp: ExtractionResponse = serde_json::from_str(json).unwrap();
        assert_eq!(resp.memories.len(), 2);
        assert_eq!(resp.memories[0].content, "prefers terse answers");
        assert_eq!(resp.memories[0].category, "preference");
        assert_eq!(resp.memories[1].category, "fact");
    }

    #[test]
    fn test_parse_extraction_response_empty() {
        let json = r#"{"memories": []}"#;
        let resp: ExtractionResponse = serde_json::from_str(json).unwrap();
        assert!(resp.memories.is_empty());
    }

    #[test]
    fn test_parse_extraction_response_invalid_json() {
        let json = "not json at all";
        assert!(serde_json::from_str::<ExtractionResponse>(json).is_err());
    }

    // `category` has no serde default, so omitting it must be a hard error.
    #[test]
    fn test_parse_extraction_response_missing_field() {
        let json = r#"{"memories": [{"content": "hi"}]}"#;
        assert!(serde_json::from_str::<ExtractionResponse>(json).is_err());
    }

    #[test]
    fn test_normalize_category_valid() {
        assert_eq!(normalize_category("preference"), "preference");
        assert_eq!(normalize_category("fact"), "fact");
        assert_eq!(normalize_category("context"), "context");
    }

    // Uppercase input is intentionally NOT accepted — normalization is exact-match.
    #[test]
    fn test_normalize_category_unknown_falls_back() {
        assert_eq!(normalize_category("opinion"), "general");
        assert_eq!(normalize_category(""), "general");
        assert_eq!(normalize_category("PREFERENCE"), "general");
    }
}

3
src/memory/mod.rs Normal file
View File

@@ -0,0 +1,3 @@
//! Per-user memory subsystem: fact extraction, OpenSearch persistence,
//! and index schema management.

pub mod extractor;
pub mod schema;
pub mod store;

118
src/memory/schema.rs Normal file
View File

@@ -0,0 +1,118 @@
use opensearch::OpenSearch;
use serde::{Deserialize, Serialize};
use tracing::info;
/// One stored memory — a single OpenSearch document per remembered fact.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryDocument {
    // UUID string; also used as the OpenSearch document id.
    pub id: String,
    // Owning user — queries always filter on this for per-user isolation.
    pub user_id: String,
    // The remembered fact, free text.
    pub content: String,
    // "preference" | "fact" | "context" | "general" (see normalize_category).
    pub category: String,
    // Epoch milliseconds.
    pub created_at: i64,
    // Epoch milliseconds; recency sort key in get_recent.
    pub updated_at: i64,
    // Provenance: "auto" (post-response extraction) or "script" (sol.memory.set).
    pub source: String,
}
// Mapping for the per-user memory index: keyword fields for exact-match
// filtering (id, user_id, category, source), analyzed full-text `content`,
// and epoch-millis date fields. Single shard, zero replicas.
const INDEX_MAPPING: &str = r#"{
    "settings": {
        "number_of_shards": 1,
        "number_of_replicas": 0
    },
    "mappings": {
        "properties": {
            "id": { "type": "keyword" },
            "user_id": { "type": "keyword" },
            "content": { "type": "text", "analyzer": "standard" },
            "category": { "type": "keyword" },
            "created_at": { "type": "date", "format": "epoch_millis" },
            "updated_at": { "type": "date", "format": "epoch_millis" },
            "source": { "type": "keyword" }
        }
    }
}"#;
/// Ensure the memory index exists, creating it from `INDEX_MAPPING` if missing.
///
/// Idempotent: when the index already exists this is a no-op. Errors only
/// on transport failures or a rejected create request.
pub async fn create_index_if_not_exists(client: &OpenSearch, index: &str) -> anyhow::Result<()> {
    let exists = client
        .indices()
        .exists(opensearch::indices::IndicesExistsParts::Index(&[index]))
        .send()
        .await?;
    // A success status on HEAD means the index is already there.
    if exists.status_code().is_success() {
        info!(index, "Memory index already exists");
        return Ok(());
    }
    let mapping: serde_json::Value = serde_json::from_str(INDEX_MAPPING)?;
    let response = client
        .indices()
        .create(opensearch::indices::IndicesCreateParts::Index(index))
        .body(mapping)
        .send()
        .await?;
    if response.status_code().is_success() {
        info!(index, "Created memory index");
        Ok(())
    } else {
        let body = response.text().await?;
        anyhow::bail!("Failed to create memory index {index}: {body}")
    }
}
// Unit tests: serde round-trips for MemoryDocument and a sanity check that
// the embedded index mapping is valid JSON with the expected field types.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_document_serialize() {
        let doc = MemoryDocument {
            id: "abc-123".into(),
            user_id: "sienna@sunbeam.pt".into(),
            content: "prefers terse answers".into(),
            category: "preference".into(),
            created_at: 1710000000000,
            updated_at: 1710000000000,
            source: "auto".into(),
        };
        let json = serde_json::to_value(&doc).unwrap();
        assert_eq!(json["user_id"], "sienna@sunbeam.pt");
        assert_eq!(json["category"], "preference");
        assert_eq!(json["source"], "auto");
    }

    #[test]
    fn test_memory_document_roundtrip() {
        let doc = MemoryDocument {
            id: "xyz".into(),
            user_id: "lonni@sunbeam.pt".into(),
            content: "working on UI redesign".into(),
            category: "fact".into(),
            created_at: 1710000000000,
            updated_at: 1710000000000,
            source: "script".into(),
        };
        let json_str = serde_json::to_string(&doc).unwrap();
        let roundtrip: MemoryDocument = serde_json::from_str(&json_str).unwrap();
        assert_eq!(roundtrip.id, doc.id);
        assert_eq!(roundtrip.user_id, doc.user_id);
        assert_eq!(roundtrip.content, doc.content);
    }

    // Guards against typos in the raw mapping string breaking index creation.
    #[test]
    fn test_index_mapping_valid_json() {
        let mapping: serde_json::Value = serde_json::from_str(INDEX_MAPPING).unwrap();
        assert_eq!(
            mapping["mappings"]["properties"]["user_id"]["type"]
                .as_str()
                .unwrap(),
            "keyword"
        );
        assert_eq!(
            mapping["mappings"]["properties"]["content"]["type"]
                .as_str()
                .unwrap(),
            "text"
        );
    }
}

187
src/memory/store.rs Normal file
View File

@@ -0,0 +1,187 @@
use chrono::Utc;
use opensearch::OpenSearch;
use serde_json::json;
use uuid::Uuid;
use super::schema::MemoryDocument;
/// Search memories by content relevance, filtered to a specific user.
///
/// The user_id term is a non-scoring filter; ranking comes entirely from
/// the full-text match on `content`.
pub async fn query(
    client: &OpenSearch,
    index: &str,
    user_id: &str,
    query_text: &str,
    limit: usize,
) -> anyhow::Result<Vec<MemoryDocument>> {
    let request = json!({
        "size": limit,
        "query": {
            "bool": {
                "filter": [{ "term": { "user_id": user_id } }],
                "must": [{ "match": { "content": query_text } }]
            }
        },
        "sort": [{ "_score": "desc" }]
    });
    let raw: serde_json::Value = client
        .search(opensearch::SearchParts::Index(&[index]))
        .body(request)
        .send()
        .await?
        .json()
        .await?;
    parse_hits(&raw)
}
/// Get the most recent memories for a user, sorted by updated_at desc.
///
/// No scoring clause — just the per-user filter plus recency ordering.
pub async fn get_recent(
    client: &OpenSearch,
    index: &str,
    user_id: &str,
    limit: usize,
) -> anyhow::Result<Vec<MemoryDocument>> {
    let request = json!({
        "size": limit,
        "query": {
            "bool": {
                "filter": [{ "term": { "user_id": user_id } }]
            }
        },
        "sort": [{ "updated_at": "desc" }]
    });
    let raw: serde_json::Value = client
        .search(opensearch::SearchParts::Index(&[index]))
        .body(request)
        .send()
        .await?
        .json()
        .await?;
    parse_hits(&raw)
}
/// Store a new memory document for a user.
///
/// Generates a fresh UUID for the document id and stamps both created_at
/// and updated_at with the current time (epoch millis).
pub async fn set(
    client: &OpenSearch,
    index: &str,
    user_id: &str,
    content: &str,
    category: &str,
    source: &str,
) -> anyhow::Result<()> {
    let id = Uuid::new_v4().to_string();
    let now = Utc::now().timestamp_millis();
    let doc = MemoryDocument {
        id: id.clone(),
        user_id: user_id.to_owned(),
        content: content.to_owned(),
        category: category.to_owned(),
        created_at: now,
        updated_at: now,
        source: source.to_owned(),
    };
    let response = client
        .index(opensearch::IndexParts::IndexId(index, &id))
        .body(serde_json::to_value(&doc)?)
        .send()
        .await?;
    if response.status_code().is_success() {
        return Ok(());
    }
    let body = response.text().await?;
    anyhow::bail!("Failed to store memory: {body}")
}
/// Deserialize `hits.hits[]._source` documents from an OpenSearch search
/// response, silently skipping entries that don't match the schema.
///
/// Missing or malformed response structure yields an empty Vec rather than
/// an error. Borrows the hits array in place instead of cloning it wholesale.
pub(crate) fn parse_hits(data: &serde_json::Value) -> anyhow::Result<Vec<MemoryDocument>> {
    let docs = data["hits"]["hits"]
        .as_array()
        .map(|hits| {
            hits.iter()
                // Each _source must still be cloned: from_value consumes its input.
                .filter_map(|hit| {
                    serde_json::from_value::<MemoryDocument>(hit["_source"].clone()).ok()
                })
                .collect()
        })
        .unwrap_or_default();
    Ok(docs)
}
// Unit tests for parse_hits: well-formed responses, empty results, missing
// response structure, and partially-malformed hit arrays.
#[cfg(test)]
mod tests {
    use super::*;

    // Build a minimal OpenSearch-shaped response wrapping the given _source docs.
    fn fake_os_response(sources: Vec<serde_json::Value>) -> serde_json::Value {
        let hits: Vec<serde_json::Value> = sources
            .into_iter()
            .map(|s| json!({ "_source": s }))
            .collect();
        json!({ "hits": { "hits": hits } })
    }

    #[test]
    fn test_parse_hits_multiple() {
        let data = fake_os_response(vec![
            json!({
                "id": "a", "user_id": "sienna@sunbeam.pt",
                "content": "prefers terse answers", "category": "preference",
                "created_at": 1710000000000_i64, "updated_at": 1710000000000_i64,
                "source": "auto"
            }),
            json!({
                "id": "b", "user_id": "sienna@sunbeam.pt",
                "content": "working on drive UI", "category": "fact",
                "created_at": 1710000000000_i64, "updated_at": 1710000000000_i64,
                "source": "script"
            }),
        ]);
        let docs = parse_hits(&data).unwrap();
        assert_eq!(docs.len(), 2);
        assert_eq!(docs[0].id, "a");
        assert_eq!(docs[0].content, "prefers terse answers");
        assert_eq!(docs[1].id, "b");
        assert_eq!(docs[1].category, "fact");
    }

    #[test]
    fn test_parse_hits_empty() {
        let data = json!({ "hits": { "hits": [] } });
        let docs = parse_hits(&data).unwrap();
        assert!(docs.is_empty());
    }

    // Entirely missing hits structure degrades to empty, not an error.
    #[test]
    fn test_parse_hits_missing_structure() {
        let data = json!({});
        let docs = parse_hits(&data).unwrap();
        assert!(docs.is_empty());
    }

    // A hit whose _source doesn't deserialize is dropped; valid ones survive.
    #[test]
    fn test_parse_hits_skips_malformed() {
        let data = fake_os_response(vec![
            json!({
                "id": "good", "user_id": "x@y",
                "content": "ok", "category": "fact",
                "created_at": 1, "updated_at": 1, "source": "auto"
            }),
            json!({ "bad": "no fields" }),
        ]);
        let docs = parse_hits(&data).unwrap();
        assert_eq!(docs.len(), 1);
        assert_eq!(docs[0].id, "good");
    }
}

View File

@@ -1,13 +1,18 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Instant;
use matrix_sdk::config::SyncSettings;
use matrix_sdk::room::Room;
use matrix_sdk::Client;
use ruma::events::reaction::OriginalSyncReactionEvent;
use ruma::events::room::member::StrippedRoomMemberEvent;
use ruma::events::room::message::OriginalSyncRoomMessageEvent;
use ruma::events::room::redaction::OriginalSyncRoomRedactionEvent;
use tokio::sync::Mutex;
use tracing::{error, info, warn};
use tracing::{debug, error, info, warn};
use opensearch::OpenSearch;
use crate::archive::indexer::Indexer;
use crate::archive::schema::ArchiveDocument;
@@ -15,7 +20,9 @@ use crate::brain::conversation::{ContextMessage, ConversationManager};
use crate::brain::evaluator::{Engagement, Evaluator};
use crate::brain::responder::Responder;
use crate::config::Config;
use crate::context::{self, ResponseContext};
use crate::matrix_utils;
use crate::memory;
pub struct AppState {
pub config: Arc<Config>,
@@ -24,6 +31,11 @@ pub struct AppState {
pub responder: Arc<Responder>,
pub conversations: Arc<Mutex<ConversationManager>>,
pub mistral: Arc<mistralai_client::v1::client::Client>,
pub opensearch: OpenSearch,
/// Tracks when Sol last responded in each room (for cooldown)
pub last_response: Arc<Mutex<HashMap<String, Instant>>>,
/// Tracks rooms where a response is currently being generated (in-flight guard)
pub responding_in: Arc<Mutex<std::collections::HashSet<String>>>,
}
pub async fn start_sync(client: Client, state: Arc<AppState>) -> anyhow::Result<()> {
@@ -50,6 +62,16 @@ pub async fn start_sync(client: Client, state: Arc<AppState>) -> anyhow::Result<
},
);
let s = state.clone();
client.add_event_handler(
move |event: OriginalSyncReactionEvent, _room: Room| {
let state = s.clone();
async move {
handle_reaction(event, &state).await;
}
},
);
client.add_event_handler(
move |event: StrippedRoomMemberEvent, room: Room| async move {
handle_invite(event, room).await;
@@ -95,6 +117,7 @@ async fn handle_message(
.and_then(|m| m.display_name().map(|s| s.to_string()));
let reply_to = matrix_utils::extract_reply_to(&event).map(|id| id.to_string());
let is_reply = reply_to.is_some();
let thread_id = matrix_utils::extract_thread_id(&event).map(|id| id.to_string());
// Archive the message
@@ -112,11 +135,22 @@ async fn handle_message(
event_type: "m.room.message".into(),
edited: false,
redacted: false,
reactions: Vec::new(),
};
state.indexer.add(doc).await;
// Update conversation context
let is_dm = room.is_direct().await.unwrap_or(false);
let response_ctx = ResponseContext {
matrix_user_id: sender.clone(),
user_id: context::derive_user_id(&sender),
display_name: sender_name.clone(),
is_dm,
is_reply,
room_id: room_id.clone(),
};
{
let mut convs = state.conversations.lock().await;
convs.add_message(
@@ -147,13 +181,20 @@ async fn handle_message(
let (should_respond, is_spontaneous) = match engagement {
Engagement::MustRespond { reason } => {
info!(?reason, "Must respond");
info!(room = room_id.as_str(), ?reason, "Must respond");
(true, false)
}
Engagement::MaybeRespond { relevance, hook } => {
info!(relevance, hook = hook.as_str(), "Maybe respond (spontaneous)");
info!(room = room_id.as_str(), relevance, hook = hook.as_str(), "Maybe respond (spontaneous)");
(true, true)
}
Engagement::React { emoji, relevance } => {
info!(room = room_id.as_str(), relevance, emoji = emoji.as_str(), "Reacting with emoji");
if let Err(e) = matrix_utils::send_reaction(&room, event.event_id.clone().into(), &emoji).await {
error!("Failed to send reaction: {e}");
}
(false, false)
}
Engagement::Ignore => (false, false),
};
@@ -161,8 +202,38 @@ async fn handle_message(
return Ok(());
}
// Show typing indicator
let _ = room.typing_notice(true).await;
// In-flight guard: skip if we're already generating a response for this room
{
let responding = state.responding_in.lock().await;
if responding.contains(&room_id) {
debug!(room = room_id.as_str(), "Skipping — response already in flight for this room");
return Ok(());
}
}
// Cooldown check: skip spontaneous if we responded recently
if is_spontaneous {
let last = state.last_response.lock().await;
if let Some(ts) = last.get(&room_id) {
let elapsed = ts.elapsed().as_millis() as u64;
let cooldown = state.config.behavior.cooldown_after_response_ms;
if elapsed < cooldown {
debug!(
room = room_id.as_str(),
elapsed_ms = elapsed,
cooldown_ms = cooldown,
"Skipping spontaneous — within cooldown period"
);
return Ok(());
}
}
}
// Mark room as in-flight
{
let mut responding = state.responding_in.lock().await;
responding.insert(room_id.clone());
}
let context = {
let convs = state.conversations.lock().await;
@@ -181,22 +252,74 @@ async fn handle_message(
&members,
is_spontaneous,
&state.mistral,
&room,
&response_ctx,
)
.await;
// Stop typing indicator
let _ = room.typing_notice(false).await;
if let Some(text) = response {
let content = matrix_utils::make_reply_content(&text, event.event_id.to_owned());
// Reply with reference only when directly addressed. Spontaneous
// and DM messages are sent as plain content — feels more natural.
let content = if !is_spontaneous && !is_dm {
matrix_utils::make_reply_content(&text, event.event_id.to_owned())
} else {
ruma::events::room::message::RoomMessageEventContent::text_markdown(&text)
};
if let Err(e) = room.send(content).await {
error!("Failed to send response: {e}");
} else {
info!(room = room_id.as_str(), len = text.len(), is_dm, "Response sent");
}
// Post-response memory extraction (fire-and-forget)
if state.config.behavior.memory_extraction_enabled {
let ctx = response_ctx.clone();
let mistral = state.mistral.clone();
let os = state.opensearch.clone();
let config = state.config.clone();
let user_msg = body.clone();
let sol_response = text.clone();
tokio::spawn(async move {
if let Err(e) = memory::extractor::extract_and_store(
&mistral, &os, &config, &ctx, &user_msg, &sol_response,
)
.await
{
warn!("Memory extraction failed (non-fatal): {e}");
}
});
}
// Update last response timestamp
let mut last = state.last_response.lock().await;
last.insert(room_id.clone(), Instant::now());
}
// Clear in-flight flag
{
let mut responding = state.responding_in.lock().await;
responding.remove(&room_id);
}
Ok(())
}
/// Index a live m.reaction event onto its target archive document.
async fn handle_reaction(event: OriginalSyncReactionEvent, state: &AppState) {
    let relation = &event.content.relates_to;
    let target_event_id = relation.event_id.to_string();
    let sender = event.sender.to_string();
    let timestamp: i64 = event.origin_server_ts.0.into();
    info!(
        target = target_event_id.as_str(),
        sender = sender.as_str(),
        emoji = relation.key.as_str(),
        "Indexing reaction"
    );
    state
        .indexer
        .add_reaction(&target_event_id, &sender, &relation.key, timestamp)
        .await;
}
async fn handle_redaction(event: OriginalSyncRoomRedactionEvent, state: &AppState) {
if let Some(redacted_id) = &event.redacts {
state.indexer.update_redaction(&redacted_id.to_string()).await;

View File

@@ -1,5 +1,6 @@
pub mod room_history;
pub mod room_info;
pub mod script;
pub mod search;
use std::sync::Arc;
@@ -10,6 +11,7 @@ use opensearch::OpenSearch;
use serde_json::json;
use crate::config::Config;
use crate::context::ResponseContext;
pub struct ToolRegistry {
opensearch: OpenSearch,
@@ -122,10 +124,44 @@ impl ToolRegistry {
"required": ["room_id"]
}),
),
Tool::new(
"run_script".into(),
"Execute a TypeScript/JavaScript snippet in a sandboxed runtime. \
Use this for math, date calculations, data transformations, or any \
computation that needs precision. The script has access to:\n\
- sol.search(query, opts?) — search the message archive. opts: \
{ room?, sender?, after?, before?, limit?, semantic? }\n\
- sol.rooms() — list joined rooms (returns array of {name, id, members})\n\
- sol.members(roomName) — get room members (returns array of {name, id})\n\
- sol.fetch(url) — HTTP GET (allowlisted domains only)\n\
- sol.memory.get(query?) — retrieve internal notes relevant to the query\n\
- sol.memory.set(content, category?) — save an internal note for later reference\n\
- sol.fs.read(path), sol.fs.write(path, content), sol.fs.list(path?) — \
sandboxed temp filesystem for intermediate files\n\
- console.log() to produce output\n\
All sol.* methods are async — use await. The last expression value is \
also captured. Output is truncated to 4096 chars."
.into(),
json!({
"type": "object",
"properties": {
"code": {
"type": "string",
"description": "TypeScript or JavaScript code to execute"
}
},
"required": ["code"]
}),
),
]
}
pub async fn execute(&self, name: &str, arguments: &str) -> anyhow::Result<String> {
pub async fn execute(
&self,
name: &str,
arguments: &str,
response_ctx: &ResponseContext,
) -> anyhow::Result<String> {
match name {
"search_archive" => {
search::search_archive(
@@ -145,6 +181,16 @@ impl ToolRegistry {
}
"list_rooms" => room_info::list_rooms(&self.matrix).await,
"get_room_members" => room_info::get_room_members(&self.matrix, arguments).await,
"run_script" => {
script::run_script(
&self.opensearch,
&self.matrix,
&self.config,
arguments,
response_ctx,
)
.await
}
_ => anyhow::bail!("Unknown tool: {name}"),
}
}

706
src/tools/script.rs Normal file
View File

@@ -0,0 +1,706 @@
use std::cell::RefCell;
use std::path::{Path, PathBuf};
use std::rc::Rc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::time::Duration;
use deno_core::{extension, op2, JsRuntime, OpState, RuntimeOptions};
use deno_error::JsErrorBox;
use matrix_sdk::Client as MatrixClient;
use opensearch::OpenSearch;
use serde::Deserialize;
use tempfile::TempDir;
use tracing::info;
use crate::config::Config;
use crate::context::ResponseContext;
// ---------------------------------------------------------------------------
// State types stored in OpState
// ---------------------------------------------------------------------------
/// Per-execution state injected into the Deno runtime's OpState; read by
/// every sol.* op.
struct ScriptState {
    opensearch: OpenSearch,
    matrix: MatrixClient,
    config: Arc<Config>,
    // Root of the sandboxed temp filesystem used by the sol.fs.* ops.
    tmpdir: PathBuf,
    // Memory-ops user scope. Set on the Rust side, never from script
    // arguments, so scripts cannot read another user's memories.
    user_id: String,
}

// Script output buffer, written via op_sol_set_output.
struct ScriptOutput(String);

// Arguments for the run_script tool call (JSON: {"code": "..."}).
#[derive(Debug, Deserialize)]
struct RunScriptArgs {
    code: String,
}
// ---------------------------------------------------------------------------
// Sandbox path resolution
// ---------------------------------------------------------------------------
/// Resolve a script-supplied relative path to a location inside the sandbox.
///
/// Rejects absolute paths and any `..` component *before* touching the
/// filesystem, so hostile inputs cannot cause intermediate directories to
/// be created outside `tmpdir` (previously `create_dir_all` ran before the
/// containment check). A post-canonicalization `starts_with` check still
/// guards against symlinks placed inside the sandbox.
fn resolve_sandbox_path(
    tmpdir: &Path,
    requested: &str,
) -> Result<PathBuf, JsErrorBox> {
    let base = tmpdir
        .canonicalize()
        .map_err(|e| JsErrorBox::generic(format!("sandbox error: {e}")))?;
    // Lexical pre-check: refuse escapes before any filesystem side effects.
    let rel = Path::new(requested);
    if rel.is_absolute()
        || rel
            .components()
            .any(|c| matches!(c, std::path::Component::ParentDir))
    {
        return Err(JsErrorBox::generic("path escapes sandbox"));
    }
    let joined = base.join(rel);
    if let Some(parent) = joined.parent() {
        // Best-effort: a failure here surfaces later as a read/write error.
        std::fs::create_dir_all(parent).ok();
    }
    let resolved = if joined.exists() {
        joined
            .canonicalize()
            .map_err(|e| JsErrorBox::generic(format!("path error: {e}")))?
    } else {
        // Target doesn't exist yet (e.g. a write): canonicalize the parent
        // and re-append the final component.
        let parent = joined
            .parent()
            .ok_or_else(|| JsErrorBox::generic("invalid path"))?
            .canonicalize()
            .map_err(|e| JsErrorBox::generic(format!("path error: {e}")))?;
        parent.join(joined.file_name().unwrap_or_default())
    };
    if !resolved.starts_with(&base) {
        return Err(JsErrorBox::generic("path escapes sandbox"));
    }
    Ok(resolved)
}
// ---------------------------------------------------------------------------
// Ops — async (search, rooms, members, fetch)
// ---------------------------------------------------------------------------
/// `sol.search(query, opts?)` — archive search, delegating to the
/// search_archive tool with the query merged into the opts object.
#[op2]
#[string]
async fn op_sol_search(
    state: Rc<RefCell<OpState>>,
    #[string] query: String,
    #[string] opts_json: String,
) -> Result<String, JsErrorBox> {
    // Clone what we need out of OpState so no borrow is held across await.
    let (os, index) = {
        let st = state.borrow();
        let ss = st.borrow::<ScriptState>();
        (ss.opensearch.clone(), ss.config.opensearch.index.clone())
    };
    // Unparseable opts degrade to an empty object rather than erroring.
    let mut merged: serde_json::Value =
        serde_json::from_str(&opts_json).unwrap_or_else(|_| serde_json::json!({}));
    merged["query"] = serde_json::Value::String(query);
    super::search::search_archive(&os, &index, &merged.to_string())
        .await
        .map_err(|e| JsErrorBox::generic(e.to_string()))
}
/// `sol.rooms()` — joined rooms as JSON `[{name, id, members}]`.
#[op2]
#[string]
async fn op_sol_rooms(
    state: Rc<RefCell<OpState>>,
) -> Result<String, JsErrorBox> {
    let matrix = {
        let st = state.borrow();
        st.borrow::<ScriptState>().matrix.clone()
    };
    let summaries: Vec<serde_json::Value> = matrix
        .joined_rooms()
        .iter()
        .map(|room| {
            // Fall back to the raw room id when no display name is cached.
            let name = room
                .cached_display_name()
                .map(|n| n.to_string())
                .unwrap_or_else(|| room.room_id().to_string());
            serde_json::json!({
                "name": name,
                "id": room.room_id().to_string(),
                "members": room.joined_members_count(),
            })
        })
        .collect();
    serde_json::to_string(&summaries).map_err(|e| JsErrorBox::generic(e.to_string()))
}
/// `sol.members(roomName)` — joined members of a room as JSON `[{name, id}]`.
///
/// Rooms are addressed by display name (as returned by sol.rooms()), not id.
#[op2]
#[string]
async fn op_sol_members(
    state: Rc<RefCell<OpState>>,
    #[string] room_name: String,
) -> Result<String, JsErrorBox> {
    let matrix = {
        let st = state.borrow();
        st.borrow::<ScriptState>().matrix.clone()
    };
    let rooms = matrix.joined_rooms();
    let room = rooms
        .iter()
        .find(|r| {
            r.cached_display_name()
                .map(|n| n.to_string() == room_name)
                .unwrap_or(false)
        })
        .ok_or_else(|| JsErrorBox::generic(format!("room not found: {room_name}")))?;
    let members = room
        .members(matrix_sdk::RoomMemberships::JOIN)
        .await
        .map_err(|e| JsErrorBox::generic(e.to_string()))?;
    let listing: Vec<serde_json::Value> = members
        .iter()
        .map(|member| {
            serde_json::json!({
                "name": member.display_name().unwrap_or_else(|| member.user_id().as_str()),
                "id": member.user_id().to_string(),
            })
        })
        .collect();
    serde_json::to_string(&listing).map_err(|e| JsErrorBox::generic(e.to_string()))
}
/// `sol.fetch(url)` — HTTP GET restricted to allowlisted domains.
///
/// A domain matches the allowlist either exactly or as a subdomain
/// (`foo.example.com` matches allowlist entry `example.com`). Responses
/// are truncated to 32 KiB, always cutting on a UTF-8 char boundary.
#[op2]
#[string]
async fn op_sol_fetch(
    state: Rc<RefCell<OpState>>,
    #[string] url_str: String,
) -> Result<String, JsErrorBox> {
    let allowlist = {
        let st = state.borrow();
        st.borrow::<ScriptState>()
            .config
            .behavior
            .script_fetch_allowlist
            .clone()
    };
    let parsed = url::Url::parse(&url_str)
        .map_err(|e| JsErrorBox::generic(format!("invalid URL: {e}")))?;
    let domain = parsed
        .host_str()
        .ok_or_else(|| JsErrorBox::generic("URL has no host"))?;
    if !allowlist
        .iter()
        .any(|d| domain == d || domain.ends_with(&format!(".{d}")))
    {
        return Err(JsErrorBox::generic(format!(
            "domain not in allowlist: {domain}"
        )));
    }
    let resp = reqwest::get(&url_str)
        .await
        .map_err(|e| JsErrorBox::generic(e.to_string()))?;
    let text = resp
        .text()
        .await
        .map_err(|e| JsErrorBox::generic(e.to_string()))?;
    let max_len = 32768;
    if text.len() > max_len {
        // Back off to a char boundary: slicing at a fixed byte offset panics
        // when it lands inside a multi-byte UTF-8 sequence.
        let mut cut = max_len;
        while !text.is_char_boundary(cut) {
            cut -= 1;
        }
        Ok(format!("{}...(truncated)", &text[..cut]))
    } else {
        Ok(text)
    }
}
// ---------------------------------------------------------------------------
// Ops — sync (filesystem sandbox + output collection)
// ---------------------------------------------------------------------------
/// `sol.fs.read(path)` — read a UTF-8 file from the sandboxed tmpdir.
#[op2]
#[string]
fn op_sol_read_file(
    state: &mut OpState,
    #[string] path: String,
) -> Result<String, JsErrorBox> {
    let tmpdir = state.borrow::<ScriptState>().tmpdir.clone();
    // resolve_sandbox_path rejects anything outside tmpdir.
    let resolved = resolve_sandbox_path(&tmpdir, &path)?;
    std::fs::read_to_string(&resolved)
        .map_err(|e| JsErrorBox::generic(format!("read error: {e}")))
}
/// `sol.fs.write(path, content)` — write a file inside the sandboxed tmpdir,
/// creating or overwriting it.
#[op2(fast)]
fn op_sol_write_file(
    state: &mut OpState,
    #[string] path: String,
    #[string] content: String,
) -> Result<(), JsErrorBox> {
    let tmpdir = state.borrow::<ScriptState>().tmpdir.clone();
    // resolve_sandbox_path rejects anything outside tmpdir.
    let resolved = resolve_sandbox_path(&tmpdir, &path)?;
    std::fs::write(&resolved, content)
        .map_err(|e| JsErrorBox::generic(format!("write error: {e}")))
}
/// `sol.fs.list(path?)` — list entry names in a sandboxed directory,
/// returned as a JSON string array.
#[op2]
#[string]
fn op_sol_list_dir(
    state: &mut OpState,
    #[string] path: String,
) -> Result<String, JsErrorBox> {
    let tmpdir = state.borrow::<ScriptState>().tmpdir.clone();
    let resolved = resolve_sandbox_path(&tmpdir, &path)?;
    // Unreadable entries are silently skipped.
    let entries: Vec<String> = std::fs::read_dir(&resolved)
        .map_err(|e| JsErrorBox::generic(format!("list error: {e}")))?
        .filter_map(|e| e.ok())
        .map(|e| e.file_name().to_string_lossy().into_owned())
        .collect();
    serde_json::to_string(&entries).map_err(|e| JsErrorBox::generic(e.to_string()))
}
/// Store the script's collected output in OpState, replacing any prior value.
#[op2(fast)]
fn op_sol_set_output(
    state: &mut OpState,
    #[string] output: String,
) -> Result<(), JsErrorBox> {
    *state.borrow_mut::<ScriptOutput>() = ScriptOutput(output);
    Ok(())
}
// ---------------------------------------------------------------------------
// Ops — async (memory)
// ---------------------------------------------------------------------------
/// `sol.memory.get(query?)` — retrieve up to 10 memories for the current
/// user, as JSON `[{content, category}]`.
///
/// The user scope comes from ScriptState (set on the Rust side), never from
/// script input.
#[op2]
#[string]
async fn op_sol_memory_get(
    state: Rc<RefCell<OpState>>,
    #[string] query: String,
) -> Result<String, JsErrorBox> {
    // Clone out of OpState so no borrow is held across await.
    let (os, index, user_id) = {
        let st = state.borrow();
        let ss = st.borrow::<ScriptState>();
        (
            ss.opensearch.clone(),
            ss.config.opensearch.memory_index.clone(),
            ss.user_id.clone(),
        )
    };
    // Empty query → most recent notes; otherwise relevance search.
    let found = if query.is_empty() {
        crate::memory::store::get_recent(&os, &index, &user_id, 10).await
    } else {
        crate::memory::store::query(&os, &index, &user_id, &query, 10).await
    }
    .map_err(|e| JsErrorBox::generic(e.to_string()))?;
    let items: Vec<serde_json::Value> = found
        .iter()
        .map(|m| {
            serde_json::json!({
                "content": m.content,
                "category": m.category,
            })
        })
        .collect();
    serde_json::to_string(&items).map_err(|e| JsErrorBox::generic(e.to_string()))
}
/// `sol.memory.set(content, category)` — persist a memory for the current
/// user with source "script"; returns "ok" on success.
#[op2]
#[string]
async fn op_sol_memory_set(
    state: Rc<RefCell<OpState>>,
    #[string] content: String,
    #[string] category: String,
) -> Result<String, JsErrorBox> {
    // Clone out of OpState so no borrow is held across await.
    let (os, index, user_id) = {
        let st = state.borrow();
        let ss = st.borrow::<ScriptState>();
        (
            ss.opensearch.clone(),
            ss.config.opensearch.memory_index.clone(),
            ss.user_id.clone(),
        )
    };
    crate::memory::store::set(&os, &index, &user_id, &content, &category, "script")
        .await
        .map(|_| "ok".to_string())
        .map_err(|e| JsErrorBox::generic(e.to_string()))
}
// ---------------------------------------------------------------------------
// Extension
// ---------------------------------------------------------------------------
// Deno extension registering every sol.* op with the runtime. The state
// closure only seeds an empty ScriptOutput; ScriptState must be inserted
// separately before any op runs (ops borrow it from OpState) — presumably
// by run_script, which is outside this view.
extension!(
    sol_script,
    ops = [
        op_sol_search,
        op_sol_rooms,
        op_sol_members,
        op_sol_fetch,
        op_sol_read_file,
        op_sol_write_file,
        op_sol_list_dir,
        op_sol_set_output,
        op_sol_memory_get,
        op_sol_memory_set,
    ],
    state = |state| {
        state.put(ScriptOutput(String::new()));
    },
);
// ---------------------------------------------------------------------------
// Bootstrap JS — injected before user code
// ---------------------------------------------------------------------------
// ---------------------------------------------------------------------------
// Bootstrap JS — injected before user code
// ---------------------------------------------------------------------------

/// Prelude evaluated ahead of the user script: installs a capturing
/// `console` (lines accumulate in `__output`) and the `sol.*` API surface,
/// each method delegating to an op registered in the `sol_script` extension.
/// The op names inside this string must stay in sync with that ops list.
const BOOTSTRAP_JS: &str = r#"
const __output = [];
globalThis.console = {
    log: (...args) => __output.push(args.map(a => typeof a === 'object' ? JSON.stringify(a) : String(a)).join(' ')),
    error: (...args) => __output.push('ERROR: ' + args.map(a => typeof a === 'object' ? JSON.stringify(a) : String(a)).join(' ')),
    warn: (...args) => __output.push('WARN: ' + args.map(a => typeof a === 'object' ? JSON.stringify(a) : String(a)).join(' ')),
    info: (...args) => __output.push(args.map(a => typeof a === 'object' ? JSON.stringify(a) : String(a)).join(' ')),
};
globalThis.sol = {
    search: (query, opts) => Deno.core.ops.op_sol_search(query, opts ? JSON.stringify(opts) : "{}"),
    rooms: async () => JSON.parse(await Deno.core.ops.op_sol_rooms()),
    members: async (room) => JSON.parse(await Deno.core.ops.op_sol_members(room)),
    fetch: (url) => Deno.core.ops.op_sol_fetch(url),
    memory: {
        get: async (query) => JSON.parse(await Deno.core.ops.op_sol_memory_get(query || "")),
        set: async (content, category) => {
            await Deno.core.ops.op_sol_memory_set(content, category || "general");
        },
    },
    fs: {
        read: (path) => Deno.core.ops.op_sol_read_file(path),
        write: (path, content) => Deno.core.ops.op_sol_write_file(path, content),
        list: (path) => JSON.parse(Deno.core.ops.op_sol_list_dir(path || ".")),
    },
};
"#;
// ---------------------------------------------------------------------------
// TypeScript transpilation
// ---------------------------------------------------------------------------
/// Transpile a TypeScript snippet into plain JavaScript via deno_ast.
///
/// Returns an error when the snippet does not parse as a TS module.
fn transpile_ts(code: &str) -> anyhow::Result<String> {
    use deno_ast::{MediaType, ParseParams, TranspileModuleOptions};
    // The parser wants a module specifier even though scripts are anonymous;
    // use a fixed synthetic one.
    let specifier = deno_ast::ModuleSpecifier::parse("file:///script.ts")?;
    let module = deno_ast::parse_module(ParseParams {
        specifier,
        media_type: MediaType::TypeScript,
        text: code.into(),
        capture_tokens: false,
        scope_analysis: false,
        maybe_syntax: None,
    })?;
    let emitted = module.transpile(
        &deno_ast::TranspileOptions::default(),
        &TranspileModuleOptions::default(),
        &deno_ast::EmitOptions::default(),
    )?;
    Ok(emitted.into_source().text.to_string())
}
// ---------------------------------------------------------------------------
// Public entry point
// ---------------------------------------------------------------------------
/// Execute a user-supplied TypeScript snippet in a sandboxed V8 isolate and
/// return its captured console output (joined by newlines, truncated to 4 KiB).
///
/// The script runs on a dedicated blocking thread with its own single-thread
/// Tokio runtime, a per-run temp directory for the fs ops, a heap limit, and
/// a watchdog thread that terminates the isolate after the configured timeout.
/// Script-visible identity (`user_id`) is taken from `response_ctx`, never
/// from `args_json`.
///
/// Errors from the script itself are returned as `Ok("Error: ...")` strings
/// (they are tool output, not host failures); `Err` is reserved for host-side
/// failures (bad args, transpile failure, runtime construction).
pub async fn run_script(
    opensearch: &OpenSearch,
    matrix: &MatrixClient,
    config: &Config,
    args_json: &str,
    response_ctx: &ResponseContext,
) -> anyhow::Result<String> {
    let args: RunScriptArgs = serde_json::from_str(args_json)?;
    let code = args.code.clone();
    info!(code_len = code.len(), "Executing script");
    // Transpile TS to JS before entering the blocking task so parse errors
    // surface as a host error immediately.
    let js_code = transpile_ts(&code)?;
    // Clone state for move into spawn_blocking
    let os = opensearch.clone();
    let mx = matrix.clone();
    let cfg = Arc::new(config.clone());
    let timeout_secs = cfg.behavior.script_timeout_secs;
    let max_heap_mb = cfg.behavior.script_max_heap_mb;
    let user_id = response_ctx.user_id.clone();
    // Wrap user code: async IIFE that captures output (and the final
    // expression value, if any), then stores it via op_sol_set_output.
    let wrapped = format!(
        r#"
(async () => {{
    try {{
        const __result = await (async () => {{
            {js_code}
        }})();
        if (__result !== undefined) {{
            __output.push(typeof __result === 'object' ? JSON.stringify(__result, null, 2) : String(__result));
        }}
    }} catch(e) {{
        __output.push('Error: ' + (e.stack || e.message || String(e)));
    }}
    Deno.core.ops.op_sol_set_output(JSON.stringify(__output));
}})();"#
    );
    let timeout = Duration::from_secs(timeout_secs);
    let result = tokio::task::spawn_blocking(move || {
        // The JsRuntime is !Send, so it lives entirely on this thread with
        // its own current-thread reactor for driving async ops.
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()?;
        rt.block_on(async move {
            let tmpdir = TempDir::new()?;
            let tmpdir_path = tmpdir.path().to_path_buf();
            let mut runtime = JsRuntime::new(RuntimeOptions {
                extensions: vec![sol_script::init()],
                create_params: Some(
                    deno_core::v8::CreateParams::default()
                        .heap_limits(0, max_heap_mb * 1024 * 1024),
                ),
                ..Default::default()
            });
            // Inject shared state for the ops (OpenSearch handle, Matrix
            // client, config, fs sandbox root, and the caller's user id).
            {
                let op_state = runtime.op_state();
                let mut state = op_state.borrow_mut();
                state.put(ScriptState {
                    opensearch: os,
                    matrix: mx,
                    config: cfg,
                    tmpdir: tmpdir_path,
                    user_id,
                });
            }
            // Watchdog: V8 cannot be interrupted cooperatively, so a helper
            // thread terminates the isolate when the deadline passes. The
            // `done` flag lets it exit early on normal completion.
            let done = Arc::new(AtomicBool::new(false));
            let done_clone = done.clone();
            let isolate_handle = runtime.v8_isolate().thread_safe_handle();
            std::thread::spawn(move || {
                let deadline = std::time::Instant::now() + timeout;
                while std::time::Instant::now() < deadline {
                    if done_clone.load(Ordering::Relaxed) {
                        return;
                    }
                    std::thread::sleep(Duration::from_millis(50));
                }
                isolate_handle.terminate_execution();
            });
            // Execute bootstrap (console capture + sol API).
            runtime.execute_script("<bootstrap>", BOOTSTRAP_JS)?;
            // Execute user code; "terminated" in the message means the
            // watchdog fired.
            let exec_result = runtime.execute_script("<script>", wrapped);
            if let Err(e) = exec_result {
                done.store(true, Ordering::Relaxed);
                let msg = e.to_string();
                if msg.contains("terminated") {
                    return Ok("Error: script execution timed out".into());
                }
                return Ok(format!("Error: {msg}"));
            }
            // Drive async ops to completion
            if let Err(e) = runtime.run_event_loop(Default::default()).await {
                done.store(true, Ordering::Relaxed);
                let msg = e.to_string();
                if msg.contains("terminated") {
                    return Ok("Error: script execution timed out".into());
                }
                return Ok(format!("Error: {msg}"));
            }
            done.store(true, Ordering::Relaxed);
            // Read captured output from OpState
            let output_json = {
                let op_state = runtime.op_state();
                let state = op_state.borrow();
                state.borrow::<ScriptOutput>().0.clone()
            };
            let output_arr: Vec<String> =
                serde_json::from_str(&output_json).unwrap_or_else(|_| {
                    if output_json.is_empty() {
                        vec![]
                    } else {
                        vec![output_json]
                    }
                });
            let result = output_arr.join("\n");
            let max_output = 4096;
            let result = if result.len() > max_output {
                // Byte index `max_output` may fall inside a multi-byte UTF-8
                // character; slicing there would panic. Back up to the
                // nearest char boundary before truncating.
                let mut cut = max_output;
                while !result.is_char_boundary(cut) {
                    cut -= 1;
                }
                format!("{}...(truncated)", &result[..cut])
            } else {
                result
            };
            drop(tmpdir);
            Ok::<String, anyhow::Error>(result)
        })
    })
    .await?;
    result
}
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_parse_run_script_args() {
let args: RunScriptArgs =
serde_json::from_str(r#"{"code": "console.log(42)"}"#).unwrap();
assert_eq!(args.code, "console.log(42)");
}
#[test]
fn test_transpile_ts_basic() {
let js = transpile_ts("const x: number = 42; x;").unwrap();
assert!(js.contains("const x = 42"));
assert!(!js.contains(": number"));
}
#[test]
fn test_transpile_ts_arrow() {
let js =
transpile_ts("const add = (a: number, b: number): number => a + b;").unwrap();
assert!(js.contains("const add"));
assert!(!js.contains(": number"));
}
#[test]
fn test_transpile_ts_interface() {
let js = transpile_ts(
"interface Foo { bar: string; } const x: Foo = { bar: 'hello' };",
)
.unwrap();
assert!(!js.contains("interface"));
assert!(js.contains("const x"));
}
#[test]
fn test_transpile_ts_invalid() {
let result = transpile_ts("const x: number = {");
assert!(result.is_err());
}
#[test]
fn test_sandbox_path_within() {
let dir = TempDir::new().unwrap();
let base = dir.path();
std::fs::write(base.join("test.txt"), "hello").unwrap();
let resolved = resolve_sandbox_path(base, "test.txt").unwrap();
assert!(resolved.starts_with(base.canonicalize().unwrap()));
}
#[test]
fn test_sandbox_path_escape_rejected() {
let dir = TempDir::new().unwrap();
let base = dir.path();
// Path traversal is blocked — either by canonicalization failure
// (parent doesn't exist) or by the starts_with sandbox check
let result = resolve_sandbox_path(base, "../../../etc/passwd");
assert!(result.is_err());
}
#[test]
fn test_sandbox_path_symlink_escape_rejected() {
let dir = TempDir::new().unwrap();
let base = dir.path();
// Create a symlink pointing outside the sandbox
let link_path = base.join("escape");
std::os::unix::fs::symlink("/tmp", &link_path).unwrap();
// Following the symlink resolves outside sandbox
let result = resolve_sandbox_path(base, "escape/somefile");
assert!(result.is_err());
}
#[test]
fn test_sandbox_path_new_file() {
let dir = TempDir::new().unwrap();
let base = dir.path();
let resolved = resolve_sandbox_path(base, "newfile.txt").unwrap();
assert!(resolved.starts_with(base.canonicalize().unwrap()));
assert!(resolved.ends_with("newfile.txt"));
}
#[test]
fn test_sandbox_path_nested_dir() {
let dir = TempDir::new().unwrap();
let base = dir.path();
let resolved = resolve_sandbox_path(base, "subdir/file.txt").unwrap();
assert!(resolved.starts_with(base.canonicalize().unwrap()));
}
#[tokio::test]
async fn test_basic_script_execution() {
let mut runtime = JsRuntime::new(RuntimeOptions {
extensions: vec![sol_script::init()],
..Default::default()
});
runtime
.execute_script("<bootstrap>", BOOTSTRAP_JS)
.unwrap();
runtime
.execute_script(
"<test>",
r#"
console.log(2 ** 64);
console.log(Math.PI);
Deno.core.ops.op_sol_set_output(JSON.stringify(__output));
"#,
)
.unwrap();
let op_state = runtime.op_state();
let state = op_state.borrow();
let output = &state.borrow::<ScriptOutput>().0;
let arr: Vec<String> = serde_json::from_str(output).unwrap();
assert_eq!(arr.len(), 2);
assert!(arr[0].contains("18446744073709552000"));
assert!(arr[1].contains("3.14159"));
}
}

View File

@@ -1,7 +1,7 @@
use opensearch::OpenSearch;
use serde::Deserialize;
use serde_json::json;
use tracing::debug;
use tracing::{debug, info};
#[derive(Debug, Deserialize)]
pub struct SearchArgs {
@@ -24,19 +24,24 @@ fn default_limit() -> usize { 10 }
/// Build the OpenSearch query body from parsed SearchArgs. Extracted for testability.
pub fn build_search_query(args: &SearchArgs) -> serde_json::Value {
let must = vec![json!({
"match": { "content": args.query }
})];
// Handle empty/wildcard queries as match_all
let must = if args.query.is_empty() || args.query == "*" {
vec![json!({ "match_all": {} })]
} else {
vec![json!({
"match": { "content": args.query }
})]
};
let mut filter = vec![json!({
"term": { "redacted": false }
})];
if let Some(ref room) = args.room {
filter.push(json!({ "term": { "room_name": room } }));
filter.push(json!({ "term": { "room_name.keyword": room } }));
}
if let Some(ref sender) = args.sender {
filter.push(json!({ "term": { "sender_name": sender } }));
filter.push(json!({ "term": { "sender_name.keyword": sender } }));
}
let mut range = serde_json::Map::new();
@@ -73,10 +78,19 @@ pub async fn search_archive(
args_json: &str,
) -> anyhow::Result<String> {
let args: SearchArgs = serde_json::from_str(args_json)?;
debug!(query = args.query.as_str(), "Searching archive");
let query_body = build_search_query(&args);
info!(
query = args.query.as_str(),
room = args.room.as_deref().unwrap_or("*"),
sender = args.sender.as_deref().unwrap_or("*"),
after = args.after.as_deref().unwrap_or("*"),
before = args.before.as_deref().unwrap_or("*"),
limit = args.limit,
query_json = %query_body,
"Executing search"
);
let response = client
.search(opensearch::SearchParts::Index(&[index]))
.body(query_body)
@@ -84,6 +98,8 @@ pub async fn search_archive(
.await?;
let body: serde_json::Value = response.json().await?;
let hit_count = body["hits"]["total"]["value"].as_i64().unwrap_or(0);
info!(hit_count, "Search results");
let hits = &body["hits"]["hits"];
let Some(hits_arr) = hits.as_array() else {
@@ -173,7 +189,7 @@ mod tests {
let filters = q["query"]["bool"]["filter"].as_array().unwrap();
assert_eq!(filters.len(), 2);
assert_eq!(filters[1]["term"]["room_name"], "design");
assert_eq!(filters[1]["term"]["room_name.keyword"], "design");
}
#[test]
@@ -183,7 +199,7 @@ mod tests {
let filters = q["query"]["bool"]["filter"].as_array().unwrap();
assert_eq!(filters.len(), 2);
assert_eq!(filters[1]["term"]["sender_name"], "Bob");
assert_eq!(filters[1]["term"]["sender_name.keyword"], "Bob");
}
#[test]
@@ -193,8 +209,8 @@ mod tests {
let filters = q["query"]["bool"]["filter"].as_array().unwrap();
assert_eq!(filters.len(), 3);
assert_eq!(filters[1]["term"]["room_name"], "dev");
assert_eq!(filters[2]["term"]["sender_name"], "Carol");
assert_eq!(filters[1]["term"]["room_name.keyword"], "dev");
assert_eq!(filters[2]["term"]["sender_name.keyword"], "Carol");
}
#[test]