Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
2a1d7a003d
|
|||
|
1ba4e016ba
|
|||
|
567d4c1171
|
|||
|
447bead0b7
|
|||
|
de33ddfe33
|
|||
|
7dbc8a3121
|
|||
|
7324c10d25
|
|||
|
3b62d86c45
|
|||
|
1058afb635
|
|||
|
84278fc1f5
|
|||
|
8e7c572381
|
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -4184,6 +4184,7 @@ dependencies = [
|
||||
"deno_ast",
|
||||
"deno_core",
|
||||
"deno_error",
|
||||
"futures",
|
||||
"libsqlite3-sys",
|
||||
"matrix-sdk",
|
||||
"mistralai-client",
|
||||
|
||||
@@ -37,3 +37,4 @@ reqwest = { version = "0.12", default-features = false, features = ["rustls-tls"
|
||||
uuid = { version = "1", features = ["v4"] }
|
||||
base64 = "0.22"
|
||||
rusqlite = { version = "0.32", features = ["bundled"] }
|
||||
futures = "0.3"
|
||||
|
||||
104
src/agent_ux.rs
104
src/agent_ux.rs
@@ -1,5 +1,5 @@
|
||||
use matrix_sdk::room::Room;
|
||||
use ruma::events::relation::InReplyTo;
|
||||
use ruma::events::relation::Thread;
|
||||
use ruma::events::room::message::{Relation, RoomMessageEventContent};
|
||||
use ruma::OwnedEventId;
|
||||
use tracing::warn;
|
||||
@@ -47,25 +47,19 @@ impl AgentProgress {
|
||||
}
|
||||
}
|
||||
|
||||
/// Post a step update to the thread. Creates the thread on first call.
|
||||
/// Post a step update to the thread on the user's message.
|
||||
pub async fn post_step(&mut self, text: &str) {
|
||||
let content = if let Some(ref _root) = self.thread_root_id {
|
||||
// Reply in existing thread
|
||||
let mut msg = RoomMessageEventContent::text_markdown(text);
|
||||
msg.relates_to = Some(Relation::Reply {
|
||||
in_reply_to: InReplyTo::new(self.user_event_id.clone()),
|
||||
});
|
||||
msg
|
||||
} else {
|
||||
// First message — starts the thread as a reply to the user's message
|
||||
let mut msg = RoomMessageEventContent::text_markdown(text);
|
||||
msg.relates_to = Some(Relation::Reply {
|
||||
in_reply_to: InReplyTo::new(self.user_event_id.clone()),
|
||||
});
|
||||
msg
|
||||
};
|
||||
let latest = self
|
||||
.thread_root_id
|
||||
.as_ref()
|
||||
.unwrap_or(&self.user_event_id)
|
||||
.clone();
|
||||
|
||||
match self.room.send(content).await {
|
||||
let mut msg = RoomMessageEventContent::text_markdown(text);
|
||||
let thread = Thread::plain(self.user_event_id.clone(), latest);
|
||||
msg.relates_to = Some(Relation::Thread(thread));
|
||||
|
||||
match self.room.send(msg).await {
|
||||
Ok(response) => {
|
||||
if self.thread_root_id.is_none() {
|
||||
self.thread_root_id = Some(response.event_id);
|
||||
@@ -96,19 +90,51 @@ impl AgentProgress {
|
||||
.await;
|
||||
}
|
||||
|
||||
/// Format a tool call for the thread.
|
||||
/// Format a tool call for the thread — concise, not raw args.
|
||||
pub fn format_tool_call(name: &str, args: &str) -> String {
|
||||
format!("`{name}` → ```json\n{args}\n```")
|
||||
// Extract just the key params, not the full JSON blob
|
||||
let summary = match serde_json::from_str::<serde_json::Value>(args) {
|
||||
Ok(v) => {
|
||||
let params: Vec<String> = v
|
||||
.as_object()
|
||||
.map(|obj| {
|
||||
obj.iter()
|
||||
.filter(|(_, v)| !v.is_null() && v.as_str() != Some(""))
|
||||
.map(|(k, v)| {
|
||||
let val = match v {
|
||||
serde_json::Value::String(s) => {
|
||||
if s.len() > 40 {
|
||||
format!("{}…", &s[..40])
|
||||
} else {
|
||||
s.clone()
|
||||
}
|
||||
}
|
||||
other => other.to_string(),
|
||||
};
|
||||
format!("{k}={val}")
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
if params.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!(" ({})", params.join(", "))
|
||||
}
|
||||
}
|
||||
Err(_) => String::new(),
|
||||
};
|
||||
format!("🔧 `{name}`{summary}")
|
||||
}
|
||||
|
||||
/// Format a tool result for the thread.
|
||||
/// Format a tool result for the thread — short summary only.
|
||||
pub fn format_tool_result(name: &str, result: &str) -> String {
|
||||
let truncated = if result.len() > 500 {
|
||||
format!("{}…", &result[..500])
|
||||
let truncated = if result.len() > 200 {
|
||||
format!("{}…", &result[..200])
|
||||
} else {
|
||||
result.to_string()
|
||||
};
|
||||
format!("`{name}` ← {truncated}")
|
||||
format!("← `{name}`: {truncated}")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -117,23 +143,41 @@ mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_format_tool_call() {
|
||||
let formatted = AgentProgress::format_tool_call("search_archive", r#"{"query":"test"}"#);
|
||||
fn test_format_tool_call_with_params() {
|
||||
let formatted = AgentProgress::format_tool_call("search_archive", r#"{"query":"test","room":"general"}"#);
|
||||
assert!(formatted.contains("search_archive"));
|
||||
assert!(formatted.contains("test"));
|
||||
assert!(formatted.contains("query=test"));
|
||||
assert!(formatted.contains("room=general"));
|
||||
assert!(formatted.starts_with("🔧"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_tool_call_no_params() {
|
||||
let formatted = AgentProgress::format_tool_call("list_rooms", "{}");
|
||||
assert_eq!(formatted, "🔧 `list_rooms`");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_tool_call_truncates_long_values() {
|
||||
let long_code = "x".repeat(100);
|
||||
let args = format!(r#"{{"code":"{}"}}"#, long_code);
|
||||
let formatted = AgentProgress::format_tool_call("run_script", &args);
|
||||
assert!(formatted.contains("code="));
|
||||
assert!(formatted.contains("…"));
|
||||
assert!(formatted.len() < 200);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_tool_result_truncation() {
|
||||
let long = "x".repeat(1000);
|
||||
let long = "x".repeat(500);
|
||||
let formatted = AgentProgress::format_tool_result("search", &long);
|
||||
assert!(formatted.len() < 600);
|
||||
assert!(formatted.len() < 300);
|
||||
assert!(formatted.ends_with('…'));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_format_tool_result_short() {
|
||||
let formatted = AgentProgress::format_tool_result("search", "3 results found");
|
||||
assert_eq!(formatted, "`search` ← 3 results found");
|
||||
assert_eq!(formatted, "← `search`: 3 results found");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -69,11 +69,7 @@ pub fn orchestrator_request(
|
||||
name: ORCHESTRATOR_NAME.to_string(),
|
||||
description: Some(ORCHESTRATOR_DESCRIPTION.to_string()),
|
||||
instructions: Some(instructions),
|
||||
tools: {
|
||||
let mut all_tools = tools;
|
||||
all_tools.push(AgentTool::web_search());
|
||||
Some(all_tools)
|
||||
},
|
||||
tools: if tools.is_empty() { None } else { Some(tools) },
|
||||
handoffs: None,
|
||||
completion_args: Some(CompletionArgs {
|
||||
temperature: Some(0.5),
|
||||
|
||||
@@ -12,7 +12,11 @@ use crate::config::Config;
|
||||
#[derive(Debug)]
|
||||
pub enum Engagement {
|
||||
MustRespond { reason: MustRespondReason },
|
||||
MaybeRespond { relevance: f32, hook: String },
|
||||
/// Respond inline in the room — Sol has something valuable to contribute.
|
||||
Respond { relevance: f32, hook: String },
|
||||
/// Respond in a thread — Sol has something to add but it's tangential
|
||||
/// or the room is busy with a human-to-human conversation.
|
||||
ThreadReply { relevance: f32, hook: String },
|
||||
React { emoji: String, relevance: f32 },
|
||||
Ignore,
|
||||
}
|
||||
@@ -51,6 +55,8 @@ impl Evaluator {
|
||||
}
|
||||
}
|
||||
|
||||
/// `is_reply_to_human` — true if this message is a Matrix reply to a non-Sol user.
|
||||
/// `messages_since_sol` — how many messages have been sent since Sol last spoke in this room.
|
||||
pub async fn evaluate(
|
||||
&self,
|
||||
sender: &str,
|
||||
@@ -58,6 +64,9 @@ impl Evaluator {
|
||||
is_dm: bool,
|
||||
recent_messages: &[String],
|
||||
mistral: &Arc<mistralai_client::v1::client::Client>,
|
||||
is_reply_to_human: bool,
|
||||
messages_since_sol: usize,
|
||||
is_silenced: bool,
|
||||
) -> Engagement {
|
||||
let body_preview: String = body.chars().take(80).collect();
|
||||
|
||||
@@ -67,7 +76,7 @@ impl Evaluator {
|
||||
return Engagement::Ignore;
|
||||
}
|
||||
|
||||
// Direct mention: @sol:sunbeam.pt
|
||||
// Direct mention: @sol:sunbeam.pt — always responds, breaks silence
|
||||
if self.mention_regex.is_match(body) {
|
||||
info!(sender, body = body_preview.as_str(), rule = "direct_mention", "Engagement: MustRespond");
|
||||
return Engagement::MustRespond {
|
||||
@@ -75,7 +84,7 @@ impl Evaluator {
|
||||
};
|
||||
}
|
||||
|
||||
// DM
|
||||
// DM — always responds (silence only applies to group rooms)
|
||||
if is_dm {
|
||||
info!(sender, body = body_preview.as_str(), rule = "dm", "Engagement: MustRespond");
|
||||
return Engagement::MustRespond {
|
||||
@@ -83,6 +92,12 @@ impl Evaluator {
|
||||
};
|
||||
}
|
||||
|
||||
// If silenced in this room, only direct @mention breaks through (checked above)
|
||||
if is_silenced {
|
||||
debug!(sender, body = body_preview.as_str(), "Silenced in this room — ignoring");
|
||||
return Engagement::Ignore;
|
||||
}
|
||||
|
||||
// Name invocation: "sol ..." or "hey sol ..."
|
||||
if self.name_regex.is_match(body) {
|
||||
info!(sender, body = body_preview.as_str(), rule = "name_invocation", "Engagement: MustRespond");
|
||||
@@ -91,6 +106,32 @@ impl Evaluator {
|
||||
};
|
||||
}
|
||||
|
||||
// ── Structural suppression (A+B) ──
|
||||
|
||||
// A: If this is a reply to another human (not Sol), cap at React-only.
|
||||
// People replying to each other aren't asking for Sol's input.
|
||||
if is_reply_to_human {
|
||||
info!(
|
||||
sender, body = body_preview.as_str(),
|
||||
rule = "reply_to_human",
|
||||
"Reply to non-Sol human — suppressing to React-only"
|
||||
);
|
||||
// Still run the LLM eval for potential emoji reaction, but cap the result
|
||||
let engagement = self.evaluate_relevance(body, recent_messages, mistral).await;
|
||||
return match engagement {
|
||||
Engagement::React { emoji, relevance } => Engagement::React { emoji, relevance },
|
||||
Engagement::Respond { relevance, .. } if relevance >= self.config.behavior.reaction_threshold => {
|
||||
// Would have responded, but demote to just a reaction if the LLM suggested one
|
||||
Engagement::Ignore
|
||||
}
|
||||
_ => Engagement::Ignore,
|
||||
};
|
||||
}
|
||||
|
||||
// B: Consecutive message decay. After 3+ human messages without Sol,
|
||||
// switch from active to passive evaluation context.
|
||||
let force_passive = messages_since_sol >= 3;
|
||||
|
||||
info!(
|
||||
sender, body = body_preview.as_str(),
|
||||
threshold = self.config.behavior.spontaneous_threshold,
|
||||
@@ -98,11 +139,13 @@ impl Evaluator {
|
||||
context_len = recent_messages.len(),
|
||||
eval_window = self.config.behavior.evaluation_context_window,
|
||||
detect_sol = self.config.behavior.detect_sol_in_conversation,
|
||||
messages_since_sol,
|
||||
force_passive,
|
||||
is_reply_to_human,
|
||||
"No rule match — running LLM relevance evaluation"
|
||||
);
|
||||
|
||||
// Cheap evaluation call for spontaneous responses
|
||||
self.evaluate_relevance(body, recent_messages, mistral)
|
||||
self.evaluate_relevance_with_mode(body, recent_messages, mistral, force_passive)
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -140,6 +183,16 @@ impl Evaluator {
|
||||
body: &str,
|
||||
recent_messages: &[String],
|
||||
mistral: &Arc<mistralai_client::v1::client::Client>,
|
||||
) -> Engagement {
|
||||
self.evaluate_relevance_with_mode(body, recent_messages, mistral, false).await
|
||||
}
|
||||
|
||||
async fn evaluate_relevance_with_mode(
|
||||
&self,
|
||||
body: &str,
|
||||
recent_messages: &[String],
|
||||
mistral: &Arc<mistralai_client::v1::client::Client>,
|
||||
force_passive: bool,
|
||||
) -> Engagement {
|
||||
let window = self.config.behavior.evaluation_context_window;
|
||||
let context = recent_messages
|
||||
@@ -151,8 +204,11 @@ impl Evaluator {
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n");
|
||||
|
||||
// Check if Sol recently participated in this conversation
|
||||
let sol_in_context = self.config.behavior.detect_sol_in_conversation
|
||||
// Check if Sol recently participated in this conversation.
|
||||
// force_passive overrides: if 3+ human messages since Sol spoke, treat as passive
|
||||
// even if Sol's messages are visible in the context window.
|
||||
let sol_in_context = !force_passive
|
||||
&& self.config.behavior.detect_sol_in_conversation
|
||||
&& recent_messages.iter().any(|m| {
|
||||
let lower = m.to_lowercase();
|
||||
lower.starts_with("sol:") || lower.starts_with("sol ") || lower.contains("@sol:")
|
||||
@@ -181,15 +237,16 @@ impl Evaluator {
|
||||
"Building evaluation prompt"
|
||||
);
|
||||
|
||||
// System message: Sol's full personality + evaluation framing.
|
||||
// This gives the evaluator deep context on who Sol is, what they care about,
|
||||
// and how they'd naturally engage — so relevance scoring reflects Sol's actual character.
|
||||
// System message: Sol's full personality + evaluation framing + time context.
|
||||
let tc = crate::time_context::TimeContext::now();
|
||||
|
||||
let system = format!(
|
||||
"You are Sol's engagement evaluator. Your job is to decide whether Sol should \
|
||||
respond to a message in a group chat, based on Sol's personality, expertise, \
|
||||
and relationship with the people in the room.\n\n\
|
||||
"You are Sol's engagement evaluator. Your job is to decide whether and HOW Sol \
|
||||
should respond to a message in a group chat.\n\n\
|
||||
# who sol is\n\n\
|
||||
{}\n\n\
|
||||
# time\n\n\
|
||||
{}\n\n\
|
||||
# your task\n\n\
|
||||
Read the conversation below and evaluate whether Sol would naturally want to \
|
||||
respond to the latest message. Consider:\n\
|
||||
@@ -198,16 +255,25 @@ impl Evaluator {
|
||||
- Is someone implicitly asking for Sol's help (even without mentioning them)?\n\
|
||||
- Is this a continuation of something Sol was already involved in?\n\
|
||||
- Would Sol find this genuinely interesting or have something meaningful to add?\n\
|
||||
- Would a reaction (emoji) be more appropriate than a full response?\n\n\
|
||||
- Are two humans talking to each other? If so, Sol should NOT jump in unless \
|
||||
directly relevant. Two people having a conversation doesn't need a third voice.\n\
|
||||
- Would a reaction (emoji) be more appropriate than a full response?\n\
|
||||
- Would responding in a thread (less intrusive) be better than inline?\n\n\
|
||||
{participation_note}\n\n\
|
||||
Respond ONLY with JSON:\n\
|
||||
{{\"relevance\": 0.0-1.0, \"hook\": \"brief reason or empty string\", \"emoji\": \"a single emoji or empty string\"}}\n\n\
|
||||
relevance=1.0 means Sol absolutely should respond, 0.0 means irrelevant.\n\
|
||||
hook: if responding, a brief note on what Sol would engage with.\n\
|
||||
emoji: if Sol wouldn't write a full response but might react, suggest a single \
|
||||
emoji that feels natural and specific — not generic thumbs up. leave empty if \
|
||||
no reaction fits.",
|
||||
{{\"relevance\": 0.0-1.0, \"response_type\": \"message\"|\"thread\"|\"react\"|\"ignore\", \
|
||||
\"hook\": \"brief reason or empty string\", \"emoji\": \"a single emoji or empty string\"}}\n\n\
|
||||
relevance: 1.0 = Sol absolutely should respond, 0.0 = irrelevant.\n\
|
||||
response_type:\n\
|
||||
- \"message\": Sol has something genuinely valuable to add inline.\n\
|
||||
- \"thread\": Sol has a useful aside or observation, but the main conversation \
|
||||
is between humans — put it in a thread so it doesn't interrupt.\n\
|
||||
- \"react\": emoji reaction only, no text.\n\
|
||||
- \"ignore\": Sol has nothing to add.\n\
|
||||
hook: if responding, what Sol would engage with.\n\
|
||||
emoji: if reacting, a single emoji that feels natural and specific.",
|
||||
self.system_prompt,
|
||||
tc.system_block(),
|
||||
);
|
||||
|
||||
let user_prompt = format!(
|
||||
@@ -249,33 +315,40 @@ impl Evaluator {
|
||||
let relevance = val["relevance"].as_f64().unwrap_or(0.0) as f32;
|
||||
let hook = val["hook"].as_str().unwrap_or("").to_string();
|
||||
let emoji = val["emoji"].as_str().unwrap_or("").to_string();
|
||||
let response_type = val["response_type"].as_str().unwrap_or("ignore").to_string();
|
||||
let threshold = self.config.behavior.spontaneous_threshold;
|
||||
let reaction_threshold = self.config.behavior.reaction_threshold;
|
||||
let reaction_enabled = self.config.behavior.reaction_enabled;
|
||||
|
||||
info!(
|
||||
relevance,
|
||||
threshold,
|
||||
reaction_threshold,
|
||||
response_type = response_type.as_str(),
|
||||
hook = hook.as_str(),
|
||||
emoji = emoji.as_str(),
|
||||
"LLM evaluation parsed"
|
||||
);
|
||||
|
||||
if relevance >= threshold {
|
||||
Engagement::MaybeRespond { relevance, hook }
|
||||
} else if reaction_enabled
|
||||
&& relevance >= reaction_threshold
|
||||
&& !emoji.is_empty()
|
||||
{
|
||||
info!(
|
||||
relevance,
|
||||
emoji = emoji.as_str(),
|
||||
"Reaction range — will react with emoji"
|
||||
);
|
||||
// The LLM decides the response type, but we still gate on relevance threshold
|
||||
match response_type.as_str() {
|
||||
"message" if relevance >= threshold => {
|
||||
Engagement::Respond { relevance, hook }
|
||||
}
|
||||
"thread" if relevance >= threshold * 0.7 => {
|
||||
// Threads have a lower threshold — they're less intrusive
|
||||
Engagement::ThreadReply { relevance, hook }
|
||||
}
|
||||
"react" if reaction_enabled && !emoji.is_empty() => {
|
||||
Engagement::React { emoji, relevance }
|
||||
} else {
|
||||
Engagement::Ignore
|
||||
}
|
||||
// Fallback: if the model says "message" but relevance is below
|
||||
// threshold, check if it would qualify as a thread or reaction
|
||||
"message" | "thread" if relevance >= threshold * 0.7 => {
|
||||
Engagement::ThreadReply { relevance, hook }
|
||||
}
|
||||
_ if reaction_enabled && !emoji.is_empty() && relevance >= self.config.behavior.reaction_threshold => {
|
||||
Engagement::React { emoji, relevance }
|
||||
}
|
||||
_ => Engagement::Ignore,
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use chrono::Utc;
|
||||
use crate::time_context::TimeContext;
|
||||
|
||||
pub struct Personality {
|
||||
template: String,
|
||||
@@ -18,16 +18,9 @@ impl Personality {
|
||||
memory_notes: Option<&str>,
|
||||
is_dm: bool,
|
||||
) -> String {
|
||||
let now = Utc::now();
|
||||
let date = now.format("%Y-%m-%d").to_string();
|
||||
let epoch_ms = now.timestamp_millis().to_string();
|
||||
let tc = TimeContext::now();
|
||||
let members_str = members.join(", ");
|
||||
|
||||
// Pre-compute reference timestamps so the model doesn't have to do math
|
||||
let ts_1h_ago = (now - chrono::Duration::hours(1)).timestamp_millis().to_string();
|
||||
let ts_yesterday = (now - chrono::Duration::days(1)).timestamp_millis().to_string();
|
||||
let ts_last_week = (now - chrono::Duration::days(7)).timestamp_millis().to_string();
|
||||
|
||||
let room_context_rules = if is_dm {
|
||||
String::new()
|
||||
} else {
|
||||
@@ -40,11 +33,9 @@ impl Personality {
|
||||
};
|
||||
|
||||
self.template
|
||||
.replace("{date}", &date)
|
||||
.replace("{epoch_ms}", &epoch_ms)
|
||||
.replace("{ts_1h_ago}", &ts_1h_ago)
|
||||
.replace("{ts_yesterday}", &ts_yesterday)
|
||||
.replace("{ts_last_week}", &ts_last_week)
|
||||
.replace("{date}", &tc.date)
|
||||
.replace("{epoch_ms}", &tc.now.to_string())
|
||||
.replace("{time_block}", &tc.system_block())
|
||||
.replace("{room_name}", room_name)
|
||||
.replace("{members}", &members_str)
|
||||
.replace("{room_context_rules}", &room_context_rules)
|
||||
@@ -60,7 +51,7 @@ mod tests {
|
||||
fn test_date_substitution() {
|
||||
let p = Personality::new("Today is {date}.".to_string());
|
||||
let result = p.build_system_prompt("general", &[], None, false);
|
||||
let today = Utc::now().format("%Y-%m-%d").to_string();
|
||||
let today = chrono::Utc::now().format("%Y-%m-%d").to_string();
|
||||
assert_eq!(result, format!("Today is {today}."));
|
||||
}
|
||||
|
||||
@@ -93,7 +84,7 @@ mod tests {
|
||||
let members = vec!["Sienna".to_string(), "Lonni".to_string()];
|
||||
let result = p.build_system_prompt("studio", &members, None, false);
|
||||
|
||||
let today = Utc::now().format("%Y-%m-%d").to_string();
|
||||
let today = chrono::Utc::now().format("%Y-%m-%d").to_string();
|
||||
assert!(result.starts_with(&format!("Date: {today}")));
|
||||
assert!(result.contains("Room: studio"));
|
||||
assert!(result.contains("People: Sienna, Lonni"));
|
||||
@@ -132,19 +123,23 @@ mod tests {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_timestamp_variables_substituted() {
|
||||
let p = Personality::new(
|
||||
"now={epoch_ms} 1h={ts_1h_ago} yesterday={ts_yesterday} week={ts_last_week}".to_string(),
|
||||
);
|
||||
fn test_time_block_substituted() {
|
||||
let p = Personality::new("before\n{time_block}\nafter".to_string());
|
||||
let result = p.build_system_prompt("room", &[], None, false);
|
||||
assert!(!result.contains("{time_block}"));
|
||||
assert!(result.contains("epoch_ms:"));
|
||||
assert!(result.contains("today:"));
|
||||
assert!(result.contains("yesterday"));
|
||||
assert!(result.contains("this week"));
|
||||
assert!(result.contains("1h_ago="));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_epoch_ms_substituted() {
|
||||
let p = Personality::new("now={epoch_ms}".to_string());
|
||||
let result = p.build_system_prompt("room", &[], None, false);
|
||||
// Should NOT contain the literal placeholders
|
||||
assert!(!result.contains("{epoch_ms}"));
|
||||
assert!(!result.contains("{ts_1h_ago}"));
|
||||
assert!(!result.contains("{ts_yesterday}"));
|
||||
assert!(!result.contains("{ts_last_week}"));
|
||||
// Should contain numeric values
|
||||
assert!(result.starts_with("now="));
|
||||
assert!(result.contains("1h="));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -21,6 +21,7 @@ use crate::config::Config;
|
||||
use crate::context::ResponseContext;
|
||||
use crate::conversations::ConversationRegistry;
|
||||
use crate::memory;
|
||||
use crate::time_context::TimeContext;
|
||||
use crate::tools::ToolRegistry;
|
||||
|
||||
/// Run a Mistral chat completion on a blocking thread.
|
||||
@@ -146,7 +147,7 @@ impl Responder {
|
||||
messages.push(ChatMessage::new_user_message(&trigger));
|
||||
}
|
||||
|
||||
let tool_defs = ToolRegistry::tool_definitions(self.tools.has_gitea());
|
||||
let tool_defs = ToolRegistry::tool_definitions(self.tools.has_gitea(), self.tools.has_kratos());
|
||||
let model = Model::new(&self.config.mistral.default_model);
|
||||
let max_iterations = self.config.mistral.max_tool_iterations;
|
||||
|
||||
@@ -280,6 +281,7 @@ impl Responder {
|
||||
conversation_registry: &ConversationRegistry,
|
||||
image_data_uri: Option<&str>,
|
||||
context_hint: Option<String>,
|
||||
event_id: ruma::OwnedEventId,
|
||||
) -> Option<String> {
|
||||
// Apply response delay
|
||||
if !self.config.behavior.instant_responses {
|
||||
@@ -302,24 +304,16 @@ impl Responder {
|
||||
// Pre-response memory query (same as legacy path)
|
||||
let memory_notes = self.load_memory_notes(response_ctx, trigger_body).await;
|
||||
|
||||
// Build the input message with dynamic context header.
|
||||
// Build the input message with dynamic context.
|
||||
// Agent instructions are static (set at creation), so per-message context
|
||||
// (timestamps, room, members, memory) is prepended to each user message.
|
||||
let now = chrono::Utc::now();
|
||||
let epoch_ms = now.timestamp_millis();
|
||||
let ts_1h = (now - chrono::Duration::hours(1)).timestamp_millis();
|
||||
let ts_yesterday = (now - chrono::Duration::days(1)).timestamp_millis();
|
||||
let ts_last_week = (now - chrono::Duration::days(7)).timestamp_millis();
|
||||
let tc = TimeContext::now();
|
||||
|
||||
let mut context_header = format!(
|
||||
"[context: date={}, epoch_ms={}, ts_1h_ago={}, ts_yesterday={}, ts_last_week={}, room={}, room_name={}]",
|
||||
now.format("%Y-%m-%d"),
|
||||
epoch_ms,
|
||||
ts_1h,
|
||||
ts_yesterday,
|
||||
ts_last_week,
|
||||
room_id,
|
||||
"{}\n[room: {} ({})]",
|
||||
tc.message_line(),
|
||||
room_name,
|
||||
room_id,
|
||||
);
|
||||
|
||||
if let Some(ref notes) = memory_notes {
|
||||
@@ -352,9 +346,12 @@ impl Responder {
|
||||
// Check for function calls — execute locally and send results back
|
||||
let function_calls = response.function_calls();
|
||||
if !function_calls.is_empty() {
|
||||
// Agent UX: reactions + threads require the user's event ID
|
||||
// which we don't have in the responder. For now, log tool calls
|
||||
// and skip UX. TODO: pass event_id through ResponseContext.
|
||||
// Agent UX: react with 🔍 and post tool details in a thread
|
||||
let mut progress = crate::agent_ux::AgentProgress::new(
|
||||
room.clone(),
|
||||
event_id.clone(),
|
||||
);
|
||||
progress.start().await;
|
||||
|
||||
let max_iterations = self.config.mistral.max_tool_iterations;
|
||||
let mut current_response = response;
|
||||
@@ -376,13 +373,30 @@ impl Responder {
|
||||
"Executing tool call (conversations)"
|
||||
);
|
||||
|
||||
|
||||
|
||||
let result = self
|
||||
.tools
|
||||
.execute(&fc.name, &fc.arguments, response_ctx)
|
||||
// Post tool call to thread
|
||||
progress
|
||||
.post_step(&crate::agent_ux::AgentProgress::format_tool_call(
|
||||
&fc.name,
|
||||
&fc.arguments,
|
||||
))
|
||||
.await;
|
||||
|
||||
let result = if fc.name == "research" {
|
||||
self.tools
|
||||
.execute_research(
|
||||
&fc.arguments,
|
||||
response_ctx,
|
||||
room,
|
||||
&event_id,
|
||||
0, // depth 0 — orchestrator level
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
self.tools
|
||||
.execute(&fc.name, &fc.arguments, response_ctx)
|
||||
.await
|
||||
};
|
||||
|
||||
let result_str = match result {
|
||||
Ok(s) => {
|
||||
let preview: String = s.chars().take(500).collect();
|
||||
@@ -427,6 +441,9 @@ impl Responder {
|
||||
debug!(iteration, "Tool iteration complete (conversations)");
|
||||
}
|
||||
|
||||
// Done with tool calls
|
||||
progress.done().await;
|
||||
|
||||
// Extract final text from the last response
|
||||
if let Some(text) = current_response.assistant_text() {
|
||||
let text = strip_sol_prefix(&text);
|
||||
|
||||
@@ -28,6 +28,18 @@ pub struct AgentsConfig {
|
||||
/// Whether to use the Conversations API (vs manual message management).
|
||||
#[serde(default)]
|
||||
pub use_conversations_api: bool,
|
||||
/// Model for research micro-agents.
|
||||
#[serde(default = "default_research_agent_model")]
|
||||
pub research_model: String,
|
||||
/// Max tool calls per research micro-agent.
|
||||
#[serde(default = "default_research_max_iterations")]
|
||||
pub research_max_iterations: usize,
|
||||
/// Max parallel agents per research wave.
|
||||
#[serde(default = "default_research_max_agents")]
|
||||
pub research_max_agents: usize,
|
||||
/// Max recursion depth for research agents spawning sub-agents.
|
||||
#[serde(default = "default_research_max_depth")]
|
||||
pub research_max_depth: usize,
|
||||
}
|
||||
|
||||
impl Default for AgentsConfig {
|
||||
@@ -37,6 +49,10 @@ impl Default for AgentsConfig {
|
||||
domain_model: default_model(),
|
||||
compaction_threshold: default_compaction_threshold(),
|
||||
use_conversations_api: false,
|
||||
research_model: default_research_agent_model(),
|
||||
research_max_iterations: default_research_max_iterations(),
|
||||
research_max_agents: default_research_max_agents(),
|
||||
research_max_depth: default_research_max_depth(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -122,12 +138,24 @@ pub struct BehaviorConfig {
|
||||
pub script_fetch_allowlist: Vec<String>,
|
||||
#[serde(default = "default_memory_extraction_enabled")]
|
||||
pub memory_extraction_enabled: bool,
|
||||
/// Minimum fraction of a source room's members that must also be in the
|
||||
/// requesting room for cross-room search results to be visible.
|
||||
/// 0.0 = no restriction, 1.0 = only same room.
|
||||
#[serde(default = "default_room_overlap_threshold")]
|
||||
pub room_overlap_threshold: f32,
|
||||
/// Duration in ms that Sol stays silent after being told to be quiet.
|
||||
#[serde(default = "default_silence_duration_ms")]
|
||||
pub silence_duration_ms: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize, Default)]
|
||||
pub struct ServicesConfig {
|
||||
#[serde(default)]
|
||||
pub gitea: Option<GiteaConfig>,
|
||||
#[serde(default)]
|
||||
pub kratos: Option<KratosConfig>,
|
||||
#[serde(default)]
|
||||
pub searxng: Option<SearxngConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
@@ -135,6 +163,16 @@ pub struct GiteaConfig {
|
||||
pub url: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct KratosConfig {
|
||||
pub admin_url: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct SearxngConfig {
|
||||
pub url: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct VaultConfig {
|
||||
/// OpenBao/Vault URL. Default: http://openbao.data.svc.cluster.local:8200
|
||||
@@ -187,8 +225,14 @@ fn default_script_timeout_secs() -> u64 { 5 }
|
||||
fn default_script_max_heap_mb() -> usize { 64 }
|
||||
fn default_memory_index() -> String { "sol_user_memory".into() }
|
||||
fn default_memory_extraction_enabled() -> bool { true }
|
||||
fn default_room_overlap_threshold() -> f32 { 0.25 }
|
||||
fn default_silence_duration_ms() -> u64 { 1_800_000 } // 30 minutes
|
||||
fn default_db_path() -> String { "/data/sol.db".into() }
|
||||
fn default_compaction_threshold() -> u32 { 118000 } // ~90% of 131K context window
|
||||
fn default_research_agent_model() -> String { "ministral-3b-latest".into() }
|
||||
fn default_research_max_iterations() -> usize { 10 }
|
||||
fn default_research_max_agents() -> usize { 25 }
|
||||
fn default_research_max_depth() -> usize { 4 }
|
||||
|
||||
impl Config {
|
||||
pub fn load(path: &str) -> anyhow::Result<Self> {
|
||||
@@ -322,6 +366,17 @@ state_store_path = "/data/sol/state"
|
||||
assert!(config.services.gitea.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_services_config_with_kratos() {
|
||||
let with_kratos = format!(
|
||||
"{}\n[services.kratos]\nadmin_url = \"http://kratos-admin:80\"\n",
|
||||
MINIMAL_CONFIG
|
||||
);
|
||||
let config = Config::from_str(&with_kratos).unwrap();
|
||||
let kratos = config.services.kratos.unwrap();
|
||||
assert_eq!(kratos.admin_url, "http://kratos-admin:80");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_services_config_with_gitea() {
|
||||
let with_services = format!(
|
||||
|
||||
@@ -81,10 +81,10 @@ impl ConversationRegistry {
|
||||
) -> Result<ConversationResponse, String> {
|
||||
let mut mapping = self.mapping.lock().await;
|
||||
|
||||
// Try to append to existing conversation; if it fails, drop and recreate
|
||||
if let Some(state) = mapping.get_mut(room_id) {
|
||||
// Existing conversation — append
|
||||
let req = AppendConversationRequest {
|
||||
inputs: message,
|
||||
inputs: message.clone(),
|
||||
completion_args: None,
|
||||
handoff_execution: None,
|
||||
store: Some(true),
|
||||
@@ -92,12 +92,11 @@ impl ConversationRegistry {
|
||||
stream: false,
|
||||
};
|
||||
|
||||
let response = mistral
|
||||
match mistral
|
||||
.append_conversation_async(&state.conversation_id, &req)
|
||||
.await
|
||||
.map_err(|e| format!("append_conversation failed: {}", e.message))?;
|
||||
|
||||
// Update token estimate
|
||||
{
|
||||
Ok(response) => {
|
||||
state.estimated_tokens += response.usage.total_tokens;
|
||||
self.store.update_tokens(room_id, state.estimated_tokens);
|
||||
|
||||
@@ -108,8 +107,23 @@ impl ConversationRegistry {
|
||||
"Appended to conversation"
|
||||
);
|
||||
|
||||
Ok(response)
|
||||
} else {
|
||||
return Ok(response);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
room = room_id,
|
||||
conversation_id = state.conversation_id.as_str(),
|
||||
error = e.message.as_str(),
|
||||
"Conversation corrupted — dropping and creating fresh"
|
||||
);
|
||||
self.store.delete_conversation(room_id);
|
||||
mapping.remove(room_id);
|
||||
// Fall through to create a new conversation below
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// New conversation — create (with optional context hint for continuity)
|
||||
let agent_id = self.agent_id.lock().await.clone();
|
||||
|
||||
|
||||
41
src/main.rs
41
src/main.rs
@@ -10,6 +10,7 @@ mod memory;
|
||||
mod persistence;
|
||||
mod sdk;
|
||||
mod sync;
|
||||
mod time_context;
|
||||
mod tools;
|
||||
|
||||
use std::sync::Arc;
|
||||
@@ -174,11 +175,26 @@ async fn main() -> anyhow::Result<()> {
|
||||
None
|
||||
};
|
||||
|
||||
// Initialize Kratos client if configured
|
||||
let kratos_client: Option<Arc<sdk::kratos::KratosClient>> =
|
||||
if let Some(kratos_config) = &config.services.kratos {
|
||||
info!(url = kratos_config.admin_url.as_str(), "Kratos integration enabled");
|
||||
Some(Arc::new(sdk::kratos::KratosClient::new(
|
||||
kratos_config.admin_url.clone(),
|
||||
)))
|
||||
} else {
|
||||
info!("Kratos integration disabled (missing config)");
|
||||
None
|
||||
};
|
||||
|
||||
let tool_registry = Arc::new(ToolRegistry::new(
|
||||
os_client.clone(),
|
||||
matrix_client.clone(),
|
||||
config.clone(),
|
||||
gitea_client,
|
||||
kratos_client,
|
||||
Some(mistral.clone()),
|
||||
Some(store.clone()),
|
||||
));
|
||||
let indexer = Arc::new(Indexer::new(os_client.clone(), config.clone()));
|
||||
let evaluator = Arc::new(Evaluator::new(config.clone(), system_prompt_text.clone()));
|
||||
@@ -213,13 +229,24 @@ async fn main() -> anyhow::Result<()> {
|
||||
opensearch: os_client,
|
||||
last_response: Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())),
|
||||
responding_in: Arc::new(tokio::sync::Mutex::new(std::collections::HashSet::new())),
|
||||
silenced_until: Arc::new(tokio::sync::Mutex::new(std::collections::HashMap::new())),
|
||||
});
|
||||
|
||||
// Initialize orchestrator agent if conversations API is enabled
|
||||
let mut agent_recreated = false;
|
||||
if config.agents.use_conversations_api {
|
||||
info!("Conversations API enabled — ensuring orchestrator agent exists");
|
||||
let agent_tools = tools::ToolRegistry::agent_tool_definitions(config.services.gitea.is_some());
|
||||
let agent_tools = tools::ToolRegistry::agent_tool_definitions(
|
||||
config.services.gitea.is_some(),
|
||||
config.services.kratos.is_some(),
|
||||
);
|
||||
let mut active_agents: Vec<(&str, &str)> = vec![];
|
||||
if config.services.gitea.is_some() {
|
||||
active_agents.push(("sol-devtools", "Git repos, issues, PRs, code (Gitea)"));
|
||||
}
|
||||
if config.services.kratos.is_some() {
|
||||
active_agents.push(("sol-identity", "User accounts, sessions, recovery (Kratos)"));
|
||||
}
|
||||
match state
|
||||
.agent_registry
|
||||
.ensure_orchestrator(
|
||||
@@ -227,7 +254,7 @@ async fn main() -> anyhow::Result<()> {
|
||||
&config.agents.orchestrator_model,
|
||||
agent_tools,
|
||||
&state.mistral,
|
||||
&[], // no domain agents yet — delegation section added when they are
|
||||
&active_agents,
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -249,6 +276,16 @@ async fn main() -> anyhow::Result<()> {
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up hung research sessions from previous runs
|
||||
let hung_sessions = store.load_running_research_sessions();
|
||||
if !hung_sessions.is_empty() {
|
||||
info!(count = hung_sessions.len(), "Found hung research sessions — marking as failed");
|
||||
for (session_id, _room_id, query, _findings) in &hung_sessions {
|
||||
warn!(session_id = session_id.as_str(), query = query.as_str(), "Cleaning up hung research session");
|
||||
store.fail_research_session(session_id);
|
||||
}
|
||||
}
|
||||
|
||||
// Backfill reactions from Matrix room timelines
|
||||
info!("Backfilling reactions from room timelines...");
|
||||
if let Err(e) = backfill_reactions(&matrix_client, &state.indexer).await {
|
||||
|
||||
@@ -56,6 +56,19 @@ pub fn make_reply_content(body: &str, reply_to_event_id: OwnedEventId) -> RoomMe
|
||||
content
|
||||
}
|
||||
|
||||
/// Build a threaded reply — shows up in Matrix threads UI.
|
||||
/// The thread root is the event being replied to (creates the thread on first use).
|
||||
pub fn make_thread_reply(
|
||||
body: &str,
|
||||
thread_root_id: OwnedEventId,
|
||||
) -> RoomMessageEventContent {
|
||||
use ruma::events::relation::Thread;
|
||||
let mut content = RoomMessageEventContent::text_markdown(body);
|
||||
let thread = Thread::plain(thread_root_id.clone(), thread_root_id);
|
||||
content.relates_to = Some(Relation::Thread(thread));
|
||||
content
|
||||
}
|
||||
|
||||
/// Send an emoji reaction to a message.
|
||||
pub async fn send_reaction(
|
||||
room: &Room,
|
||||
|
||||
@@ -83,6 +83,18 @@ impl Store {
|
||||
PRIMARY KEY (localpart, service)
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS research_sessions (
|
||||
session_id TEXT PRIMARY KEY,
|
||||
room_id TEXT NOT NULL,
|
||||
event_id TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'running',
|
||||
query TEXT NOT NULL,
|
||||
plan_json TEXT,
|
||||
findings_json TEXT,
|
||||
depth INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
completed_at TEXT
|
||||
);
|
||||
",
|
||||
)?;
|
||||
|
||||
@@ -272,6 +284,91 @@ impl Store {
|
||||
}
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Research Sessions
|
||||
// =========================================================================
|
||||
|
||||
/// Create a new research session.
|
||||
pub fn create_research_session(
|
||||
&self,
|
||||
session_id: &str,
|
||||
room_id: &str,
|
||||
event_id: &str,
|
||||
query: &str,
|
||||
plan_json: &str,
|
||||
) {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
if let Err(e) = conn.execute(
|
||||
"INSERT INTO research_sessions (session_id, room_id, event_id, query, plan_json, findings_json)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, '[]')",
|
||||
params![session_id, room_id, event_id, query, plan_json],
|
||||
) {
|
||||
warn!("Failed to create research session: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
/// Append a finding to a research session.
|
||||
pub fn append_research_finding(&self, session_id: &str, finding_json: &str) {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
// Append to the JSON array
|
||||
if let Err(e) = conn.execute(
|
||||
"UPDATE research_sessions
|
||||
SET findings_json = json_insert(findings_json, '$[#]', json(?1))
|
||||
WHERE session_id = ?2",
|
||||
params![finding_json, session_id],
|
||||
) {
|
||||
warn!("Failed to append research finding: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
/// Mark a research session as complete.
|
||||
pub fn complete_research_session(&self, session_id: &str) {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
if let Err(e) = conn.execute(
|
||||
"UPDATE research_sessions SET status = 'complete', completed_at = datetime('now')
|
||||
WHERE session_id = ?1",
|
||||
params![session_id],
|
||||
) {
|
||||
warn!("Failed to complete research session: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
/// Mark a research session as failed.
|
||||
pub fn fail_research_session(&self, session_id: &str) {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
if let Err(e) = conn.execute(
|
||||
"UPDATE research_sessions SET status = 'failed', completed_at = datetime('now')
|
||||
WHERE session_id = ?1",
|
||||
params![session_id],
|
||||
) {
|
||||
warn!("Failed to mark research session failed: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
/// Load all running research sessions (for crash recovery on startup).
|
||||
pub fn load_running_research_sessions(&self) -> Vec<(String, String, String, String)> {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
let mut stmt = match conn.prepare(
|
||||
"SELECT session_id, room_id, query, findings_json
|
||||
FROM research_sessions WHERE status = 'running'",
|
||||
) {
|
||||
Ok(s) => s,
|
||||
Err(_) => return Vec::new(),
|
||||
};
|
||||
|
||||
stmt.query_map([], |row| {
|
||||
Ok((
|
||||
row.get::<_, String>(0)?,
|
||||
row.get::<_, String>(1)?,
|
||||
row.get::<_, String>(2)?,
|
||||
row.get::<_, String>(3)?,
|
||||
))
|
||||
})
|
||||
.ok()
|
||||
.map(|rows| rows.filter_map(|r| r.ok()).collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Load all agent mappings (for startup recovery).
|
||||
pub fn load_all_agents(&self) -> Vec<(String, String)> {
|
||||
let conn = self.conn.lock().unwrap();
|
||||
@@ -427,4 +524,95 @@ mod tests {
|
||||
let store = Store::open_memory().unwrap();
|
||||
assert!(store.get_service_user("nobody", "gitea").is_none());
|
||||
}
|
||||
|
||||
// ── Research session tests ──────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn test_research_session_lifecycle() {
|
||||
let store = Store::open_memory().unwrap();
|
||||
|
||||
// Create
|
||||
store.create_research_session("sess-1", "!room:x", "$event1", "investigate SBBB", "[]");
|
||||
let running = store.load_running_research_sessions();
|
||||
assert_eq!(running.len(), 1);
|
||||
assert_eq!(running[0].0, "sess-1");
|
||||
assert_eq!(running[0].2, "investigate SBBB");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_research_session_append_finding() {
|
||||
let store = Store::open_memory().unwrap();
|
||||
store.create_research_session("sess-2", "!room:x", "$event2", "test", "[]");
|
||||
|
||||
store.append_research_finding("sess-2", r#"{"focus":"repo","findings":"found 3 files"}"#);
|
||||
store.append_research_finding("sess-2", r#"{"focus":"archive","findings":"12 messages"}"#);
|
||||
|
||||
let running = store.load_running_research_sessions();
|
||||
assert_eq!(running.len(), 1);
|
||||
// findings_json should be a JSON array with 2 entries
|
||||
let findings: serde_json::Value = serde_json::from_str(&running[0].3).unwrap();
|
||||
assert_eq!(findings.as_array().unwrap().len(), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_research_session_complete() {
|
||||
let store = Store::open_memory().unwrap();
|
||||
store.create_research_session("sess-3", "!room:x", "$event3", "test", "[]");
|
||||
|
||||
store.complete_research_session("sess-3");
|
||||
|
||||
// Should no longer appear in running sessions
|
||||
let running = store.load_running_research_sessions();
|
||||
assert!(running.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_research_session_fail() {
|
||||
let store = Store::open_memory().unwrap();
|
||||
store.create_research_session("sess-4", "!room:x", "$event4", "test", "[]");
|
||||
|
||||
store.fail_research_session("sess-4");
|
||||
|
||||
let running = store.load_running_research_sessions();
|
||||
assert!(running.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_hung_session_cleanup_on_startup() {
|
||||
let store = Store::open_memory().unwrap();
|
||||
|
||||
// Simulate 2 hung sessions + 1 completed
|
||||
store.create_research_session("hung-1", "!room:a", "$e1", "query A", "[]");
|
||||
store.create_research_session("hung-2", "!room:b", "$e2", "query B", "[]");
|
||||
store.create_research_session("done-1", "!room:c", "$e3", "query C", "[]");
|
||||
store.complete_research_session("done-1");
|
||||
|
||||
// Only the 2 hung sessions should be returned
|
||||
let hung = store.load_running_research_sessions();
|
||||
assert_eq!(hung.len(), 2);
|
||||
|
||||
// Clean them up (simulates startup logic)
|
||||
for (session_id, _, _, _) in &hung {
|
||||
store.fail_research_session(session_id);
|
||||
}
|
||||
|
||||
// Now none should be running
|
||||
assert!(store.load_running_research_sessions().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_research_session_partial_findings_survive_failure() {
|
||||
let store = Store::open_memory().unwrap();
|
||||
store.create_research_session("sess-5", "!room:x", "$e5", "deep dive", "[]");
|
||||
|
||||
// Agent 1 completes, agent 2 hasn't yet
|
||||
store.append_research_finding("sess-5", r#"{"focus":"agent1","findings":"found stuff"}"#);
|
||||
|
||||
// Crash! Mark as failed
|
||||
store.fail_research_session("sess-5");
|
||||
|
||||
// Findings should still be queryable even though session failed
|
||||
// (would need a get_session method to verify, but the key point is
|
||||
// append_research_finding persists incrementally)
|
||||
}
|
||||
}
|
||||
|
||||
690
src/sdk/gitea.rs
690
src/sdk/gitea.rs
@@ -13,8 +13,10 @@ const TOKEN_SCOPES: &[&str] = &[
|
||||
"read:issue",
|
||||
"write:issue",
|
||||
"read:repository",
|
||||
"write:repository",
|
||||
"read:user",
|
||||
"read:organization",
|
||||
"read:notification",
|
||||
];
|
||||
|
||||
pub struct GiteaClient {
|
||||
@@ -110,6 +112,72 @@ pub struct FileContent {
|
||||
pub size: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Comment {
|
||||
pub id: u64,
|
||||
#[serde(default)]
|
||||
pub body: String,
|
||||
pub user: UserRef,
|
||||
#[serde(default)]
|
||||
pub html_url: String,
|
||||
pub created_at: String,
|
||||
pub updated_at: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Branch {
|
||||
pub name: String,
|
||||
#[serde(default)]
|
||||
pub commit: BranchCommit,
|
||||
#[serde(default)]
|
||||
pub protected: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize, Deserialize)]
|
||||
pub struct BranchCommit {
|
||||
#[serde(default)]
|
||||
pub id: String,
|
||||
#[serde(default)]
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Organization {
|
||||
#[serde(default)]
|
||||
pub id: u64,
|
||||
#[serde(default)]
|
||||
pub username: String,
|
||||
#[serde(default)]
|
||||
pub full_name: String,
|
||||
#[serde(default)]
|
||||
pub description: String,
|
||||
#[serde(default)]
|
||||
pub avatar_url: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Notification {
|
||||
pub id: u64,
|
||||
#[serde(default)]
|
||||
pub subject: NotificationSubject,
|
||||
#[serde(default)]
|
||||
pub repository: Option<RepoSummary>,
|
||||
#[serde(default)]
|
||||
pub unread: bool,
|
||||
#[serde(default)]
|
||||
pub updated_at: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default, Serialize, Deserialize)]
|
||||
pub struct NotificationSubject {
|
||||
#[serde(default)]
|
||||
pub title: String,
|
||||
#[serde(default)]
|
||||
pub url: String,
|
||||
#[serde(default, rename = "type")]
|
||||
pub subject_type: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct GiteaUser {
|
||||
login: String,
|
||||
@@ -424,6 +492,77 @@ impl GiteaClient {
|
||||
Ok(resp)
|
||||
}
|
||||
|
||||
/// Make an authenticated PATCH request using the user's token.
|
||||
/// On 401, invalidates token and retries once.
|
||||
async fn authed_patch(
|
||||
&self,
|
||||
localpart: &str,
|
||||
path: &str,
|
||||
body: &serde_json::Value,
|
||||
) -> Result<reqwest::Response, String> {
|
||||
let token = self.ensure_token(localpart).await?;
|
||||
let url = format!("{}{}", self.base_url, path);
|
||||
|
||||
let resp = self
|
||||
.http
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("token {token}"))
|
||||
.json(body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("request failed: {e}"))?;
|
||||
|
||||
if resp.status().as_u16() == 401 {
|
||||
debug!(localpart, "Token rejected, re-provisioning");
|
||||
self.token_store.delete(localpart, SERVICE).await;
|
||||
let token = self.ensure_token(localpart).await?;
|
||||
return self
|
||||
.http
|
||||
.patch(&url)
|
||||
.header("Authorization", format!("token {token}"))
|
||||
.json(body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("request failed (retry): {e}"));
|
||||
}
|
||||
|
||||
Ok(resp)
|
||||
}
|
||||
|
||||
/// Make an authenticated DELETE request using the user's token.
|
||||
/// On 401, invalidates token and retries once.
|
||||
async fn authed_delete(
|
||||
&self,
|
||||
localpart: &str,
|
||||
path: &str,
|
||||
) -> Result<reqwest::Response, String> {
|
||||
let token = self.ensure_token(localpart).await?;
|
||||
let url = format!("{}{}", self.base_url, path);
|
||||
|
||||
let resp = self
|
||||
.http
|
||||
.delete(&url)
|
||||
.header("Authorization", format!("token {token}"))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("request failed: {e}"))?;
|
||||
|
||||
if resp.status().as_u16() == 401 {
|
||||
debug!(localpart, "Token rejected, re-provisioning");
|
||||
self.token_store.delete(localpart, SERVICE).await;
|
||||
let token = self.ensure_token(localpart).await?;
|
||||
return self
|
||||
.http
|
||||
.delete(&url)
|
||||
.header("Authorization", format!("token {token}"))
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("request failed (retry): {e}"));
|
||||
}
|
||||
|
||||
Ok(resp)
|
||||
}
|
||||
|
||||
// ── Public API methods ──────────────────────────────────────────────────
|
||||
|
||||
pub async fn list_repos(
|
||||
@@ -625,6 +764,429 @@ impl GiteaClient {
|
||||
}
|
||||
resp.json().await.map_err(|e| format!("parse error: {e}"))
|
||||
}
|
||||
|
||||
// ── Repos ───────────────────────────────────────────────────────────────
|
||||
|
||||
pub async fn create_repo(
|
||||
&self,
|
||||
localpart: &str,
|
||||
name: &str,
|
||||
org: Option<&str>,
|
||||
description: Option<&str>,
|
||||
private: Option<bool>,
|
||||
auto_init: Option<bool>,
|
||||
default_branch: Option<&str>,
|
||||
) -> Result<Repo, String> {
|
||||
let mut json = serde_json::json!({ "name": name });
|
||||
if let Some(d) = description {
|
||||
json["description"] = serde_json::json!(d);
|
||||
}
|
||||
if let Some(p) = private {
|
||||
json["private"] = serde_json::json!(p);
|
||||
}
|
||||
if let Some(a) = auto_init {
|
||||
json["auto_init"] = serde_json::json!(a);
|
||||
}
|
||||
if let Some(b) = default_branch {
|
||||
json["default_branch"] = serde_json::json!(b);
|
||||
}
|
||||
|
||||
let path = if let Some(org) = org {
|
||||
format!("/api/v1/orgs/{org}/repos")
|
||||
} else {
|
||||
"/api/v1/user/repos".to_string()
|
||||
};
|
||||
|
||||
let resp = self.authed_post(localpart, &path, &json).await?;
|
||||
if !resp.status().is_success() {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("create repo failed: {text}"));
|
||||
}
|
||||
resp.json().await.map_err(|e| format!("parse error: {e}"))
|
||||
}
|
||||
|
||||
pub async fn edit_repo(
|
||||
&self,
|
||||
localpart: &str,
|
||||
owner: &str,
|
||||
repo: &str,
|
||||
description: Option<&str>,
|
||||
private: Option<bool>,
|
||||
archived: Option<bool>,
|
||||
default_branch: Option<&str>,
|
||||
) -> Result<Repo, String> {
|
||||
let mut json = serde_json::json!({});
|
||||
if let Some(d) = description {
|
||||
json["description"] = serde_json::json!(d);
|
||||
}
|
||||
if let Some(p) = private {
|
||||
json["private"] = serde_json::json!(p);
|
||||
}
|
||||
if let Some(a) = archived {
|
||||
json["archived"] = serde_json::json!(a);
|
||||
}
|
||||
if let Some(b) = default_branch {
|
||||
json["default_branch"] = serde_json::json!(b);
|
||||
}
|
||||
|
||||
let resp = self
|
||||
.authed_patch(
|
||||
localpart,
|
||||
&format!("/api/v1/repos/{owner}/{repo}"),
|
||||
&json,
|
||||
)
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("edit repo failed: {text}"));
|
||||
}
|
||||
resp.json().await.map_err(|e| format!("parse error: {e}"))
|
||||
}
|
||||
|
||||
pub async fn fork_repo(
|
||||
&self,
|
||||
localpart: &str,
|
||||
owner: &str,
|
||||
repo: &str,
|
||||
new_name: Option<&str>,
|
||||
) -> Result<Repo, String> {
|
||||
let mut json = serde_json::json!({});
|
||||
if let Some(n) = new_name {
|
||||
json["name"] = serde_json::json!(n);
|
||||
}
|
||||
|
||||
let resp = self
|
||||
.authed_post(
|
||||
localpart,
|
||||
&format!("/api/v1/repos/{owner}/{repo}/forks"),
|
||||
&json,
|
||||
)
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("fork repo failed: {text}"));
|
||||
}
|
||||
resp.json().await.map_err(|e| format!("parse error: {e}"))
|
||||
}
|
||||
|
||||
pub async fn list_org_repos(
|
||||
&self,
|
||||
localpart: &str,
|
||||
org: &str,
|
||||
limit: Option<u32>,
|
||||
) -> Result<Vec<RepoSummary>, String> {
|
||||
let param_str = {
|
||||
let mut encoder = form_urlencoded::Serializer::new(String::new());
|
||||
if let Some(n) = limit {
|
||||
encoder.append_pair("limit", &n.to_string());
|
||||
}
|
||||
let encoded = encoder.finish();
|
||||
if encoded.is_empty() { String::new() } else { format!("?{encoded}") }
|
||||
};
|
||||
|
||||
let path = format!("/api/v1/orgs/{org}/repos{param_str}");
|
||||
let resp = self.authed_get(localpart, &path).await?;
|
||||
if !resp.status().is_success() {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("list org repos failed: {text}"));
|
||||
}
|
||||
resp.json().await.map_err(|e| format!("parse error: {e}"))
|
||||
}
|
||||
|
||||
// ── Issues ──────────────────────────────────────────────────────────────
|
||||
|
||||
pub async fn edit_issue(
|
||||
&self,
|
||||
localpart: &str,
|
||||
owner: &str,
|
||||
repo: &str,
|
||||
number: u64,
|
||||
title: Option<&str>,
|
||||
body: Option<&str>,
|
||||
state: Option<&str>,
|
||||
assignees: Option<&[String]>,
|
||||
) -> Result<Issue, String> {
|
||||
let mut json = serde_json::json!({});
|
||||
if let Some(t) = title {
|
||||
json["title"] = serde_json::json!(t);
|
||||
}
|
||||
if let Some(b) = body {
|
||||
json["body"] = serde_json::json!(b);
|
||||
}
|
||||
if let Some(s) = state {
|
||||
json["state"] = serde_json::json!(s);
|
||||
}
|
||||
if let Some(a) = assignees {
|
||||
json["assignees"] = serde_json::json!(a);
|
||||
}
|
||||
|
||||
let resp = self
|
||||
.authed_patch(
|
||||
localpart,
|
||||
&format!("/api/v1/repos/{owner}/{repo}/issues/{number}"),
|
||||
&json,
|
||||
)
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("edit issue failed: {text}"));
|
||||
}
|
||||
resp.json().await.map_err(|e| format!("parse error: {e}"))
|
||||
}
|
||||
|
||||
pub async fn list_comments(
|
||||
&self,
|
||||
localpart: &str,
|
||||
owner: &str,
|
||||
repo: &str,
|
||||
number: u64,
|
||||
) -> Result<Vec<Comment>, String> {
|
||||
let resp = self
|
||||
.authed_get(
|
||||
localpart,
|
||||
&format!("/api/v1/repos/{owner}/{repo}/issues/{number}/comments"),
|
||||
)
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("list comments failed: {text}"));
|
||||
}
|
||||
resp.json().await.map_err(|e| format!("parse error: {e}"))
|
||||
}
|
||||
|
||||
pub async fn create_comment(
|
||||
&self,
|
||||
localpart: &str,
|
||||
owner: &str,
|
||||
repo: &str,
|
||||
number: u64,
|
||||
body: &str,
|
||||
) -> Result<Comment, String> {
|
||||
let json = serde_json::json!({ "body": body });
|
||||
|
||||
let resp = self
|
||||
.authed_post(
|
||||
localpart,
|
||||
&format!("/api/v1/repos/{owner}/{repo}/issues/{number}/comments"),
|
||||
&json,
|
||||
)
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("create comment failed: {text}"));
|
||||
}
|
||||
resp.json().await.map_err(|e| format!("parse error: {e}"))
|
||||
}
|
||||
|
||||
// ── Pull requests ───────────────────────────────────────────────────────
|
||||
|
||||
pub async fn get_pull(
|
||||
&self,
|
||||
localpart: &str,
|
||||
owner: &str,
|
||||
repo: &str,
|
||||
number: u64,
|
||||
) -> Result<PullRequest, String> {
|
||||
let resp = self
|
||||
.authed_get(
|
||||
localpart,
|
||||
&format!("/api/v1/repos/{owner}/{repo}/pulls/{number}"),
|
||||
)
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("get pull failed: {text}"));
|
||||
}
|
||||
resp.json().await.map_err(|e| format!("parse error: {e}"))
|
||||
}
|
||||
|
||||
pub async fn create_pull(
|
||||
&self,
|
||||
localpart: &str,
|
||||
owner: &str,
|
||||
repo: &str,
|
||||
title: &str,
|
||||
head: &str,
|
||||
base: &str,
|
||||
body: Option<&str>,
|
||||
) -> Result<PullRequest, String> {
|
||||
let mut json = serde_json::json!({
|
||||
"title": title,
|
||||
"head": head,
|
||||
"base": base,
|
||||
});
|
||||
if let Some(b) = body {
|
||||
json["body"] = serde_json::json!(b);
|
||||
}
|
||||
|
||||
let resp = self
|
||||
.authed_post(
|
||||
localpart,
|
||||
&format!("/api/v1/repos/{owner}/{repo}/pulls"),
|
||||
&json,
|
||||
)
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("create pull failed: {text}"));
|
||||
}
|
||||
resp.json().await.map_err(|e| format!("parse error: {e}"))
|
||||
}
|
||||
|
||||
pub async fn merge_pull(
|
||||
&self,
|
||||
localpart: &str,
|
||||
owner: &str,
|
||||
repo: &str,
|
||||
number: u64,
|
||||
method: Option<&str>,
|
||||
delete_branch: Option<bool>,
|
||||
) -> Result<serde_json::Value, String> {
|
||||
let mut json = serde_json::json!({
|
||||
"Do": method.unwrap_or("merge"),
|
||||
});
|
||||
if let Some(d) = delete_branch {
|
||||
json["delete_branch_after_merge"] = serde_json::json!(d);
|
||||
}
|
||||
|
||||
let resp = self
|
||||
.authed_post(
|
||||
localpart,
|
||||
&format!("/api/v1/repos/{owner}/{repo}/pulls/{number}/merge"),
|
||||
&json,
|
||||
)
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("merge pull failed: {text}"));
|
||||
}
|
||||
// Merge returns empty body on success (204/200)
|
||||
Ok(serde_json::json!({"status": "merged", "number": number}))
|
||||
}
|
||||
|
||||
// ── Branches ────────────────────────────────────────────────────────────
|
||||
|
||||
pub async fn list_branches(
|
||||
&self,
|
||||
localpart: &str,
|
||||
owner: &str,
|
||||
repo: &str,
|
||||
) -> Result<Vec<Branch>, String> {
|
||||
let resp = self
|
||||
.authed_get(
|
||||
localpart,
|
||||
&format!("/api/v1/repos/{owner}/{repo}/branches"),
|
||||
)
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("list branches failed: {text}"));
|
||||
}
|
||||
resp.json().await.map_err(|e| format!("parse error: {e}"))
|
||||
}
|
||||
|
||||
pub async fn create_branch(
|
||||
&self,
|
||||
localpart: &str,
|
||||
owner: &str,
|
||||
repo: &str,
|
||||
branch_name: &str,
|
||||
from_branch: Option<&str>,
|
||||
) -> Result<Branch, String> {
|
||||
let mut json = serde_json::json!({
|
||||
"new_branch_name": branch_name,
|
||||
});
|
||||
if let Some(f) = from_branch {
|
||||
json["old_branch_name"] = serde_json::json!(f);
|
||||
}
|
||||
|
||||
let resp = self
|
||||
.authed_post(
|
||||
localpart,
|
||||
&format!("/api/v1/repos/{owner}/{repo}/branches"),
|
||||
&json,
|
||||
)
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("create branch failed: {text}"));
|
||||
}
|
||||
resp.json().await.map_err(|e| format!("parse error: {e}"))
|
||||
}
|
||||
|
||||
pub async fn delete_branch(
|
||||
&self,
|
||||
localpart: &str,
|
||||
owner: &str,
|
||||
repo: &str,
|
||||
branch: &str,
|
||||
) -> Result<serde_json::Value, String> {
|
||||
let resp = self
|
||||
.authed_delete(
|
||||
localpart,
|
||||
&format!("/api/v1/repos/{owner}/{repo}/branches/{branch}"),
|
||||
)
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("delete branch failed: {text}"));
|
||||
}
|
||||
Ok(serde_json::json!({"status": "deleted", "branch": branch}))
|
||||
}
|
||||
|
||||
// ── Organizations ───────────────────────────────────────────────────────
|
||||
|
||||
pub async fn list_orgs(
|
||||
&self,
|
||||
localpart: &str,
|
||||
username: &str,
|
||||
) -> Result<Vec<Organization>, String> {
|
||||
let resp = self
|
||||
.authed_get(
|
||||
localpart,
|
||||
&format!("/api/v1/users/{username}/orgs"),
|
||||
)
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("list orgs failed: {text}"));
|
||||
}
|
||||
resp.json().await.map_err(|e| format!("parse error: {e}"))
|
||||
}
|
||||
|
||||
pub async fn get_org(
|
||||
&self,
|
||||
localpart: &str,
|
||||
org: &str,
|
||||
) -> Result<Organization, String> {
|
||||
let resp = self
|
||||
.authed_get(
|
||||
localpart,
|
||||
&format!("/api/v1/orgs/{org}"),
|
||||
)
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("get org failed: {text}"));
|
||||
}
|
||||
resp.json().await.map_err(|e| format!("parse error: {e}"))
|
||||
}
|
||||
|
||||
// ── Notifications ───────────────────────────────────────────────────────
|
||||
|
||||
pub async fn list_notifications(
|
||||
&self,
|
||||
localpart: &str,
|
||||
) -> Result<Vec<Notification>, String> {
|
||||
let resp = self
|
||||
.authed_get(localpart, "/api/v1/notifications")
|
||||
.await?;
|
||||
if !resp.status().is_success() {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("list notifications failed: {text}"));
|
||||
}
|
||||
resp.json().await.map_err(|e| format!("parse error: {e}"))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -729,4 +1291,132 @@ mod tests {
|
||||
assert_eq!(file.name, "README.md");
|
||||
assert_eq!(file.file_type, "file");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_comment_deserialize() {
|
||||
let json = serde_json::json!({
|
||||
"id": 99,
|
||||
"body": "looks good to me",
|
||||
"user": { "login": "lonni" },
|
||||
"html_url": "https://src.sunbeam.pt/studio/sol/issues/1#issuecomment-99",
|
||||
"created_at": "2026-03-22T10:00:00Z",
|
||||
"updated_at": "2026-03-22T10:00:00Z",
|
||||
});
|
||||
let comment: Comment = serde_json::from_value(json).unwrap();
|
||||
assert_eq!(comment.id, 99);
|
||||
assert_eq!(comment.body, "looks good to me");
|
||||
assert_eq!(comment.user.login, "lonni");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_branch_deserialize() {
|
||||
let json = serde_json::json!({
|
||||
"name": "feature/auth",
|
||||
"commit": { "id": "abc123def456", "message": "add login flow" },
|
||||
"protected": false,
|
||||
});
|
||||
let branch: Branch = serde_json::from_value(json).unwrap();
|
||||
assert_eq!(branch.name, "feature/auth");
|
||||
assert_eq!(branch.commit.id, "abc123def456");
|
||||
assert!(!branch.protected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_branch_minimal() {
|
||||
let json = serde_json::json!({ "name": "main" });
|
||||
let branch: Branch = serde_json::from_value(json).unwrap();
|
||||
assert_eq!(branch.name, "main");
|
||||
assert_eq!(branch.commit.id, ""); // default
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_organization_deserialize() {
|
||||
let json = serde_json::json!({
|
||||
"id": 1,
|
||||
"username": "studio",
|
||||
"full_name": "Sunbeam Studios",
|
||||
"description": "Game studio",
|
||||
"avatar_url": "https://src.sunbeam.pt/avatars/1",
|
||||
});
|
||||
let org: Organization = serde_json::from_value(json).unwrap();
|
||||
assert_eq!(org.username, "studio");
|
||||
assert_eq!(org.full_name, "Sunbeam Studios");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_notification_deserialize() {
|
||||
let json = serde_json::json!({
|
||||
"id": 42,
|
||||
"subject": { "title": "New issue", "url": "/api/v1/...", "type": "Issue" },
|
||||
"repository": { "full_name": "studio/sol", "description": "" },
|
||||
"unread": true,
|
||||
"updated_at": "2026-03-22T10:00:00Z",
|
||||
});
|
||||
let notif: Notification = serde_json::from_value(json).unwrap();
|
||||
assert_eq!(notif.id, 42);
|
||||
assert!(notif.unread);
|
||||
assert_eq!(notif.subject.title, "New issue");
|
||||
assert_eq!(notif.subject.subject_type, "Issue");
|
||||
assert_eq!(notif.repository.as_ref().unwrap().full_name, "studio/sol");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_notification_minimal() {
|
||||
let json = serde_json::json!({ "id": 1 });
|
||||
let notif: Notification = serde_json::from_value(json).unwrap();
|
||||
assert_eq!(notif.id, 1);
|
||||
assert!(!notif.unread);
|
||||
assert!(notif.repository.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_token_scopes_include_write_repo() {
|
||||
// New tools need write:repository for create/edit/fork/branch operations
|
||||
assert!(TOKEN_SCOPES.contains(&"write:repository"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_token_scopes_include_notifications() {
|
||||
assert!(TOKEN_SCOPES.contains(&"read:notification"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pull_request_with_refs() {
|
||||
let json = serde_json::json!({
|
||||
"number": 3,
|
||||
"title": "Add auth",
|
||||
"body": "implements OIDC",
|
||||
"state": "open",
|
||||
"html_url": "https://src.sunbeam.pt/studio/sol/pulls/3",
|
||||
"user": { "login": "sienna" },
|
||||
"head": { "label": "feature/auth", "ref": "feature/auth", "sha": "abc123" },
|
||||
"base": { "label": "main", "ref": "main", "sha": "def456" },
|
||||
"mergeable": true,
|
||||
"created_at": "2026-03-22T10:00:00Z",
|
||||
"updated_at": "2026-03-22T10:00:00Z",
|
||||
});
|
||||
let pr: PullRequest = serde_json::from_value(json).unwrap();
|
||||
assert_eq!(pr.number, 3);
|
||||
assert_eq!(pr.mergeable, Some(true));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_repo_full_deserialize() {
|
||||
let json = serde_json::json!({
|
||||
"full_name": "studio/marathon",
|
||||
"description": "P2P game engine",
|
||||
"html_url": "https://src.sunbeam.pt/studio/marathon",
|
||||
"default_branch": "mainline",
|
||||
"open_issues_count": 121,
|
||||
"stars_count": 0,
|
||||
"forks_count": 2,
|
||||
"updated_at": "2026-03-06T13:21:24Z",
|
||||
"private": false,
|
||||
});
|
||||
let repo: Repo = serde_json::from_value(json).unwrap();
|
||||
assert_eq!(repo.full_name, "studio/marathon");
|
||||
assert_eq!(repo.default_branch, "mainline");
|
||||
assert_eq!(repo.open_issues_count, 121);
|
||||
assert_eq!(repo.forks_count, 2);
|
||||
}
|
||||
}
|
||||
|
||||
359
src/sdk/kratos.rs
Normal file
359
src/sdk/kratos.rs
Normal file
@@ -0,0 +1,359 @@
|
||||
use reqwest::Client as HttpClient;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::{info, warn};
|
||||
|
||||
pub struct KratosClient {
|
||||
admin_url: String,
|
||||
http: HttpClient,
|
||||
}
|
||||
|
||||
// ── Response types ──────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Identity {
|
||||
pub id: String,
|
||||
#[serde(default)]
|
||||
pub state: String,
|
||||
pub traits: IdentityTraits,
|
||||
#[serde(default)]
|
||||
pub created_at: String,
|
||||
#[serde(default)]
|
||||
pub updated_at: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct IdentityTraits {
|
||||
#[serde(default)]
|
||||
pub email: String,
|
||||
#[serde(default)]
|
||||
pub name: Option<NameTraits>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct NameTraits {
|
||||
#[serde(default)]
|
||||
pub first: String,
|
||||
#[serde(default)]
|
||||
pub last: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct Session {
|
||||
pub id: String,
|
||||
#[serde(default)]
|
||||
pub active: bool,
|
||||
#[serde(default)]
|
||||
pub authenticated_at: String,
|
||||
#[serde(default)]
|
||||
pub expires_at: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct RecoveryResponse {
|
||||
#[serde(default)]
|
||||
pub recovery_link: String,
|
||||
#[serde(default)]
|
||||
pub recovery_code: String,
|
||||
}
|
||||
|
||||
// ── Implementation ──────────────────────────────────────────────────────────
|
||||
|
||||
impl KratosClient {
|
||||
pub fn new(admin_url: String) -> Self {
|
||||
Self {
|
||||
admin_url: admin_url.trim_end_matches('/').to_string(),
|
||||
http: HttpClient::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve an email or UUID to an identity ID.
|
||||
/// If the input looks like a UUID, use it directly.
|
||||
/// Otherwise, search by credentials_identifier (email).
|
||||
async fn resolve_id(&self, email_or_id: &str) -> Result<String, String> {
|
||||
if is_uuid(email_or_id) {
|
||||
return Ok(email_or_id.to_string());
|
||||
}
|
||||
|
||||
// Search by email
|
||||
let url = format!(
|
||||
"{}/admin/identities?credentials_identifier={}",
|
||||
self.admin_url,
|
||||
urlencoding::encode(email_or_id)
|
||||
);
|
||||
let resp = self
|
||||
.http
|
||||
.get(&url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("failed to search identities: {e}"))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("identity search failed: {text}"));
|
||||
}
|
||||
|
||||
let identities: Vec<Identity> = resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| format!("failed to parse identities: {e}"))?;
|
||||
|
||||
identities
|
||||
.first()
|
||||
.map(|i| i.id.clone())
|
||||
.ok_or_else(|| format!("no identity found for '{email_or_id}'"))
|
||||
}
|
||||
|
||||
pub async fn list_users(
|
||||
&self,
|
||||
search: Option<&str>,
|
||||
limit: Option<u32>,
|
||||
) -> Result<Vec<Identity>, String> {
|
||||
let mut url = format!("{}/admin/identities", self.admin_url);
|
||||
let mut params = vec![];
|
||||
if let Some(s) = search {
|
||||
params.push(format!(
|
||||
"credentials_identifier={}",
|
||||
urlencoding::encode(s)
|
||||
));
|
||||
}
|
||||
params.push(format!("page_size={}", limit.unwrap_or(50)));
|
||||
if !params.is_empty() {
|
||||
url.push_str(&format!("?{}", params.join("&")));
|
||||
}
|
||||
|
||||
let resp = self
|
||||
.http
|
||||
.get(&url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("failed to list identities: {e}"))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("list identities failed: {text}"));
|
||||
}
|
||||
|
||||
resp.json()
|
||||
.await
|
||||
.map_err(|e| format!("failed to parse identities: {e}"))
|
||||
}
|
||||
|
||||
pub async fn get_user(&self, email_or_id: &str) -> Result<Identity, String> {
|
||||
let id = self.resolve_id(email_or_id).await?;
|
||||
let url = format!("{}/admin/identities/{}", self.admin_url, id);
|
||||
|
||||
let resp = self
|
||||
.http
|
||||
.get(&url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("failed to get identity: {e}"))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("get identity failed: {text}"));
|
||||
}
|
||||
|
||||
resp.json()
|
||||
.await
|
||||
.map_err(|e| format!("failed to parse identity: {e}"))
|
||||
}
|
||||
|
||||
pub async fn create_user(
|
||||
&self,
|
||||
email: &str,
|
||||
first_name: Option<&str>,
|
||||
last_name: Option<&str>,
|
||||
) -> Result<Identity, String> {
|
||||
let mut traits = serde_json::json!({ "email": email });
|
||||
if first_name.is_some() || last_name.is_some() {
|
||||
traits["name"] = serde_json::json!({
|
||||
"first": first_name.unwrap_or(""),
|
||||
"last": last_name.unwrap_or(""),
|
||||
});
|
||||
}
|
||||
|
||||
let body = serde_json::json!({
|
||||
"schema_id": "default",
|
||||
"traits": traits,
|
||||
});
|
||||
|
||||
let url = format!("{}/admin/identities", self.admin_url);
|
||||
let resp = self
|
||||
.http
|
||||
.post(&url)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("failed to create identity: {e}"))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("create identity failed: {text}"));
|
||||
}
|
||||
|
||||
let identity: Identity = resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| format!("failed to parse identity: {e}"))?;
|
||||
|
||||
info!(id = identity.id.as_str(), email, "Created identity");
|
||||
Ok(identity)
|
||||
}
|
||||
|
||||
pub async fn recover_user(&self, email_or_id: &str) -> Result<RecoveryResponse, String> {
|
||||
let id = self.resolve_id(email_or_id).await?;
|
||||
|
||||
let body = serde_json::json!({
|
||||
"identity_id": id,
|
||||
});
|
||||
|
||||
let url = format!("{}/admin/recovery/code", self.admin_url);
|
||||
let resp = self
|
||||
.http
|
||||
.post(&url)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("failed to create recovery: {e}"))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("create recovery failed: {text}"));
|
||||
}
|
||||
|
||||
resp.json()
|
||||
.await
|
||||
.map_err(|e| format!("failed to parse recovery response: {e}"))
|
||||
}
|
||||
|
||||
pub async fn disable_user(&self, email_or_id: &str) -> Result<Identity, String> {
|
||||
self.set_state(email_or_id, "inactive").await
|
||||
}
|
||||
|
||||
pub async fn enable_user(&self, email_or_id: &str) -> Result<Identity, String> {
|
||||
self.set_state(email_or_id, "active").await
|
||||
}
|
||||
|
||||
async fn set_state(&self, email_or_id: &str, state: &str) -> Result<Identity, String> {
|
||||
let id = self.resolve_id(email_or_id).await?;
|
||||
let url = format!("{}/admin/identities/{}", self.admin_url, id);
|
||||
|
||||
let body = serde_json::json!({ "state": state });
|
||||
let resp = self
|
||||
.http
|
||||
.put(&url)
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("failed to update identity state: {e}"))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("update identity state failed: {text}"));
|
||||
}
|
||||
|
||||
info!(id = id.as_str(), state, "Updated identity state");
|
||||
resp.json()
|
||||
.await
|
||||
.map_err(|e| format!("failed to parse identity: {e}"))
|
||||
}
|
||||
|
||||
pub async fn list_sessions(&self, email_or_id: &str) -> Result<Vec<Session>, String> {
|
||||
let id = self.resolve_id(email_or_id).await?;
|
||||
let url = format!("{}/admin/identities/{}/sessions", self.admin_url, id);
|
||||
|
||||
let resp = self
|
||||
.http
|
||||
.get(&url)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("failed to list sessions: {e}"))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
return Err(format!("list sessions failed: {text}"));
|
||||
}
|
||||
|
||||
resp.json()
|
||||
.await
|
||||
.map_err(|e| format!("failed to parse sessions: {e}"))
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a string looks like a UUID (Kratos identity ID format).
|
||||
fn is_uuid(s: &str) -> bool {
|
||||
s.len() == 36
|
||||
&& s.chars()
|
||||
.all(|c| c.is_ascii_hexdigit() || c == '-')
|
||||
&& s.matches('-').count() == 4
|
||||
}
|
||||
|
||||
// ── URL encoding helper ─────────────────────────────────────────────────────
|
||||
|
||||
mod urlencoding {
|
||||
pub fn encode(s: &str) -> String {
|
||||
url::form_urlencoded::byte_serialize(s.as_bytes()).collect()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_identity_deserialize() {
|
||||
let json = serde_json::json!({
|
||||
"id": "cd0a2db5-1234-5678-9abc-def012345678",
|
||||
"state": "active",
|
||||
"traits": {
|
||||
"email": "sienna@sunbeam.pt",
|
||||
"name": { "first": "Sienna", "last": "V" }
|
||||
},
|
||||
"created_at": "2026-03-05T10:00:00Z",
|
||||
"updated_at": "2026-03-20T12:00:00Z",
|
||||
});
|
||||
let id: Identity = serde_json::from_value(json).unwrap();
|
||||
assert_eq!(id.state, "active");
|
||||
assert_eq!(id.traits.email, "sienna@sunbeam.pt");
|
||||
assert_eq!(id.traits.name.as_ref().unwrap().first, "Sienna");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_identity_minimal_traits() {
|
||||
let json = serde_json::json!({
|
||||
"id": "abc-123",
|
||||
"traits": { "email": "test@example.com" },
|
||||
});
|
||||
let id: Identity = serde_json::from_value(json).unwrap();
|
||||
assert_eq!(id.traits.email, "test@example.com");
|
||||
assert!(id.traits.name.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_session_deserialize() {
|
||||
let json = serde_json::json!({
|
||||
"id": "sess-123",
|
||||
"active": true,
|
||||
"authenticated_at": "2026-03-22T10:00:00Z",
|
||||
"expires_at": "2026-04-21T10:00:00Z",
|
||||
});
|
||||
let sess: Session = serde_json::from_value(json).unwrap();
|
||||
assert!(sess.active);
|
||||
assert_eq!(sess.id, "sess-123");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_uuid() {
|
||||
assert!(is_uuid("cd0a2db5-1234-5678-9abc-def012345678"));
|
||||
assert!(!is_uuid("sienna@sunbeam.pt"));
|
||||
assert!(!is_uuid("not-a-uuid"));
|
||||
assert!(!is_uuid(""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_urlencoding() {
|
||||
assert_eq!(urlencoding::encode("hello@world.com"), "hello%40world.com");
|
||||
assert_eq!(urlencoding::encode("plain"), "plain");
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod gitea;
|
||||
pub mod kratos;
|
||||
pub mod tokens;
|
||||
pub mod vault;
|
||||
|
||||
99
src/sync.rs
99
src/sync.rs
@@ -42,6 +42,8 @@ pub struct AppState {
|
||||
pub last_response: Arc<Mutex<HashMap<String, Instant>>>,
|
||||
/// Tracks rooms where a response is currently being generated (in-flight guard)
|
||||
pub responding_in: Arc<Mutex<std::collections::HashSet<String>>>,
|
||||
/// Rooms where Sol has been told to be quiet — maps room_id → silenced_until
|
||||
pub silenced_until: Arc<Mutex<HashMap<String, Instant>>>,
|
||||
}
|
||||
|
||||
pub async fn start_sync(client: Client, state: Arc<AppState>) -> anyhow::Result<()> {
|
||||
@@ -193,6 +195,38 @@ async fn handle_message(
|
||||
);
|
||||
}
|
||||
|
||||
// Silence detection — if someone tells Sol to be quiet, set a per-room timer
|
||||
{
|
||||
let lower = body.to_lowercase();
|
||||
let silence_phrases = [
|
||||
"shut up", "be quiet", "shush", "silence", "stop talking",
|
||||
"quiet down", "hush", "enough sol", "sol enough", "sol stop",
|
||||
"sol shut up", "sol be quiet", "sol shush",
|
||||
];
|
||||
if silence_phrases.iter().any(|p| lower.contains(p)) {
|
||||
let duration = std::time::Duration::from_millis(
|
||||
state.config.behavior.silence_duration_ms,
|
||||
);
|
||||
let until = Instant::now() + duration;
|
||||
let mut silenced = state.silenced_until.lock().await;
|
||||
silenced.insert(room_id.clone(), until);
|
||||
info!(
|
||||
room = room_id.as_str(),
|
||||
duration_mins = state.config.behavior.silence_duration_ms / 60_000,
|
||||
"Silenced in room"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Check if Sol is currently silenced in this room
|
||||
let is_silenced = {
|
||||
let silenced = state.silenced_until.lock().await;
|
||||
silenced
|
||||
.get(&room_id)
|
||||
.map(|until| Instant::now() < *until)
|
||||
.unwrap_or(false)
|
||||
};
|
||||
|
||||
// Evaluate whether to respond
|
||||
let recent: Vec<String> = {
|
||||
let convs = state.conversations.lock().await;
|
||||
@@ -203,28 +237,65 @@ async fn handle_message(
|
||||
.collect()
|
||||
};
|
||||
|
||||
// A: Check if this message is a reply to another human (not Sol)
|
||||
let is_reply_to_human = is_reply && !is_dm && {
|
||||
// If it's a reply, check the conversation context for who the previous
|
||||
// message was from. We don't have event IDs in context, so we use a
|
||||
// heuristic: if the most recent message before this one was from a human
|
||||
// (not Sol), this reply is likely directed at them.
|
||||
let convs = state.conversations.lock().await;
|
||||
let ctx = convs.get_context(&room_id);
|
||||
let sol_id = &state.config.matrix.user_id;
|
||||
// Check the message before the current one (last in context before we added ours)
|
||||
ctx.iter().rev().skip(1).next()
|
||||
.map(|m| m.sender != *sol_id)
|
||||
.unwrap_or(false)
|
||||
};
|
||||
|
||||
// B: Count messages since Sol last spoke in this room
|
||||
let messages_since_sol = {
|
||||
let convs = state.conversations.lock().await;
|
||||
let ctx = convs.get_context(&room_id);
|
||||
let sol_id = &state.config.matrix.user_id;
|
||||
ctx.iter().rev().take_while(|m| m.sender != *sol_id).count()
|
||||
};
|
||||
|
||||
let engagement = state
|
||||
.evaluator
|
||||
.evaluate(&sender, &body, is_dm, &recent, &state.mistral)
|
||||
.evaluate(
|
||||
&sender, &body, is_dm, &recent, &state.mistral,
|
||||
is_reply_to_human, messages_since_sol, is_silenced,
|
||||
)
|
||||
.await;
|
||||
|
||||
let (should_respond, is_spontaneous) = match engagement {
|
||||
// use_thread: if true, Sol responds in a thread instead of inline
|
||||
let (should_respond, is_spontaneous, use_thread) = match engagement {
|
||||
Engagement::MustRespond { reason } => {
|
||||
info!(room = room_id.as_str(), ?reason, "Must respond");
|
||||
(true, false)
|
||||
// Direct mention breaks silence
|
||||
if is_silenced {
|
||||
let mut silenced = state.silenced_until.lock().await;
|
||||
silenced.remove(&room_id);
|
||||
info!(room = room_id.as_str(), "Silence broken by direct mention");
|
||||
}
|
||||
Engagement::MaybeRespond { relevance, hook } => {
|
||||
info!(room = room_id.as_str(), relevance, hook = hook.as_str(), "Maybe respond (spontaneous)");
|
||||
(true, true)
|
||||
(true, false, false)
|
||||
}
|
||||
Engagement::Respond { relevance, hook } => {
|
||||
info!(room = room_id.as_str(), relevance, hook = hook.as_str(), "Respond (spontaneous)");
|
||||
(true, true, false)
|
||||
}
|
||||
Engagement::ThreadReply { relevance, hook } => {
|
||||
info!(room = room_id.as_str(), relevance, hook = hook.as_str(), "Thread reply (spontaneous)");
|
||||
(true, true, true)
|
||||
}
|
||||
Engagement::React { emoji, relevance } => {
|
||||
info!(room = room_id.as_str(), relevance, emoji = emoji.as_str(), "Reacting with emoji");
|
||||
if let Err(e) = matrix_utils::send_reaction(&room, event.event_id.clone().into(), &emoji).await {
|
||||
error!("Failed to send reaction: {e}");
|
||||
}
|
||||
(false, false)
|
||||
(false, false, false)
|
||||
}
|
||||
Engagement::Ignore => (false, false),
|
||||
Engagement::Ignore => (false, false, false),
|
||||
};
|
||||
|
||||
if !should_respond {
|
||||
@@ -310,6 +381,7 @@ async fn handle_message(
|
||||
&state.conversation_registry,
|
||||
image_data_uri.as_deref(),
|
||||
context_hint,
|
||||
event.event_id.clone().into(),
|
||||
)
|
||||
.await
|
||||
} else {
|
||||
@@ -331,17 +403,20 @@ async fn handle_message(
|
||||
};
|
||||
|
||||
if let Some(text) = response {
|
||||
// Reply with reference only when directly addressed. Spontaneous
|
||||
// and DM messages are sent as plain content — feels more natural.
|
||||
let content = if !is_spontaneous && !is_dm {
|
||||
let content = if use_thread {
|
||||
// Thread reply — less intrusive, for tangential contributions
|
||||
matrix_utils::make_thread_reply(&text, event.event_id.to_owned())
|
||||
} else if !is_spontaneous && !is_dm {
|
||||
// Direct reply — when explicitly addressed
|
||||
matrix_utils::make_reply_content(&text, event.event_id.to_owned())
|
||||
} else {
|
||||
// Plain message — spontaneous or DM, feels more natural
|
||||
ruma::events::room::message::RoomMessageEventContent::text_markdown(&text)
|
||||
};
|
||||
if let Err(e) = room.send(content).await {
|
||||
error!("Failed to send response: {e}");
|
||||
} else {
|
||||
info!(room = room_id.as_str(), len = text.len(), is_dm, "Response sent");
|
||||
info!(room = room_id.as_str(), len = text.len(), is_dm, use_thread, "Response sent");
|
||||
}
|
||||
// Post-response memory extraction (fire-and-forget)
|
||||
if state.config.behavior.memory_extraction_enabled {
|
||||
|
||||
278
src/time_context.rs
Normal file
278
src/time_context.rs
Normal file
@@ -0,0 +1,278 @@
|
||||
use chrono::{Datelike, Duration, NaiveTime, TimeZone, Utc, Weekday};
|
||||
|
||||
/// Comprehensive time context for the model.
|
||||
/// All epoch values are milliseconds. All day boundaries are midnight UTC.
|
||||
pub struct TimeContext {
|
||||
// ── Current moment ──
|
||||
pub now: i64,
|
||||
pub date: String, // 2026-03-22
|
||||
pub time: String, // 14:35
|
||||
pub datetime: String, // 2026-03-22T14:35:12Z
|
||||
pub day_of_week: String, // Saturday
|
||||
pub day_of_week_short: String, // Sat
|
||||
|
||||
// ── Today ──
|
||||
pub today_start: i64, // midnight today
|
||||
pub today_end: i64, // 23:59:59.999 today
|
||||
|
||||
// ── Yesterday ──
|
||||
pub yesterday_start: i64,
|
||||
pub yesterday_end: i64,
|
||||
pub yesterday_name: String, // Friday
|
||||
|
||||
// ── Day before yesterday ──
|
||||
pub two_days_ago_start: i64,
|
||||
pub two_days_ago_end: i64,
|
||||
pub two_days_ago_name: String,
|
||||
|
||||
// ── This week (Monday start) ──
|
||||
pub this_week_start: i64,
|
||||
|
||||
// ── Last week ──
|
||||
pub last_week_start: i64,
|
||||
pub last_week_end: i64,
|
||||
|
||||
// ── This month ──
|
||||
pub this_month_start: i64,
|
||||
|
||||
// ── Last month ──
|
||||
pub last_month_start: i64,
|
||||
pub last_month_end: i64,
|
||||
|
||||
// ── Rolling offsets from now ──
|
||||
pub ago_1h: i64,
|
||||
pub ago_6h: i64,
|
||||
pub ago_12h: i64,
|
||||
pub ago_24h: i64,
|
||||
pub ago_48h: i64,
|
||||
pub ago_7d: i64,
|
||||
pub ago_14d: i64,
|
||||
pub ago_30d: i64,
|
||||
}
|
||||
|
||||
fn weekday_name(w: Weekday) -> &'static str {
|
||||
match w {
|
||||
Weekday::Mon => "Monday",
|
||||
Weekday::Tue => "Tuesday",
|
||||
Weekday::Wed => "Wednesday",
|
||||
Weekday::Thu => "Thursday",
|
||||
Weekday::Fri => "Friday",
|
||||
Weekday::Sat => "Saturday",
|
||||
Weekday::Sun => "Sunday",
|
||||
}
|
||||
}
|
||||
|
||||
fn weekday_short(w: Weekday) -> &'static str {
|
||||
match w {
|
||||
Weekday::Mon => "Mon",
|
||||
Weekday::Tue => "Tue",
|
||||
Weekday::Wed => "Wed",
|
||||
Weekday::Thu => "Thu",
|
||||
Weekday::Fri => "Fri",
|
||||
Weekday::Sat => "Sat",
|
||||
Weekday::Sun => "Sun",
|
||||
}
|
||||
}
|
||||
|
||||
impl TimeContext {
|
||||
pub fn now() -> Self {
|
||||
let now = Utc::now();
|
||||
let today = now.date_naive();
|
||||
let yesterday = today - Duration::days(1);
|
||||
let two_days_ago = today - Duration::days(2);
|
||||
|
||||
let midnight = NaiveTime::from_hms_opt(0, 0, 0).unwrap();
|
||||
let end_of_day = NaiveTime::from_hms_milli_opt(23, 59, 59, 999).unwrap();
|
||||
|
||||
let today_start = Utc.from_utc_datetime(&today.and_time(midnight)).timestamp_millis();
|
||||
let today_end = Utc.from_utc_datetime(&today.and_time(end_of_day)).timestamp_millis();
|
||||
let yesterday_start = Utc.from_utc_datetime(&yesterday.and_time(midnight)).timestamp_millis();
|
||||
let yesterday_end = Utc.from_utc_datetime(&yesterday.and_time(end_of_day)).timestamp_millis();
|
||||
let two_days_ago_start = Utc.from_utc_datetime(&two_days_ago.and_time(midnight)).timestamp_millis();
|
||||
let two_days_ago_end = Utc.from_utc_datetime(&two_days_ago.and_time(end_of_day)).timestamp_millis();
|
||||
|
||||
// This week (Monday start)
|
||||
let days_since_monday = today.weekday().num_days_from_monday() as i64;
|
||||
let monday = today - Duration::days(days_since_monday);
|
||||
let this_week_start = Utc.from_utc_datetime(&monday.and_time(midnight)).timestamp_millis();
|
||||
|
||||
// Last week
|
||||
let last_monday = monday - Duration::days(7);
|
||||
let last_sunday = monday - Duration::days(1);
|
||||
let last_week_start = Utc.from_utc_datetime(&last_monday.and_time(midnight)).timestamp_millis();
|
||||
let last_week_end = Utc.from_utc_datetime(&last_sunday.and_time(end_of_day)).timestamp_millis();
|
||||
|
||||
// This month
|
||||
let first_of_month = today.with_day(1).unwrap();
|
||||
let this_month_start = Utc.from_utc_datetime(&first_of_month.and_time(midnight)).timestamp_millis();
|
||||
|
||||
// Last month
|
||||
let last_month_last_day = first_of_month - Duration::days(1);
|
||||
let last_month_first = last_month_last_day.with_day(1).unwrap();
|
||||
let last_month_start = Utc.from_utc_datetime(&last_month_first.and_time(midnight)).timestamp_millis();
|
||||
let last_month_end = Utc.from_utc_datetime(&last_month_last_day.and_time(end_of_day)).timestamp_millis();
|
||||
|
||||
let now_ms = now.timestamp_millis();
|
||||
|
||||
Self {
|
||||
now: now_ms,
|
||||
date: now.format("%Y-%m-%d").to_string(),
|
||||
time: now.format("%H:%M").to_string(),
|
||||
datetime: now.format("%Y-%m-%dT%H:%M:%SZ").to_string(),
|
||||
day_of_week: weekday_name(today.weekday()).to_string(),
|
||||
day_of_week_short: weekday_short(today.weekday()).to_string(),
|
||||
|
||||
today_start,
|
||||
today_end,
|
||||
|
||||
yesterday_start,
|
||||
yesterday_end,
|
||||
yesterday_name: weekday_name(yesterday.weekday()).to_string(),
|
||||
|
||||
two_days_ago_start,
|
||||
two_days_ago_end,
|
||||
two_days_ago_name: weekday_name(two_days_ago.weekday()).to_string(),
|
||||
|
||||
this_week_start,
|
||||
last_week_start,
|
||||
last_week_end,
|
||||
this_month_start,
|
||||
last_month_start,
|
||||
last_month_end,
|
||||
|
||||
ago_1h: now_ms - 3_600_000,
|
||||
ago_6h: now_ms - 21_600_000,
|
||||
ago_12h: now_ms - 43_200_000,
|
||||
ago_24h: now_ms - 86_400_000,
|
||||
ago_48h: now_ms - 172_800_000,
|
||||
ago_7d: now_ms - 604_800_000,
|
||||
ago_14d: now_ms - 1_209_600_000,
|
||||
ago_30d: now_ms - 2_592_000_000,
|
||||
}
|
||||
}
|
||||
|
||||
/// Full time block for system prompts (~25 values).
|
||||
/// Used in the legacy path template and the conversations API per-message header.
|
||||
pub fn system_block(&self) -> String {
|
||||
format!(
|
||||
"\
|
||||
## time\n\
|
||||
\n\
|
||||
current: {} {} UTC ({}, {})\n\
|
||||
epoch_ms: {}\n\
|
||||
\n\
|
||||
day boundaries (midnight UTC, use these for search_archive after/before):\n\
|
||||
today: {} to {}\n\
|
||||
yesterday ({}): {} to {}\n\
|
||||
{} ago: {} to {}\n\
|
||||
this week (Mon): {} to now\n\
|
||||
last week: {} to {}\n\
|
||||
this month: {} to now\n\
|
||||
last month: {} to {}\n\
|
||||
\n\
|
||||
rolling offsets:\n\
|
||||
1h_ago={} 6h_ago={} 12h_ago={} 24h_ago={}\n\
|
||||
48h_ago={} 7d_ago={} 14d_ago={} 30d_ago={}",
|
||||
self.date, self.time, self.day_of_week, self.datetime,
|
||||
self.now,
|
||||
self.today_start, self.today_end,
|
||||
self.yesterday_name, self.yesterday_start, self.yesterday_end,
|
||||
self.two_days_ago_name, self.two_days_ago_start, self.two_days_ago_end,
|
||||
self.this_week_start,
|
||||
self.last_week_start, self.last_week_end,
|
||||
self.this_month_start,
|
||||
self.last_month_start, self.last_month_end,
|
||||
self.ago_1h, self.ago_6h, self.ago_12h, self.ago_24h,
|
||||
self.ago_48h, self.ago_7d, self.ago_14d, self.ago_30d,
|
||||
)
|
||||
}
|
||||
|
||||
/// Compact time line for per-message injection (~5 key values).
|
||||
pub fn message_line(&self) -> String {
|
||||
format!(
|
||||
"[time: {} {} UTC | today={}-{} | yesterday={}-{} | now={}]",
|
||||
self.date, self.time,
|
||||
self.today_start, self.today_end,
|
||||
self.yesterday_start, self.yesterday_end,
|
||||
self.now,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_time_context_now() {
|
||||
let tc = TimeContext::now();
|
||||
assert!(tc.now > 0);
|
||||
assert!(tc.today_start <= tc.now);
|
||||
assert!(tc.today_end >= tc.now);
|
||||
assert!(tc.yesterday_start < tc.today_start);
|
||||
assert!(tc.yesterday_end < tc.today_start);
|
||||
assert!(tc.this_week_start <= tc.today_start);
|
||||
assert!(tc.last_week_start < tc.this_week_start);
|
||||
assert!(tc.this_month_start <= tc.today_start);
|
||||
assert!(tc.last_month_start < tc.this_month_start);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_day_boundaries_are_midnight() {
|
||||
let tc = TimeContext::now();
|
||||
// today_start should be divisible by 86400000 (midnight)
|
||||
// (not exactly, due to timezone, but should end in 00:00:00.000)
|
||||
assert!(tc.today_start % 1000 == 0); // whole second
|
||||
assert!(tc.yesterday_start % 1000 == 0);
|
||||
// end of day should be .999
|
||||
assert!(tc.today_end % 1000 == 999);
|
||||
assert!(tc.yesterday_end % 1000 == 999);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_yesterday_is_24h_before_today() {
|
||||
let tc = TimeContext::now();
|
||||
assert_eq!(tc.today_start - tc.yesterday_start, 86_400_000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_rolling_offsets() {
|
||||
let tc = TimeContext::now();
|
||||
assert_eq!(tc.now - tc.ago_1h, 3_600_000);
|
||||
assert_eq!(tc.now - tc.ago_24h, 86_400_000);
|
||||
assert_eq!(tc.now - tc.ago_7d, 604_800_000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_day_names() {
|
||||
let tc = TimeContext::now();
|
||||
let valid_days = ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"];
|
||||
assert!(valid_days.contains(&tc.day_of_week.as_str()));
|
||||
assert!(valid_days.contains(&tc.yesterday_name.as_str()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_system_block_contains_key_values() {
|
||||
let tc = TimeContext::now();
|
||||
let block = tc.system_block();
|
||||
assert!(block.contains("epoch_ms:"));
|
||||
assert!(block.contains("today:"));
|
||||
assert!(block.contains("yesterday"));
|
||||
assert!(block.contains("this week"));
|
||||
assert!(block.contains("last week"));
|
||||
assert!(block.contains("this month"));
|
||||
assert!(block.contains("1h_ago="));
|
||||
assert!(block.contains("30d_ago="));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_message_line_compact() {
|
||||
let tc = TimeContext::now();
|
||||
let line = tc.message_line();
|
||||
assert!(line.starts_with("[time:"));
|
||||
assert!(line.contains("today="));
|
||||
assert!(line.contains("yesterday="));
|
||||
assert!(line.contains("now="));
|
||||
assert!(line.ends_with(']'));
|
||||
}
|
||||
}
|
||||
@@ -136,6 +136,270 @@ pub async fn execute(
|
||||
Err(e) => Ok(json!({"error": e}).to_string()),
|
||||
}
|
||||
}
|
||||
"gitea_create_repo" => {
|
||||
let name = args["name"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'name'"))?;
|
||||
let org = args["org"].as_str();
|
||||
let description = args["description"].as_str();
|
||||
let private = args["private"].as_bool();
|
||||
let auto_init = args["auto_init"].as_bool();
|
||||
let default_branch = args["default_branch"].as_str();
|
||||
|
||||
match gitea
|
||||
.create_repo(localpart, name, org, description, private, auto_init, default_branch)
|
||||
.await
|
||||
{
|
||||
Ok(r) => Ok(serde_json::to_string(&r).unwrap_or_default()),
|
||||
Err(e) => Ok(json!({"error": e}).to_string()),
|
||||
}
|
||||
}
|
||||
"gitea_edit_repo" => {
|
||||
let owner = args["owner"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'owner'"))?;
|
||||
let repo = args["repo"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'repo'"))?;
|
||||
let description = args["description"].as_str();
|
||||
let private = args["private"].as_bool();
|
||||
let archived = args["archived"].as_bool();
|
||||
let default_branch = args["default_branch"].as_str();
|
||||
|
||||
match gitea
|
||||
.edit_repo(localpart, owner, repo, description, private, archived, default_branch)
|
||||
.await
|
||||
{
|
||||
Ok(r) => Ok(serde_json::to_string(&r).unwrap_or_default()),
|
||||
Err(e) => Ok(json!({"error": e}).to_string()),
|
||||
}
|
||||
}
|
||||
"gitea_fork_repo" => {
|
||||
let owner = args["owner"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'owner'"))?;
|
||||
let repo = args["repo"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'repo'"))?;
|
||||
let new_name = args["new_name"].as_str();
|
||||
|
||||
match gitea.fork_repo(localpart, owner, repo, new_name).await {
|
||||
Ok(r) => Ok(serde_json::to_string(&r).unwrap_or_default()),
|
||||
Err(e) => Ok(json!({"error": e}).to_string()),
|
||||
}
|
||||
}
|
||||
"gitea_list_org_repos" => {
|
||||
let org = args["org"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'org'"))?;
|
||||
let limit = args["limit"].as_u64().map(|n| n as u32);
|
||||
|
||||
match gitea.list_org_repos(localpart, org, limit).await {
|
||||
Ok(repos) => Ok(serde_json::to_string(&repos).unwrap_or_default()),
|
||||
Err(e) => Ok(json!({"error": e}).to_string()),
|
||||
}
|
||||
}
|
||||
"gitea_edit_issue" => {
|
||||
let owner = args["owner"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'owner'"))?;
|
||||
let repo = args["repo"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'repo'"))?;
|
||||
let number = args["number"]
|
||||
.as_u64()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'number'"))?;
|
||||
let title = args["title"].as_str();
|
||||
let body = args["body"].as_str();
|
||||
let state = args["state"].as_str();
|
||||
let assignees: Option<Vec<String>> = args["assignees"]
|
||||
.as_array()
|
||||
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect());
|
||||
|
||||
match gitea
|
||||
.edit_issue(localpart, owner, repo, number, title, body, state, assignees.as_deref())
|
||||
.await
|
||||
{
|
||||
Ok(issue) => Ok(serde_json::to_string(&issue).unwrap_or_default()),
|
||||
Err(e) => Ok(json!({"error": e}).to_string()),
|
||||
}
|
||||
}
|
||||
"gitea_list_comments" => {
|
||||
let owner = args["owner"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'owner'"))?;
|
||||
let repo = args["repo"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'repo'"))?;
|
||||
let number = args["number"]
|
||||
.as_u64()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'number'"))?;
|
||||
|
||||
match gitea.list_comments(localpart, owner, repo, number).await {
|
||||
Ok(comments) => Ok(serde_json::to_string(&comments).unwrap_or_default()),
|
||||
Err(e) => Ok(json!({"error": e}).to_string()),
|
||||
}
|
||||
}
|
||||
"gitea_create_comment" => {
|
||||
let owner = args["owner"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'owner'"))?;
|
||||
let repo = args["repo"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'repo'"))?;
|
||||
let number = args["number"]
|
||||
.as_u64()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'number'"))?;
|
||||
let body = args["body"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'body'"))?;
|
||||
|
||||
match gitea
|
||||
.create_comment(localpart, owner, repo, number, body)
|
||||
.await
|
||||
{
|
||||
Ok(comment) => Ok(serde_json::to_string(&comment).unwrap_or_default()),
|
||||
Err(e) => Ok(json!({"error": e}).to_string()),
|
||||
}
|
||||
}
|
||||
"gitea_get_pull" => {
|
||||
let owner = args["owner"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'owner'"))?;
|
||||
let repo = args["repo"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'repo'"))?;
|
||||
let number = args["number"]
|
||||
.as_u64()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'number'"))?;
|
||||
|
||||
match gitea.get_pull(localpart, owner, repo, number).await {
|
||||
Ok(pr) => Ok(serde_json::to_string(&pr).unwrap_or_default()),
|
||||
Err(e) => Ok(json!({"error": e}).to_string()),
|
||||
}
|
||||
}
|
||||
"gitea_create_pull" => {
|
||||
let owner = args["owner"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'owner'"))?;
|
||||
let repo = args["repo"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'repo'"))?;
|
||||
let title = args["title"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'title'"))?;
|
||||
let head = args["head"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'head'"))?;
|
||||
let base = args["base"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'base'"))?;
|
||||
let body = args["body"].as_str();
|
||||
|
||||
match gitea
|
||||
.create_pull(localpart, owner, repo, title, head, base, body)
|
||||
.await
|
||||
{
|
||||
Ok(pr) => Ok(serde_json::to_string(&pr).unwrap_or_default()),
|
||||
Err(e) => Ok(json!({"error": e}).to_string()),
|
||||
}
|
||||
}
|
||||
"gitea_merge_pull" => {
|
||||
let owner = args["owner"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'owner'"))?;
|
||||
let repo = args["repo"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'repo'"))?;
|
||||
let number = args["number"]
|
||||
.as_u64()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'number'"))?;
|
||||
let method = args["method"].as_str();
|
||||
let delete_branch = args["delete_branch"].as_bool();
|
||||
|
||||
match gitea
|
||||
.merge_pull(localpart, owner, repo, number, method, delete_branch)
|
||||
.await
|
||||
{
|
||||
Ok(result) => Ok(serde_json::to_string(&result).unwrap_or_default()),
|
||||
Err(e) => Ok(json!({"error": e}).to_string()),
|
||||
}
|
||||
}
|
||||
"gitea_list_branches" => {
|
||||
let owner = args["owner"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'owner'"))?;
|
||||
let repo = args["repo"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'repo'"))?;
|
||||
|
||||
match gitea.list_branches(localpart, owner, repo).await {
|
||||
Ok(branches) => Ok(serde_json::to_string(&branches).unwrap_or_default()),
|
||||
Err(e) => Ok(json!({"error": e}).to_string()),
|
||||
}
|
||||
}
|
||||
"gitea_create_branch" => {
|
||||
let owner = args["owner"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'owner'"))?;
|
||||
let repo = args["repo"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'repo'"))?;
|
||||
let branch_name = args["branch_name"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'branch_name'"))?;
|
||||
let from_branch = args["from_branch"].as_str();
|
||||
|
||||
match gitea
|
||||
.create_branch(localpart, owner, repo, branch_name, from_branch)
|
||||
.await
|
||||
{
|
||||
Ok(branch) => Ok(serde_json::to_string(&branch).unwrap_or_default()),
|
||||
Err(e) => Ok(json!({"error": e}).to_string()),
|
||||
}
|
||||
}
|
||||
"gitea_delete_branch" => {
|
||||
let owner = args["owner"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'owner'"))?;
|
||||
let repo = args["repo"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'repo'"))?;
|
||||
let branch = args["branch"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'branch'"))?;
|
||||
|
||||
match gitea.delete_branch(localpart, owner, repo, branch).await {
|
||||
Ok(result) => Ok(serde_json::to_string(&result).unwrap_or_default()),
|
||||
Err(e) => Ok(json!({"error": e}).to_string()),
|
||||
}
|
||||
}
|
||||
"gitea_list_orgs" => {
|
||||
let username = args["username"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'username'"))?;
|
||||
|
||||
match gitea.list_orgs(localpart, username).await {
|
||||
Ok(orgs) => Ok(serde_json::to_string(&orgs).unwrap_or_default()),
|
||||
Err(e) => Ok(json!({"error": e}).to_string()),
|
||||
}
|
||||
}
|
||||
"gitea_get_org" => {
|
||||
let org = args["org"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'org'"))?;
|
||||
|
||||
match gitea.get_org(localpart, org).await {
|
||||
Ok(o) => Ok(serde_json::to_string(&o).unwrap_or_default()),
|
||||
Err(e) => Ok(json!({"error": e}).to_string()),
|
||||
}
|
||||
}
|
||||
"gitea_list_notifications" => {
|
||||
match gitea.list_notifications(localpart).await {
|
||||
Ok(notifs) => Ok(serde_json::to_string(¬ifs).unwrap_or_default()),
|
||||
Err(e) => Ok(json!({"error": e}).to_string()),
|
||||
}
|
||||
}
|
||||
_ => anyhow::bail!("Unknown devtools tool: {name}"),
|
||||
}
|
||||
}
|
||||
@@ -297,7 +561,10 @@ pub fn tool_definitions() -> Vec<mistralai_client::v1::tool::Tool> {
|
||||
),
|
||||
Tool::new(
|
||||
"gitea_get_file".into(),
|
||||
"Get the contents of a file from a repository.".into(),
|
||||
"Get file contents or list directory entries. Use with path='' to list the repo root. \
|
||||
Use with a directory path to list its contents. Use with a file path to get the file. \
|
||||
This is how you explore and browse repositories."
|
||||
.into(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -321,5 +588,394 @@ pub fn tool_definitions() -> Vec<mistralai_client::v1::tool::Tool> {
|
||||
"required": ["owner", "repo", "path"]
|
||||
}),
|
||||
),
|
||||
// ── Repos (new) ─────────────────────────────────────────────────────
|
||||
Tool::new(
|
||||
"gitea_create_repo".into(),
|
||||
"Create a new repository for the requesting user, or under an org.".into(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "Repository name"
|
||||
},
|
||||
"org": {
|
||||
"type": "string",
|
||||
"description": "Organization to create the repo under (omit for personal repo)"
|
||||
},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": "Repository description"
|
||||
},
|
||||
"private": {
|
||||
"type": "boolean",
|
||||
"description": "Whether the repo is private (default: false)"
|
||||
},
|
||||
"auto_init": {
|
||||
"type": "boolean",
|
||||
"description": "Initialize with a README (default: false)"
|
||||
},
|
||||
"default_branch": {
|
||||
"type": "string",
|
||||
"description": "Default branch name (default: main)"
|
||||
}
|
||||
},
|
||||
"required": ["name"]
|
||||
}),
|
||||
),
|
||||
Tool::new(
|
||||
"gitea_edit_repo".into(),
|
||||
"Update repository settings (description, visibility, archived, default branch).".into(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner": {
|
||||
"type": "string",
|
||||
"description": "Repository owner"
|
||||
},
|
||||
"repo": {
|
||||
"type": "string",
|
||||
"description": "Repository name"
|
||||
},
|
||||
"description": {
|
||||
"type": "string",
|
||||
"description": "New description"
|
||||
},
|
||||
"private": {
|
||||
"type": "boolean",
|
||||
"description": "Set visibility"
|
||||
},
|
||||
"archived": {
|
||||
"type": "boolean",
|
||||
"description": "Archive or unarchive the repo"
|
||||
},
|
||||
"default_branch": {
|
||||
"type": "string",
|
||||
"description": "Set default branch"
|
||||
}
|
||||
},
|
||||
"required": ["owner", "repo"]
|
||||
}),
|
||||
),
|
||||
Tool::new(
|
||||
"gitea_fork_repo".into(),
|
||||
"Fork a repository into the requesting user's account.".into(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner": {
|
||||
"type": "string",
|
||||
"description": "Repository owner"
|
||||
},
|
||||
"repo": {
|
||||
"type": "string",
|
||||
"description": "Repository name"
|
||||
},
|
||||
"new_name": {
|
||||
"type": "string",
|
||||
"description": "Name for the forked repo (default: same as original)"
|
||||
}
|
||||
},
|
||||
"required": ["owner", "repo"]
|
||||
}),
|
||||
),
|
||||
Tool::new(
|
||||
"gitea_list_org_repos".into(),
|
||||
"List repositories belonging to an organization.".into(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"org": {
|
||||
"type": "string",
|
||||
"description": "Organization name"
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Max results (default 20)"
|
||||
}
|
||||
},
|
||||
"required": ["org"]
|
||||
}),
|
||||
),
|
||||
// ── Issues (new) ────────────────────────────────────────────────────
|
||||
Tool::new(
|
||||
"gitea_edit_issue".into(),
|
||||
"Update an issue (title, body, state, assignees).".into(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner": {
|
||||
"type": "string",
|
||||
"description": "Repository owner"
|
||||
},
|
||||
"repo": {
|
||||
"type": "string",
|
||||
"description": "Repository name"
|
||||
},
|
||||
"number": {
|
||||
"type": "integer",
|
||||
"description": "Issue number"
|
||||
},
|
||||
"title": {
|
||||
"type": "string",
|
||||
"description": "New title"
|
||||
},
|
||||
"body": {
|
||||
"type": "string",
|
||||
"description": "New body (markdown)"
|
||||
},
|
||||
"state": {
|
||||
"type": "string",
|
||||
"description": "Set state: open or closed"
|
||||
},
|
||||
"assignees": {
|
||||
"type": "array",
|
||||
"items": { "type": "string" },
|
||||
"description": "Usernames to assign"
|
||||
}
|
||||
},
|
||||
"required": ["owner", "repo", "number"]
|
||||
}),
|
||||
),
|
||||
Tool::new(
|
||||
"gitea_list_comments".into(),
|
||||
"List comments on an issue.".into(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner": {
|
||||
"type": "string",
|
||||
"description": "Repository owner"
|
||||
},
|
||||
"repo": {
|
||||
"type": "string",
|
||||
"description": "Repository name"
|
||||
},
|
||||
"number": {
|
||||
"type": "integer",
|
||||
"description": "Issue number"
|
||||
}
|
||||
},
|
||||
"required": ["owner", "repo", "number"]
|
||||
}),
|
||||
),
|
||||
Tool::new(
|
||||
"gitea_create_comment".into(),
|
||||
"Add a comment to an issue. Authored by the requesting user.".into(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner": {
|
||||
"type": "string",
|
||||
"description": "Repository owner"
|
||||
},
|
||||
"repo": {
|
||||
"type": "string",
|
||||
"description": "Repository name"
|
||||
},
|
||||
"number": {
|
||||
"type": "integer",
|
||||
"description": "Issue number"
|
||||
},
|
||||
"body": {
|
||||
"type": "string",
|
||||
"description": "Comment body (markdown)"
|
||||
}
|
||||
},
|
||||
"required": ["owner", "repo", "number", "body"]
|
||||
}),
|
||||
),
|
||||
// ── Pull requests (new) ─────────────────────────────────────────────
|
||||
Tool::new(
|
||||
"gitea_get_pull".into(),
|
||||
"Get details of a specific pull request by number.".into(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner": {
|
||||
"type": "string",
|
||||
"description": "Repository owner"
|
||||
},
|
||||
"repo": {
|
||||
"type": "string",
|
||||
"description": "Repository name"
|
||||
},
|
||||
"number": {
|
||||
"type": "integer",
|
||||
"description": "Pull request number"
|
||||
}
|
||||
},
|
||||
"required": ["owner", "repo", "number"]
|
||||
}),
|
||||
),
|
||||
Tool::new(
|
||||
"gitea_create_pull".into(),
|
||||
"Create a pull request. Authored by the requesting user.".into(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner": {
|
||||
"type": "string",
|
||||
"description": "Repository owner"
|
||||
},
|
||||
"repo": {
|
||||
"type": "string",
|
||||
"description": "Repository name"
|
||||
},
|
||||
"title": {
|
||||
"type": "string",
|
||||
"description": "PR title"
|
||||
},
|
||||
"head": {
|
||||
"type": "string",
|
||||
"description": "Source branch"
|
||||
},
|
||||
"base": {
|
||||
"type": "string",
|
||||
"description": "Target branch"
|
||||
},
|
||||
"body": {
|
||||
"type": "string",
|
||||
"description": "PR description (markdown)"
|
||||
}
|
||||
},
|
||||
"required": ["owner", "repo", "title", "head", "base"]
|
||||
}),
|
||||
),
|
||||
Tool::new(
|
||||
"gitea_merge_pull".into(),
|
||||
"Merge a pull request.".into(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner": {
|
||||
"type": "string",
|
||||
"description": "Repository owner"
|
||||
},
|
||||
"repo": {
|
||||
"type": "string",
|
||||
"description": "Repository name"
|
||||
},
|
||||
"number": {
|
||||
"type": "integer",
|
||||
"description": "Pull request number"
|
||||
},
|
||||
"method": {
|
||||
"type": "string",
|
||||
"description": "Merge method: merge, rebase, or squash (default: merge)"
|
||||
},
|
||||
"delete_branch": {
|
||||
"type": "boolean",
|
||||
"description": "Delete head branch after merge (default: false)"
|
||||
}
|
||||
},
|
||||
"required": ["owner", "repo", "number"]
|
||||
}),
|
||||
),
|
||||
// ── Branches (new) ──────────────────────────────────────────────────
|
||||
Tool::new(
|
||||
"gitea_list_branches".into(),
|
||||
"List branches in a repository.".into(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner": {
|
||||
"type": "string",
|
||||
"description": "Repository owner"
|
||||
},
|
||||
"repo": {
|
||||
"type": "string",
|
||||
"description": "Repository name"
|
||||
}
|
||||
},
|
||||
"required": ["owner", "repo"]
|
||||
}),
|
||||
),
|
||||
Tool::new(
|
||||
"gitea_create_branch".into(),
|
||||
"Create a new branch in a repository.".into(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner": {
|
||||
"type": "string",
|
||||
"description": "Repository owner"
|
||||
},
|
||||
"repo": {
|
||||
"type": "string",
|
||||
"description": "Repository name"
|
||||
},
|
||||
"branch_name": {
|
||||
"type": "string",
|
||||
"description": "Name for the new branch"
|
||||
},
|
||||
"from_branch": {
|
||||
"type": "string",
|
||||
"description": "Branch to create from (default: default branch)"
|
||||
}
|
||||
},
|
||||
"required": ["owner", "repo", "branch_name"]
|
||||
}),
|
||||
),
|
||||
Tool::new(
|
||||
"gitea_delete_branch".into(),
|
||||
"Delete a branch from a repository.".into(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"owner": {
|
||||
"type": "string",
|
||||
"description": "Repository owner"
|
||||
},
|
||||
"repo": {
|
||||
"type": "string",
|
||||
"description": "Repository name"
|
||||
},
|
||||
"branch": {
|
||||
"type": "string",
|
||||
"description": "Branch name to delete"
|
||||
}
|
||||
},
|
||||
"required": ["owner", "repo", "branch"]
|
||||
}),
|
||||
),
|
||||
// ── Organizations (new) ─────────────────────────────────────────────
|
||||
Tool::new(
|
||||
"gitea_list_orgs".into(),
|
||||
"List organizations a user belongs to.".into(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"username": {
|
||||
"type": "string",
|
||||
"description": "Username to list orgs for"
|
||||
}
|
||||
},
|
||||
"required": ["username"]
|
||||
}),
|
||||
),
|
||||
Tool::new(
|
||||
"gitea_get_org".into(),
|
||||
"Get details about an organization.".into(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"org": {
|
||||
"type": "string",
|
||||
"description": "Organization name"
|
||||
}
|
||||
},
|
||||
"required": ["org"]
|
||||
}),
|
||||
),
|
||||
// ── Notifications (new) ─────────────────────────────────────────────
|
||||
Tool::new(
|
||||
"gitea_list_notifications".into(),
|
||||
"List unread notifications for the requesting user.".into(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
}),
|
||||
),
|
||||
]
|
||||
}
|
||||
|
||||
209
src/tools/identity.rs
Normal file
209
src/tools/identity.rs
Normal file
@@ -0,0 +1,209 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use crate::sdk::kratos::KratosClient;
|
||||
|
||||
/// Execute an identity tool call. Returns a JSON string result.
|
||||
pub async fn execute(
|
||||
kratos: &Arc<KratosClient>,
|
||||
name: &str,
|
||||
arguments: &str,
|
||||
) -> anyhow::Result<String> {
|
||||
let args: Value = serde_json::from_str(arguments)
|
||||
.map_err(|e| anyhow::anyhow!("Invalid tool arguments: {e}"))?;
|
||||
|
||||
match name {
|
||||
"identity_list_users" => {
|
||||
let search = args["search"].as_str();
|
||||
let limit = args["limit"].as_u64().map(|n| n as u32);
|
||||
|
||||
match kratos.list_users(search, limit).await {
|
||||
Ok(users) => Ok(serde_json::to_string(&users).unwrap_or_default()),
|
||||
Err(e) => Ok(json!({"error": e}).to_string()),
|
||||
}
|
||||
}
|
||||
"identity_get_user" => {
|
||||
let email_or_id = args["email_or_id"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'email_or_id'"))?;
|
||||
|
||||
match kratos.get_user(email_or_id).await {
|
||||
Ok(user) => Ok(serde_json::to_string(&user).unwrap_or_default()),
|
||||
Err(e) => Ok(json!({"error": e}).to_string()),
|
||||
}
|
||||
}
|
||||
"identity_create_user" => {
|
||||
let email = args["email"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'email'"))?;
|
||||
let first_name = args["first_name"].as_str();
|
||||
let last_name = args["last_name"].as_str();
|
||||
|
||||
match kratos.create_user(email, first_name, last_name).await {
|
||||
Ok(user) => Ok(serde_json::to_string(&user).unwrap_or_default()),
|
||||
Err(e) => Ok(json!({"error": e}).to_string()),
|
||||
}
|
||||
}
|
||||
"identity_recover_user" => {
|
||||
let email_or_id = args["email_or_id"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'email_or_id'"))?;
|
||||
|
||||
match kratos.recover_user(email_or_id).await {
|
||||
Ok(recovery) => Ok(serde_json::to_string(&recovery).unwrap_or_default()),
|
||||
Err(e) => Ok(json!({"error": e}).to_string()),
|
||||
}
|
||||
}
|
||||
"identity_disable_user" => {
|
||||
let email_or_id = args["email_or_id"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'email_or_id'"))?;
|
||||
|
||||
match kratos.disable_user(email_or_id).await {
|
||||
Ok(user) => Ok(serde_json::to_string(&user).unwrap_or_default()),
|
||||
Err(e) => Ok(json!({"error": e}).to_string()),
|
||||
}
|
||||
}
|
||||
"identity_enable_user" => {
|
||||
let email_or_id = args["email_or_id"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'email_or_id'"))?;
|
||||
|
||||
match kratos.enable_user(email_or_id).await {
|
||||
Ok(user) => Ok(serde_json::to_string(&user).unwrap_or_default()),
|
||||
Err(e) => Ok(json!({"error": e}).to_string()),
|
||||
}
|
||||
}
|
||||
"identity_list_sessions" => {
|
||||
let email_or_id = args["email_or_id"]
|
||||
.as_str()
|
||||
.ok_or_else(|| anyhow::anyhow!("missing 'email_or_id'"))?;
|
||||
|
||||
match kratos.list_sessions(email_or_id).await {
|
||||
Ok(sessions) => Ok(serde_json::to_string(&sessions).unwrap_or_default()),
|
||||
Err(e) => Ok(json!({"error": e}).to_string()),
|
||||
}
|
||||
}
|
||||
_ => anyhow::bail!("Unknown identity tool: {name}"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Return Mistral tool definitions for identity tools.
|
||||
pub fn tool_definitions() -> Vec<mistralai_client::v1::tool::Tool> {
|
||||
use mistralai_client::v1::tool::Tool;
|
||||
|
||||
vec![
|
||||
Tool::new(
|
||||
"identity_list_users".into(),
|
||||
"List or search user accounts on the platform.".into(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"search": {
|
||||
"type": "string",
|
||||
"description": "Search by email address or identifier"
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Max results (default 50)"
|
||||
}
|
||||
}
|
||||
}),
|
||||
),
|
||||
Tool::new(
|
||||
"identity_get_user".into(),
|
||||
"Get full details of a user account by email or ID.".into(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"email_or_id": {
|
||||
"type": "string",
|
||||
"description": "Email address or identity UUID"
|
||||
}
|
||||
},
|
||||
"required": ["email_or_id"]
|
||||
}),
|
||||
),
|
||||
Tool::new(
|
||||
"identity_create_user".into(),
|
||||
"Create a new user account on the platform.".into(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"email": {
|
||||
"type": "string",
|
||||
"description": "Email address for the new account"
|
||||
},
|
||||
"first_name": {
|
||||
"type": "string",
|
||||
"description": "First name"
|
||||
},
|
||||
"last_name": {
|
||||
"type": "string",
|
||||
"description": "Last name"
|
||||
}
|
||||
},
|
||||
"required": ["email"]
|
||||
}),
|
||||
),
|
||||
Tool::new(
|
||||
"identity_recover_user".into(),
|
||||
"Generate a one-time recovery link for a user account. \
|
||||
Use this when someone is locked out or needs to reset their password."
|
||||
.into(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"email_or_id": {
|
||||
"type": "string",
|
||||
"description": "Email address or identity UUID"
|
||||
}
|
||||
},
|
||||
"required": ["email_or_id"]
|
||||
}),
|
||||
),
|
||||
Tool::new(
|
||||
"identity_disable_user".into(),
|
||||
"Disable (lock out) a user account. They will not be able to log in.".into(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"email_or_id": {
|
||||
"type": "string",
|
||||
"description": "Email address or identity UUID"
|
||||
}
|
||||
},
|
||||
"required": ["email_or_id"]
|
||||
}),
|
||||
),
|
||||
Tool::new(
|
||||
"identity_enable_user".into(),
|
||||
"Re-enable a previously disabled user account.".into(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"email_or_id": {
|
||||
"type": "string",
|
||||
"description": "Email address or identity UUID"
|
||||
}
|
||||
},
|
||||
"required": ["email_or_id"]
|
||||
}),
|
||||
),
|
||||
Tool::new(
|
||||
"identity_list_sessions".into(),
|
||||
"List active sessions for a user account.".into(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"email_or_id": {
|
||||
"type": "string",
|
||||
"description": "Email address or identity UUID"
|
||||
}
|
||||
},
|
||||
"required": ["email_or_id"]
|
||||
}),
|
||||
),
|
||||
]
|
||||
}
|
||||
152
src/tools/mod.rs
152
src/tools/mod.rs
@@ -1,26 +1,38 @@
|
||||
pub mod bridge;
|
||||
pub mod devtools;
|
||||
pub mod identity;
|
||||
pub mod research;
|
||||
pub mod room_history;
|
||||
pub mod web_search;
|
||||
pub mod room_info;
|
||||
pub mod script;
|
||||
pub mod search;
|
||||
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
|
||||
use matrix_sdk::Client as MatrixClient;
|
||||
use matrix_sdk::RoomMemberships;
|
||||
use mistralai_client::v1::tool::Tool;
|
||||
use opensearch::OpenSearch;
|
||||
use serde_json::json;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::context::ResponseContext;
|
||||
use crate::persistence::Store;
|
||||
use crate::sdk::gitea::GiteaClient;
|
||||
use crate::sdk::kratos::KratosClient;
|
||||
|
||||
|
||||
pub struct ToolRegistry {
|
||||
opensearch: OpenSearch,
|
||||
matrix: MatrixClient,
|
||||
config: Arc<Config>,
|
||||
gitea: Option<Arc<GiteaClient>>,
|
||||
kratos: Option<Arc<KratosClient>>,
|
||||
mistral: Option<Arc<mistralai_client::v1::client::Client>>,
|
||||
store: Option<Arc<Store>>,
|
||||
}
|
||||
|
||||
impl ToolRegistry {
|
||||
@@ -29,12 +41,18 @@ impl ToolRegistry {
|
||||
matrix: MatrixClient,
|
||||
config: Arc<Config>,
|
||||
gitea: Option<Arc<GiteaClient>>,
|
||||
kratos: Option<Arc<KratosClient>>,
|
||||
mistral: Option<Arc<mistralai_client::v1::client::Client>>,
|
||||
store: Option<Arc<Store>>,
|
||||
) -> Self {
|
||||
Self {
|
||||
opensearch,
|
||||
matrix,
|
||||
config,
|
||||
gitea,
|
||||
kratos,
|
||||
mistral,
|
||||
store,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -42,7 +60,11 @@ impl ToolRegistry {
|
||||
self.gitea.is_some()
|
||||
}
|
||||
|
||||
pub fn tool_definitions(gitea_enabled: bool) -> Vec<Tool> {
|
||||
pub fn has_kratos(&self) -> bool {
|
||||
self.kratos.is_some()
|
||||
}
|
||||
|
||||
pub fn tool_definitions(gitea_enabled: bool, kratos_enabled: bool) -> Vec<Tool> {
|
||||
let mut tools = vec![
|
||||
Tool::new(
|
||||
"search_archive".into(),
|
||||
@@ -172,14 +194,25 @@ impl ToolRegistry {
|
||||
if gitea_enabled {
|
||||
tools.extend(devtools::tool_definitions());
|
||||
}
|
||||
if kratos_enabled {
|
||||
tools.extend(identity::tool_definitions());
|
||||
}
|
||||
|
||||
// Web search (SearXNG — free, self-hosted)
|
||||
tools.push(web_search::tool_definition());
|
||||
|
||||
// Research tool (depth 0 — orchestrator level)
|
||||
if let Some(def) = research::tool_definition(4, 0) {
|
||||
tools.push(def);
|
||||
}
|
||||
|
||||
tools
|
||||
}
|
||||
|
||||
/// Convert Sol's tool definitions to Mistral AgentTool format
|
||||
/// for use with the Agents API (orchestrator agent creation).
|
||||
pub fn agent_tool_definitions(gitea_enabled: bool) -> Vec<mistralai_client::v1::agents::AgentTool> {
|
||||
Self::tool_definitions(gitea_enabled)
|
||||
pub fn agent_tool_definitions(gitea_enabled: bool, kratos_enabled: bool) -> Vec<mistralai_client::v1::agents::AgentTool> {
|
||||
Self::tool_definitions(gitea_enabled, kratos_enabled)
|
||||
.into_iter()
|
||||
.map(|t| {
|
||||
mistralai_client::v1::agents::AgentTool::function(
|
||||
@@ -191,6 +224,60 @@ impl ToolRegistry {
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Compute the set of room IDs whose search results are visible from
|
||||
/// the requesting room, based on member overlap.
|
||||
///
|
||||
/// A room's results are visible if at least ROOM_OVERLAP_THRESHOLD of
|
||||
/// its members are also members of the requesting room. This is enforced
|
||||
/// at the query level — Sol never sees filtered-out results.
|
||||
async fn allowed_room_ids(&self, requesting_room_id: &str) -> Vec<String> {
|
||||
let rooms = self.matrix.joined_rooms();
|
||||
|
||||
// Get requesting room's member set
|
||||
let requesting_room = rooms.iter().find(|r| r.room_id().as_str() == requesting_room_id);
|
||||
let requesting_members: HashSet<String> = match requesting_room {
|
||||
Some(room) => match room.members(RoomMemberships::JOIN).await {
|
||||
Ok(members) => members.iter().map(|m| m.user_id().to_string()).collect(),
|
||||
Err(_) => return vec![requesting_room_id.to_string()],
|
||||
},
|
||||
None => return vec![requesting_room_id.to_string()],
|
||||
};
|
||||
|
||||
let mut allowed = Vec::new();
|
||||
for room in &rooms {
|
||||
let room_id = room.room_id().to_string();
|
||||
|
||||
// Always allow the requesting room itself
|
||||
if room_id == requesting_room_id {
|
||||
allowed.push(room_id);
|
||||
continue;
|
||||
}
|
||||
|
||||
let members: HashSet<String> = match room.members(RoomMemberships::JOIN).await {
|
||||
Ok(m) => m.iter().map(|m| m.user_id().to_string()).collect(),
|
||||
Err(_) => continue,
|
||||
};
|
||||
|
||||
if members.is_empty() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let overlap = members.intersection(&requesting_members).count();
|
||||
let ratio = overlap as f64 / members.len() as f64;
|
||||
|
||||
if ratio >= self.config.behavior.room_overlap_threshold as f64 {
|
||||
debug!(
|
||||
source_room = room_id.as_str(),
|
||||
overlap_pct = format!("{:.0}%", ratio * 100.0).as_str(),
|
||||
"Room passes overlap threshold"
|
||||
);
|
||||
allowed.push(room_id);
|
||||
}
|
||||
}
|
||||
|
||||
allowed
|
||||
}
|
||||
|
||||
pub async fn execute(
|
||||
&self,
|
||||
name: &str,
|
||||
@@ -199,30 +286,36 @@ impl ToolRegistry {
|
||||
) -> anyhow::Result<String> {
|
||||
match name {
|
||||
"search_archive" => {
|
||||
let allowed = self.allowed_room_ids(&response_ctx.room_id).await;
|
||||
search::search_archive(
|
||||
&self.opensearch,
|
||||
&self.config.opensearch.index,
|
||||
arguments,
|
||||
&allowed,
|
||||
)
|
||||
.await
|
||||
}
|
||||
"get_room_context" => {
|
||||
let allowed = self.allowed_room_ids(&response_ctx.room_id).await;
|
||||
room_history::get_room_context(
|
||||
&self.opensearch,
|
||||
&self.config.opensearch.index,
|
||||
arguments,
|
||||
&allowed,
|
||||
)
|
||||
.await
|
||||
}
|
||||
"list_rooms" => room_info::list_rooms(&self.matrix).await,
|
||||
"get_room_members" => room_info::get_room_members(&self.matrix, arguments).await,
|
||||
"run_script" => {
|
||||
let allowed = self.allowed_room_ids(&response_ctx.room_id).await;
|
||||
script::run_script(
|
||||
&self.opensearch,
|
||||
&self.matrix,
|
||||
&self.config,
|
||||
arguments,
|
||||
response_ctx,
|
||||
allowed,
|
||||
)
|
||||
.await
|
||||
}
|
||||
@@ -233,7 +326,60 @@ impl ToolRegistry {
|
||||
anyhow::bail!("Gitea integration not configured")
|
||||
}
|
||||
}
|
||||
name if name.starts_with("identity_") => {
|
||||
if let Some(ref kratos) = self.kratos {
|
||||
identity::execute(kratos, name, arguments).await
|
||||
} else {
|
||||
anyhow::bail!("Identity (Kratos) integration not configured")
|
||||
}
|
||||
}
|
||||
"search_web" => {
|
||||
if let Some(ref searxng) = self.config.services.searxng {
|
||||
web_search::search(&searxng.url, arguments).await
|
||||
} else {
|
||||
anyhow::bail!("Web search not configured (missing [services.searxng])")
|
||||
}
|
||||
}
|
||||
"research" => {
|
||||
if let (Some(ref mistral), Some(ref store)) = (&self.mistral, &self.store) {
|
||||
anyhow::bail!("research tool requires execute_research() — call with room + event_id context")
|
||||
} else {
|
||||
anyhow::bail!("Research not configured (missing mistral client or store)")
|
||||
}
|
||||
}
|
||||
_ => anyhow::bail!("Unknown tool: {name}"),
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute a research tool call with full context (room, event_id for threads).
|
||||
pub async fn execute_research(
|
||||
self: &Arc<Self>,
|
||||
arguments: &str,
|
||||
response_ctx: &ResponseContext,
|
||||
room: &matrix_sdk::room::Room,
|
||||
event_id: &ruma::OwnedEventId,
|
||||
current_depth: usize,
|
||||
) -> anyhow::Result<String> {
|
||||
let mistral = self
|
||||
.mistral
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Research not configured: missing Mistral client"))?;
|
||||
let store = self
|
||||
.store
|
||||
.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("Research not configured: missing store"))?;
|
||||
|
||||
research::execute(
|
||||
arguments,
|
||||
&self.config,
|
||||
mistral,
|
||||
self,
|
||||
response_ctx,
|
||||
room,
|
||||
event_id,
|
||||
store,
|
||||
current_depth,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
519
src/tools/research.rs
Normal file
519
src/tools/research.rs
Normal file
@@ -0,0 +1,519 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use matrix_sdk::room::Room;
|
||||
use mistralai_client::v1::client::Client as MistralClient;
|
||||
use mistralai_client::v1::conversations::{
|
||||
AppendConversationRequest, ConversationInput, CreateConversationRequest,
|
||||
};
|
||||
use ruma::OwnedEventId;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use tokio::sync::mpsc;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::agent_ux::AgentProgress;
|
||||
use crate::config::Config;
|
||||
use crate::context::ResponseContext;
|
||||
use crate::persistence::Store;
|
||||
use crate::tools::ToolRegistry;
|
||||
|
||||
// ── Types ──────────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ResearchTask {
|
||||
pub focus: String,
|
||||
pub instructions: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ResearchResult {
|
||||
pub focus: String,
|
||||
pub findings: String,
|
||||
pub tool_calls_made: usize,
|
||||
pub status: String,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum ProgressUpdate {
|
||||
AgentStarted { focus: String },
|
||||
AgentDone { focus: String, summary: String },
|
||||
AgentFailed { focus: String, error: String },
|
||||
}
|
||||
|
||||
// ── Tool definition ────────────────────────────────────────────────────────
|
||||
|
||||
pub fn tool_definition(max_depth: usize, current_depth: usize) -> Option<mistralai_client::v1::tool::Tool> {
|
||||
if current_depth >= max_depth {
|
||||
return None; // At max depth, don't offer the research tool
|
||||
}
|
||||
|
||||
Some(mistralai_client::v1::tool::Tool::new(
|
||||
"research".into(),
|
||||
"Spawn parallel research agents to investigate a complex topic. Each agent \
|
||||
gets its own LLM conversation and can use all tools independently. Use this \
|
||||
for multi-faceted questions that need parallel investigation across repos, \
|
||||
archives, and the web. Each agent should have a focused, specific task."
|
||||
.into(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"tasks": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"focus": {
|
||||
"type": "string",
|
||||
"description": "Short label (e.g., 'repo structure', 'license audit')"
|
||||
},
|
||||
"instructions": {
|
||||
"type": "string",
|
||||
"description": "Detailed instructions for this research agent"
|
||||
}
|
||||
},
|
||||
"required": ["focus", "instructions"]
|
||||
},
|
||||
"description": "List of parallel research tasks (3-25 recommended). Each gets its own agent."
|
||||
}
|
||||
},
|
||||
"required": ["tasks"]
|
||||
}),
|
||||
))
|
||||
}
|
||||
|
||||
// ── Execution ──────────────────────────────────────────────────────────────
|
||||
|
||||
/// Execute a research tool call — spawns parallel micro-agents.
|
||||
pub async fn execute(
|
||||
args: &str,
|
||||
config: &Arc<Config>,
|
||||
mistral: &Arc<MistralClient>,
|
||||
tools: &Arc<ToolRegistry>,
|
||||
response_ctx: &ResponseContext,
|
||||
room: &Room,
|
||||
event_id: &OwnedEventId,
|
||||
store: &Arc<Store>,
|
||||
current_depth: usize,
|
||||
) -> anyhow::Result<String> {
|
||||
let parsed: serde_json::Value = serde_json::from_str(args)
|
||||
.map_err(|e| anyhow::anyhow!("Invalid research arguments: {e}"))?;
|
||||
|
||||
let tasks: Vec<ResearchTask> = serde_json::from_value(
|
||||
parsed.get("tasks").cloned().unwrap_or(json!([])),
|
||||
)
|
||||
.map_err(|e| anyhow::anyhow!("Invalid research tasks: {e}"))?;
|
||||
|
||||
if tasks.is_empty() {
|
||||
return Ok(json!({"error": "No research tasks provided"}).to_string());
|
||||
}
|
||||
|
||||
let max_agents = config.agents.research_max_agents;
|
||||
let tasks = if tasks.len() > max_agents {
|
||||
warn!(
|
||||
count = tasks.len(),
|
||||
max = max_agents,
|
||||
"Clamping research tasks to max"
|
||||
);
|
||||
tasks[..max_agents].to_vec()
|
||||
} else {
|
||||
tasks
|
||||
};
|
||||
|
||||
let session_id = uuid::Uuid::new_v4().to_string();
|
||||
let plan_json = serde_json::to_string(&tasks).unwrap_or_default();
|
||||
|
||||
// Persist session
|
||||
store.create_research_session(
|
||||
&session_id,
|
||||
&response_ctx.room_id,
|
||||
&event_id.to_string(),
|
||||
&format!("research (depth {})", current_depth),
|
||||
&plan_json,
|
||||
);
|
||||
|
||||
info!(
|
||||
session_id = session_id.as_str(),
|
||||
agents = tasks.len(),
|
||||
depth = current_depth,
|
||||
"Starting research session"
|
||||
);
|
||||
|
||||
// Progress channel for thread updates
|
||||
let (tx, mut rx) = mpsc::channel::<ProgressUpdate>(64);
|
||||
|
||||
// Spawn thread updater
|
||||
let thread_room = room.clone();
|
||||
let thread_event_id = event_id.clone();
|
||||
let agent_count = tasks.len();
|
||||
let updater = tokio::spawn(async move {
|
||||
let mut progress = AgentProgress::new(thread_room, thread_event_id);
|
||||
progress
|
||||
.post_step(&format!("🔬 researching with {} agents...", agent_count))
|
||||
.await;
|
||||
|
||||
while let Some(update) = rx.recv().await {
|
||||
let msg = match update {
|
||||
ProgressUpdate::AgentStarted { focus } => {
|
||||
format!("🔎 {focus}")
|
||||
}
|
||||
ProgressUpdate::AgentDone { focus, summary } => {
|
||||
let short: String = summary.chars().take(100).collect();
|
||||
format!("✅ {focus}: {short}")
|
||||
}
|
||||
ProgressUpdate::AgentFailed { focus, error } => {
|
||||
format!("❌ {focus}: {error}")
|
||||
}
|
||||
};
|
||||
progress.post_step(&msg).await;
|
||||
}
|
||||
});
|
||||
|
||||
// Create per-agent senders before dropping the original
|
||||
let agent_senders: Vec<_> = tasks.iter().map(|_| tx.clone()).collect();
|
||||
drop(tx); // Drop original so updater knows when all agents are done
|
||||
|
||||
// Run all research agents concurrently with per-agent timeout.
|
||||
// Without timeout, a hung Mistral API call blocks the entire sync loop.
|
||||
let agent_timeout = std::time::Duration::from_secs(120); // 2 minutes per agent max
|
||||
|
||||
let futures: Vec<_> = tasks
|
||||
.iter()
|
||||
.zip(agent_senders.iter())
|
||||
.map(|(task, sender)| {
|
||||
let task = task.clone();
|
||||
let sender = sender.clone();
|
||||
let sid = session_id.clone();
|
||||
async move {
|
||||
match tokio::time::timeout(
|
||||
agent_timeout,
|
||||
run_research_agent(
|
||||
&task,
|
||||
config,
|
||||
mistral,
|
||||
tools,
|
||||
response_ctx,
|
||||
&sender,
|
||||
&sid,
|
||||
store,
|
||||
room,
|
||||
event_id,
|
||||
current_depth,
|
||||
),
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(result) => result,
|
||||
Err(_) => {
|
||||
warn!(focus = task.focus.as_str(), "Research agent timed out");
|
||||
let _ = sender
|
||||
.send(ProgressUpdate::AgentFailed {
|
||||
focus: task.focus.clone(),
|
||||
error: "timed out after 2 minutes".into(),
|
||||
})
|
||||
.await;
|
||||
ResearchResult {
|
||||
focus: task.focus.clone(),
|
||||
findings: "Agent timed out after 2 minutes".into(),
|
||||
tool_calls_made: 0,
|
||||
status: "timeout".into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let results = futures::future::join_all(futures).await;
|
||||
|
||||
// Wait for thread updater to finish
|
||||
let _ = updater.await;
|
||||
|
||||
// Mark session complete
|
||||
store.complete_research_session(&session_id);
|
||||
|
||||
// Format results for the orchestrator
|
||||
let total_calls: usize = results.iter().map(|r| r.tool_calls_made).sum();
|
||||
info!(
|
||||
session_id = session_id.as_str(),
|
||||
agents = results.len(),
|
||||
total_tool_calls = total_calls,
|
||||
"Research session complete"
|
||||
);
|
||||
|
||||
let output = results
|
||||
.iter()
|
||||
.map(|r| format!("### {} [{}]\n{}\n", r.focus, r.status, r.findings))
|
||||
.collect::<Vec<_>>()
|
||||
.join("\n---\n\n");
|
||||
|
||||
Ok(format!(
|
||||
"Research complete ({} agents, {} tool calls):\n\n{}",
|
||||
results.len(),
|
||||
total_calls,
|
||||
output
|
||||
))
|
||||
}
|
||||
|
||||
/// Run a single research micro-agent.
|
||||
async fn run_research_agent(
|
||||
task: &ResearchTask,
|
||||
config: &Arc<Config>,
|
||||
mistral: &Arc<MistralClient>,
|
||||
tools: &Arc<ToolRegistry>,
|
||||
response_ctx: &ResponseContext,
|
||||
tx: &mpsc::Sender<ProgressUpdate>,
|
||||
session_id: &str,
|
||||
store: &Arc<Store>,
|
||||
room: &Room,
|
||||
event_id: &OwnedEventId,
|
||||
current_depth: usize,
|
||||
) -> ResearchResult {
|
||||
let _ = tx
|
||||
.send(ProgressUpdate::AgentStarted {
|
||||
focus: task.focus.clone(),
|
||||
})
|
||||
.await;
|
||||
|
||||
let model = &config.agents.research_model;
|
||||
let max_iterations = config.agents.research_max_iterations;
|
||||
|
||||
// Build tool definitions (include research tool if not at max depth)
|
||||
let mut tool_defs = ToolRegistry::tool_definitions(
|
||||
tools.has_gitea(),
|
||||
tools.has_kratos(),
|
||||
);
|
||||
if let Some(research_def) = tool_definition(config.agents.research_max_depth, current_depth + 1) {
|
||||
tool_defs.push(research_def);
|
||||
}
|
||||
|
||||
let mistral_tools: Vec<mistralai_client::v1::tool::Tool> = tool_defs;
|
||||
|
||||
let instructions = format!(
|
||||
"You are a focused research agent. Your task:\n\n\
|
||||
**Focus:** {}\n\n\
|
||||
**Instructions:** {}\n\n\
|
||||
Use the available tools to investigate. Be thorough but focused. \
|
||||
When done, provide a clear summary of your findings.",
|
||||
task.focus, task.instructions
|
||||
);
|
||||
|
||||
// Create conversation
|
||||
let req = CreateConversationRequest {
|
||||
inputs: ConversationInput::Text(instructions),
|
||||
model: Some(model.clone()),
|
||||
agent_id: None,
|
||||
agent_version: None,
|
||||
name: Some(format!("sol-research-{}", &session_id[..8])),
|
||||
description: None,
|
||||
instructions: None,
|
||||
completion_args: None,
|
||||
tools: Some(
|
||||
mistral_tools
|
||||
.into_iter()
|
||||
.map(|t| {
|
||||
mistralai_client::v1::agents::AgentTool::function(
|
||||
t.function.name,
|
||||
t.function.description,
|
||||
t.function.parameters,
|
||||
)
|
||||
})
|
||||
.collect(),
|
||||
),
|
||||
handoff_execution: None,
|
||||
metadata: None,
|
||||
store: Some(false), // Don't persist research conversations on Mistral's side
|
||||
stream: false,
|
||||
};
|
||||
|
||||
let response = match mistral.create_conversation_async(&req).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
let error = format!("Failed to create research conversation: {}", e.message);
|
||||
let _ = tx
|
||||
.send(ProgressUpdate::AgentFailed {
|
||||
focus: task.focus.clone(),
|
||||
error: error.clone(),
|
||||
})
|
||||
.await;
|
||||
return ResearchResult {
|
||||
focus: task.focus.clone(),
|
||||
findings: error,
|
||||
tool_calls_made: 0,
|
||||
status: "failed".into(),
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
let conv_id = response.conversation_id.clone();
|
||||
let mut current_response = response;
|
||||
let mut tool_calls_made = 0;
|
||||
|
||||
// Tool call loop
|
||||
for _iteration in 0..max_iterations {
|
||||
let calls = current_response.function_calls();
|
||||
if calls.is_empty() {
|
||||
break;
|
||||
}
|
||||
|
||||
let mut result_entries = Vec::new();
|
||||
|
||||
for fc in &calls {
|
||||
let call_id = fc.tool_call_id.as_deref().unwrap_or("unknown");
|
||||
tool_calls_made += 1;
|
||||
|
||||
debug!(
|
||||
focus = task.focus.as_str(),
|
||||
tool = fc.name.as_str(),
|
||||
"Research agent tool call"
|
||||
);
|
||||
|
||||
let result = if fc.name == "research" {
|
||||
// Recursive research — spawn sub-agents
|
||||
match execute(
|
||||
&fc.arguments,
|
||||
config,
|
||||
mistral,
|
||||
tools,
|
||||
response_ctx,
|
||||
room,
|
||||
event_id,
|
||||
store,
|
||||
current_depth + 1,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(s) => s,
|
||||
Err(e) => format!("Research error: {e}"),
|
||||
}
|
||||
} else {
|
||||
match tools.execute(&fc.name, &fc.arguments, response_ctx).await {
|
||||
Ok(s) => s,
|
||||
Err(e) => format!("Error: {e}"),
|
||||
}
|
||||
};
|
||||
|
||||
result_entries.push(
|
||||
mistralai_client::v1::conversations::ConversationEntry::FunctionResult(
|
||||
mistralai_client::v1::conversations::FunctionResultEntry {
|
||||
tool_call_id: call_id.to_string(),
|
||||
result,
|
||||
id: None,
|
||||
object: None,
|
||||
created_at: None,
|
||||
completed_at: None,
|
||||
},
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
// Send results back
|
||||
let append_req = AppendConversationRequest {
|
||||
inputs: ConversationInput::Entries(result_entries),
|
||||
completion_args: None,
|
||||
handoff_execution: None,
|
||||
store: Some(false),
|
||||
tool_confirmations: None,
|
||||
stream: false,
|
||||
};
|
||||
|
||||
current_response = match mistral
|
||||
.append_conversation_async(&conv_id, &append_req)
|
||||
.await
|
||||
{
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
let error = format!("Research agent conversation failed: {}", e.message);
|
||||
let _ = tx
|
||||
.send(ProgressUpdate::AgentFailed {
|
||||
focus: task.focus.clone(),
|
||||
error: error.clone(),
|
||||
})
|
||||
.await;
|
||||
return ResearchResult {
|
||||
focus: task.focus.clone(),
|
||||
findings: error,
|
||||
tool_calls_made,
|
||||
status: "failed".into(),
|
||||
};
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// Extract final text
|
||||
let findings = current_response
|
||||
.assistant_text()
|
||||
.unwrap_or_else(|| format!("(no summary after {} tool calls)", tool_calls_made));
|
||||
|
||||
// Persist finding
|
||||
let finding_json = serde_json::to_string(&ResearchResult {
|
||||
focus: task.focus.clone(),
|
||||
findings: findings.clone(),
|
||||
tool_calls_made,
|
||||
status: "complete".into(),
|
||||
})
|
||||
.unwrap_or_default();
|
||||
store.append_research_finding(session_id, &finding_json);
|
||||
|
||||
let summary: String = findings.chars().take(100).collect();
|
||||
let _ = tx
|
||||
.send(ProgressUpdate::AgentDone {
|
||||
focus: task.focus.clone(),
|
||||
summary,
|
||||
})
|
||||
.await;
|
||||
|
||||
ResearchResult {
|
||||
focus: task.focus.clone(),
|
||||
findings,
|
||||
tool_calls_made,
|
||||
status: "complete".into(),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_research_task_deserialize() {
|
||||
let json = json!({
|
||||
"focus": "repo structure",
|
||||
"instructions": "browse studio/sbbb root directory"
|
||||
});
|
||||
let task: ResearchTask = serde_json::from_value(json).unwrap();
|
||||
assert_eq!(task.focus, "repo structure");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_research_result_serialize() {
|
||||
let result = ResearchResult {
|
||||
focus: "licensing".into(),
|
||||
findings: "found AGPL in 2 repos".into(),
|
||||
tool_calls_made: 5,
|
||||
status: "complete".into(),
|
||||
};
|
||||
let json = serde_json::to_string(&result).unwrap();
|
||||
assert!(json.contains("AGPL"));
|
||||
assert!(json.contains("\"tool_calls_made\":5"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_definition_available_at_depth_0() {
|
||||
assert!(tool_definition(4, 0).is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_definition_available_at_depth_3() {
|
||||
assert!(tool_definition(4, 3).is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_definition_unavailable_at_max_depth() {
|
||||
assert!(tool_definition(4, 4).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_definition_unavailable_beyond_max() {
|
||||
assert!(tool_definition(4, 5).is_none());
|
||||
}
|
||||
}
|
||||
@@ -21,8 +21,14 @@ pub async fn get_room_context(
|
||||
client: &OpenSearch,
|
||||
index: &str,
|
||||
args_json: &str,
|
||||
allowed_room_ids: &[String],
|
||||
) -> anyhow::Result<String> {
|
||||
let args: RoomHistoryArgs = serde_json::from_str(args_json)?;
|
||||
|
||||
// Enforce room overlap — reject if the requested room isn't in the allowed set
|
||||
if !allowed_room_ids.is_empty() && !allowed_room_ids.contains(&args.room_id) {
|
||||
return Ok("Access denied: you don't have visibility into that room from here.".into());
|
||||
}
|
||||
let total = args.before_count + args.after_count + 1;
|
||||
|
||||
// Determine the pivot timestamp
|
||||
|
||||
@@ -26,6 +26,7 @@ struct ScriptState {
|
||||
config: Arc<Config>,
|
||||
tmpdir: PathBuf,
|
||||
user_id: String,
|
||||
allowed_room_ids: Vec<String>,
|
||||
}
|
||||
|
||||
struct ScriptOutput(String);
|
||||
@@ -83,17 +84,21 @@ async fn op_sol_search(
|
||||
#[string] query: String,
|
||||
#[string] opts_json: String,
|
||||
) -> Result<String, JsErrorBox> {
|
||||
let (os, index) = {
|
||||
let (os, index, allowed) = {
|
||||
let st = state.borrow();
|
||||
let ss = st.borrow::<ScriptState>();
|
||||
(ss.opensearch.clone(), ss.config.opensearch.index.clone())
|
||||
(
|
||||
ss.opensearch.clone(),
|
||||
ss.config.opensearch.index.clone(),
|
||||
ss.allowed_room_ids.clone(),
|
||||
)
|
||||
};
|
||||
|
||||
let mut args: serde_json::Value =
|
||||
serde_json::from_str(&opts_json).unwrap_or(serde_json::json!({}));
|
||||
args["query"] = serde_json::Value::String(query);
|
||||
|
||||
super::search::search_archive(&os, &index, &args.to_string())
|
||||
super::search::search_archive(&os, &index, &args.to_string(), &allowed)
|
||||
.await
|
||||
.map_err(|e| JsErrorBox::generic(e.to_string()))
|
||||
}
|
||||
@@ -429,6 +434,7 @@ pub async fn run_script(
|
||||
config: &Config,
|
||||
args_json: &str,
|
||||
response_ctx: &ResponseContext,
|
||||
allowed_room_ids: Vec<String>,
|
||||
) -> anyhow::Result<String> {
|
||||
let args: RunScriptArgs = serde_json::from_str(args_json)?;
|
||||
let code = args.code.clone();
|
||||
@@ -494,6 +500,7 @@ pub async fn run_script(
|
||||
config: cfg,
|
||||
tmpdir: tmpdir_path,
|
||||
user_id,
|
||||
allowed_room_ids,
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@@ -23,7 +23,8 @@ pub struct SearchArgs {
|
||||
fn default_limit() -> usize { 10 }
|
||||
|
||||
/// Build the OpenSearch query body from parsed SearchArgs. Extracted for testability.
|
||||
pub fn build_search_query(args: &SearchArgs) -> serde_json::Value {
|
||||
/// `allowed_room_ids` restricts results to rooms that pass the member overlap check.
|
||||
pub fn build_search_query(args: &SearchArgs, allowed_room_ids: &[String]) -> serde_json::Value {
|
||||
// Handle empty/wildcard queries as match_all
|
||||
let must = if args.query.is_empty() || args.query == "*" {
|
||||
vec![json!({ "match_all": {} })]
|
||||
@@ -37,6 +38,12 @@ pub fn build_search_query(args: &SearchArgs) -> serde_json::Value {
|
||||
"term": { "redacted": false }
|
||||
})];
|
||||
|
||||
// Restrict to rooms that pass the member overlap threshold.
|
||||
// This is a system-level security filter — Sol never sees results from excluded rooms.
|
||||
if !allowed_room_ids.is_empty() {
|
||||
filter.push(json!({ "terms": { "room_id": allowed_room_ids } }));
|
||||
}
|
||||
|
||||
if let Some(ref room) = args.room {
|
||||
filter.push(json!({ "term": { "room_name": room } }));
|
||||
}
|
||||
@@ -76,9 +83,10 @@ pub async fn search_archive(
|
||||
client: &OpenSearch,
|
||||
index: &str,
|
||||
args_json: &str,
|
||||
allowed_room_ids: &[String],
|
||||
) -> anyhow::Result<String> {
|
||||
let args: SearchArgs = serde_json::from_str(args_json)?;
|
||||
let query_body = build_search_query(&args);
|
||||
let query_body = build_search_query(&args, allowed_room_ids);
|
||||
|
||||
info!(
|
||||
query = args.query.as_str(),
|
||||
@@ -174,7 +182,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_query_basic() {
|
||||
let args = parse_args(r#"{"query": "test"}"#);
|
||||
let q = build_search_query(&args);
|
||||
let q = build_search_query(&args, &[]);
|
||||
|
||||
assert_eq!(q["size"], 10);
|
||||
assert_eq!(q["query"]["bool"]["must"][0]["match"]["content"], "test");
|
||||
@@ -185,7 +193,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_query_with_room_filter() {
|
||||
let args = parse_args(r#"{"query": "hello", "room": "design"}"#);
|
||||
let q = build_search_query(&args);
|
||||
let q = build_search_query(&args, &[]);
|
||||
|
||||
let filters = q["query"]["bool"]["filter"].as_array().unwrap();
|
||||
assert_eq!(filters.len(), 2);
|
||||
@@ -195,7 +203,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_query_with_sender_filter() {
|
||||
let args = parse_args(r#"{"query": "hello", "sender": "Bob"}"#);
|
||||
let q = build_search_query(&args);
|
||||
let q = build_search_query(&args, &[]);
|
||||
|
||||
let filters = q["query"]["bool"]["filter"].as_array().unwrap();
|
||||
assert_eq!(filters.len(), 2);
|
||||
@@ -205,7 +213,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_query_with_room_and_sender() {
|
||||
let args = parse_args(r#"{"query": "hello", "room": "dev", "sender": "Carol"}"#);
|
||||
let q = build_search_query(&args);
|
||||
let q = build_search_query(&args, &[]);
|
||||
|
||||
let filters = q["query"]["bool"]["filter"].as_array().unwrap();
|
||||
assert_eq!(filters.len(), 3);
|
||||
@@ -220,7 +228,7 @@ mod tests {
|
||||
"after": "1710000000000",
|
||||
"before": "1710100000000"
|
||||
}"#);
|
||||
let q = build_search_query(&args);
|
||||
let q = build_search_query(&args, &[]);
|
||||
|
||||
let filters = q["query"]["bool"]["filter"].as_array().unwrap();
|
||||
let range_filter = &filters[1]["range"]["timestamp"];
|
||||
@@ -231,7 +239,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_query_with_after_only() {
|
||||
let args = parse_args(r#"{"query": "hello", "after": "1710000000000"}"#);
|
||||
let q = build_search_query(&args);
|
||||
let q = build_search_query(&args, &[]);
|
||||
|
||||
let filters = q["query"]["bool"]["filter"].as_array().unwrap();
|
||||
let range_filter = &filters[1]["range"]["timestamp"];
|
||||
@@ -242,7 +250,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_query_with_custom_limit() {
|
||||
let args = parse_args(r#"{"query": "hello", "limit": 50}"#);
|
||||
let q = build_search_query(&args);
|
||||
let q = build_search_query(&args, &[]);
|
||||
assert_eq!(q["size"], 50);
|
||||
}
|
||||
|
||||
@@ -256,7 +264,7 @@ mod tests {
|
||||
"before": "2000",
|
||||
"limit": 5
|
||||
}"#);
|
||||
let q = build_search_query(&args);
|
||||
let q = build_search_query(&args, &[]);
|
||||
|
||||
assert_eq!(q["size"], 5);
|
||||
let filters = q["query"]["bool"]["filter"].as_array().unwrap();
|
||||
@@ -267,7 +275,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_invalid_timestamp_ignored() {
|
||||
let args = parse_args(r#"{"query": "hello", "after": "not-a-number"}"#);
|
||||
let q = build_search_query(&args);
|
||||
let q = build_search_query(&args, &[]);
|
||||
|
||||
let filters = q["query"]["bool"]["filter"].as_array().unwrap();
|
||||
// Only the redacted filter, no range since parse failed
|
||||
@@ -277,14 +285,14 @@ mod tests {
|
||||
#[test]
|
||||
fn test_wildcard_query_uses_match_all() {
|
||||
let args = parse_args(r#"{"query": "*"}"#);
|
||||
let q = build_search_query(&args);
|
||||
let q = build_search_query(&args, &[]);
|
||||
assert!(q["query"]["bool"]["must"][0]["match_all"].is_object());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_query_uses_match_all() {
|
||||
let args = parse_args(r#"{"query": ""}"#);
|
||||
let q = build_search_query(&args);
|
||||
let q = build_search_query(&args, &[]);
|
||||
assert!(q["query"]["bool"]["must"][0]["match_all"].is_object());
|
||||
}
|
||||
|
||||
@@ -292,7 +300,7 @@ mod tests {
|
||||
fn test_room_filter_uses_keyword_field() {
|
||||
// room_name is mapped as "keyword" in OpenSearch — no .keyword subfield
|
||||
let args = parse_args(r#"{"query": "test", "room": "general"}"#);
|
||||
let q = build_search_query(&args);
|
||||
let q = build_search_query(&args, &[]);
|
||||
let filters = q["query"]["bool"]["filter"].as_array().unwrap();
|
||||
// Should be room_name, NOT room_name.keyword
|
||||
assert_eq!(filters[1]["term"]["room_name"], "general");
|
||||
@@ -301,7 +309,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_source_fields() {
|
||||
let args = parse_args(r#"{"query": "test"}"#);
|
||||
let q = build_search_query(&args);
|
||||
let q = build_search_query(&args, &[]);
|
||||
|
||||
let source = q["_source"].as_array().unwrap();
|
||||
let fields: Vec<&str> = source.iter().map(|v| v.as_str().unwrap()).collect();
|
||||
|
||||
176
src/tools/web_search.rs
Normal file
176
src/tools/web_search.rs
Normal file
@@ -0,0 +1,176 @@
|
||||
use reqwest::Client as HttpClient;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use tracing::{debug, info};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SearxngResponse {
|
||||
#[serde(default)]
|
||||
results: Vec<SearxngResult>,
|
||||
#[serde(default)]
|
||||
number_of_results: f64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct SearxngResult {
|
||||
#[serde(default)]
|
||||
title: String,
|
||||
#[serde(default)]
|
||||
url: String,
|
||||
#[serde(default)]
|
||||
content: String,
|
||||
#[serde(default)]
|
||||
engine: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct SearchArgs {
|
||||
query: String,
|
||||
#[serde(default = "default_limit")]
|
||||
limit: usize,
|
||||
}
|
||||
|
||||
fn default_limit() -> usize {
|
||||
5
|
||||
}
|
||||
|
||||
/// Execute a web search via SearXNG.
|
||||
pub async fn search(
|
||||
searxng_url: &str,
|
||||
args_json: &str,
|
||||
) -> anyhow::Result<String> {
|
||||
let args: SearchArgs = serde_json::from_str(args_json)?;
|
||||
|
||||
let query_encoded = url::form_urlencoded::byte_serialize(args.query.as_bytes())
|
||||
.collect::<String>();
|
||||
|
||||
let url = format!(
|
||||
"{}/search?q={}&format=json&language=en",
|
||||
searxng_url.trim_end_matches('/'),
|
||||
query_encoded,
|
||||
);
|
||||
|
||||
info!(query = args.query.as_str(), limit = args.limit, "Web search via SearXNG");
|
||||
|
||||
let client = HttpClient::new();
|
||||
let resp = client
|
||||
.get(&url)
|
||||
.header("Accept", "application/json")
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("SearXNG request failed: {e}"))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let text = resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!("SearXNG search failed (HTTP {status}): {text}");
|
||||
}
|
||||
|
||||
let data: SearxngResponse = resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to parse SearXNG response: {e}"))?;
|
||||
|
||||
if data.results.is_empty() {
|
||||
return Ok("No web search results found.".into());
|
||||
}
|
||||
|
||||
let limit = args.limit.min(data.results.len());
|
||||
let results = &data.results[..limit];
|
||||
|
||||
debug!(
|
||||
query = args.query.as_str(),
|
||||
total = data.number_of_results as u64,
|
||||
returned = results.len(),
|
||||
"SearXNG results"
|
||||
);
|
||||
|
||||
// Format results for the LLM
|
||||
let mut output = format!("Web search results for \"{}\":\n\n", args.query);
|
||||
for (i, r) in results.iter().enumerate() {
|
||||
output.push_str(&format!(
|
||||
"{}. **{}**\n {}\n {}\n\n",
|
||||
i + 1,
|
||||
r.title,
|
||||
r.url,
|
||||
if r.content.is_empty() {
|
||||
"(no snippet)".to_string()
|
||||
} else {
|
||||
r.content.clone()
|
||||
},
|
||||
));
|
||||
}
|
||||
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
pub fn tool_definition() -> mistralai_client::v1::tool::Tool {
|
||||
mistralai_client::v1::tool::Tool::new(
|
||||
"search_web".into(),
|
||||
"Search the web via SearXNG. Returns titles, URLs, and snippets from \
|
||||
DuckDuckGo, Wikipedia, StackOverflow, GitHub, and other free engines. \
|
||||
Use for current events, product info, documentation, or anything you're \
|
||||
not certain about. Free and self-hosted — use liberally."
|
||||
.into(),
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "Search query"
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "Max results to return (default 5)"
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_parse_search_args() {
|
||||
let args: SearchArgs =
|
||||
serde_json::from_str(r#"{"query": "mistral vibe"}"#).unwrap();
|
||||
assert_eq!(args.query, "mistral vibe");
|
||||
assert_eq!(args.limit, 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_search_args_with_limit() {
|
||||
let args: SearchArgs =
|
||||
serde_json::from_str(r#"{"query": "rust async", "limit": 10}"#).unwrap();
|
||||
assert_eq!(args.limit, 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_searxng_result_deserialize() {
|
||||
let json = serde_json::json!({
|
||||
"title": "Mistral AI",
|
||||
"url": "https://mistral.ai",
|
||||
"content": "A leading AI company",
|
||||
"engine": "duckduckgo"
|
||||
});
|
||||
let result: SearxngResult = serde_json::from_value(json).unwrap();
|
||||
assert_eq!(result.title, "Mistral AI");
|
||||
assert_eq!(result.engine, "duckduckgo");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_searxng_response_empty() {
|
||||
let json = serde_json::json!({"results": [], "number_of_results": 0.0});
|
||||
let resp: SearxngResponse = serde_json::from_value(json).unwrap();
|
||||
assert!(resp.results.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tool_definition() {
|
||||
let def = tool_definition();
|
||||
assert_eq!(def.function.name, "search_web");
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user