use std::sync::Arc; use mistralai_client::v1::{ chat::{ChatMessage, ChatParams, ResponseFormat}, constants::Model, }; use opensearch::OpenSearch; use serde::Deserialize; use tracing::{debug, warn}; use crate::config::Config; use crate::context::ResponseContext; use crate::brain::responder::chat_blocking; use super::store; #[derive(Debug, Deserialize)] pub(crate) struct ExtractionResponse { pub memories: Vec, } #[derive(Debug, Deserialize)] pub(crate) struct ExtractedMemory { pub content: String, pub category: String, } /// Validate and normalize a category string. pub(crate) fn normalize_category(raw: &str) -> &str { match raw { "preference" | "fact" | "context" => raw, _ => "general", } } pub async fn extract_and_store( mistral: &Arc, opensearch: &OpenSearch, config: &Config, ctx: &ResponseContext, user_message: &str, sol_response: &str, ) -> anyhow::Result<()> { let display = ctx .display_name .as_deref() .unwrap_or(&ctx.matrix_user_id); let prompt = format!( "Analyze this conversation exchange and extract any facts worth remembering about {display}.\n\ Focus on: preferences, personal details, ongoing projects, opinions, recurring topics.\n\n\ They said: {user_message}\n\ Response: {sol_response}\n\n\ Respond ONLY with JSON: {{\"memories\": [{{\"content\": \"...\", \"category\": \"preference|fact|context\"}}]}}\n\ If nothing worth remembering, respond with {{\"memories\": []}}.\n\ Be selective — only genuinely useful information." ); let messages = vec![ChatMessage::new_user_message(&prompt)]; let model = Model::new(&config.mistral.evaluation_model); let params = ChatParams { response_format: Some(ResponseFormat::json_object()), ..Default::default() }; let response = chat_blocking(mistral, model, messages, params).await?; let text = response.choices[0].message.content.text(); let text = text.trim(); let extraction: ExtractionResponse = match serde_json::from_str(text) { Ok(e) => e, Err(e) => { debug!(raw = text, "Failed to parse extraction response: {e}"); return Ok(()); } }; if extraction.memories.is_empty() { debug!("No memories extracted"); return Ok(()); } let index = &config.opensearch.memory_index; for mem in &extraction.memories { let category = normalize_category(&mem.category); if let Err(e) = store::set( opensearch, index, &ctx.user_id, &mem.content, category, "auto", ) .await { warn!("Failed to store extracted memory: {e}"); } } debug!( count = extraction.memories.len(), user = ctx.user_id.as_str(), "Extracted and stored memories" ); Ok(()) } #[cfg(test)] mod tests { use super::*; #[test] fn test_parse_extraction_response_with_memories() { let json = r#"{"memories": [ {"content": "prefers terse answers", "category": "preference"}, {"content": "working on drive UI", "category": "fact"} ]}"#; let resp: ExtractionResponse = serde_json::from_str(json).unwrap(); assert_eq!(resp.memories.len(), 2); assert_eq!(resp.memories[0].content, "prefers terse answers"); assert_eq!(resp.memories[0].category, "preference"); assert_eq!(resp.memories[1].category, "fact"); } #[test] fn test_parse_extraction_response_empty() { let json = r#"{"memories": []}"#; let resp: ExtractionResponse = serde_json::from_str(json).unwrap(); assert!(resp.memories.is_empty()); } #[test] fn test_parse_extraction_response_invalid_json() { let json = "not json at all"; assert!(serde_json::from_str::(json).is_err()); } #[test] fn test_parse_extraction_response_missing_field() { let json = r#"{"memories": [{"content": "hi"}]}"#; assert!(serde_json::from_str::(json).is_err()); } #[test] fn test_normalize_category_valid() { assert_eq!(normalize_category("preference"), "preference"); assert_eq!(normalize_category("fact"), "fact"); assert_eq!(normalize_category("context"), "context"); } #[test] fn test_normalize_category_unknown_falls_back() { assert_eq!(normalize_category("opinion"), "general"); assert_eq!(normalize_category(""), "general"); assert_eq!(normalize_category("PREFERENCE"), "general"); } }