feat: per-user auto-memory with ResponseContext
Three memory channels: hidden tool (sol.memory.set/get in scripts), pre-response injection (relevant memories loaded into system prompt), and post-response extraction (ministral-3b extracts facts after each response). User isolation enforced at Rust level — user_id derived from Matrix sender, never from script arguments. New modules: context (ResponseContext), memory (schema, store, extractor). ResponseContext threaded through responder → tools → script runtime. OpenSearch index sol_user_memory created on startup alongside archive.
This commit is contained in:
158
src/memory/extractor.rs
Normal file
158
src/memory/extractor.rs
Normal file
@@ -0,0 +1,158 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use mistralai_client::v1::{
|
||||
chat::{ChatMessage, ChatParams, ResponseFormat},
|
||||
constants::Model,
|
||||
};
|
||||
use opensearch::OpenSearch;
|
||||
use serde::Deserialize;
|
||||
use tracing::{debug, warn};
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::context::ResponseContext;
|
||||
use crate::brain::responder::chat_blocking;
|
||||
|
||||
use super::store;
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub(crate) struct ExtractionResponse {
|
||||
pub memories: Vec<ExtractedMemory>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub(crate) struct ExtractedMemory {
|
||||
pub content: String,
|
||||
pub category: String,
|
||||
}
|
||||
|
||||
/// Validate and normalize a category string.
|
||||
pub(crate) fn normalize_category(raw: &str) -> &str {
|
||||
match raw {
|
||||
"preference" | "fact" | "context" => raw,
|
||||
_ => "general",
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn extract_and_store(
|
||||
mistral: &Arc<mistralai_client::v1::client::Client>,
|
||||
opensearch: &OpenSearch,
|
||||
config: &Config,
|
||||
ctx: &ResponseContext,
|
||||
user_message: &str,
|
||||
sol_response: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let display = ctx
|
||||
.display_name
|
||||
.as_deref()
|
||||
.unwrap_or(&ctx.matrix_user_id);
|
||||
|
||||
let prompt = format!(
|
||||
"Analyze this conversation exchange and extract any facts worth remembering about {display}.\n\
|
||||
Focus on: preferences, personal details, ongoing projects, opinions, recurring topics.\n\n\
|
||||
They said: {user_message}\n\
|
||||
Response: {sol_response}\n\n\
|
||||
Respond ONLY with JSON: {{\"memories\": [{{\"content\": \"...\", \"category\": \"preference|fact|context\"}}]}}\n\
|
||||
If nothing worth remembering, respond with {{\"memories\": []}}.\n\
|
||||
Be selective — only genuinely useful information."
|
||||
);
|
||||
|
||||
let messages = vec![ChatMessage::new_user_message(&prompt)];
|
||||
let model = Model::new(&config.mistral.evaluation_model);
|
||||
let params = ChatParams {
|
||||
response_format: Some(ResponseFormat::json_object()),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let response = chat_blocking(mistral, model, messages, params).await?;
|
||||
let text = response.choices[0].message.content.trim();
|
||||
|
||||
let extraction: ExtractionResponse = match serde_json::from_str(text) {
|
||||
Ok(e) => e,
|
||||
Err(e) => {
|
||||
debug!(raw = text, "Failed to parse extraction response: {e}");
|
||||
return Ok(());
|
||||
}
|
||||
};
|
||||
|
||||
if extraction.memories.is_empty() {
|
||||
debug!("No memories extracted");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let index = &config.opensearch.memory_index;
|
||||
for mem in &extraction.memories {
|
||||
let category = normalize_category(&mem.category);
|
||||
|
||||
if let Err(e) = store::set(
|
||||
opensearch,
|
||||
index,
|
||||
&ctx.user_id,
|
||||
&mem.content,
|
||||
category,
|
||||
"auto",
|
||||
)
|
||||
.await
|
||||
{
|
||||
warn!("Failed to store extracted memory: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
debug!(
|
||||
count = extraction.memories.len(),
|
||||
user = ctx.user_id.as_str(),
|
||||
"Extracted and stored memories"
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_extraction_response_with_memories() {
|
||||
let json = r#"{"memories": [
|
||||
{"content": "prefers terse answers", "category": "preference"},
|
||||
{"content": "working on drive UI", "category": "fact"}
|
||||
]}"#;
|
||||
let resp: ExtractionResponse = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(resp.memories.len(), 2);
|
||||
assert_eq!(resp.memories[0].content, "prefers terse answers");
|
||||
assert_eq!(resp.memories[0].category, "preference");
|
||||
assert_eq!(resp.memories[1].category, "fact");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_extraction_response_empty() {
|
||||
let json = r#"{"memories": []}"#;
|
||||
let resp: ExtractionResponse = serde_json::from_str(json).unwrap();
|
||||
assert!(resp.memories.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_extraction_response_invalid_json() {
|
||||
let json = "not json at all";
|
||||
assert!(serde_json::from_str::<ExtractionResponse>(json).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_extraction_response_missing_field() {
|
||||
let json = r#"{"memories": [{"content": "hi"}]}"#;
|
||||
assert!(serde_json::from_str::<ExtractionResponse>(json).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normalize_category_valid() {
|
||||
assert_eq!(normalize_category("preference"), "preference");
|
||||
assert_eq!(normalize_category("fact"), "fact");
|
||||
assert_eq!(normalize_category("context"), "context");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normalize_category_unknown_falls_back() {
|
||||
assert_eq!(normalize_category("opinion"), "general");
|
||||
assert_eq!(normalize_category(""), "general");
|
||||
assert_eq!(normalize_category("PREFERENCE"), "general");
|
||||
}
|
||||
}
|
||||
3
src/memory/mod.rs
Normal file
3
src/memory/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod extractor;
|
||||
pub mod schema;
|
||||
pub mod store;
|
||||
118
src/memory/schema.rs
Normal file
118
src/memory/schema.rs
Normal file
@@ -0,0 +1,118 @@
|
||||
use opensearch::OpenSearch;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::info;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MemoryDocument {
|
||||
pub id: String,
|
||||
pub user_id: String,
|
||||
pub content: String,
|
||||
pub category: String,
|
||||
pub created_at: i64,
|
||||
pub updated_at: i64,
|
||||
pub source: String,
|
||||
}
|
||||
|
||||
const INDEX_MAPPING: &str = r#"{
|
||||
"settings": {
|
||||
"number_of_shards": 1,
|
||||
"number_of_replicas": 0
|
||||
},
|
||||
"mappings": {
|
||||
"properties": {
|
||||
"id": { "type": "keyword" },
|
||||
"user_id": { "type": "keyword" },
|
||||
"content": { "type": "text", "analyzer": "standard" },
|
||||
"category": { "type": "keyword" },
|
||||
"created_at": { "type": "date", "format": "epoch_millis" },
|
||||
"updated_at": { "type": "date", "format": "epoch_millis" },
|
||||
"source": { "type": "keyword" }
|
||||
}
|
||||
}
|
||||
}"#;
|
||||
|
||||
pub async fn create_index_if_not_exists(client: &OpenSearch, index: &str) -> anyhow::Result<()> {
|
||||
let exists = client
|
||||
.indices()
|
||||
.exists(opensearch::indices::IndicesExistsParts::Index(&[index]))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if exists.status_code().is_success() {
|
||||
info!(index, "Memory index already exists");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mapping: serde_json::Value = serde_json::from_str(INDEX_MAPPING)?;
|
||||
let response = client
|
||||
.indices()
|
||||
.create(opensearch::indices::IndicesCreateParts::Index(index))
|
||||
.body(mapping)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status_code().is_success() {
|
||||
let body = response.text().await?;
|
||||
anyhow::bail!("Failed to create memory index {index}: {body}");
|
||||
}
|
||||
|
||||
info!(index, "Created memory index");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_memory_document_serialize() {
|
||||
let doc = MemoryDocument {
|
||||
id: "abc-123".into(),
|
||||
user_id: "sienna@sunbeam.pt".into(),
|
||||
content: "prefers terse answers".into(),
|
||||
category: "preference".into(),
|
||||
created_at: 1710000000000,
|
||||
updated_at: 1710000000000,
|
||||
source: "auto".into(),
|
||||
};
|
||||
let json = serde_json::to_value(&doc).unwrap();
|
||||
assert_eq!(json["user_id"], "sienna@sunbeam.pt");
|
||||
assert_eq!(json["category"], "preference");
|
||||
assert_eq!(json["source"], "auto");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_memory_document_roundtrip() {
|
||||
let doc = MemoryDocument {
|
||||
id: "xyz".into(),
|
||||
user_id: "lonni@sunbeam.pt".into(),
|
||||
content: "working on UI redesign".into(),
|
||||
category: "fact".into(),
|
||||
created_at: 1710000000000,
|
||||
updated_at: 1710000000000,
|
||||
source: "script".into(),
|
||||
};
|
||||
let json_str = serde_json::to_string(&doc).unwrap();
|
||||
let roundtrip: MemoryDocument = serde_json::from_str(&json_str).unwrap();
|
||||
assert_eq!(roundtrip.id, doc.id);
|
||||
assert_eq!(roundtrip.user_id, doc.user_id);
|
||||
assert_eq!(roundtrip.content, doc.content);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_index_mapping_valid_json() {
|
||||
let mapping: serde_json::Value = serde_json::from_str(INDEX_MAPPING).unwrap();
|
||||
assert_eq!(
|
||||
mapping["mappings"]["properties"]["user_id"]["type"]
|
||||
.as_str()
|
||||
.unwrap(),
|
||||
"keyword"
|
||||
);
|
||||
assert_eq!(
|
||||
mapping["mappings"]["properties"]["content"]["type"]
|
||||
.as_str()
|
||||
.unwrap(),
|
||||
"text"
|
||||
);
|
||||
}
|
||||
}
|
||||
187
src/memory/store.rs
Normal file
187
src/memory/store.rs
Normal file
@@ -0,0 +1,187 @@
|
||||
use chrono::Utc;
|
||||
use opensearch::OpenSearch;
|
||||
use serde_json::json;
|
||||
use uuid::Uuid;
|
||||
|
||||
use super::schema::MemoryDocument;
|
||||
|
||||
/// Search memories by content relevance, filtered to a specific user.
|
||||
pub async fn query(
|
||||
client: &OpenSearch,
|
||||
index: &str,
|
||||
user_id: &str,
|
||||
query_text: &str,
|
||||
limit: usize,
|
||||
) -> anyhow::Result<Vec<MemoryDocument>> {
|
||||
let body = json!({
|
||||
"size": limit,
|
||||
"query": {
|
||||
"bool": {
|
||||
"filter": [
|
||||
{ "term": { "user_id": user_id } }
|
||||
],
|
||||
"must": [
|
||||
{ "match": { "content": query_text } }
|
||||
]
|
||||
}
|
||||
},
|
||||
"sort": [{ "_score": "desc" }]
|
||||
});
|
||||
|
||||
let response = client
|
||||
.search(opensearch::SearchParts::Index(&[index]))
|
||||
.body(body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let data: serde_json::Value = response.json().await?;
|
||||
parse_hits(&data)
|
||||
}
|
||||
|
||||
/// Get the most recent memories for a user, sorted by updated_at desc.
|
||||
pub async fn get_recent(
|
||||
client: &OpenSearch,
|
||||
index: &str,
|
||||
user_id: &str,
|
||||
limit: usize,
|
||||
) -> anyhow::Result<Vec<MemoryDocument>> {
|
||||
let body = json!({
|
||||
"size": limit,
|
||||
"query": {
|
||||
"bool": {
|
||||
"filter": [
|
||||
{ "term": { "user_id": user_id } }
|
||||
]
|
||||
}
|
||||
},
|
||||
"sort": [{ "updated_at": "desc" }]
|
||||
});
|
||||
|
||||
let response = client
|
||||
.search(opensearch::SearchParts::Index(&[index]))
|
||||
.body(body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let data: serde_json::Value = response.json().await?;
|
||||
parse_hits(&data)
|
||||
}
|
||||
|
||||
/// Store a new memory document for a user.
|
||||
pub async fn set(
|
||||
client: &OpenSearch,
|
||||
index: &str,
|
||||
user_id: &str,
|
||||
content: &str,
|
||||
category: &str,
|
||||
source: &str,
|
||||
) -> anyhow::Result<()> {
|
||||
let now = Utc::now().timestamp_millis();
|
||||
let id = Uuid::new_v4().to_string();
|
||||
|
||||
let doc = MemoryDocument {
|
||||
id: id.clone(),
|
||||
user_id: user_id.to_string(),
|
||||
content: content.to_string(),
|
||||
category: category.to_string(),
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
source: source.to_string(),
|
||||
};
|
||||
|
||||
let response = client
|
||||
.index(opensearch::IndexParts::IndexId(index, &id))
|
||||
.body(serde_json::to_value(&doc)?)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status_code().is_success() {
|
||||
let body = response.text().await?;
|
||||
anyhow::bail!("Failed to store memory: {body}");
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn parse_hits(data: &serde_json::Value) -> anyhow::Result<Vec<MemoryDocument>> {
|
||||
let hits = data["hits"]["hits"]
|
||||
.as_array()
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
|
||||
let mut docs = Vec::with_capacity(hits.len());
|
||||
for hit in &hits {
|
||||
if let Ok(doc) = serde_json::from_value::<MemoryDocument>(hit["_source"].clone()) {
|
||||
docs.push(doc);
|
||||
}
|
||||
}
|
||||
Ok(docs)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn fake_os_response(sources: Vec<serde_json::Value>) -> serde_json::Value {
|
||||
let hits: Vec<serde_json::Value> = sources
|
||||
.into_iter()
|
||||
.map(|s| json!({ "_source": s }))
|
||||
.collect();
|
||||
json!({ "hits": { "hits": hits } })
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_hits_multiple() {
|
||||
let data = fake_os_response(vec![
|
||||
json!({
|
||||
"id": "a", "user_id": "sienna@sunbeam.pt",
|
||||
"content": "prefers terse answers", "category": "preference",
|
||||
"created_at": 1710000000000_i64, "updated_at": 1710000000000_i64,
|
||||
"source": "auto"
|
||||
}),
|
||||
json!({
|
||||
"id": "b", "user_id": "sienna@sunbeam.pt",
|
||||
"content": "working on drive UI", "category": "fact",
|
||||
"created_at": 1710000000000_i64, "updated_at": 1710000000000_i64,
|
||||
"source": "script"
|
||||
}),
|
||||
]);
|
||||
|
||||
let docs = parse_hits(&data).unwrap();
|
||||
assert_eq!(docs.len(), 2);
|
||||
assert_eq!(docs[0].id, "a");
|
||||
assert_eq!(docs[0].content, "prefers terse answers");
|
||||
assert_eq!(docs[1].id, "b");
|
||||
assert_eq!(docs[1].category, "fact");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_hits_empty() {
|
||||
let data = json!({ "hits": { "hits": [] } });
|
||||
let docs = parse_hits(&data).unwrap();
|
||||
assert!(docs.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_hits_missing_structure() {
|
||||
let data = json!({});
|
||||
let docs = parse_hits(&data).unwrap();
|
||||
assert!(docs.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_hits_skips_malformed() {
|
||||
let data = fake_os_response(vec![
|
||||
json!({
|
||||
"id": "good", "user_id": "x@y",
|
||||
"content": "ok", "category": "fact",
|
||||
"created_at": 1, "updated_at": 1, "source": "auto"
|
||||
}),
|
||||
json!({ "bad": "no fields" }),
|
||||
]);
|
||||
|
||||
let docs = parse_hits(&data).unwrap();
|
||||
assert_eq!(docs.len(), 1);
|
||||
assert_eq!(docs[0].id, "good");
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user