//! Engagement evaluation: decides whether Sol must, may, or should not reply to a message.
use std::sync::Arc;
|
||
|
|
|
||
|
|
use mistralai_client::v1::{
|
||
|
|
chat::{ChatMessage, ChatParams, ResponseFormat},
|
||
|
|
constants::Model,
|
||
|
|
};
|
||
|
|
use regex::Regex;
|
||
|
|
use tracing::{debug, warn};
|
||
|
|
|
||
|
|
use crate::config::Config;
|
||
|
|
|
||
|
|
/// Outcome of deciding whether the bot should reply to an incoming message.
#[derive(Debug, Clone, PartialEq)]
pub enum Engagement {
    /// A deterministic rule matched; the bot must reply.
    MustRespond { reason: MustRespondReason },
    /// The evaluation model scored the message at or above the configured
    /// spontaneous-response threshold. `relevance` is the model's score (it is prompted
    /// to answer in 0.0–1.0) and `hook` is its brief stated reason (possibly empty).
    MaybeRespond { relevance: f32, hook: String },
    /// The bot should not reply.
    Ignore,
}

/// Which rule forced an [`Engagement::MustRespond`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MustRespondReason {
    /// The message body contains the bot's own user id.
    DirectMention,
    /// The caller flagged the message as a direct message (`is_dm`).
    DirectMessage,
    /// The message starts with "sol" or contains "hey sol" (case-insensitive).
    NameInvocation,
}
|
||
|
|
|
||
|
|
pub struct Evaluator {
|
||
|
|
config: Arc<Config>,
|
||
|
|
mention_regex: Regex,
|
||
|
|
name_regex: Regex,
|
||
|
|
}
|
||
|
|
|
||
|
|
impl Evaluator {
|
||
|
|
// todo(sienna): regex must be configrable
|
||
|
|
pub fn new(config: Arc<Config>) -> Self {
|
||
|
|
let user_id = &config.matrix.user_id;
|
||
|
|
let mention_pattern = regex::escape(user_id);
|
||
|
|
let mention_regex = Regex::new(&mention_pattern).expect("Failed to compile mention regex");
|
||
|
|
let name_regex =
|
||
|
|
Regex::new(r"(?i)(?:^|\bhey\s+)\bsol\b").expect("Failed to compile name regex");
|
||
|
|
|
||
|
|
Self {
|
||
|
|
config,
|
||
|
|
mention_regex,
|
||
|
|
name_regex,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
pub async fn evaluate(
|
||
|
|
&self,
|
||
|
|
sender: &str,
|
||
|
|
body: &str,
|
||
|
|
is_dm: bool,
|
||
|
|
recent_messages: &[String],
|
||
|
|
mistral: &Arc<mistralai_client::v1::client::Client>,
|
||
|
|
) -> Engagement {
|
||
|
|
// Don't respond to ourselves
|
||
|
|
if sender == self.config.matrix.user_id {
|
||
|
|
return Engagement::Ignore;
|
||
|
|
}
|
||
|
|
|
||
|
|
// Direct mention: @sol:sunbeam.pt
|
||
|
|
if self.mention_regex.is_match(body) {
|
||
|
|
return Engagement::MustRespond {
|
||
|
|
reason: MustRespondReason::DirectMention,
|
||
|
|
};
|
||
|
|
}
|
||
|
|
|
||
|
|
// DM
|
||
|
|
if is_dm {
|
||
|
|
return Engagement::MustRespond {
|
||
|
|
reason: MustRespondReason::DirectMessage,
|
||
|
|
};
|
||
|
|
}
|
||
|
|
|
||
|
|
// Name invocation: "sol ..." or "hey sol ..."
|
||
|
|
if self.name_regex.is_match(body) {
|
||
|
|
return Engagement::MustRespond {
|
||
|
|
reason: MustRespondReason::NameInvocation,
|
||
|
|
};
|
||
|
|
}
|
||
|
|
|
||
|
|
// Cheap evaluation call for spontaneous responses
|
||
|
|
self.evaluate_relevance(body, recent_messages, mistral)
|
||
|
|
.await
|
||
|
|
}
|
||
|
|
|
||
|
|
/// Check rule-based engagement (without calling Mistral). Returns Some(Engagement)
|
||
|
|
/// if a rule matched, None if we need to fall through to the LLM evaluation.
|
||
|
|
pub fn evaluate_rules(
|
||
|
|
&self,
|
||
|
|
sender: &str,
|
||
|
|
body: &str,
|
||
|
|
is_dm: bool,
|
||
|
|
) -> Option<Engagement> {
|
||
|
|
if sender == self.config.matrix.user_id {
|
||
|
|
return Some(Engagement::Ignore);
|
||
|
|
}
|
||
|
|
if self.mention_regex.is_match(body) {
|
||
|
|
return Some(Engagement::MustRespond {
|
||
|
|
reason: MustRespondReason::DirectMention,
|
||
|
|
});
|
||
|
|
}
|
||
|
|
if is_dm {
|
||
|
|
return Some(Engagement::MustRespond {
|
||
|
|
reason: MustRespondReason::DirectMessage,
|
||
|
|
});
|
||
|
|
}
|
||
|
|
if self.name_regex.is_match(body) {
|
||
|
|
return Some(Engagement::MustRespond {
|
||
|
|
reason: MustRespondReason::NameInvocation,
|
||
|
|
});
|
||
|
|
}
|
||
|
|
None
|
||
|
|
}
|
||
|
|
|
||
|
|
async fn evaluate_relevance(
|
||
|
|
&self,
|
||
|
|
body: &str,
|
||
|
|
recent_messages: &[String],
|
||
|
|
mistral: &Arc<mistralai_client::v1::client::Client>,
|
||
|
|
) -> Engagement {
|
||
|
|
let context = recent_messages
|
||
|
|
.iter()
|
||
|
|
.rev()
|
||
|
|
.take(5) //todo(sienna): must be configurable
|
||
|
|
.rev()
|
||
|
|
.cloned()
|
||
|
|
.collect::<Vec<_>>()
|
||
|
|
.join("\n");
|
||
|
|
|
||
|
|
let prompt = format!(
|
||
|
|
"You are evaluating whether a virtual librarian named Sol should spontaneously join \
|
||
|
|
a conversation. Sol has deep knowledge of the group's message archive and helps \
|
||
|
|
people find information.\n\n\
|
||
|
|
Recent conversation:\n{context}\n\n\
|
||
|
|
Latest message: {body}\n\n\
|
||
|
|
Respond ONLY with JSON: {{\"relevance\": 0.0-1.0, \"hook\": \"brief reason or empty string\"}}\n\
|
||
|
|
relevance=1.0 means Sol absolutely should respond, 0.0 means irrelevant."
|
||
|
|
);
|
||
|
|
|
||
|
|
let messages = vec![ChatMessage::new_user_message(&prompt)];
|
||
|
|
let params = ChatParams {
|
||
|
|
response_format: Some(ResponseFormat::json_object()),
|
||
|
|
temperature: Some(0.1),
|
||
|
|
max_tokens: Some(100),
|
||
|
|
..Default::default()
|
||
|
|
};
|
||
|
|
|
||
|
|
let model = Model::new(&self.config.mistral.evaluation_model);
|
||
|
|
let client = Arc::clone(mistral);
|
||
|
|
let result = tokio::task::spawn_blocking(move || {
|
||
|
|
client.chat(model, messages, Some(params))
|
||
|
|
})
|
||
|
|
.await
|
||
|
|
.unwrap_or_else(|e| Err(mistralai_client::v1::error::ApiError {
|
||
|
|
message: format!("spawn_blocking join error: {e}"),
|
||
|
|
}));
|
||
|
|
|
||
|
|
match result {
|
||
|
|
Ok(response) => {
|
||
|
|
let text = &response.choices[0].message.content;
|
||
|
|
match serde_json::from_str::<serde_json::Value>(text) {
|
||
|
|
Ok(val) => {
|
||
|
|
let relevance = val["relevance"].as_f64().unwrap_or(0.0) as f32;
|
||
|
|
let hook = val["hook"].as_str().unwrap_or("").to_string();
|
||
|
|
|
||
|
|
debug!(relevance, hook = hook.as_str(), "Evaluation result");
|
||
|
|
|
||
|
|
if relevance >= self.config.behavior.spontaneous_threshold {
|
||
|
|
Engagement::MaybeRespond { relevance, hook }
|
||
|
|
} else {
|
||
|
|
Engagement::Ignore
|
||
|
|
}
|
||
|
|
}
|
||
|
|
Err(e) => {
|
||
|
|
warn!("Failed to parse evaluation response: {e}");
|
||
|
|
Engagement::Ignore
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
Err(e) => {
|
||
|
|
warn!("Evaluation call failed: {e}");
|
||
|
|
Engagement::Ignore
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;

    /// A sender that is not the bot.
    const ALICE: &str = "@alice:sunbeam.pt";

    /// Minimal config whose bot id is `@sol:sunbeam.pt`; other sections use their defaults.
    fn test_config() -> Arc<Config> {
        let toml = r#"
            [matrix]
            homeserver_url = "https://chat.sunbeam.pt"
            user_id = "@sol:sunbeam.pt"
            state_store_path = "/tmp/sol"

            [opensearch]
            url = "http://localhost:9200"
            index = "test"

            [mistral]
            [behavior]
        "#;
        Arc::new(Config::from_str(toml).unwrap())
    }

    fn evaluator() -> Evaluator {
        Evaluator::new(test_config())
    }

    #[test]
    fn test_ignore_own_messages() {
        let ev = evaluator();
        let result = ev.evaluate_rules("@sol:sunbeam.pt", "hello everyone", false);
        assert!(matches!(result, Some(Engagement::Ignore)));
    }

    #[test]
    fn test_own_message_ignored_even_when_mentioning_self() {
        // The self-check must outrank every MustRespond rule, or the bot would loop on itself.
        let ev = evaluator();
        let result = ev.evaluate_rules("@sol:sunbeam.pt", "ping @sol:sunbeam.pt", true);
        assert!(matches!(result, Some(Engagement::Ignore)));
    }

    #[test]
    fn test_direct_mention() {
        let ev = evaluator();
        let result = ev.evaluate_rules(ALICE, "hey @sol:sunbeam.pt what's up?", false);
        assert!(matches!(
            result,
            Some(Engagement::MustRespond { reason: MustRespondReason::DirectMention })
        ));
    }

    #[test]
    fn test_dm_detection() {
        let ev = evaluator();
        let result = ev.evaluate_rules(ALICE, "random message", true);
        assert!(matches!(
            result,
            Some(Engagement::MustRespond { reason: MustRespondReason::DirectMessage })
        ));
    }

    #[test]
    fn test_name_invocation_start_of_message() {
        let ev = evaluator();
        let result = ev.evaluate_rules(ALICE, "sol, can you find that link?", false);
        assert!(matches!(
            result,
            Some(Engagement::MustRespond { reason: MustRespondReason::NameInvocation })
        ));
    }

    #[test]
    fn test_name_invocation_hey_sol() {
        let ev = evaluator();
        let result = ev.evaluate_rules(ALICE, "hey sol do you remember?", false);
        assert!(matches!(
            result,
            Some(Engagement::MustRespond { reason: MustRespondReason::NameInvocation })
        ));
    }

    #[test]
    fn test_name_invocation_hey_sol_mid_sentence() {
        let ev = evaluator();
        let result = ev.evaluate_rules(ALICE, "oh hey sol, quick question", false);
        assert!(matches!(
            result,
            Some(Engagement::MustRespond { reason: MustRespondReason::NameInvocation })
        ));
    }

    #[test]
    fn test_name_invocation_case_insensitive() {
        let ev = evaluator();
        let result = ev.evaluate_rules(ALICE, "Hey Sol, help me", false);
        assert!(matches!(
            result,
            Some(Engagement::MustRespond { reason: MustRespondReason::NameInvocation })
        ));
    }

    #[test]
    fn test_name_invocation_sol_uppercase() {
        let ev = evaluator();
        let result = ev.evaluate_rules(ALICE, "SOL what do you think?", false);
        assert!(matches!(
            result,
            Some(Engagement::MustRespond { reason: MustRespondReason::NameInvocation })
        ));
    }

    #[test]
    fn test_no_false_positive_solstice() {
        let ev = evaluator();
        // "solstice" should NOT trigger name invocation — \b boundary prevents it
        let result = ev.evaluate_rules(ALICE, "the solstice is coming", false);
        assert!(result.is_none());
    }

    #[test]
    fn test_name_mid_sentence_without_hey_falls_through() {
        // Only a leading "sol" or a "hey sol" counts as an invocation.
        let ev = evaluator();
        let result = ev.evaluate_rules(ALICE, "I think sol already answered that", false);
        assert!(result.is_none());
    }

    #[test]
    fn test_random_message_falls_through() {
        let ev = evaluator();
        let result = ev.evaluate_rules(ALICE, "what's for lunch?", false);
        assert!(result.is_none());
    }

    #[test]
    fn test_priority_mention_over_dm() {
        // When both mention and DM are true, mention should match first
        let ev = evaluator();
        let result = ev.evaluate_rules(ALICE, "hi @sol:sunbeam.pt", true);
        assert!(matches!(
            result,
            Some(Engagement::MustRespond { reason: MustRespondReason::DirectMention })
        ));
    }

    #[test]
    fn test_priority_dm_over_name_invocation() {
        // When both DM and name invocation apply, DM should match first
        let ev = evaluator();
        let result = ev.evaluate_rules(ALICE, "sol, are you there?", true);
        assert!(matches!(
            result,
            Some(Engagement::MustRespond { reason: MustRespondReason::DirectMessage })
        ));
    }
}
|