Files
sol/src/brain/evaluator.rs

308 lines
9.3 KiB
Rust
Raw Normal View History

use std::sync::Arc;
use mistralai_client::v1::{
chat::{ChatMessage, ChatParams, ResponseFormat},
constants::Model,
};
use regex::Regex;
use tracing::{debug, warn};
use crate::config::Config;
/// Outcome of deciding whether the bot should react to an incoming message.
#[derive(Debug, Clone, PartialEq)]
pub enum Engagement {
    /// A deterministic rule matched; a reply is required.
    MustRespond { reason: MustRespondReason },
    /// The relevance model scored the message at or above
    /// `behavior.spontaneous_threshold`.
    ///
    /// `relevance` is the model's score (the prompt asks for 0.0–1.0) and `hook`
    /// is the model's brief justification, which may be empty.
    MaybeRespond { relevance: f32, hook: String },
    /// No reply: the message is the bot's own, scored below the threshold,
    /// or the relevance evaluation failed.
    Ignore,
}

/// Which rule forced an [`Engagement::MustRespond`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MustRespondReason {
    /// The message body contains the bot's full Matrix user id.
    DirectMention,
    /// The message arrived in a direct-message room.
    DirectMessage,
    /// The message starts with "sol" or contains "hey sol" (case-insensitive).
    NameInvocation,
}
/// Decides whether the bot should engage with an incoming message.
///
/// Cheap, deterministic rule checks (own message, mention, DM, name invocation)
/// run first; anything they do not settle can be scored by a Mistral model.
pub struct Evaluator {
    /// Shared bot configuration: Matrix user id, evaluation model, spontaneity threshold.
    config: Arc<Config>,
    /// Literal (escaped) match on the bot's Matrix user id anywhere in a message body.
    mention_regex: Regex,
    /// Case-insensitive "sol" at the start of a message, or after "hey " anywhere.
    name_regex: Regex,
}
impl Evaluator {
    // todo(sienna): must be configurable
    /// Number of most recent messages included as context in the relevance prompt.
    const CONTEXT_MESSAGES: usize = 5;

    /// Builds an evaluator for the bot account named by `config.matrix.user_id`.
    ///
    /// Both regexes are compiled once here:
    /// - the mention regex is the escaped user id, so it matches the id literally;
    /// - the name regex matches the word "sol" at the very start of a message, or
    ///   preceded by "hey " anywhere in it, case-insensitively.
    ///
    /// # Panics
    /// Only if a regex fails to compile. Both patterns are fixed or escaped, so this
    /// would indicate a programming error rather than bad configuration.
    // todo(sienna): regex must be configurable
    pub fn new(config: Arc<Config>) -> Self {
        let user_id = &config.matrix.user_id;
        let mention_regex =
            Regex::new(&regex::escape(user_id)).expect("Failed to compile mention regex");
        let name_regex =
            Regex::new(r"(?i)(?:^|\bhey\s+)\bsol\b").expect("Failed to compile name regex");
        Self {
            config,
            mention_regex,
            name_regex,
        }
    }

    /// Decides how to engage with a message.
    ///
    /// Applies the rules of [`Evaluator::evaluate_rules`] first. Only when none
    /// match does it ask the evaluation model (via `mistral`) how relevant the
    /// message is, given the tail of `recent_messages` as context.
    pub async fn evaluate(
        &self,
        sender: &str,
        body: &str,
        is_dm: bool,
        recent_messages: &[String],
        mistral: &Arc<mistralai_client::v1::client::Client>,
    ) -> Engagement {
        // Delegate to the single source of truth for rule ordering; only pay for a
        // model call when no rule settles the question.
        match self.evaluate_rules(sender, body, is_dm) {
            Some(engagement) => engagement,
            None => {
                self.evaluate_relevance(body, recent_messages, mistral)
                    .await
            }
        }
    }

    /// Check rule-based engagement (without calling Mistral). Returns `Some(Engagement)`
    /// if a rule matched, `None` if we need to fall through to the LLM evaluation.
    ///
    /// Rules are checked in this order, and the first match wins:
    /// 1. the sender is the bot itself → `Ignore`;
    /// 2. the body contains the bot's user id → `DirectMention`;
    /// 3. the room is a DM → `DirectMessage`;
    /// 4. the body invokes the bot by name → `NameInvocation`.
    pub fn evaluate_rules(
        &self,
        sender: &str,
        body: &str,
        is_dm: bool,
    ) -> Option<Engagement> {
        // Never react to our own messages, even in DMs or when they mention us.
        if sender == self.config.matrix.user_id {
            return Some(Engagement::Ignore);
        }
        // Direct mention, e.g. "@sol:sunbeam.pt".
        if self.mention_regex.is_match(body) {
            return Some(Engagement::MustRespond {
                reason: MustRespondReason::DirectMention,
            });
        }
        if is_dm {
            return Some(Engagement::MustRespond {
                reason: MustRespondReason::DirectMessage,
            });
        }
        // Name invocation: "sol ..." or "hey sol ...".
        if self.name_regex.is_match(body) {
            return Some(Engagement::MustRespond {
                reason: MustRespondReason::NameInvocation,
            });
        }
        None
    }

