Files
sol/src/memory/extractor.rs

159 lines
4.7 KiB
Rust
Raw Normal View History

use std::sync::Arc;
use mistralai_client::v1::{
chat::{ChatMessage, ChatParams, ResponseFormat},
constants::Model,
};
use opensearch::OpenSearch;
use serde::Deserialize;
use tracing::{debug, warn};
use crate::config::Config;
use crate::context::ResponseContext;
use crate::brain::responder::chat_blocking;
use super::store;
/// Top-level JSON object the extraction prompt asks the model to return.
#[derive(Debug, Deserialize)]
pub(crate) struct ExtractionResponse {
    /// Memories pulled out of the exchange; the prompt requests an empty
    /// list when nothing is worth remembering.
    pub memories: Vec<ExtractedMemory>,
}
/// One memory entry inside an [`ExtractionResponse`].
///
/// Both fields are required: a JSON entry missing either one fails to
/// deserialize.
#[derive(Debug, Deserialize)]
pub(crate) struct ExtractedMemory {
    /// The text of the memory itself.
    pub content: String,
    /// Raw category label exactly as the model produced it; it is not
    /// validated at deserialization time.
    pub category: String,
}
/// Map a raw category label onto the set of accepted categories.
///
/// Returns `raw` unchanged when it is exactly `"preference"`, `"fact"` or
/// `"context"`; any other value (including different casing, surrounding
/// whitespace, or the empty string) collapses to `"general"`.
pub(crate) fn normalize_category(raw: &str) -> &str {
    // Exact, case-sensitive allow-list of recognised labels.
    const ACCEPTED: [&str; 3] = ["preference", "fact", "context"];

    if ACCEPTED.iter().any(|&label| label == raw) {
        raw
    } else {
        "general"
    }
}
/// Extract facts worth remembering from one exchange and persist them.
///
/// Builds an extraction prompt from `user_message` and `sol_response`, sends
/// it to the model named by `config.mistral.evaluation_model` with a JSON
/// response format, parses the reply as an [`ExtractionResponse`], and writes
/// every memory to the `config.opensearch.memory_index` index through
/// [`store::set`], keyed by `ctx.user_id`.
///
/// The operation is best-effort: a reply with no choices, a reply that is not
/// valid extraction JSON, an empty memory list, and individual storage
/// failures are all logged and otherwise ignored.
///
/// # Errors
///
/// Returns an error only when the chat request itself fails.
pub async fn extract_and_store(
    mistral: &Arc<mistralai_client::v1::client::Client>,
    opensearch: &OpenSearch,
    config: &Config,
    ctx: &ResponseContext,
    user_message: &str,
    sol_response: &str,
) -> anyhow::Result<()> {
    // Name used inside the prompt: the display name when set, otherwise the
    // Matrix user id.
    let display = ctx
        .display_name
        .as_deref()
        .unwrap_or(&ctx.matrix_user_id);

    let prompt = format!(
        "Analyze this conversation exchange and extract any facts worth remembering about {display}.\n\
         Focus on: preferences, personal details, ongoing projects, opinions, recurring topics.\n\n\
         They said: {user_message}\n\
         Response: {sol_response}\n\n\
         Respond ONLY with JSON: {{\"memories\": [{{\"content\": \"...\", \"category\": \"preference|fact|context\"}}]}}\n\
         If nothing worth remembering, respond with {{\"memories\": []}}.\n\
         Be selective only genuinely useful information."
    );

    let messages = vec![ChatMessage::new_user_message(&prompt)];
    let model = Model::new(&config.mistral.evaluation_model);
    let params = ChatParams {
        response_format: Some(ResponseFormat::json_object()),
        ..Default::default()
    };

    let response = chat_blocking(mistral, model, messages, params).await?;

    // Indexing `choices[0]` would panic on an empty list; treat a reply with
    // no choices like any other unusable reply and bail out quietly.
    let text = match response.choices.first() {
        Some(choice) => choice.message.content.trim(),
        None => {
            debug!("Extraction response contained no choices");
            return Ok(());
        }
    };

    let extraction: ExtractionResponse = match serde_json::from_str(text) {
        Ok(parsed) => parsed,
        Err(e) => {
            debug!(raw = text, "Failed to parse extraction response: {e}");
            return Ok(());
        }
    };

    if extraction.memories.is_empty() {
        debug!("No memories extracted");
        return Ok(());
    }

    let index = &config.opensearch.memory_index;

    // Track successful writes separately so the final log line does not
    // claim memories were stored when the store call failed.
    let mut stored = 0usize;
    for mem in &extraction.memories {
        let category = normalize_category(&mem.category);
        // NOTE(review): "auto" presumably tags the memory's source as
        // automatically extracted; confirm against `store::set`.
        match store::set(
            opensearch,
            index,
            &ctx.user_id,
            &mem.content,
            category,
            "auto",
        )
        .await
        {
            Ok(_) => stored += 1,
            Err(e) => warn!("Failed to store extracted memory: {e}"),
        }
    }

    debug!(
        count = stored,
        extracted = extraction.memories.len(),
        user = ctx.user_id.as_str(),
        "Extracted and stored memories"
    );
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserialize a raw payload into an [`ExtractionResponse`].
    fn parse(raw: &str) -> serde_json::Result<ExtractionResponse> {
        serde_json::from_str(raw)
    }

    // --- ExtractionResponse deserialization ---

    #[test]
    fn test_parse_extraction_response_with_memories() {
        let payload = r#"{"memories": [
{"content": "prefers terse answers", "category": "preference"},
{"content": "working on drive UI", "category": "fact"}
]}"#;
        let parsed = parse(payload).unwrap();

        // Flatten to (content, category) pairs so the checks read as a table.
        let pairs: Vec<(&str, &str)> = parsed
            .memories
            .iter()
            .map(|m| (m.content.as_str(), m.category.as_str()))
            .collect();

        assert_eq!(pairs.len(), 2);
        assert_eq!(pairs[0].0, "prefers terse answers");
        assert_eq!(pairs[0].1, "preference");
        assert_eq!(pairs[1].1, "fact");
    }

    #[test]
    fn test_parse_extraction_response_empty() {
        let parsed = parse(r#"{"memories": []}"#).unwrap();
        assert!(parsed.memories.is_empty());
    }

    #[test]
    fn test_parse_extraction_response_invalid_json() {
        assert!(parse("not json at all").is_err());
    }

    #[test]
    fn test_parse_extraction_response_missing_field() {
        // `category` is absent, so the entry must be rejected.
        assert!(parse(r#"{"memories": [{"content": "hi"}]}"#).is_err());
    }

    // --- normalize_category ---

    #[test]
    fn test_normalize_category_valid() {
        for &known in &["preference", "fact", "context"] {
            assert_eq!(normalize_category(known), known);
        }
    }

    #[test]
    fn test_normalize_category_unknown_falls_back() {
        // Unknown labels, the empty string, and wrong casing all collapse.
        for &unknown in &["opinion", "", "PREFERENCE"] {
            assert_eq!(normalize_category(unknown), "general");
        }
    }
}