feat: per-user auto-memory with ResponseContext
Three memory channels: hidden tool (sol.memory.set/get in scripts), pre-response injection (relevant memories loaded into system prompt), and post-response extraction (ministral-3b extracts facts after each response). User isolation enforced at Rust level — user_id derived from Matrix sender, never from script arguments. New modules: context (ResponseContext), memory (schema, store, extractor). ResponseContext threaded through responder → tools → script runtime. OpenSearch index sol_user_memory created on startup alongside archive.
This commit is contained in:
@@ -59,6 +59,34 @@ impl Indexer {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn add_reaction(&self, target_event_id: &str, sender: &str, emoji: &str, timestamp: i64) {
|
||||
// Use a script to append to the reactions array (upsert-safe)
|
||||
let body = json!({
|
||||
"script": {
|
||||
"source": "if (ctx._source.reactions == null) { ctx._source.reactions = []; } ctx._source.reactions.add(params.reaction)",
|
||||
"params": {
|
||||
"reaction": {
|
||||
"sender": sender,
|
||||
"emoji": emoji,
|
||||
"timestamp": timestamp
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
if let Err(e) = self
|
||||
.client
|
||||
.update(opensearch::UpdateParts::IndexId(
|
||||
&self.config.opensearch.index,
|
||||
target_event_id,
|
||||
))
|
||||
.body(body)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
warn!(target_event_id, sender, emoji, "Failed to add reaction: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn update_redaction(&self, event_id: &str) {
|
||||
let body = json!({
|
||||
"doc": {
|
||||
|
||||
@@ -24,6 +24,15 @@ pub struct ArchiveDocument {
|
||||
pub edited: bool,
|
||||
#[serde(default)]
|
||||
pub redacted: bool,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub reactions: Vec<Reaction>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Reaction {
|
||||
pub sender: String,
|
||||
pub emoji: String,
|
||||
pub timestamp: i64,
|
||||
}
|
||||
|
||||
const INDEX_MAPPING: &str = r#"{
|
||||
@@ -45,7 +54,15 @@ const INDEX_MAPPING: &str = r#"{
|
||||
"media_urls": { "type": "keyword" },
|
||||
"event_type": { "type": "keyword" },
|
||||
"edited": { "type": "boolean" },
|
||||
"redacted": { "type": "boolean" }
|
||||
"redacted": { "type": "boolean" },
|
||||
"reactions": {
|
||||
"type": "nested",
|
||||
"properties": {
|
||||
"sender": { "type": "keyword" },
|
||||
"emoji": { "type": "keyword" },
|
||||
"timestamp": { "type": "date", "format": "epoch_millis" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}"#;
|
||||
@@ -63,6 +80,25 @@ pub async fn create_index_if_not_exists(client: &OpenSearch, index: &str) -> any
|
||||
|
||||
if exists.status_code().is_success() {
|
||||
info!(index, "OpenSearch index already exists");
|
||||
// Ensure reactions field exists (added after initial schema)
|
||||
let reactions_mapping = serde_json::json!({
|
||||
"properties": {
|
||||
"reactions": {
|
||||
"type": "nested",
|
||||
"properties": {
|
||||
"sender": { "type": "keyword" },
|
||||
"emoji": { "type": "keyword" },
|
||||
"timestamp": { "type": "date", "format": "epoch_millis" }
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
let _ = client
|
||||
.indices()
|
||||
.put_mapping(opensearch::indices::IndicesPutMappingParts::Index(&[index]))
|
||||
.body(reactions_mapping)
|
||||
.send()
|
||||
.await;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
@@ -102,6 +138,7 @@ mod tests {
|
||||
event_type: "m.room.message".to_string(),
|
||||
edited: false,
|
||||
redacted: false,
|
||||
reactions: vec![],
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ use mistralai_client::v1::{
|
||||
constants::Model,
|
||||
};
|
||||
use regex::Regex;
|
||||
use tracing::{debug, warn};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::config::Config;
|
||||
|
||||
@@ -13,6 +13,7 @@ use crate::config::Config;
|
||||
pub enum Engagement {
|
||||
MustRespond { reason: MustRespondReason },
|
||||
MaybeRespond { relevance: f32, hook: String },
|
||||
React { emoji: String, relevance: f32 },
|
||||
Ignore,
|
||||
}
|
||||
|
||||
@@ -33,7 +34,9 @@ impl Evaluator {
|
||||
// todo(sienna): regex must be configrable
|
||||
pub fn new(config: Arc<Config>) -> Self {
|
||||
let user_id = &config.matrix.user_id;
|
||||
let mention_pattern = regex::escape(user_id);
|
||||
// Match both plain @sol:sunbeam.pt and Matrix link format [sol](https://matrix.to/#/@sol:sunbeam.pt)
|
||||
let escaped = regex::escape(user_id);
|
||||
let mention_pattern = format!(r"{}|matrix\.to/#/{}", escaped, escaped);
|
||||
let mention_regex = Regex::new(&mention_pattern).expect("Failed to compile mention regex");
|
||||
let name_regex =
|
||||
Regex::new(r"(?i)(?:^|\bhey\s+)\bsol\b").expect("Failed to compile name regex");
|
||||
@@ -53,13 +56,17 @@ impl Evaluator {
|
||||
recent_messages: &[String],
|
||||
mistral: &Arc<mistralai_client::v1::client::Client>,
|
||||
) -> Engagement {
|
||||
let body_preview: String = body.chars().take(80).collect();
|
||||
|
||||
// Don't respond to ourselves
|
||||
if sender == self.config.matrix.user_id {
|
||||
debug!(sender, body = body_preview.as_str(), "Ignoring own message");
|
||||
return Engagement::Ignore;
|
||||
}
|
||||
|
||||
// Direct mention: @sol:sunbeam.pt
|
||||
if self.mention_regex.is_match(body) {
|
||||
info!(sender, body = body_preview.as_str(), rule = "direct_mention", "Engagement: MustRespond");
|
||||
return Engagement::MustRespond {
|
||||
reason: MustRespondReason::DirectMention,
|
||||
};
|
||||
@@ -67,6 +74,7 @@ impl Evaluator {
|
||||
|
||||
// DM
|
||||
if is_dm {
|
||||
info!(sender, body = body_preview.as_str(), rule = "dm", "Engagement: MustRespond");
|
||||
return Engagement::MustRespond {
|
||||
reason: MustRespondReason::DirectMessage,
|
||||
};
|
||||
@@ -74,11 +82,22 @@ impl Evaluator {
|
||||
|
||||
// Name invocation: "sol ..." or "hey sol ..."
|
||||
if self.name_regex.is_match(body) {
|
||||
info!(sender, body = body_preview.as_str(), rule = "name_invocation", "Engagement: MustRespond");
|
||||
return Engagement::MustRespond {
|
||||
reason: MustRespondReason::NameInvocation,
|
||||
};
|
||||
}
|
||||
|
||||
info!(
|
||||
sender, body = body_preview.as_str(),
|
||||
threshold = self.config.behavior.spontaneous_threshold,
|
||||
model = self.config.mistral.evaluation_model.as_str(),
|
||||
context_len = recent_messages.len(),
|
||||
eval_window = self.config.behavior.evaluation_context_window,
|
||||
detect_sol = self.config.behavior.detect_sol_in_conversation,
|
||||
"No rule match — running LLM relevance evaluation"
|
||||
);
|
||||
|
||||
// Cheap evaluation call for spontaneous responses
|
||||
self.evaluate_relevance(body, recent_messages, mistral)
|
||||
.await
|
||||
@@ -119,23 +138,56 @@ impl Evaluator {
|
||||
recent_messages: &[String],
|
||||
mistral: &Arc<mistralai_client::v1::client::Client>,
|
||||
) -> Engagement {
|
||||
let window = self.config.behavior.evaluation_context_window;
|
||||
let context = recent_messages
|
||||
.iter()
|
||||
.rev()
|
||||
.take(5) //todo(sienna): must be configurable
|
||||
.take(window)
|
||||
.rev()
|
||||
.cloned()
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
// Check if Sol recently participated in this conversation
|
||||
let sol_in_context = self.config.behavior.detect_sol_in_conversation
|
||||
&& recent_messages.iter().any(|m| {
|
||||
let lower = m.to_lowercase();
|
||||
lower.starts_with("sol:") || lower.starts_with("sol ") || lower.contains("@sol:")
|
||||
});
|
||||
|
||||
let default_active = "Sol is ALREADY part of this conversation (see messages above from Sol). \
|
||||
Messages that follow up on Sol's response, ask Sol a question, or continue \
|
||||
a thread Sol is in should score HIGH (0.8+). Sol should respond to follow-ups \
|
||||
directed at them even if not mentioned by name.".to_string();
|
||||
|
||||
let default_passive = "Sol has NOT spoken in this conversation yet. Only score high if the message \
|
||||
is clearly relevant to Sol's expertise (archive search, finding past conversations, \
|
||||
information retrieval) or touches a topic Sol has genuine insight on.".to_string();
|
||||
|
||||
let participation_note = if sol_in_context {
|
||||
self.config.behavior.evaluation_prompt_active.as_deref()
|
||||
.unwrap_or(&default_active)
|
||||
} else {
|
||||
self.config.behavior.evaluation_prompt_passive.as_deref()
|
||||
.unwrap_or(&default_passive)
|
||||
};
|
||||
|
||||
info!(
|
||||
sol_in_context,
|
||||
context_window = window,
|
||||
"Building evaluation prompt"
|
||||
);
|
||||
|
||||
let prompt = format!(
|
||||
"You are evaluating whether a virtual librarian named Sol should spontaneously join \
|
||||
a conversation. Sol has deep knowledge of the group's message archive and helps \
|
||||
people find information.\n\n\
|
||||
"You are evaluating whether Sol should respond to a message in a group chat. \
|
||||
Sol is a librarian with access to the team's message archive.\n\n\
|
||||
Recent conversation:\n{context}\n\n\
|
||||
Latest message: {body}\n\n\
|
||||
Respond ONLY with JSON: {{\"relevance\": 0.0-1.0, \"hook\": \"brief reason or empty string\"}}\n\
|
||||
relevance=1.0 means Sol absolutely should respond, 0.0 means irrelevant."
|
||||
{participation_note}\n\n\
|
||||
Respond ONLY with JSON: {{\"relevance\": 0.0-1.0, \"hook\": \"brief reason or empty string\", \"emoji\": \"a single emoji reaction or empty string\"}}\n\
|
||||
relevance=1.0 means Sol absolutely should respond, 0.0 means irrelevant.\n\
|
||||
emoji: if Sol wouldn't write a full response but might react to the message, suggest a single emoji. \
|
||||
pick something that feels natural and specific to the message — not generic thumbs up. leave empty if no reaction fits."
|
||||
);
|
||||
|
||||
let messages = vec![ChatMessage::new_user_message(&prompt)];
|
||||
@@ -159,21 +211,48 @@ impl Evaluator {
|
||||
match result {
|
||||
Ok(response) => {
|
||||
let text = &response.choices[0].message.content;
|
||||
info!(
|
||||
raw_response = text.as_str(),
|
||||
model = self.config.mistral.evaluation_model.as_str(),
|
||||
"LLM evaluation raw response"
|
||||
);
|
||||
|
||||
match serde_json::from_str::<serde_json::Value>(text) {
|
||||
Ok(val) => {
|
||||
let relevance = val["relevance"].as_f64().unwrap_or(0.0) as f32;
|
||||
let hook = val["hook"].as_str().unwrap_or("").to_string();
|
||||
let emoji = val["emoji"].as_str().unwrap_or("").to_string();
|
||||
let threshold = self.config.behavior.spontaneous_threshold;
|
||||
let reaction_threshold = self.config.behavior.reaction_threshold;
|
||||
let reaction_enabled = self.config.behavior.reaction_enabled;
|
||||
|
||||
debug!(relevance, hook = hook.as_str(), "Evaluation result");
|
||||
info!(
|
||||
relevance,
|
||||
threshold,
|
||||
reaction_threshold,
|
||||
hook = hook.as_str(),
|
||||
emoji = emoji.as_str(),
|
||||
"LLM evaluation parsed"
|
||||
);
|
||||
|
||||
if relevance >= self.config.behavior.spontaneous_threshold {
|
||||
if relevance >= threshold {
|
||||
Engagement::MaybeRespond { relevance, hook }
|
||||
} else if reaction_enabled
|
||||
&& relevance >= reaction_threshold
|
||||
&& !emoji.is_empty()
|
||||
{
|
||||
info!(
|
||||
relevance,
|
||||
emoji = emoji.as_str(),
|
||||
"Reaction range — will react with emoji"
|
||||
);
|
||||
Engagement::React { emoji, relevance }
|
||||
} else {
|
||||
Engagement::Ignore
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to parse evaluation response: {e}");
|
||||
warn!(raw = text.as_str(), "Failed to parse evaluation response: {e}");
|
||||
Engagement::Ignore
|
||||
}
|
||||
}
|
||||
|
||||
@@ -15,14 +15,19 @@ impl Personality {
|
||||
&self,
|
||||
room_name: &str,
|
||||
members: &[String],
|
||||
memory_notes: Option<&str>,
|
||||
) -> String {
|
||||
let date = Utc::now().format("%Y-%m-%d").to_string();
|
||||
let now = Utc::now();
|
||||
let date = now.format("%Y-%m-%d").to_string();
|
||||
let epoch_ms = now.timestamp_millis().to_string();
|
||||
let members_str = members.join(", ");
|
||||
|
||||
self.template
|
||||
.replace("{date}", &date)
|
||||
.replace("{epoch_ms}", &epoch_ms)
|
||||
.replace("{room_name}", room_name)
|
||||
.replace("{members}", &members_str)
|
||||
.replace("{memory_notes}", memory_notes.unwrap_or(""))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,7 +38,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_date_substitution() {
|
||||
let p = Personality::new("Today is {date}.".to_string());
|
||||
let result = p.build_system_prompt("general", &[]);
|
||||
let result = p.build_system_prompt("general", &[], None);
|
||||
let today = Utc::now().format("%Y-%m-%d").to_string();
|
||||
assert_eq!(result, format!("Today is {today}."));
|
||||
}
|
||||
@@ -41,7 +46,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_room_name_substitution() {
|
||||
let p = Personality::new("You are in {room_name}.".to_string());
|
||||
let result = p.build_system_prompt("design-chat", &[]);
|
||||
let result = p.build_system_prompt("design-chat", &[], None);
|
||||
assert!(result.contains("design-chat"));
|
||||
}
|
||||
|
||||
@@ -49,14 +54,14 @@ mod tests {
|
||||
fn test_members_substitution() {
|
||||
let p = Personality::new("Members: {members}".to_string());
|
||||
let members = vec!["Alice".to_string(), "Bob".to_string(), "Carol".to_string()];
|
||||
let result = p.build_system_prompt("room", &members);
|
||||
let result = p.build_system_prompt("room", &members, None);
|
||||
assert_eq!(result, "Members: Alice, Bob, Carol");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_members() {
|
||||
let p = Personality::new("Members: {members}".to_string());
|
||||
let result = p.build_system_prompt("room", &[]);
|
||||
let result = p.build_system_prompt("room", &[], None);
|
||||
assert_eq!(result, "Members: ");
|
||||
}
|
||||
|
||||
@@ -65,7 +70,7 @@ mod tests {
|
||||
let template = "Date: {date}, Room: {room_name}, People: {members}".to_string();
|
||||
let p = Personality::new(template);
|
||||
let members = vec!["Sienna".to_string(), "Lonni".to_string()];
|
||||
let result = p.build_system_prompt("studio", &members);
|
||||
let result = p.build_system_prompt("studio", &members, None);
|
||||
|
||||
let today = Utc::now().format("%Y-%m-%d").to_string();
|
||||
assert!(result.starts_with(&format!("Date: {today}")));
|
||||
@@ -76,14 +81,32 @@ mod tests {
|
||||
#[test]
|
||||
fn test_no_placeholders_passthrough() {
|
||||
let p = Personality::new("Static prompt with no variables.".to_string());
|
||||
let result = p.build_system_prompt("room", &["Alice".to_string()]);
|
||||
let result = p.build_system_prompt("room", &["Alice".to_string()], None);
|
||||
assert_eq!(result, "Static prompt with no variables.");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiple_same_placeholder() {
|
||||
let p = Personality::new("{room_name} is great. I love {room_name}.".to_string());
|
||||
let result = p.build_system_prompt("lounge", &[]);
|
||||
let result = p.build_system_prompt("lounge", &[], None);
|
||||
assert_eq!(result, "lounge is great. I love lounge.");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_memory_notes_substitution() {
|
||||
let p = Personality::new("Context:\n{memory_notes}\nEnd.".to_string());
|
||||
let notes = "## notes about sienna\n- [preference] likes terse answers";
|
||||
let result = p.build_system_prompt("room", &[], Some(notes));
|
||||
assert!(result.contains("## notes about sienna"));
|
||||
assert!(result.contains("- [preference] likes terse answers"));
|
||||
assert!(result.starts_with("Context:\n"));
|
||||
assert!(result.ends_with("\nEnd."));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_memory_notes_none_clears_placeholder() {
|
||||
let p = Personality::new("Before\n{memory_notes}\nAfter".to_string());
|
||||
let result = p.build_system_prompt("room", &[], None);
|
||||
assert_eq!(result, "Before\n\nAfter");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,9 +10,14 @@ use rand::Rng;
|
||||
use tokio::time::{sleep, Duration};
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use matrix_sdk::room::Room;
|
||||
use opensearch::OpenSearch;
|
||||
|
||||
use crate::brain::conversation::ContextMessage;
|
||||
use crate::brain::personality::Personality;
|
||||
use crate::config::Config;
|
||||
use crate::context::ResponseContext;
|
||||
use crate::memory;
|
||||
use crate::tools::ToolRegistry;
|
||||
|
||||
/// Run a Mistral chat completion on a blocking thread.
|
||||
@@ -38,6 +43,7 @@ pub struct Responder {
|
||||
config: Arc<Config>,
|
||||
personality: Arc<Personality>,
|
||||
tools: Arc<ToolRegistry>,
|
||||
opensearch: OpenSearch,
|
||||
}
|
||||
|
||||
impl Responder {
|
||||
@@ -45,11 +51,13 @@ impl Responder {
|
||||
config: Arc<Config>,
|
||||
personality: Arc<Personality>,
|
||||
tools: Arc<ToolRegistry>,
|
||||
opensearch: OpenSearch,
|
||||
) -> Self {
|
||||
Self {
|
||||
config,
|
||||
personality,
|
||||
tools,
|
||||
opensearch,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -62,31 +70,52 @@ impl Responder {
|
||||
members: &[String],
|
||||
is_spontaneous: bool,
|
||||
mistral: &Arc<mistralai_client::v1::client::Client>,
|
||||
room: &Room,
|
||||
response_ctx: &ResponseContext,
|
||||
) -> Option<String> {
|
||||
// Apply response delay
|
||||
let delay = if is_spontaneous {
|
||||
rand::thread_rng().gen_range(
|
||||
self.config.behavior.spontaneous_delay_min_ms
|
||||
..=self.config.behavior.spontaneous_delay_max_ms,
|
||||
)
|
||||
} else {
|
||||
rand::thread_rng().gen_range(
|
||||
self.config.behavior.response_delay_min_ms
|
||||
..=self.config.behavior.response_delay_max_ms,
|
||||
)
|
||||
};
|
||||
sleep(Duration::from_millis(delay)).await;
|
||||
// Apply response delay (skip if instant_responses is enabled)
|
||||
// Delay happens BEFORE typing indicator — Sol "notices" the message first
|
||||
if !self.config.behavior.instant_responses {
|
||||
let delay = if is_spontaneous {
|
||||
rand::thread_rng().gen_range(
|
||||
self.config.behavior.spontaneous_delay_min_ms
|
||||
..=self.config.behavior.spontaneous_delay_max_ms,
|
||||
)
|
||||
} else {
|
||||
rand::thread_rng().gen_range(
|
||||
self.config.behavior.response_delay_min_ms
|
||||
..=self.config.behavior.response_delay_max_ms,
|
||||
)
|
||||
};
|
||||
debug!(delay_ms = delay, is_spontaneous, "Applying response delay");
|
||||
sleep(Duration::from_millis(delay)).await;
|
||||
}
|
||||
|
||||
let system_prompt = self.personality.build_system_prompt(room_name, members);
|
||||
// Start typing AFTER the delay — Sol has decided to respond
|
||||
let _ = room.typing_notice(true).await;
|
||||
|
||||
// Pre-response memory query
|
||||
let memory_notes = self
|
||||
.load_memory_notes(response_ctx, trigger_body)
|
||||
.await;
|
||||
|
||||
let system_prompt = self.personality.build_system_prompt(
|
||||
room_name,
|
||||
members,
|
||||
memory_notes.as_deref(),
|
||||
);
|
||||
|
||||
let mut messages = vec![ChatMessage::new_system_message(&system_prompt)];
|
||||
|
||||
// Add context messages
|
||||
// Add context messages with timestamps so the model has time awareness
|
||||
for msg in context {
|
||||
let ts = chrono::DateTime::from_timestamp_millis(msg.timestamp)
|
||||
.map(|d| d.format("%H:%M").to_string())
|
||||
.unwrap_or_default();
|
||||
if msg.sender == self.config.matrix.user_id {
|
||||
messages.push(ChatMessage::new_assistant_message(&msg.content, None));
|
||||
} else {
|
||||
let user_msg = format!("{}: {}", msg.sender, msg.content);
|
||||
let user_msg = format!("[{}] {}: {}", ts, msg.sender, msg.content);
|
||||
messages.push(ChatMessage::new_user_message(&user_msg));
|
||||
}
|
||||
}
|
||||
@@ -117,6 +146,7 @@ impl Responder {
|
||||
let response = match chat_blocking(mistral, model.clone(), messages.clone(), params).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
let _ = room.typing_notice(false).await;
|
||||
error!("Mistral chat failed: {e}");
|
||||
return None;
|
||||
}
|
||||
@@ -137,12 +167,13 @@ impl Responder {
|
||||
info!(
|
||||
tool = tc.function.name.as_str(),
|
||||
id = call_id,
|
||||
args = tc.function.arguments.as_str(),
|
||||
"Executing tool call"
|
||||
);
|
||||
|
||||
let result = self
|
||||
.tools
|
||||
.execute(&tc.function.name, &tc.function.arguments)
|
||||
.execute(&tc.function.name, &tc.function.arguments, response_ctx)
|
||||
.await;
|
||||
|
||||
let result_str = match result {
|
||||
@@ -165,15 +196,155 @@ impl Responder {
|
||||
}
|
||||
}
|
||||
|
||||
// Final text response
|
||||
let text = choice.message.content.trim().to_string();
|
||||
// Final text response — strip own name prefix if present
|
||||
let mut text = choice.message.content.trim().to_string();
|
||||
|
||||
// Strip "sol:" or "sol 💕:" or similar prefixes the model sometimes adds
|
||||
let lower = text.to_lowercase();
|
||||
for prefix in &["sol:", "sol 💕:", "sol💕:"] {
|
||||
if lower.starts_with(prefix) {
|
||||
text = text[prefix.len()..].trim().to_string();
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if text.is_empty() {
|
||||
info!("Generated empty response, skipping send");
|
||||
let _ = room.typing_notice(false).await;
|
||||
return None;
|
||||
}
|
||||
|
||||
let preview: String = text.chars().take(120).collect();
|
||||
let _ = room.typing_notice(false).await;
|
||||
info!(
|
||||
response_len = text.len(),
|
||||
response_preview = preview.as_str(),
|
||||
is_spontaneous,
|
||||
tool_iterations = iteration,
|
||||
"Generated response"
|
||||
);
|
||||
return Some(text);
|
||||
}
|
||||
|
||||
let _ = room.typing_notice(false).await;
|
||||
warn!("Exceeded max tool iterations");
|
||||
None
|
||||
}
|
||||
|
||||
async fn load_memory_notes(
|
||||
&self,
|
||||
ctx: &ResponseContext,
|
||||
trigger_body: &str,
|
||||
) -> Option<String> {
|
||||
let index = &self.config.opensearch.memory_index;
|
||||
let user_id = &ctx.user_id;
|
||||
|
||||
// Search for topically relevant memories
|
||||
let mut memories = memory::store::query(
|
||||
&self.opensearch,
|
||||
index,
|
||||
user_id,
|
||||
trigger_body,
|
||||
5,
|
||||
)
|
||||
.await
|
||||
.unwrap_or_default();
|
||||
|
||||
// Backfill with recent memories if we have fewer than 3
|
||||
if memories.len() < 3 {
|
||||
let remaining = 5 - memories.len();
|
||||
if let Ok(recent) = memory::store::get_recent(
|
||||
&self.opensearch,
|
||||
index,
|
||||
user_id,
|
||||
remaining,
|
||||
)
|
||||
.await
|
||||
{
|
||||
let existing_ids: std::collections::HashSet<String> =
|
||||
memories.iter().map(|m| m.id.clone()).collect();
|
||||
for doc in recent {
|
||||
if !existing_ids.contains(&doc.id) && memories.len() < 5 {
|
||||
memories.push(doc);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if memories.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let display = ctx
|
||||
.display_name
|
||||
.as_deref()
|
||||
.unwrap_or(&ctx.matrix_user_id);
|
||||
|
||||
Some(format_memory_notes(display, &memories))
|
||||
}
|
||||
}
|
||||
|
||||
/// Format memory documents into a notes block for the system prompt.
|
||||
pub(crate) fn format_memory_notes(
|
||||
display_name: &str,
|
||||
memories: &[memory::schema::MemoryDocument],
|
||||
) -> String {
|
||||
let mut lines = vec![format!(
|
||||
"## notes about {display_name}\n\n\
|
||||
these are your private notes about the person you're talking to.\n\
|
||||
use them to inform your responses but don't mention that you have notes.\n"
|
||||
)];
|
||||
|
||||
for mem in memories {
|
||||
lines.push(format!("- [{}] {}", mem.category, mem.content));
|
||||
}
|
||||
|
||||
lines.join("\n")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::memory::schema::MemoryDocument;
|
||||
|
||||
fn make_mem(id: &str, content: &str, category: &str) -> MemoryDocument {
|
||||
MemoryDocument {
|
||||
id: id.into(),
|
||||
user_id: "sienna@sunbeam.pt".into(),
|
||||
content: content.into(),
|
||||
category: category.into(),
|
||||
created_at: 1710000000000,
|
||||
updated_at: 1710000000000,
|
||||
source: "auto".into(),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_memory_notes_basic() {
|
||||
let memories = vec![
|
||||
make_mem("a", "prefers terse answers", "preference"),
|
||||
make_mem("b", "working on drive UI", "fact"),
|
||||
];
|
||||
|
||||
let result = format_memory_notes("sienna", &memories);
|
||||
assert!(result.contains("## notes about sienna"));
|
||||
assert!(result.contains("don't mention that you have notes"));
|
||||
assert!(result.contains("- [preference] prefers terse answers"));
|
||||
assert!(result.contains("- [fact] working on drive UI"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_memory_notes_single() {
|
||||
let memories = vec![make_mem("x", "birthday is march 12", "context")];
|
||||
let result = format_memory_notes("lonni", &memories);
|
||||
assert!(result.contains("## notes about lonni"));
|
||||
assert!(result.contains("- [context] birthday is march 12"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_memory_notes_uses_display_name() {
|
||||
let memories = vec![make_mem("a", "test", "general")];
|
||||
let result = format_memory_notes("Amber", &memories);
|
||||
assert!(result.contains("## notes about Amber"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,6 +25,8 @@ pub struct OpenSearchConfig {
|
||||
pub flush_interval_ms: u64,
|
||||
#[serde(default = "default_embedding_pipeline")]
|
||||
pub embedding_pipeline: String,
|
||||
#[serde(default = "default_memory_index")]
|
||||
pub memory_index: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
@@ -59,6 +61,30 @@ pub struct BehaviorConfig {
|
||||
pub backfill_on_join: bool,
|
||||
#[serde(default = "default_backfill_limit")]
|
||||
pub backfill_limit: usize,
|
||||
#[serde(default)]
|
||||
pub instant_responses: bool,
|
||||
#[serde(default = "default_cooldown_after_response_ms")]
|
||||
pub cooldown_after_response_ms: u64,
|
||||
#[serde(default = "default_evaluation_context_window")]
|
||||
pub evaluation_context_window: usize,
|
||||
#[serde(default = "default_detect_sol_in_conversation")]
|
||||
pub detect_sol_in_conversation: bool,
|
||||
#[serde(default)]
|
||||
pub evaluation_prompt_active: Option<String>,
|
||||
#[serde(default)]
|
||||
pub evaluation_prompt_passive: Option<String>,
|
||||
#[serde(default = "default_reaction_threshold")]
|
||||
pub reaction_threshold: f32,
|
||||
#[serde(default = "default_reaction_enabled")]
|
||||
pub reaction_enabled: bool,
|
||||
#[serde(default = "default_script_timeout_secs")]
|
||||
pub script_timeout_secs: u64,
|
||||
#[serde(default = "default_script_max_heap_mb")]
|
||||
pub script_max_heap_mb: usize,
|
||||
#[serde(default)]
|
||||
pub script_fetch_allowlist: Vec<String>,
|
||||
#[serde(default = "default_memory_extraction_enabled")]
|
||||
pub memory_extraction_enabled: bool,
|
||||
}
|
||||
|
||||
fn default_batch_size() -> usize { 50 }
|
||||
@@ -68,15 +94,24 @@ fn default_model() -> String { "mistral-medium-latest".into() }
|
||||
fn default_evaluation_model() -> String { "ministral-3b-latest".into() }
|
||||
fn default_research_model() -> String { "mistral-large-latest".into() }
|
||||
fn default_max_tool_iterations() -> usize { 5 }
|
||||
fn default_response_delay_min_ms() -> u64 { 2000 }
|
||||
fn default_response_delay_max_ms() -> u64 { 8000 }
|
||||
fn default_response_delay_min_ms() -> u64 { 100 }
|
||||
fn default_response_delay_max_ms() -> u64 { 2300 }
|
||||
fn default_spontaneous_delay_min_ms() -> u64 { 15000 }
|
||||
fn default_spontaneous_delay_max_ms() -> u64 { 60000 }
|
||||
fn default_spontaneous_threshold() -> f32 { 0.7 }
|
||||
fn default_spontaneous_threshold() -> f32 { 0.85 }
|
||||
fn default_cooldown_after_response_ms() -> u64 { 15000 }
|
||||
fn default_evaluation_context_window() -> usize { 25 }
|
||||
fn default_detect_sol_in_conversation() -> bool { true }
|
||||
fn default_reaction_threshold() -> f32 { 0.6 }
|
||||
fn default_reaction_enabled() -> bool { true }
|
||||
fn default_room_context_window() -> usize { 30 }
|
||||
fn default_dm_context_window() -> usize { 100 }
|
||||
fn default_backfill_on_join() -> bool { true }
|
||||
fn default_backfill_limit() -> usize { 10000 }
|
||||
fn default_script_timeout_secs() -> u64 { 5 }
|
||||
fn default_script_max_heap_mb() -> usize { 64 }
|
||||
fn default_memory_index() -> String { "sol_user_memory".into() }
|
||||
fn default_memory_extraction_enabled() -> bool { true }
|
||||
|
||||
impl Config {
|
||||
pub fn load(path: &str) -> anyhow::Result<Self> {
|
||||
@@ -155,19 +190,23 @@ backfill_limit = 5000
|
||||
assert_eq!(config.opensearch.batch_size, 50);
|
||||
assert_eq!(config.opensearch.flush_interval_ms, 2000);
|
||||
assert_eq!(config.opensearch.embedding_pipeline, "tuwunel_embedding_pipeline");
|
||||
assert_eq!(config.opensearch.memory_index, "sol_user_memory");
|
||||
assert_eq!(config.mistral.default_model, "mistral-medium-latest");
|
||||
assert_eq!(config.mistral.evaluation_model, "ministral-3b-latest");
|
||||
assert_eq!(config.mistral.research_model, "mistral-large-latest");
|
||||
assert_eq!(config.mistral.max_tool_iterations, 5);
|
||||
assert_eq!(config.behavior.response_delay_min_ms, 2000);
|
||||
assert_eq!(config.behavior.response_delay_max_ms, 8000);
|
||||
assert_eq!(config.behavior.response_delay_min_ms, 100);
|
||||
assert_eq!(config.behavior.response_delay_max_ms, 2300);
|
||||
assert_eq!(config.behavior.spontaneous_delay_min_ms, 15000);
|
||||
assert_eq!(config.behavior.spontaneous_delay_max_ms, 60000);
|
||||
assert!((config.behavior.spontaneous_threshold - 0.7).abs() < f32::EPSILON);
|
||||
assert!((config.behavior.spontaneous_threshold - 0.85).abs() < f32::EPSILON);
|
||||
assert!(!config.behavior.instant_responses);
|
||||
assert_eq!(config.behavior.cooldown_after_response_ms, 15000);
|
||||
assert_eq!(config.behavior.room_context_window, 30);
|
||||
assert_eq!(config.behavior.dm_context_window, 100);
|
||||
assert!(config.behavior.backfill_on_join);
|
||||
assert_eq!(config.behavior.backfill_limit, 10000);
|
||||
assert!(config.behavior.memory_extraction_enabled);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
56
src/context.rs
Normal file
56
src/context.rs
Normal file
@@ -0,0 +1,56 @@
|
||||
/// Per-message response context, threading the sender's identity from Matrix
|
||||
/// through the tool loop and memory system.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ResponseContext {
|
||||
/// Full Matrix user ID, e.g. `@sienna:sunbeam.pt`
|
||||
pub matrix_user_id: String,
|
||||
/// Derived portable ID, e.g. `sienna@sunbeam.pt`
|
||||
pub user_id: String,
|
||||
/// Display name if available
|
||||
pub display_name: Option<String>,
|
||||
/// Whether this message was sent in a DM
|
||||
pub is_dm: bool,
|
||||
/// Whether this message is a reply to Sol
|
||||
pub is_reply: bool,
|
||||
/// The room this message was sent in
|
||||
pub room_id: String,
|
||||
}
|
||||
|
||||
/// Derive a portable user ID from a Matrix user ID.
|
||||
///
|
||||
/// `@sienna:sunbeam.pt` → `sienna@sunbeam.pt`
|
||||
pub fn derive_user_id(matrix_user_id: &str) -> String {
|
||||
let stripped = matrix_user_id.strip_prefix('@').unwrap_or(matrix_user_id);
|
||||
stripped.replacen(':', "@", 1)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_derive_user_id_standard() {
|
||||
assert_eq!(derive_user_id("@sienna:sunbeam.pt"), "sienna@sunbeam.pt");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_derive_user_id_no_at_prefix() {
|
||||
assert_eq!(derive_user_id("sienna:sunbeam.pt"), "sienna@sunbeam.pt");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_derive_user_id_complex() {
|
||||
assert_eq!(
|
||||
derive_user_id("@user.name:matrix.org"),
|
||||
"user.name@matrix.org"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_derive_user_id_only_first_colon() {
|
||||
assert_eq!(
|
||||
derive_user_id("@user:server:8448"),
|
||||
"user@server:8448"
|
||||
);
|
||||
}
|
||||
}
|
||||
194
src/main.rs
194
src/main.rs
@@ -1,7 +1,9 @@
|
||||
mod archive;
|
||||
mod brain;
|
||||
mod config;
|
||||
mod context;
|
||||
mod matrix_utils;
|
||||
mod memory;
|
||||
mod sync;
|
||||
mod tools;
|
||||
|
||||
@@ -18,7 +20,8 @@ use url::Url;
|
||||
|
||||
use archive::indexer::Indexer;
|
||||
use archive::schema::create_index_if_not_exists;
|
||||
use brain::conversation::ConversationManager;
|
||||
use brain::conversation::{ContextMessage, ConversationManager};
|
||||
use memory::schema::create_index_if_not_exists as create_memory_index;
|
||||
use brain::evaluator::Evaluator;
|
||||
use brain::personality::Personality;
|
||||
use brain::responder::Responder;
|
||||
@@ -93,8 +96,9 @@ async fn main() -> anyhow::Result<()> {
|
||||
.build()?;
|
||||
let os_client = OpenSearch::new(os_transport);
|
||||
|
||||
// Ensure index exists
|
||||
// Ensure indices exist
|
||||
create_index_if_not_exists(&os_client, &config.opensearch.index).await?;
|
||||
create_memory_index(&os_client, &config.opensearch.memory_index).await?;
|
||||
|
||||
// Initialize Mistral client
|
||||
let mistral_client = mistralai_client::v1::client::Client::new(
|
||||
@@ -107,22 +111,32 @@ async fn main() -> anyhow::Result<()> {
|
||||
|
||||
// Build components
|
||||
let personality = Arc::new(Personality::new(system_prompt));
|
||||
let conversations = Arc::new(Mutex::new(ConversationManager::new(
|
||||
config.behavior.room_context_window,
|
||||
config.behavior.dm_context_window,
|
||||
)));
|
||||
|
||||
// Backfill conversation context from archive before starting
|
||||
if config.behavior.backfill_on_join {
|
||||
info!("Backfilling conversation context from archive...");
|
||||
if let Err(e) = backfill_conversations(&os_client, &config, &conversations).await {
|
||||
error!("Backfill failed (non-fatal): {e}");
|
||||
}
|
||||
}
|
||||
|
||||
let tool_registry = Arc::new(ToolRegistry::new(
|
||||
os_client.clone(),
|
||||
matrix_client.clone(),
|
||||
config.clone(),
|
||||
));
|
||||
let indexer = Arc::new(Indexer::new(os_client, config.clone()));
|
||||
let indexer = Arc::new(Indexer::new(os_client.clone(), config.clone()));
|
||||
let evaluator = Arc::new(Evaluator::new(config.clone()));
|
||||
let responder = Arc::new(Responder::new(
|
||||
config.clone(),
|
||||
personality,
|
||||
tool_registry,
|
||||
os_client.clone(),
|
||||
));
|
||||
let conversations = Arc::new(Mutex::new(ConversationManager::new(
|
||||
config.behavior.room_context_window,
|
||||
config.behavior.dm_context_window,
|
||||
)));
|
||||
|
||||
// Start background flush task
|
||||
let _flush_handle = indexer.start_flush_task();
|
||||
@@ -135,8 +149,17 @@ async fn main() -> anyhow::Result<()> {
|
||||
responder,
|
||||
conversations,
|
||||
mistral,
|
||||
opensearch: os_client,
|
||||
last_response: Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())),
|
||||
responding_in: Arc::new(tokio::sync::Mutex::new(std::collections::HashSet::new())),
|
||||
});
|
||||
|
||||
// Backfill reactions from Matrix room timelines
|
||||
info!("Backfilling reactions from room timelines...");
|
||||
if let Err(e) = backfill_reactions(&matrix_client, &state.indexer).await {
|
||||
error!("Reaction backfill failed (non-fatal): {e}");
|
||||
}
|
||||
|
||||
// Start sync loop in background
|
||||
let sync_client = matrix_client.clone();
|
||||
let sync_state = state.clone();
|
||||
@@ -158,3 +181,160 @@ async fn main() -> anyhow::Result<()> {
|
||||
info!("Sol has shut down");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Backfill conversation context from the OpenSearch archive.
|
||||
///
|
||||
/// Queries the most recent messages per room and seeds the ConversationManager
|
||||
/// so Sol has context surviving restarts.
|
||||
async fn backfill_conversations(
|
||||
os_client: &OpenSearch,
|
||||
config: &Config,
|
||||
conversations: &Arc<Mutex<ConversationManager>>,
|
||||
) -> anyhow::Result<()> {
|
||||
use serde_json::json;
|
||||
|
||||
let window = config.behavior.room_context_window.max(config.behavior.dm_context_window);
|
||||
let index = &config.opensearch.index;
|
||||
|
||||
// Get all distinct rooms
|
||||
let agg_body = json!({
|
||||
"size": 0,
|
||||
"aggs": {
|
||||
"rooms": {
|
||||
"terms": { "field": "room_id", "size": 500 }
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let response = os_client
|
||||
.search(opensearch::SearchParts::Index(&[index]))
|
||||
.body(agg_body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let body: serde_json::Value = response.json().await?;
|
||||
let buckets = body["aggregations"]["rooms"]["buckets"]
|
||||
.as_array()
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
|
||||
let mut total = 0;
|
||||
for bucket in &buckets {
|
||||
let room_id = bucket["key"].as_str().unwrap_or("");
|
||||
if room_id.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Fetch recent messages for this room
|
||||
let query = json!({
|
||||
"size": window,
|
||||
"sort": [{ "timestamp": "asc" }],
|
||||
"query": {
|
||||
"bool": {
|
||||
"filter": [
|
||||
{ "term": { "room_id": room_id } },
|
||||
{ "term": { "redacted": false } }
|
||||
]
|
||||
}
|
||||
},
|
||||
"_source": ["sender_name", "sender", "content", "timestamp"]
|
||||
});
|
||||
|
||||
let resp = os_client
|
||||
.search(opensearch::SearchParts::Index(&[index]))
|
||||
.body(query)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let data: serde_json::Value = resp.json().await?;
|
||||
let hits = data["hits"]["hits"].as_array().cloned().unwrap_or_default();
|
||||
|
||||
if hits.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut convs = conversations.lock().await;
|
||||
for hit in &hits {
|
||||
let src = &hit["_source"];
|
||||
let sender = src["sender_name"]
|
||||
.as_str()
|
||||
.or_else(|| src["sender"].as_str())
|
||||
.unwrap_or("unknown");
|
||||
let content = src["content"].as_str().unwrap_or("");
|
||||
let timestamp = src["timestamp"].as_i64().unwrap_or(0);
|
||||
|
||||
convs.add_message(
|
||||
room_id,
|
||||
false, // we don't know if it's a DM from the archive, use group window
|
||||
ContextMessage {
|
||||
sender: sender.to_string(),
|
||||
content: content.to_string(),
|
||||
timestamp,
|
||||
},
|
||||
);
|
||||
total += 1;
|
||||
}
|
||||
}
|
||||
|
||||
info!(rooms = buckets.len(), messages = total, "Backfill complete");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Backfill reactions from Matrix room timelines into the archive.
|
||||
///
|
||||
/// For each joined room, fetches recent timeline events and indexes any
|
||||
/// m.reaction events that aren't already in the archive.
|
||||
async fn backfill_reactions(
|
||||
client: &Client,
|
||||
indexer: &Arc<Indexer>,
|
||||
) -> anyhow::Result<()> {
|
||||
use matrix_sdk::room::MessagesOptions;
|
||||
use ruma::events::AnySyncTimelineEvent;
|
||||
use ruma::uint;
|
||||
|
||||
let rooms = client.joined_rooms();
|
||||
let mut total = 0;
|
||||
|
||||
for room in &rooms {
|
||||
let room_id = room.room_id().to_string();
|
||||
|
||||
// Fetch recent messages (backwards from now)
|
||||
let mut options = MessagesOptions::backward();
|
||||
options.limit = uint!(500);
|
||||
|
||||
let messages = match room.messages(options).await {
|
||||
Ok(m) => m,
|
||||
Err(e) => {
|
||||
error!(room = room_id.as_str(), "Failed to fetch timeline for reaction backfill: {e}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
for event in &messages.chunk {
|
||||
let Ok(deserialized) = event.raw().deserialize() else {
|
||||
continue;
|
||||
};
|
||||
|
||||
if let AnySyncTimelineEvent::MessageLike(
|
||||
ruma::events::AnySyncMessageLikeEvent::Reaction(reaction_event),
|
||||
) = deserialized
|
||||
{
|
||||
let original = match reaction_event {
|
||||
ruma::events::SyncMessageLikeEvent::Original(ref o) => o,
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
let target_event_id = original.content.relates_to.event_id.to_string();
|
||||
let sender = original.sender.to_string();
|
||||
let emoji = &original.content.relates_to.key;
|
||||
let timestamp: i64 = original.origin_server_ts.0.into();
|
||||
|
||||
indexer.add_reaction(&target_event_id, &sender, emoji, timestamp).await;
|
||||
total += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
info!(reactions = total, rooms = rooms.len(), "Reaction backfill complete");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -3,7 +3,8 @@ use matrix_sdk::RoomMemberships;
|
||||
use ruma::events::room::message::{
|
||||
MessageType, OriginalSyncRoomMessageEvent, Relation, RoomMessageEventContent,
|
||||
};
|
||||
use ruma::events::relation::InReplyTo;
|
||||
use ruma::events::relation::{Annotation, InReplyTo};
|
||||
use ruma::events::reaction::ReactionEventContent;
|
||||
use ruma::OwnedEventId;
|
||||
|
||||
/// Extract the plain-text body from a message event.
|
||||
@@ -45,15 +46,27 @@ pub fn extract_thread_id(event: &OriginalSyncRoomMessageEvent) -> Option<OwnedEv
|
||||
None
|
||||
}
|
||||
|
||||
/// Build a reply message content with m.in_reply_to relation.
|
||||
/// Build a reply message content with m.in_reply_to relation and markdown rendering.
|
||||
pub fn make_reply_content(body: &str, reply_to_event_id: OwnedEventId) -> RoomMessageEventContent {
|
||||
let mut content = RoomMessageEventContent::text_plain(body);
|
||||
let mut content = RoomMessageEventContent::text_markdown(body);
|
||||
content.relates_to = Some(Relation::Reply {
|
||||
in_reply_to: InReplyTo::new(reply_to_event_id),
|
||||
});
|
||||
content
|
||||
}
|
||||
|
||||
/// Send an emoji reaction to a message.
|
||||
pub async fn send_reaction(
|
||||
room: &Room,
|
||||
event_id: OwnedEventId,
|
||||
emoji: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let annotation = Annotation::new(event_id, emoji.to_string());
|
||||
let content = ReactionEventContent::new(annotation);
|
||||
room.send(content).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get the display name for a room.
|
||||
pub fn room_display_name(room: &Room) -> String {
|
||||
room.cached_display_name()
|
||||
|
||||
158
src/memory/extractor.rs
Normal file
158
src/memory/extractor.rs
Normal file
@@ -0,0 +1,158 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use mistralai_client::v1::{
|
||||
chat::{ChatMessage, ChatParams, ResponseFormat},
|
||||
constants::Model,
|
||||
};
|
||||
use opensearch::OpenSearch;
|
||||
use serde::Deserialize;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::context::ResponseContext;
|
||||
use crate::brain::responder::chat_blocking;
|
||||
|
||||
use super::store;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub(crate) struct ExtractionResponse {
|
||||
pub memories: Vec<ExtractedMemory>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub(crate) struct ExtractedMemory {
|
||||
pub content: String,
|
||||
pub category: String,
|
||||
}
|
||||
|
||||
/// Validate and normalize a category string.
|
||||
pub(crate) fn normalize_category(raw: &str) -> &str {
|
||||
match raw {
|
||||
"preference" | "fact" | "context" => raw,
|
||||
_ => "general",
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn extract_and_store(
|
||||
mistral: &Arc<mistralai_client::v1::client::Client>,
|
||||
opensearch: &OpenSearch,
|
||||
config: &Config,
|
||||
ctx: &ResponseContext,
|
||||
user_message: &str,
|
||||
sol_response: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let display = ctx
|
||||
.display_name
|
||||
.as_deref()
|
||||
.unwrap_or(&ctx.matrix_user_id);
|
||||
|
||||
let prompt = format!(
|
||||
"Analyze this conversation exchange and extract any facts worth remembering about {display}.\n\
|
||||
Focus on: preferences, personal details, ongoing projects, opinions, recurring topics.\n\n\
|
||||
They said: {user_message}\n\
|
||||
Response: {sol_response}\n\n\
|
||||
Respond ONLY with JSON: {{\"memories\": [{{\"content\": \"...\", \"category\": \"preference|fact|context\"}}]}}\n\
|
||||
If nothing worth remembering, respond with {{\"memories\": []}}.\n\
|
||||
Be selective — only genuinely useful information."
|
||||
);
|
||||
|
||||
let messages = vec![ChatMessage::new_user_message(&prompt)];
|
||||
let model = Model::new(&config.mistral.evaluation_model);
|
||||
let params = ChatParams {
|
||||
response_format: Some(ResponseFormat::json_object()),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let response = chat_blocking(mistral, model, messages, params).await?;
|
||||
let text = response.choices[0].message.content.trim();
|
||||
|
||||
let extraction: ExtractionResponse = match serde_json::from_str(text) {
|
||||
Ok(e) => e,
|
||||
Err(e) => {
|
||||
debug!(raw = text, "Failed to parse extraction response: {e}");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
if extraction.memories.is_empty() {
|
||||
debug!("No memories extracted");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let index = &config.opensearch.memory_index;
|
||||
for mem in &extraction.memories {
|
||||
let category = normalize_category(&mem.category);
|
||||
|
||||
if let Err(e) = store::set(
|
||||
opensearch,
|
||||
index,
|
||||
&ctx.user_id,
|
||||
&mem.content,
|
||||
category,
|
||||
"auto",
|
||||
)
|
||||
.await
|
||||
{
|
||||
warn!("Failed to store extracted memory: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
debug!(
|
||||
count = extraction.memories.len(),
|
||||
user = ctx.user_id.as_str(),
|
||||
"Extracted and stored memories"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_extraction_response_with_memories() {
|
||||
let json = r#"{"memories": [
|
||||
{"content": "prefers terse answers", "category": "preference"},
|
||||
{"content": "working on drive UI", "category": "fact"}
|
||||
]}"#;
|
||||
let resp: ExtractionResponse = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(resp.memories.len(), 2);
|
||||
assert_eq!(resp.memories[0].content, "prefers terse answers");
|
||||
assert_eq!(resp.memories[0].category, "preference");
|
||||
assert_eq!(resp.memories[1].category, "fact");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_extraction_response_empty() {
|
||||
let json = r#"{"memories": []}"#;
|
||||
let resp: ExtractionResponse = serde_json::from_str(json).unwrap();
|
||||
assert!(resp.memories.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_extraction_response_invalid_json() {
|
||||
let json = "not json at all";
|
||||
assert!(serde_json::from_str::<ExtractionResponse>(json).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_extraction_response_missing_field() {
|
||||
let json = r#"{"memories": [{"content": "hi"}]}"#;
|
||||
assert!(serde_json::from_str::<ExtractionResponse>(json).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normalize_category_valid() {
|
||||
assert_eq!(normalize_category("preference"), "preference");
|
||||
assert_eq!(normalize_category("fact"), "fact");
|
||||
assert_eq!(normalize_category("context"), "context");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normalize_category_unknown_falls_back() {
|
||||
assert_eq!(normalize_category("opinion"), "general");
|
||||
assert_eq!(normalize_category(""), "general");
|
||||
assert_eq!(normalize_category("PREFERENCE"), "general");
|
||||
}
|
||||
}
|
||||
3
src/memory/mod.rs
Normal file
3
src/memory/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod extractor;
|
||||
pub mod schema;
|
||||
pub mod store;
|
||||
118
src/memory/schema.rs
Normal file
118
src/memory/schema.rs
Normal file
@@ -0,0 +1,118 @@
|
||||
use opensearch::OpenSearch;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::info;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MemoryDocument {
|
||||
pub id: String,
|
||||
pub user_id: String,
|
||||
pub content: String,
|
||||
pub category: String,
|
||||
pub created_at: i64,
|
||||
pub updated_at: i64,
|
||||
pub source: String,
|
||||
}
|
||||
|
||||
const INDEX_MAPPING: &str = r#"{
|
||||
"settings": {
|
||||
"number_of_shards": 1,
|
||||
"number_of_replicas": 0
|
||||
},
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"id": { "type": "keyword" },
|
||||
"user_id": { "type": "keyword" },
|
||||
"content": { "type": "text", "analyzer": "standard" },
|
||||
"category": { "type": "keyword" },
|
||||
"created_at": { "type": "date", "format": "epoch_millis" },
|
||||
"updated_at": { "type": "date", "format": "epoch_millis" },
|
||||
"source": { "type": "keyword" }
|
||||
}
|
||||
}
|
||||
}"#;
|
||||
|
||||
pub async fn create_index_if_not_exists(client: &OpenSearch, index: &str) -> anyhow::Result<()> {
|
||||
let exists = client
|
||||
.indices()
|
||||
.exists(opensearch::indices::IndicesExistsParts::Index(&[index]))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if exists.status_code().is_success() {
|
||||
info!(index, "Memory index already exists");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mapping: serde_json::Value = serde_json::from_str(INDEX_MAPPING)?;
|
||||
let response = client
|
||||
.indices()
|
||||
.create(opensearch::indices::IndicesCreateParts::Index(index))
|
||||
.body(mapping)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status_code().is_success() {
|
||||
let body = response.text().await?;
|
||||
anyhow::bail!("Failed to create memory index {index}: {body}");
|
||||
}
|
||||
|
||||
info!(index, "Created memory index");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_memory_document_serialize() {
|
||||
let doc = MemoryDocument {
|
||||
id: "abc-123".into(),
|
||||
user_id: "sienna@sunbeam.pt".into(),
|
||||
content: "prefers terse answers".into(),
|
||||
category: "preference".into(),
|
||||
created_at: 1710000000000,
|
||||
updated_at: 1710000000000,
|
||||
source: "auto".into(),
|
||||
};
|
||||
let json = serde_json::to_value(&doc).unwrap();
|
||||
assert_eq!(json["user_id"], "sienna@sunbeam.pt");
|
||||
assert_eq!(json["category"], "preference");
|
||||
assert_eq!(json["source"], "auto");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_memory_document_roundtrip() {
|
||||
let doc = MemoryDocument {
|
||||
id: "xyz".into(),
|
||||
user_id: "lonni@sunbeam.pt".into(),
|
||||
content: "working on UI redesign".into(),
|
||||
category: "fact".into(),
|
||||
created_at: 1710000000000,
|
||||
updated_at: 1710000000000,
|
||||
source: "script".into(),
|
||||
};
|
||||
let json_str = serde_json::to_string(&doc).unwrap();
|
||||
let roundtrip: MemoryDocument = serde_json::from_str(&json_str).unwrap();
|
||||
assert_eq!(roundtrip.id, doc.id);
|
||||
assert_eq!(roundtrip.user_id, doc.user_id);
|
||||
assert_eq!(roundtrip.content, doc.content);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_index_mapping_valid_json() {
|
||||
let mapping: serde_json::Value = serde_json::from_str(INDEX_MAPPING).unwrap();
|
||||
assert_eq!(
|
||||
mapping["mappings"]["properties"]["user_id"]["type"]
|
||||
.as_str()
|
||||
.unwrap(),
|
||||
"keyword"
|
||||
);
|
||||
assert_eq!(
|
||||
mapping["mappings"]["properties"]["content"]["type"]
|
||||
.as_str()
|
||||
.unwrap(),
|
||||
"text"
|
||||
);
|
||||
}
|
||||
}
|
||||
187
src/memory/store.rs
Normal file
187
src/memory/store.rs
Normal file
@@ -0,0 +1,187 @@
|
||||
use chrono::Utc;
|
||||
use opensearch::OpenSearch;
|
||||
use serde_json::json;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::schema::MemoryDocument;
|
||||
|
||||
/// Search memories by content relevance, filtered to a specific user.
|
||||
pub async fn query(
|
||||
client: &OpenSearch,
|
||||
index: &str,
|
||||
user_id: &str,
|
||||
query_text: &str,
|
||||
limit: usize,
|
||||
) -> anyhow::Result<Vec<MemoryDocument>> {
|
||||
let body = json!({
|
||||
"size": limit,
|
||||
"query": {
|
||||
"bool": {
|
||||
"filter": [
|
||||
{ "term": { "user_id": user_id } }
|
||||
],
|
||||
"must": [
|
||||
{ "match": { "content": query_text } }
|
||||
]
|
||||
}
|
||||
},
|
||||
"sort": [{ "_score": "desc" }]
|
||||
});
|
||||
|
||||
let response = client
|
||||
.search(opensearch::SearchParts::Index(&[index]))
|
||||
.body(body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let data: serde_json::Value = response.json().await?;
|
||||
parse_hits(&data)
|
||||
}
|
||||
|
||||
/// Get the most recent memories for a user, sorted by updated_at desc.
|
||||
pub async fn get_recent(
|
||||
client: &OpenSearch,
|
||||
index: &str,
|
||||
user_id: &str,
|
||||
limit: usize,
|
||||
) -> anyhow::Result<Vec<MemoryDocument>> {
|
||||
let body = json!({
|
||||
"size": limit,
|
||||
"query": {
|
||||
"bool": {
|
||||
"filter": [
|
||||
{ "term": { "user_id": user_id } }
|
||||
]
|
||||
}
|
||||
},
|
||||
"sort": [{ "updated_at": "desc" }]
|
||||
});
|
||||
|
||||
let response = client
|
||||
.search(opensearch::SearchParts::Index(&[index]))
|
||||
.body(body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let data: serde_json::Value = response.json().await?;
|
||||
parse_hits(&data)
|
||||
}
|
||||
|
||||
/// Store a new memory document for a user.
|
||||
pub async fn set(
|
||||
client: &OpenSearch,
|
||||
index: &str,
|
||||
user_id: &str,
|
||||
content: &str,
|
||||
category: &str,
|
||||
source: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let now = Utc::now().timestamp_millis();
|
||||
let id = Uuid::new_v4().to_string();
|
||||
|
||||
let doc = MemoryDocument {
|
||||
id: id.clone(),
|
||||
user_id: user_id.to_string(),
|
||||
content: content.to_string(),
|
||||
category: category.to_string(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
source: source.to_string(),
|
||||
};
|
||||
|
||||
let response = client
|
||||
.index(opensearch::IndexParts::IndexId(index, &id))
|
||||
.body(serde_json::to_value(&doc)?)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status_code().is_success() {
|
||||
let body = response.text().await?;
|
||||
anyhow::bail!("Failed to store memory: {body}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn parse_hits(data: &serde_json::Value) -> anyhow::Result<Vec<MemoryDocument>> {
|
||||
let hits = data["hits"]["hits"]
|
||||
.as_array()
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
|
||||
let mut docs = Vec::with_capacity(hits.len());
|
||||
for hit in &hits {
|
||||
if let Ok(doc) = serde_json::from_value::<MemoryDocument>(hit["_source"].clone()) {
|
||||
docs.push(doc);
|
||||
}
|
||||
}
|
||||
Ok(docs)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn fake_os_response(sources: Vec<serde_json::Value>) -> serde_json::Value {
|
||||
let hits: Vec<serde_json::Value> = sources
|
||||
.into_iter()
|
||||
.map(|s| json!({ "_source": s }))
|
||||
.collect();
|
||||
json!({ "hits": { "hits": hits } })
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_hits_multiple() {
|
||||
let data = fake_os_response(vec![
|
||||
json!({
|
||||
"id": "a", "user_id": "sienna@sunbeam.pt",
|
||||
"content": "prefers terse answers", "category": "preference",
|
||||
"created_at": 1710000000000_i64, "updated_at": 1710000000000_i64,
|
||||
"source": "auto"
|
||||
}),
|
||||
json!({
|
||||
"id": "b", "user_id": "sienna@sunbeam.pt",
|
||||
"content": "working on drive UI", "category": "fact",
|
||||
"created_at": 1710000000000_i64, "updated_at": 1710000000000_i64,
|
||||
"source": "script"
|
||||
}),
|
||||
]);
|
||||
|
||||
let docs = parse_hits(&data).unwrap();
|
||||
assert_eq!(docs.len(), 2);
|
||||
assert_eq!(docs[0].id, "a");
|
||||
assert_eq!(docs[0].content, "prefers terse answers");
|
||||
assert_eq!(docs[1].id, "b");
|
||||
assert_eq!(docs[1].category, "fact");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_hits_empty() {
|
||||
let data = json!({ "hits": { "hits": [] } });
|
||||
let docs = parse_hits(&data).unwrap();
|
||||
assert!(docs.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_hits_missing_structure() {
|
||||
let data = json!({});
|
||||
let docs = parse_hits(&data).unwrap();
|
||||
assert!(docs.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_hits_skips_malformed() {
|
||||
let data = fake_os_response(vec![
|
||||
json!({
|
||||
"id": "good", "user_id": "x@y",
|
||||
"content": "ok", "category": "fact",
|
||||
"created_at": 1, "updated_at": 1, "source": "auto"
|
||||
}),
|
||||
json!({ "bad": "no fields" }),
|
||||
]);
|
||||
|
||||
let docs = parse_hits(&data).unwrap();
|
||||
assert_eq!(docs.len(), 1);
|
||||
assert_eq!(docs[0].id, "good");
|
||||
}
|
||||
}
|
||||
141
src/sync.rs
141
src/sync.rs
@@ -1,13 +1,18 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use matrix_sdk::config::SyncSettings;
|
||||
use matrix_sdk::room::Room;
|
||||
use matrix_sdk::Client;
|
||||
use ruma::events::reaction::OriginalSyncReactionEvent;
|
||||
use ruma::events::room::member::StrippedRoomMemberEvent;
|
||||
use ruma::events::room::message::OriginalSyncRoomMessageEvent;
|
||||
use ruma::events::room::redaction::OriginalSyncRoomRedactionEvent;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::{error, info, warn};
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use opensearch::OpenSearch;
|
||||
|
||||
use crate::archive::indexer::Indexer;
|
||||
use crate::archive::schema::ArchiveDocument;
|
||||
@@ -15,7 +20,9 @@ use crate::brain::conversation::{ContextMessage, ConversationManager};
|
||||
use crate::brain::evaluator::{Engagement, Evaluator};
|
||||
use crate::brain::responder::Responder;
|
||||
use crate::config::Config;
|
||||
use crate::context::{self, ResponseContext};
|
||||
use crate::matrix_utils;
|
||||
use crate::memory;
|
||||
|
||||
pub struct AppState {
|
||||
pub config: Arc<Config>,
|
||||
@@ -24,6 +31,11 @@ pub struct AppState {
|
||||
pub responder: Arc<Responder>,
|
||||
pub conversations: Arc<Mutex<ConversationManager>>,
|
||||
pub mistral: Arc<mistralai_client::v1::client::Client>,
|
||||
pub opensearch: OpenSearch,
|
||||
/// Tracks when Sol last responded in each room (for cooldown)
|
||||
pub last_response: Arc<Mutex<HashMap<String, Instant>>>,
|
||||
/// Tracks rooms where a response is currently being generated (in-flight guard)
|
||||
pub responding_in: Arc<Mutex<std::collections::HashSet<String>>>,
|
||||
}
|
||||
|
||||
pub async fn start_sync(client: Client, state: Arc<AppState>) -> anyhow::Result<()> {
|
||||
@@ -50,6 +62,16 @@ pub async fn start_sync(client: Client, state: Arc<AppState>) -> anyhow::Result<
|
||||
},
|
||||
);
|
||||
|
||||
let s = state.clone();
|
||||
client.add_event_handler(
|
||||
move |event: OriginalSyncReactionEvent, _room: Room| {
|
||||
let state = s.clone();
|
||||
async move {
|
||||
handle_reaction(event, &state).await;
|
||||
}
|
||||
},
|
||||
);
|
||||
|
||||
client.add_event_handler(
|
||||
move |event: StrippedRoomMemberEvent, room: Room| async move {
|
||||
handle_invite(event, room).await;
|
||||
@@ -95,6 +117,7 @@ async fn handle_message(
|
||||
.and_then(|m| m.display_name().map(|s| s.to_string()));
|
||||
|
||||
let reply_to = matrix_utils::extract_reply_to(&event).map(|id| id.to_string());
|
||||
let is_reply = reply_to.is_some();
|
||||
let thread_id = matrix_utils::extract_thread_id(&event).map(|id| id.to_string());
|
||||
|
||||
// Archive the message
|
||||
@@ -112,11 +135,22 @@ async fn handle_message(
|
||||
event_type: "m.room.message".into(),
|
||||
edited: false,
|
||||
redacted: false,
|
||||
reactions: Vec::new(),
|
||||
};
|
||||
state.indexer.add(doc).await;
|
||||
|
||||
// Update conversation context
|
||||
let is_dm = room.is_direct().await.unwrap_or(false);
|
||||
|
||||
let response_ctx = ResponseContext {
|
||||
matrix_user_id: sender.clone(),
|
||||
user_id: context::derive_user_id(&sender),
|
||||
display_name: sender_name.clone(),
|
||||
is_dm,
|
||||
is_reply,
|
||||
room_id: room_id.clone(),
|
||||
};
|
||||
|
||||
{
|
||||
let mut convs = state.conversations.lock().await;
|
||||
convs.add_message(
|
||||
@@ -147,13 +181,20 @@ async fn handle_message(
|
||||
|
||||
let (should_respond, is_spontaneous) = match engagement {
|
||||
Engagement::MustRespond { reason } => {
|
||||
info!(?reason, "Must respond");
|
||||
info!(room = room_id.as_str(), ?reason, "Must respond");
|
||||
(true, false)
|
||||
}
|
||||
Engagement::MaybeRespond { relevance, hook } => {
|
||||
info!(relevance, hook = hook.as_str(), "Maybe respond (spontaneous)");
|
||||
info!(room = room_id.as_str(), relevance, hook = hook.as_str(), "Maybe respond (spontaneous)");
|
||||
(true, true)
|
||||
}
|
||||
Engagement::React { emoji, relevance } => {
|
||||
info!(room = room_id.as_str(), relevance, emoji = emoji.as_str(), "Reacting with emoji");
|
||||
if let Err(e) = matrix_utils::send_reaction(&room, event.event_id.clone().into(), &emoji).await {
|
||||
error!("Failed to send reaction: {e}");
|
||||
}
|
||||
(false, false)
|
||||
}
|
||||
Engagement::Ignore => (false, false),
|
||||
};
|
||||
|
||||
@@ -161,8 +202,38 @@ async fn handle_message(
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
// Show typing indicator
|
||||
let _ = room.typing_notice(true).await;
|
||||
// In-flight guard: skip if we're already generating a response for this room
|
||||
{
|
||||
let responding = state.responding_in.lock().await;
|
||||
if responding.contains(&room_id) {
|
||||
debug!(room = room_id.as_str(), "Skipping — response already in flight for this room");
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
|
||||
// Cooldown check: skip spontaneous if we responded recently
|
||||
if is_spontaneous {
|
||||
let last = state.last_response.lock().await;
|
||||
if let Some(ts) = last.get(&room_id) {
|
||||
let elapsed = ts.elapsed().as_millis() as u64;
|
||||
let cooldown = state.config.behavior.cooldown_after_response_ms;
|
||||
if elapsed < cooldown {
|
||||
debug!(
|
||||
room = room_id.as_str(),
|
||||
elapsed_ms = elapsed,
|
||||
cooldown_ms = cooldown,
|
||||
"Skipping spontaneous — within cooldown period"
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Mark room as in-flight
|
||||
{
|
||||
let mut responding = state.responding_in.lock().await;
|
||||
responding.insert(room_id.clone());
|
||||
}
|
||||
|
||||
let context = {
|
||||
let convs = state.conversations.lock().await;
|
||||
@@ -181,22 +252,74 @@ async fn handle_message(
|
||||
&members,
|
||||
is_spontaneous,
|
||||
&state.mistral,
|
||||
&room,
|
||||
&response_ctx,
|
||||
)
|
||||
.await;
|
||||
|
||||
// Stop typing indicator
|
||||
let _ = room.typing_notice(false).await;
|
||||
|
||||
if let Some(text) = response {
|
||||
let content = matrix_utils::make_reply_content(&text, event.event_id.to_owned());
|
||||
// Reply with reference only when directly addressed. Spontaneous
|
||||
// and DM messages are sent as plain content — feels more natural.
|
||||
let content = if !is_spontaneous && !is_dm {
|
||||
matrix_utils::make_reply_content(&text, event.event_id.to_owned())
|
||||
} else {
|
||||
ruma::events::room::message::RoomMessageEventContent::text_markdown(&text)
|
||||
};
|
||||
if let Err(e) = room.send(content).await {
|
||||
error!("Failed to send response: {e}");
|
||||
} else {
|
||||
info!(room = room_id.as_str(), len = text.len(), is_dm, "Response sent");
|
||||
}
|
||||
// Post-response memory extraction (fire-and-forget)
|
||||
if state.config.behavior.memory_extraction_enabled {
|
||||
let ctx = response_ctx.clone();
|
||||
let mistral = state.mistral.clone();
|
||||
let os = state.opensearch.clone();
|
||||
let config = state.config.clone();
|
||||
let user_msg = body.clone();
|
||||
let sol_response = text.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = memory::extractor::extract_and_store(
|
||||
&mistral, &os, &config, &ctx, &user_msg, &sol_response,
|
||||
)
|
||||
.await
|
||||
{
|
||||
warn!("Memory extraction failed (non-fatal): {e}");
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Update last response timestamp
|
||||
let mut last = state.last_response.lock().await;
|
||||
last.insert(room_id.clone(), Instant::now());
|
||||
}
|
||||
|
||||
// Clear in-flight flag
|
||||
{
|
||||
let mut responding = state.responding_in.lock().await;
|
||||
responding.remove(&room_id);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_reaction(event: OriginalSyncReactionEvent, state: &AppState) {
|
||||
let target_event_id = event.content.relates_to.event_id.to_string();
|
||||
let sender = event.sender.to_string();
|
||||
let emoji = &event.content.relates_to.key;
|
||||
let timestamp: i64 = event.origin_server_ts.0.into();
|
||||
|
||||
info!(
|
||||
target = target_event_id.as_str(),
|
||||
sender = sender.as_str(),
|
||||
emoji = emoji.as_str(),
|
||||
"Indexing reaction"
|
||||
);
|
||||
|
||||
state.indexer.add_reaction(&target_event_id, &sender, emoji, timestamp).await;
|
||||
}
|
||||
|
||||
async fn handle_redaction(event: OriginalSyncRoomRedactionEvent, state: &AppState) {
|
||||
if let Some(redacted_id) = &event.redacts {
|
||||
state.indexer.update_redaction(&redacted_id.to_string()).await;
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
pub mod room_history;
|
||||
pub mod room_info;
|
||||
pub mod script;
|
||||
pub mod search;
|
||||
|
||||
use std::sync::Arc;
|
||||
@@ -10,6 +11,7 @@ use opensearch::OpenSearch;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::context::ResponseContext;
|
||||
|
||||
pub struct ToolRegistry {
|
||||
opensearch: OpenSearch,
|
||||
@@ -122,10 +124,44 @@ impl ToolRegistry {
|
||||
"required": ["room_id"]
|
||||
}),
|
||||
),
|
||||
Tool::new(
|
||||
"run_script".into(),
|
||||
"Execute a TypeScript/JavaScript snippet in a sandboxed runtime. \
|
||||
Use this for math, date calculations, data transformations, or any \
|
||||
computation that needs precision. The script has access to:\n\
|
||||
- sol.search(query, opts?) — search the message archive. opts: \
|
||||
{ room?, sender?, after?, before?, limit?, semantic? }\n\
|
||||
- sol.rooms() — list joined rooms (returns array of {name, id, members})\n\
|
||||
- sol.members(roomName) — get room members (returns array of {name, id})\n\
|
||||
- sol.fetch(url) — HTTP GET (allowlisted domains only)\n\
|
||||
- sol.memory.get(query?) — retrieve internal notes relevant to the query\n\
|
||||
- sol.memory.set(content, category?) — save an internal note for later reference\n\
|
||||
- sol.fs.read(path), sol.fs.write(path, content), sol.fs.list(path?) — \
|
||||
sandboxed temp filesystem for intermediate files\n\
|
||||
- console.log() to produce output\n\
|
||||
All sol.* methods are async — use await. The last expression value is \
|
||||
also captured. Output is truncated to 4096 chars."
|
||||
.into(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"code": {
|
||||
"type": "string",
|
||||
"description": "TypeScript or JavaScript code to execute"
|
||||
}
|
||||
},
|
||||
"required": ["code"]
|
||||
}),
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
pub async fn execute(&self, name: &str, arguments: &str) -> anyhow::Result<String> {
|
||||
pub async fn execute(
|
||||
&self,
|
||||
name: &str,
|
||||
arguments: &str,
|
||||
response_ctx: &ResponseContext,
|
||||
) -> anyhow::Result<String> {
|
||||
match name {
|
||||
"search_archive" => {
|
||||
search::search_archive(
|
||||
@@ -145,6 +181,16 @@ impl ToolRegistry {
|
||||
}
|
||||
"list_rooms" => room_info::list_rooms(&self.matrix).await,
|
||||
"get_room_members" => room_info::get_room_members(&self.matrix, arguments).await,
|
||||
"run_script" => {
|
||||
script::run_script(
|
||||
&self.opensearch,
|
||||
&self.matrix,
|
||||
&self.config,
|
||||
arguments,
|
||||
response_ctx,
|
||||
)
|
||||
.await
|
||||
}
|
||||
_ => anyhow::bail!("Unknown tool: {name}"),
|
||||
}
|
||||
}
|
||||
|
||||
706
src/tools/script.rs
Normal file
706
src/tools/script.rs
Normal file
@@ -0,0 +1,706 @@
|
||||
use std::cell::RefCell;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::rc::Rc;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use deno_core::{extension, op2, JsRuntime, OpState, RuntimeOptions};
|
||||
use deno_error::JsErrorBox;
|
||||
use matrix_sdk::Client as MatrixClient;
|
||||
use opensearch::OpenSearch;
|
||||
use serde::Deserialize;
|
||||
use tempfile::TempDir;
|
||||
use tracing::info;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::context::ResponseContext;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// State types stored in OpState
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
struct ScriptState {
|
||||
opensearch: OpenSearch,
|
||||
matrix: MatrixClient,
|
||||
config: Arc<Config>,
|
||||
tmpdir: PathBuf,
|
||||
user_id: String,
|
||||
}
|
||||
|
||||
struct ScriptOutput(String);
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct RunScriptArgs {
|
||||
code: String,
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Sandbox path resolution
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn resolve_sandbox_path(
|
||||
tmpdir: &Path,
|
||||
requested: &str,
|
||||
) -> Result<PathBuf, JsErrorBox> {
|
||||
let base = tmpdir
|
||||
.canonicalize()
|
||||
.map_err(|e| JsErrorBox::generic(format!("sandbox error: {e}")))?;
|
||||
let joined = base.join(requested);
|
||||
|
||||
if let Some(parent) = joined.parent() {
|
||||
std::fs::create_dir_all(parent).ok();
|
||||
}
|
||||
|
||||
let resolved = if joined.exists() {
|
||||
joined
|
||||
.canonicalize()
|
||||
.map_err(|e| JsErrorBox::generic(format!("path error: {e}")))?
|
||||
} else {
|
||||
let parent = joined
|
||||
.parent()
|
||||
.ok_or_else(|| JsErrorBox::generic("invalid path"))?
|
||||
.canonicalize()
|
||||
.map_err(|e| JsErrorBox::generic(format!("path error: {e}")))?;
|
||||
parent.join(joined.file_name().unwrap_or_default())
|
||||
};
|
||||
|
||||
if !resolved.starts_with(&base) {
|
||||
return Err(JsErrorBox::generic("path escapes sandbox"));
|
||||
}
|
||||
|
||||
Ok(resolved)
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Ops — async (search, rooms, members, fetch)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[op2]
|
||||
#[string]
|
||||
async fn op_sol_search(
|
||||
state: Rc<RefCell<OpState>>,
|
||||
#[string] query: String,
|
||||
#[string] opts_json: String,
|
||||
) -> Result<String, JsErrorBox> {
|
||||
let (os, index) = {
|
||||
let st = state.borrow();
|
||||
let ss = st.borrow::<ScriptState>();
|
||||
(ss.opensearch.clone(), ss.config.opensearch.index.clone())
|
||||
};
|
||||
|
||||
let mut args: serde_json::Value =
|
||||
serde_json::from_str(&opts_json).unwrap_or(serde_json::json!({}));
|
||||
args["query"] = serde_json::Value::String(query);
|
||||
|
||||
super::search::search_archive(&os, &index, &args.to_string())
|
||||
.await
|
||||
.map_err(|e| JsErrorBox::generic(e.to_string()))
|
||||
}
|
||||
|
||||
#[op2]
|
||||
#[string]
|
||||
async fn op_sol_rooms(
|
||||
state: Rc<RefCell<OpState>>,
|
||||
) -> Result<String, JsErrorBox> {
|
||||
let matrix = {
|
||||
let st = state.borrow();
|
||||
st.borrow::<ScriptState>().matrix.clone()
|
||||
};
|
||||
|
||||
let rooms = matrix.joined_rooms();
|
||||
let names: Vec<serde_json::Value> = rooms
|
||||
.iter()
|
||||
.map(|r| {
|
||||
let name = match r.cached_display_name() {
|
||||
Some(n) => n.to_string(),
|
||||
None => r.room_id().to_string(),
|
||||
};
|
||||
serde_json::json!({
|
||||
"name": name,
|
||||
"id": r.room_id().to_string(),
|
||||
"members": r.joined_members_count(),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
serde_json::to_string(&names).map_err(|e| JsErrorBox::generic(e.to_string()))
|
||||
}
|
||||
|
||||
#[op2]
|
||||
#[string]
|
||||
async fn op_sol_members(
|
||||
state: Rc<RefCell<OpState>>,
|
||||
#[string] room_name: String,
|
||||
) -> Result<String, JsErrorBox> {
|
||||
let matrix = {
|
||||
let st = state.borrow();
|
||||
st.borrow::<ScriptState>().matrix.clone()
|
||||
};
|
||||
|
||||
let rooms = matrix.joined_rooms();
|
||||
let room = rooms.iter().find(|r| {
|
||||
r.cached_display_name()
|
||||
.map(|n| n.to_string() == room_name)
|
||||
.unwrap_or(false)
|
||||
});
|
||||
|
||||
let Some(room) = room else {
|
||||
return Err(JsErrorBox::generic(format!(
|
||||
"room not found: {room_name}"
|
||||
)));
|
||||
};
|
||||
|
||||
let members = room
|
||||
.members(matrix_sdk::RoomMemberships::JOIN)
|
||||
.await
|
||||
.map_err(|e| JsErrorBox::generic(e.to_string()))?;
|
||||
|
||||
let names: Vec<serde_json::Value> = members
|
||||
.iter()
|
||||
.map(|m| {
|
||||
serde_json::json!({
|
||||
"name": m.display_name().unwrap_or_else(|| m.user_id().as_str()),
|
||||
"id": m.user_id().to_string(),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
serde_json::to_string(&names).map_err(|e| JsErrorBox::generic(e.to_string()))
|
||||
}
|
||||
|
||||
#[op2]
|
||||
#[string]
|
||||
async fn op_sol_fetch(
|
||||
state: Rc<RefCell<OpState>>,
|
||||
#[string] url_str: String,
|
||||
) -> Result<String, JsErrorBox> {
|
||||
let allowlist = {
|
||||
let st = state.borrow();
|
||||
st.borrow::<ScriptState>()
|
||||
.config
|
||||
.behavior
|
||||
.script_fetch_allowlist
|
||||
.clone()
|
||||
};
|
||||
|
||||
let parsed = url::Url::parse(&url_str)
|
||||
.map_err(|e| JsErrorBox::generic(format!("invalid URL: {e}")))?;
|
||||
let domain = parsed
|
||||
.host_str()
|
||||
.ok_or_else(|| JsErrorBox::generic("URL has no host"))?;
|
||||
|
||||
if !allowlist
|
||||
.iter()
|
||||
.any(|d| domain == d || domain.ends_with(&format!(".{d}")))
|
||||
{
|
||||
return Err(JsErrorBox::generic(format!(
|
||||
"domain not in allowlist: {domain}"
|
||||
)));
|
||||
}
|
||||
|
||||
let resp = reqwest::get(&url_str)
|
||||
.await
|
||||
.map_err(|e| JsErrorBox::generic(e.to_string()))?;
|
||||
|
||||
let text = resp
|
||||
.text()
|
||||
.await
|
||||
.map_err(|e| JsErrorBox::generic(e.to_string()))?;
|
||||
|
||||
let max_len = 32768;
|
||||
if text.len() > max_len {
|
||||
Ok(format!("{}...(truncated)", &text[..max_len]))
|
||||
} else {
|
||||
Ok(text)
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Ops — sync (filesystem sandbox + output collection)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[op2]
|
||||
#[string]
|
||||
fn op_sol_read_file(
|
||||
state: &mut OpState,
|
||||
#[string] path: String,
|
||||
) -> Result<String, JsErrorBox> {
|
||||
let tmpdir = state.borrow::<ScriptState>().tmpdir.clone();
|
||||
let resolved = resolve_sandbox_path(&tmpdir, &path)?;
|
||||
std::fs::read_to_string(&resolved)
|
||||
.map_err(|e| JsErrorBox::generic(format!("read error: {e}")))
|
||||
}
|
||||
|
||||
#[op2(fast)]
|
||||
fn op_sol_write_file(
|
||||
state: &mut OpState,
|
||||
#[string] path: String,
|
||||
#[string] content: String,
|
||||
) -> Result<(), JsErrorBox> {
|
||||
let tmpdir = state.borrow::<ScriptState>().tmpdir.clone();
|
||||
let resolved = resolve_sandbox_path(&tmpdir, &path)?;
|
||||
std::fs::write(&resolved, content)
|
||||
.map_err(|e| JsErrorBox::generic(format!("write error: {e}")))
|
||||
}
|
||||
|
||||
#[op2]
|
||||
#[string]
|
||||
fn op_sol_list_dir(
|
||||
state: &mut OpState,
|
||||
#[string] path: String,
|
||||
) -> Result<String, JsErrorBox> {
|
||||
let tmpdir = state.borrow::<ScriptState>().tmpdir.clone();
|
||||
let resolved = resolve_sandbox_path(&tmpdir, &path)?;
|
||||
|
||||
let entries: Vec<String> = std::fs::read_dir(&resolved)
|
||||
.map_err(|e| JsErrorBox::generic(format!("list error: {e}")))?
|
||||
.filter_map(|e| e.ok())
|
||||
.map(|e| e.file_name().to_string_lossy().into_owned())
|
||||
.collect();
|
||||
|
||||
serde_json::to_string(&entries).map_err(|e| JsErrorBox::generic(e.to_string()))
|
||||
}
|
||||
|
||||
#[op2(fast)]
|
||||
fn op_sol_set_output(
|
||||
state: &mut OpState,
|
||||
#[string] output: String,
|
||||
) -> Result<(), JsErrorBox> {
|
||||
*state.borrow_mut::<ScriptOutput>() = ScriptOutput(output);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Ops — async (memory)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[op2]
|
||||
#[string]
|
||||
async fn op_sol_memory_get(
|
||||
state: Rc<RefCell<OpState>>,
|
||||
#[string] query: String,
|
||||
) -> Result<String, JsErrorBox> {
|
||||
let (os, index, user_id) = {
|
||||
let st = state.borrow();
|
||||
let ss = st.borrow::<ScriptState>();
|
||||
(
|
||||
ss.opensearch.clone(),
|
||||
ss.config.opensearch.memory_index.clone(),
|
||||
ss.user_id.clone(),
|
||||
)
|
||||
};
|
||||
|
||||
let results = if query.is_empty() {
|
||||
crate::memory::store::get_recent(&os, &index, &user_id, 10)
|
||||
.await
|
||||
.map_err(|e| JsErrorBox::generic(e.to_string()))?
|
||||
} else {
|
||||
crate::memory::store::query(&os, &index, &user_id, &query, 10)
|
||||
.await
|
||||
.map_err(|e| JsErrorBox::generic(e.to_string()))?
|
||||
};
|
||||
|
||||
let items: Vec<serde_json::Value> = results
|
||||
.iter()
|
||||
.map(|m| {
|
||||
serde_json::json!({
|
||||
"content": m.content,
|
||||
"category": m.category,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
serde_json::to_string(&items).map_err(|e| JsErrorBox::generic(e.to_string()))
|
||||
}
|
||||
|
||||
#[op2]
|
||||
#[string]
|
||||
async fn op_sol_memory_set(
|
||||
state: Rc<RefCell<OpState>>,
|
||||
#[string] content: String,
|
||||
#[string] category: String,
|
||||
) -> Result<String, JsErrorBox> {
|
||||
let (os, index, user_id) = {
|
||||
let st = state.borrow();
|
||||
let ss = st.borrow::<ScriptState>();
|
||||
(
|
||||
ss.opensearch.clone(),
|
||||
ss.config.opensearch.memory_index.clone(),
|
||||
ss.user_id.clone(),
|
||||
)
|
||||
};
|
||||
|
||||
crate::memory::store::set(&os, &index, &user_id, &content, &category, "script")
|
||||
.await
|
||||
.map_err(|e| JsErrorBox::generic(e.to_string()))?;
|
||||
|
||||
Ok("ok".into())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Extension
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
extension!(
|
||||
sol_script,
|
||||
ops = [
|
||||
op_sol_search,
|
||||
op_sol_rooms,
|
||||
op_sol_members,
|
||||
op_sol_fetch,
|
||||
op_sol_read_file,
|
||||
op_sol_write_file,
|
||||
op_sol_list_dir,
|
||||
op_sol_set_output,
|
||||
op_sol_memory_get,
|
||||
op_sol_memory_set,
|
||||
],
|
||||
state = |state| {
|
||||
state.put(ScriptOutput(String::new()));
|
||||
},
|
||||
);
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Bootstrap JS — injected before user code
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const BOOTSTRAP_JS: &str = r#"
|
||||
const __output = [];
|
||||
globalThis.console = {
|
||||
log: (...args) => __output.push(args.map(a => typeof a === 'object' ? JSON.stringify(a) : String(a)).join(' ')),
|
||||
error: (...args) => __output.push('ERROR: ' + args.map(a => typeof a === 'object' ? JSON.stringify(a) : String(a)).join(' ')),
|
||||
warn: (...args) => __output.push('WARN: ' + args.map(a => typeof a === 'object' ? JSON.stringify(a) : String(a)).join(' ')),
|
||||
info: (...args) => __output.push(args.map(a => typeof a === 'object' ? JSON.stringify(a) : String(a)).join(' ')),
|
||||
};
|
||||
globalThis.sol = {
|
||||
search: (query, opts) => Deno.core.ops.op_sol_search(query, opts ? JSON.stringify(opts) : "{}"),
|
||||
rooms: async () => JSON.parse(await Deno.core.ops.op_sol_rooms()),
|
||||
members: async (room) => JSON.parse(await Deno.core.ops.op_sol_members(room)),
|
||||
fetch: (url) => Deno.core.ops.op_sol_fetch(url),
|
||||
memory: {
|
||||
get: async (query) => JSON.parse(await Deno.core.ops.op_sol_memory_get(query || "")),
|
||||
set: async (content, category) => {
|
||||
await Deno.core.ops.op_sol_memory_set(content, category || "general");
|
||||
},
|
||||
},
|
||||
fs: {
|
||||
read: (path) => Deno.core.ops.op_sol_read_file(path),
|
||||
write: (path, content) => Deno.core.ops.op_sol_write_file(path, content),
|
||||
list: (path) => JSON.parse(Deno.core.ops.op_sol_list_dir(path || ".")),
|
||||
},
|
||||
};
|
||||
"#;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// TypeScript transpilation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
fn transpile_ts(code: &str) -> anyhow::Result<String> {
|
||||
use deno_ast::{MediaType, ParseParams, TranspileModuleOptions};
|
||||
|
||||
let specifier = deno_ast::ModuleSpecifier::parse("file:///script.ts")?;
|
||||
|
||||
let parsed = deno_ast::parse_module(ParseParams {
|
||||
specifier,
|
||||
text: code.into(),
|
||||
media_type: MediaType::TypeScript,
|
||||
capture_tokens: false,
|
||||
maybe_syntax: None,
|
||||
scope_analysis: false,
|
||||
})?;
|
||||
|
||||
let transpiled = parsed.transpile(
|
||||
&deno_ast::TranspileOptions::default(),
|
||||
&TranspileModuleOptions::default(),
|
||||
&deno_ast::EmitOptions::default(),
|
||||
)?;
|
||||
|
||||
Ok(transpiled.into_source().text.to_string())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Public entry point
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
pub async fn run_script(
|
||||
opensearch: &OpenSearch,
|
||||
matrix: &MatrixClient,
|
||||
config: &Config,
|
||||
args_json: &str,
|
||||
response_ctx: &ResponseContext,
|
||||
) -> anyhow::Result<String> {
|
||||
let args: RunScriptArgs = serde_json::from_str(args_json)?;
|
||||
let code = args.code.clone();
|
||||
|
||||
info!(code_len = code.len(), "Executing script");
|
||||
|
||||
// Transpile TS to JS
|
||||
let js_code = transpile_ts(&code)?;
|
||||
|
||||
// Clone state for move into spawn_blocking
|
||||
let os = opensearch.clone();
|
||||
let mx = matrix.clone();
|
||||
let cfg = Arc::new(config.clone());
|
||||
let timeout_secs = cfg.behavior.script_timeout_secs;
|
||||
let max_heap_mb = cfg.behavior.script_max_heap_mb;
|
||||
let user_id = response_ctx.user_id.clone();
|
||||
|
||||
// Wrap user code: async IIFE that captures output, then stores via op
|
||||
let wrapped = format!(
|
||||
r#"
|
||||
(async () => {{
|
||||
try {{
|
||||
const __result = await (async () => {{
|
||||
{js_code}
|
||||
}})();
|
||||
if (__result !== undefined) {{
|
||||
__output.push(typeof __result === 'object' ? JSON.stringify(__result, null, 2) : String(__result));
|
||||
}}
|
||||
}} catch(e) {{
|
||||
__output.push('Error: ' + (e.stack || e.message || String(e)));
|
||||
}}
|
||||
Deno.core.ops.op_sol_set_output(JSON.stringify(__output));
|
||||
}})();"#
|
||||
);
|
||||
|
||||
let timeout = Duration::from_secs(timeout_secs);
|
||||
|
||||
let result = tokio::task::spawn_blocking(move || {
|
||||
let rt = tokio::runtime::Builder::new_current_thread()
|
||||
.enable_all()
|
||||
.build()?;
|
||||
|
||||
rt.block_on(async move {
|
||||
let tmpdir = TempDir::new()?;
|
||||
let tmpdir_path = tmpdir.path().to_path_buf();
|
||||
|
||||
let mut runtime = JsRuntime::new(RuntimeOptions {
|
||||
extensions: vec![sol_script::init()],
|
||||
create_params: Some(
|
||||
deno_core::v8::CreateParams::default()
|
||||
.heap_limits(0, max_heap_mb * 1024 * 1024),
|
||||
),
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
// Inject shared state
|
||||
{
|
||||
let op_state = runtime.op_state();
|
||||
let mut state = op_state.borrow_mut();
|
||||
state.put(ScriptState {
|
||||
opensearch: os,
|
||||
matrix: mx,
|
||||
config: cfg,
|
||||
tmpdir: tmpdir_path,
|
||||
user_id,
|
||||
});
|
||||
}
|
||||
|
||||
// V8 isolate termination for timeout
|
||||
let done = Arc::new(AtomicBool::new(false));
|
||||
let done_clone = done.clone();
|
||||
let isolate_handle = runtime.v8_isolate().thread_safe_handle();
|
||||
std::thread::spawn(move || {
|
||||
let deadline = std::time::Instant::now() + timeout;
|
||||
while std::time::Instant::now() < deadline {
|
||||
if done_clone.load(Ordering::Relaxed) {
|
||||
return;
|
||||
}
|
||||
std::thread::sleep(Duration::from_millis(50));
|
||||
}
|
||||
isolate_handle.terminate_execution();
|
||||
});
|
||||
|
||||
// Execute bootstrap
|
||||
runtime.execute_script("<bootstrap>", BOOTSTRAP_JS)?;
|
||||
|
||||
// Execute user code
|
||||
let exec_result = runtime.execute_script("<script>", wrapped);
|
||||
|
||||
if let Err(e) = exec_result {
|
||||
done.store(true, Ordering::Relaxed);
|
||||
let msg = e.to_string();
|
||||
if msg.contains("terminated") {
|
||||
return Ok("Error: script execution timed out".into());
|
||||
}
|
||||
return Ok(format!("Error: {msg}"));
|
||||
}
|
||||
|
||||
// Drive async ops to completion
|
||||
if let Err(e) = runtime.run_event_loop(Default::default()).await {
|
||||
done.store(true, Ordering::Relaxed);
|
||||
let msg = e.to_string();
|
||||
if msg.contains("terminated") {
|
||||
return Ok("Error: script execution timed out".into());
|
||||
}
|
||||
return Ok(format!("Error: {msg}"));
|
||||
}
|
||||
|
||||
done.store(true, Ordering::Relaxed);
|
||||
|
||||
// Read captured output from OpState
|
||||
let output_json = {
|
||||
let op_state = runtime.op_state();
|
||||
let state = op_state.borrow();
|
||||
state.borrow::<ScriptOutput>().0.clone()
|
||||
};
|
||||
|
||||
let output_arr: Vec<String> =
|
||||
serde_json::from_str(&output_json).unwrap_or_else(|_| {
|
||||
if output_json.is_empty() {
|
||||
vec![]
|
||||
} else {
|
||||
vec![output_json]
|
||||
}
|
||||
});
|
||||
let result = output_arr.join("\n");
|
||||
|
||||
let max_output = 4096;
|
||||
let result = if result.len() > max_output {
|
||||
format!("{}...(truncated)", &result[..max_output])
|
||||
} else {
|
||||
result
|
||||
};
|
||||
|
||||
drop(tmpdir);
|
||||
|
||||
Ok::<String, anyhow::Error>(result)
|
||||
})
|
||||
})
|
||||
.await?;
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Tests
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_run_script_args() {
|
||||
let args: RunScriptArgs =
|
||||
serde_json::from_str(r#"{"code": "console.log(42)"}"#).unwrap();
|
||||
assert_eq!(args.code, "console.log(42)");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transpile_ts_basic() {
|
||||
let js = transpile_ts("const x: number = 42; x;").unwrap();
|
||||
assert!(js.contains("const x = 42"));
|
||||
assert!(!js.contains(": number"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transpile_ts_arrow() {
|
||||
let js =
|
||||
transpile_ts("const add = (a: number, b: number): number => a + b;").unwrap();
|
||||
assert!(js.contains("const add"));
|
||||
assert!(!js.contains(": number"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transpile_ts_interface() {
|
||||
let js = transpile_ts(
|
||||
"interface Foo { bar: string; } const x: Foo = { bar: 'hello' };",
|
||||
)
|
||||
.unwrap();
|
||||
assert!(!js.contains("interface"));
|
||||
assert!(js.contains("const x"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_transpile_ts_invalid() {
|
||||
let result = transpile_ts("const x: number = {");
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sandbox_path_within() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
let base = dir.path();
|
||||
std::fs::write(base.join("test.txt"), "hello").unwrap();
|
||||
|
||||
let resolved = resolve_sandbox_path(base, "test.txt").unwrap();
|
||||
assert!(resolved.starts_with(base.canonicalize().unwrap()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sandbox_path_escape_rejected() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
let base = dir.path();
|
||||
|
||||
// Path traversal is blocked — either by canonicalization failure
|
||||
// (parent doesn't exist) or by the starts_with sandbox check
|
||||
let result = resolve_sandbox_path(base, "../../../etc/passwd");
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sandbox_path_symlink_escape_rejected() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
let base = dir.path();
|
||||
|
||||
// Create a symlink pointing outside the sandbox
|
||||
let link_path = base.join("escape");
|
||||
std::os::unix::fs::symlink("/tmp", &link_path).unwrap();
|
||||
|
||||
// Following the symlink resolves outside sandbox
|
||||
let result = resolve_sandbox_path(base, "escape/somefile");
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sandbox_path_new_file() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
let base = dir.path();
|
||||
|
||||
let resolved = resolve_sandbox_path(base, "newfile.txt").unwrap();
|
||||
assert!(resolved.starts_with(base.canonicalize().unwrap()));
|
||||
assert!(resolved.ends_with("newfile.txt"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sandbox_path_nested_dir() {
|
||||
let dir = TempDir::new().unwrap();
|
||||
let base = dir.path();
|
||||
|
||||
let resolved = resolve_sandbox_path(base, "subdir/file.txt").unwrap();
|
||||
assert!(resolved.starts_with(base.canonicalize().unwrap()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_basic_script_execution() {
|
||||
let mut runtime = JsRuntime::new(RuntimeOptions {
|
||||
extensions: vec![sol_script::init()],
|
||||
..Default::default()
|
||||
});
|
||||
|
||||
runtime
|
||||
.execute_script("<bootstrap>", BOOTSTRAP_JS)
|
||||
.unwrap();
|
||||
|
||||
runtime
|
||||
.execute_script(
|
||||
"<test>",
|
||||
r#"
|
||||
console.log(2 ** 64);
|
||||
console.log(Math.PI);
|
||||
Deno.core.ops.op_sol_set_output(JSON.stringify(__output));
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let op_state = runtime.op_state();
|
||||
let state = op_state.borrow();
|
||||
let output = &state.borrow::<ScriptOutput>().0;
|
||||
let arr: Vec<String> = serde_json::from_str(output).unwrap();
|
||||
assert_eq!(arr.len(), 2);
|
||||
assert!(arr[0].contains("18446744073709552000"));
|
||||
assert!(arr[1].contains("3.14159"));
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,7 @@
|
||||
use opensearch::OpenSearch;
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
use tracing::debug;
|
||||
use tracing::{debug, info};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct SearchArgs {
|
||||
@@ -24,19 +24,24 @@ fn default_limit() -> usize { 10 }
|
||||
|
||||
/// Build the OpenSearch query body from parsed SearchArgs. Extracted for testability.
|
||||
pub fn build_search_query(args: &SearchArgs) -> serde_json::Value {
|
||||
let must = vec![json!({
|
||||
"match": { "content": args.query }
|
||||
})];
|
||||
// Handle empty/wildcard queries as match_all
|
||||
let must = if args.query.is_empty() || args.query == "*" {
|
||||
vec![json!({ "match_all": {} })]
|
||||
} else {
|
||||
vec![json!({
|
||||
"match": { "content": args.query }
|
||||
})]
|
||||
};
|
||||
|
||||
let mut filter = vec![json!({
|
||||
"term": { "redacted": false }
|
||||
})];
|
||||
|
||||
if let Some(ref room) = args.room {
|
||||
filter.push(json!({ "term": { "room_name": room } }));
|
||||
filter.push(json!({ "term": { "room_name.keyword": room } }));
|
||||
}
|
||||
if let Some(ref sender) = args.sender {
|
||||
filter.push(json!({ "term": { "sender_name": sender } }));
|
||||
filter.push(json!({ "term": { "sender_name.keyword": sender } }));
|
||||
}
|
||||
|
||||
let mut range = serde_json::Map::new();
|
||||
@@ -73,10 +78,19 @@ pub async fn search_archive(
|
||||
args_json: &str,
|
||||
) -> anyhow::Result<String> {
|
||||
let args: SearchArgs = serde_json::from_str(args_json)?;
|
||||
debug!(query = args.query.as_str(), "Searching archive");
|
||||
|
||||
let query_body = build_search_query(&args);
|
||||
|
||||
info!(
|
||||
query = args.query.as_str(),
|
||||
room = args.room.as_deref().unwrap_or("*"),
|
||||
sender = args.sender.as_deref().unwrap_or("*"),
|
||||
after = args.after.as_deref().unwrap_or("*"),
|
||||
before = args.before.as_deref().unwrap_or("*"),
|
||||
limit = args.limit,
|
||||
query_json = %query_body,
|
||||
"Executing search"
|
||||
);
|
||||
|
||||
let response = client
|
||||
.search(opensearch::SearchParts::Index(&[index]))
|
||||
.body(query_body)
|
||||
@@ -84,6 +98,8 @@ pub async fn search_archive(
|
||||
.await?;
|
||||
|
||||
let body: serde_json::Value = response.json().await?;
|
||||
let hit_count = body["hits"]["total"]["value"].as_i64().unwrap_or(0);
|
||||
info!(hit_count, "Search results");
|
||||
let hits = &body["hits"]["hits"];
|
||||
|
||||
let Some(hits_arr) = hits.as_array() else {
|
||||
@@ -173,7 +189,7 @@ mod tests {
|
||||
|
||||
let filters = q["query"]["bool"]["filter"].as_array().unwrap();
|
||||
assert_eq!(filters.len(), 2);
|
||||
assert_eq!(filters[1]["term"]["room_name"], "design");
|
||||
assert_eq!(filters[1]["term"]["room_name.keyword"], "design");
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -183,7 +199,7 @@ mod tests {
|
||||
|
||||
let filters = q["query"]["bool"]["filter"].as_array().unwrap();
|
||||
assert_eq!(filters.len(), 2);
|
||||
assert_eq!(filters[1]["term"]["sender_name"], "Bob");
|
||||
assert_eq!(filters[1]["term"]["sender_name.keyword"], "Bob");
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -193,8 +209,8 @@ mod tests {
|
||||
|
||||
let filters = q["query"]["bool"]["filter"].as_array().unwrap();
|
||||
assert_eq!(filters.len(), 3);
|
||||
assert_eq!(filters[1]["term"]["room_name"], "dev");
|
||||
assert_eq!(filters[2]["term"]["sender_name"], "Carol");
|
||||
assert_eq!(filters[1]["term"]["room_name.keyword"], "dev");
|
||||
assert_eq!(filters[2]["term"]["sender_name.keyword"], "Carol");
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
Reference in New Issue
Block a user