    /// Asks the evaluation model how relevant `body` is for a spontaneous reply.
    ///
    /// The last [`Self::CONTEXT_MESSAGES`] entries of `recent_messages` are sent as
    /// context, in their original order. Any failure (a transport or API error, a
    /// response with no choices, unparseable JSON) is logged and yields `Ignore`.
    async fn evaluate_relevance(
        &self,
        body: &str,
        recent_messages: &[String],
        mistral: &Arc<mistralai_client::v1::client::Client>,
    ) -> Engagement {
        // Keep only the tail of the history; slicing avoids cloning each message.
        let start = recent_messages
            .len()
            .saturating_sub(Self::CONTEXT_MESSAGES);
        let context = recent_messages[start..].join("\n");
        let prompt = format!(
            "You are evaluating whether a virtual librarian named Sol should spontaneously join \
            a conversation. Sol has deep knowledge of the group's message archive and helps \
            people find information.\n\n\
            Recent conversation:\n{context}\n\n\
            Latest message: {body}\n\n\
            Respond ONLY with JSON: {{\"relevance\": 0.0-1.0, \"hook\": \"brief reason or empty string\"}}\n\
            relevance=1.0 means Sol absolutely should respond, 0.0 means irrelevant."
        );
        let messages = vec![ChatMessage::new_user_message(&prompt)];
        let params = ChatParams {
            response_format: Some(ResponseFormat::json_object()),
            temperature: Some(0.1),
            max_tokens: Some(100),
            ..Default::default()
        };
        let model = Model::new(&self.config.mistral.evaluation_model);
        let client = Arc::clone(mistral);
        // The client call used here is synchronous, so keep it off the async executor.
        let result = tokio::task::spawn_blocking(move || {
            client.chat(model, messages, Some(params))
        })
        .await
        .unwrap_or_else(|e| Err(mistralai_client::v1::error::ApiError {
            message: format!("spawn_blocking join error: {e}"),
        }));
        match result {
            // Guard against an empty `choices` list instead of panicking on `[0]`.
            Ok(response) => match response.choices.first() {
                Some(choice) => self.engagement_from_reply(&choice.message.content),
                None => {
                    warn!("Evaluation response contained no choices");
                    Engagement::Ignore
                }
            },
            Err(e) => {
                warn!("Evaluation call failed: {e}");
                Engagement::Ignore
            }
        }
    }

    /// Converts the model's JSON reply into an [`Engagement`].
    ///
    /// Missing or non-numeric `relevance` counts as 0.0, and a missing `hook` counts
    /// as empty. The result is `MaybeRespond` when the relevance reaches
    /// `behavior.spontaneous_threshold`, and `Ignore` otherwise, including when the
    /// reply is not valid JSON.
    fn engagement_from_reply(&self, text: &str) -> Engagement {
        let val = match serde_json::from_str::<serde_json::Value>(text) {
            Ok(val) => val,
            Err(e) => {
                warn!("Failed to parse evaluation response: {e}");
                return Engagement::Ignore;
            }
        };
        let relevance = val["relevance"].as_f64().unwrap_or(0.0) as f32;
        let hook = val["hook"].as_str().unwrap_or("").to_string();
        debug!(relevance, hook = hook.as_str(), "Evaluation result");
        if relevance >= self.config.behavior.spontaneous_threshold {
            Engagement::MaybeRespond { relevance, hook }
        } else {
            Engagement::Ignore
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;

    /// A human sender distinct from the bot.
    const ALICE: &str = "@alice:sunbeam.pt";
    /// The bot's own user id; must match `user_id` in `test_config`.
    const SOL: &str = "@sol:sunbeam.pt";

    fn test_config() -> Arc<Config> {
        let toml = r#"
[matrix]
homeserver_url = "https://chat.sunbeam.pt"
user_id = "@sol:sunbeam.pt"
state_store_path = "/tmp/sol"
[opensearch]
url = "http://localhost:9200"
index = "test"
[mistral]
[behavior]
"#;
        Arc::new(Config::from_str(toml).unwrap())
    }

    fn evaluator() -> Evaluator {
        Evaluator::new(test_config())
    }

    #[test]
    fn test_ignore_own_messages() {
        let ev = evaluator();
        let result = ev.evaluate_rules(SOL, "hello everyone", false);
        assert!(matches!(result, Some(Engagement::Ignore)));
    }

    #[test]
    fn test_own_message_ignored_even_in_dm_with_mention() {
        // The self-sender guard must take precedence over every other rule.
        let ev = evaluator();
        let result = ev.evaluate_rules(SOL, "@sol:sunbeam.pt hey sol", true);
        assert!(matches!(result, Some(Engagement::Ignore)));
    }

    #[test]
    fn test_direct_mention() {
        let ev = evaluator();
        let result = ev.evaluate_rules(ALICE, "hey @sol:sunbeam.pt what's up?", false);
        assert!(matches!(
            result,
            Some(Engagement::MustRespond { reason: MustRespondReason::DirectMention })
        ));
    }

    #[test]
    fn test_direct_mention_inside_matrix_to_link() {
        let ev = evaluator();
        let result =
            ev.evaluate_rules(ALICE, "ping https://matrix.to/#/@sol:sunbeam.pt", false);
        assert!(matches!(
            result,
            Some(Engagement::MustRespond { reason: MustRespondReason::DirectMention })
        ));
    }

    #[test]
    fn test_mention_pattern_is_escaped() {
        // "." in the user id must be literal, so "sunbeamXpt" is not a mention.
        let ev = evaluator();
        let result = ev.evaluate_rules(ALICE, "@sol:sunbeamXpt hi", false);
        assert!(result.is_none());
    }

    #[test]
    fn test_dm_detection() {
        let ev = evaluator();
        let result = ev.evaluate_rules(ALICE, "random message", true);
        assert!(matches!(
            result,
            Some(Engagement::MustRespond { reason: MustRespondReason::DirectMessage })
        ));
    }

    #[test]
    fn test_name_invocation_start_of_message() {
        let ev = evaluator();
        let result = ev.evaluate_rules(ALICE, "sol, can you find that link?", false);
        assert!(matches!(
            result,
            Some(Engagement::MustRespond { reason: MustRespondReason::NameInvocation })
        ));
    }

    #[test]
    fn test_name_invocation_hey_sol() {
        let ev = evaluator();
        let result = ev.evaluate_rules(ALICE, "hey sol do you remember?", false);
        assert!(matches!(
            result,
            Some(Engagement::MustRespond { reason: MustRespondReason::NameInvocation })
        ));
    }

    #[test]
    fn test_name_invocation_hey_sol_mid_message() {
        let ev = evaluator();
        let result = ev.evaluate_rules(ALICE, "ok so hey sol, thoughts?", false);
        assert!(matches!(
            result,
            Some(Engagement::MustRespond { reason: MustRespondReason::NameInvocation })
        ));
    }

    #[test]
    fn test_bare_sol_mid_message_falls_through() {
        // Without "hey", "sol" only counts at the very start of the message.
        let ev = evaluator();
        let result = ev.evaluate_rules(ALICE, "I asked sol yesterday", false);
        assert!(result.is_none());
    }

    #[test]
    fn test_name_invocation_case_insensitive() {
        let ev = evaluator();
        let result = ev.evaluate_rules(ALICE, "Hey Sol, help me", false);
        assert!(matches!(
            result,
            Some(Engagement::MustRespond { reason: MustRespondReason::NameInvocation })
        ));
    }

    #[test]
    fn test_name_invocation_sol_uppercase() {
        let ev = evaluator();
        let result = ev.evaluate_rules(ALICE, "SOL what do you think?", false);
        assert!(matches!(
            result,
            Some(Engagement::MustRespond { reason: MustRespondReason::NameInvocation })
        ));
    }

    #[test]
    fn test_no_false_positive_solstice() {
        let ev = evaluator();
        // "solstice" should NOT trigger name invocation — \b boundary prevents it
        let result = ev.evaluate_rules(ALICE, "the solstice is coming", false);
        assert!(result.is_none());
    }

    #[test]
    fn test_random_message_falls_through() {
        let ev = evaluator();
        let result = ev.evaluate_rules(ALICE, "what's for lunch?", false);
        assert!(result.is_none());
    }

    #[test]
    fn test_priority_mention_over_dm() {
        // When both mention and DM are true, mention should match first
        let ev = evaluator();
        let result = ev.evaluate_rules(ALICE, "hi @sol:sunbeam.pt", true);
        assert!(matches!(
            result,
            Some(Engagement::MustRespond { reason: MustRespondReason::DirectMention })
        ));
    }
}