diff --git a/src/tools/mod.rs b/src/tools/mod.rs index cee6c05..d36a66c 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -1,26 +1,37 @@ pub mod bridge; pub mod devtools; +pub mod identity; +pub mod research; pub mod room_history; pub mod room_info; pub mod script; pub mod search; +use std::collections::HashSet; use std::sync::Arc; use matrix_sdk::Client as MatrixClient; +use matrix_sdk::RoomMemberships; use mistralai_client::v1::tool::Tool; use opensearch::OpenSearch; use serde_json::json; +use tracing::debug; use crate::config::Config; use crate::context::ResponseContext; +use crate::persistence::Store; use crate::sdk::gitea::GiteaClient; +use crate::sdk::kratos::KratosClient; + pub struct ToolRegistry { opensearch: OpenSearch, matrix: MatrixClient, config: Arc, gitea: Option>, + kratos: Option>, + mistral: Option>, + store: Option>, } impl ToolRegistry { @@ -29,12 +40,18 @@ impl ToolRegistry { matrix: MatrixClient, config: Arc, gitea: Option>, + kratos: Option>, + mistral: Option>, + store: Option>, ) -> Self { Self { opensearch, matrix, config, gitea, + kratos, + mistral, + store, } } @@ -42,7 +59,11 @@ impl ToolRegistry { self.gitea.is_some() } - pub fn tool_definitions(gitea_enabled: bool) -> Vec { + pub fn has_kratos(&self) -> bool { + self.kratos.is_some() + } + + pub fn tool_definitions(gitea_enabled: bool, kratos_enabled: bool) -> Vec { let mut tools = vec![ Tool::new( "search_archive".into(), @@ -172,14 +193,22 @@ impl ToolRegistry { if gitea_enabled { tools.extend(devtools::tool_definitions()); } + if kratos_enabled { + tools.extend(identity::tool_definitions()); + } + + // Research tool (depth 0 — orchestrator level) + if let Some(def) = research::tool_definition(4, 0) { + tools.push(def); + } tools } /// Convert Sol's tool definitions to Mistral AgentTool format /// for use with the Agents API (orchestrator agent creation). - pub fn agent_tool_definitions(gitea_enabled: bool) -> Vec { - Self::tool_definitions(gitea_enabled) + pub fn agent_tool_definitions(gitea_enabled: bool, kratos_enabled: bool) -> Vec { + Self::tool_definitions(gitea_enabled, kratos_enabled) .into_iter() .map(|t| { mistralai_client::v1::agents::AgentTool::function( @@ -191,6 +220,60 @@ impl ToolRegistry { .collect() } + /// Compute the set of room IDs whose search results are visible from + /// the requesting room, based on member overlap. + /// + /// A room's results are visible if at least ROOM_OVERLAP_THRESHOLD of + /// its members are also members of the requesting room. This is enforced + /// at the query level — Sol never sees filtered-out results. + async fn allowed_room_ids(&self, requesting_room_id: &str) -> Vec { + let rooms = self.matrix.joined_rooms(); + + // Get requesting room's member set + let requesting_room = rooms.iter().find(|r| r.room_id().as_str() == requesting_room_id); + let requesting_members: HashSet = match requesting_room { + Some(room) => match room.members(RoomMemberships::JOIN).await { + Ok(members) => members.iter().map(|m| m.user_id().to_string()).collect(), + Err(_) => return vec![requesting_room_id.to_string()], + }, + None => return vec![requesting_room_id.to_string()], + }; + + let mut allowed = Vec::new(); + for room in &rooms { + let room_id = room.room_id().to_string(); + + // Always allow the requesting room itself + if room_id == requesting_room_id { + allowed.push(room_id); + continue; + } + + let members: HashSet = match room.members(RoomMemberships::JOIN).await { + Ok(m) => m.iter().map(|m| m.user_id().to_string()).collect(), + Err(_) => continue, + }; + + if members.is_empty() { + continue; + } + + let overlap = members.intersection(&requesting_members).count(); + let ratio = overlap as f64 / members.len() as f64; + + if ratio >= self.config.behavior.room_overlap_threshold as f64 { + debug!( + source_room = room_id.as_str(), + overlap_pct = format!("{:.0}%", ratio * 100.0).as_str(), + "Room passes overlap threshold" + ); + allowed.push(room_id); + } + } + + allowed + } + pub async fn execute( &self, name: &str, @@ -199,30 +282,36 @@ impl ToolRegistry { ) -> anyhow::Result { match name { "search_archive" => { + let allowed = self.allowed_room_ids(&response_ctx.room_id).await; search::search_archive( &self.opensearch, &self.config.opensearch.index, arguments, + &allowed, ) .await } "get_room_context" => { + let allowed = self.allowed_room_ids(&response_ctx.room_id).await; room_history::get_room_context( &self.opensearch, &self.config.opensearch.index, arguments, + &allowed, ) .await } "list_rooms" => room_info::list_rooms(&self.matrix).await, "get_room_members" => room_info::get_room_members(&self.matrix, arguments).await, "run_script" => { + let allowed = self.allowed_room_ids(&response_ctx.room_id).await; script::run_script( &self.opensearch, &self.matrix, &self.config, arguments, response_ctx, + allowed, ) .await } @@ -233,7 +322,53 @@ impl ToolRegistry { anyhow::bail!("Gitea integration not configured") } } + name if name.starts_with("identity_") => { + if let Some(ref kratos) = self.kratos { + identity::execute(kratos, name, arguments).await + } else { + anyhow::bail!("Identity (Kratos) integration not configured") + } + } + "research" => { + if let (Some(ref mistral), Some(ref store)) = (&self.mistral, &self.store) { + anyhow::bail!("research tool requires execute_research() — call with room + event_id context") + } else { + anyhow::bail!("Research not configured (missing mistral client or store)") + } + } _ => anyhow::bail!("Unknown tool: {name}"), } } + + /// Execute a research tool call with full context (room, event_id for threads). + pub async fn execute_research( + self: &Arc, + arguments: &str, + response_ctx: &ResponseContext, + room: &matrix_sdk::room::Room, + event_id: &ruma::OwnedEventId, + current_depth: usize, + ) -> anyhow::Result { + let mistral = self + .mistral + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Research not configured: missing Mistral client"))?; + let store = self + .store + .as_ref() + .ok_or_else(|| anyhow::anyhow!("Research not configured: missing store"))?; + + research::execute( + arguments, + &self.config, + mistral, + self, + response_ctx, + room, + event_id, + store, + current_depth, + ) + .await + } } diff --git a/src/tools/room_history.rs b/src/tools/room_history.rs index d82131c..3e3424b 100644 --- a/src/tools/room_history.rs +++ b/src/tools/room_history.rs @@ -21,8 +21,14 @@ pub async fn get_room_context( client: &OpenSearch, index: &str, args_json: &str, + allowed_room_ids: &[String], ) -> anyhow::Result { let args: RoomHistoryArgs = serde_json::from_str(args_json)?; + + // Enforce room overlap — reject if the requested room isn't in the allowed set + if !allowed_room_ids.is_empty() && !allowed_room_ids.contains(&args.room_id) { + return Ok("Access denied: you don't have visibility into that room from here.".into()); + } let total = args.before_count + args.after_count + 1; // Determine the pivot timestamp diff --git a/src/tools/script.rs b/src/tools/script.rs index 49dd8c4..8e51c1b 100644 --- a/src/tools/script.rs +++ b/src/tools/script.rs @@ -26,6 +26,7 @@ struct ScriptState { config: Arc, tmpdir: PathBuf, user_id: String, + allowed_room_ids: Vec, } struct ScriptOutput(String); @@ -83,17 +84,21 @@ async fn op_sol_search( #[string] query: String, #[string] opts_json: String, ) -> Result { - let (os, index) = { + let (os, index, allowed) = { let st = state.borrow(); let ss = st.borrow::(); - (ss.opensearch.clone(), ss.config.opensearch.index.clone()) + ( + ss.opensearch.clone(), + ss.config.opensearch.index.clone(), + ss.allowed_room_ids.clone(), + ) }; let mut args: serde_json::Value = serde_json::from_str(&opts_json).unwrap_or(serde_json::json!({})); args["query"] = serde_json::Value::String(query); - super::search::search_archive(&os, &index, &args.to_string()) + super::search::search_archive(&os, &index, &args.to_string(), &allowed) .await .map_err(|e| JsErrorBox::generic(e.to_string())) } @@ -429,6 +434,7 @@ pub async fn run_script( config: &Config, args_json: &str, response_ctx: &ResponseContext, + allowed_room_ids: Vec, ) -> anyhow::Result { let args: RunScriptArgs = serde_json::from_str(args_json)?; let code = args.code.clone(); @@ -494,6 +500,7 @@ pub async fn run_script( config: cfg, tmpdir: tmpdir_path, user_id, + allowed_room_ids, }); } diff --git a/src/tools/search.rs b/src/tools/search.rs index f761820..67a3a38 100644 --- a/src/tools/search.rs +++ b/src/tools/search.rs @@ -23,7 +23,8 @@ pub struct SearchArgs { fn default_limit() -> usize { 10 } /// Build the OpenSearch query body from parsed SearchArgs. Extracted for testability. -pub fn build_search_query(args: &SearchArgs) -> serde_json::Value { +/// `allowed_room_ids` restricts results to rooms that pass the member overlap check. +pub fn build_search_query(args: &SearchArgs, allowed_room_ids: &[String]) -> serde_json::Value { // Handle empty/wildcard queries as match_all let must = if args.query.is_empty() || args.query == "*" { vec![json!({ "match_all": {} })] @@ -37,6 +38,12 @@ pub fn build_search_query(args: &SearchArgs) -> serde_json::Value { "term": { "redacted": false } })]; + // Restrict to rooms that pass the member overlap threshold. + // This is a system-level security filter — Sol never sees results from excluded rooms. + if !allowed_room_ids.is_empty() { + filter.push(json!({ "terms": { "room_id": allowed_room_ids } })); + } + if let Some(ref room) = args.room { filter.push(json!({ "term": { "room_name": room } })); } @@ -76,9 +83,10 @@ pub async fn search_archive( client: &OpenSearch, index: &str, args_json: &str, + allowed_room_ids: &[String], ) -> anyhow::Result { let args: SearchArgs = serde_json::from_str(args_json)?; - let query_body = build_search_query(&args); + let query_body = build_search_query(&args, allowed_room_ids); info!( query = args.query.as_str(), @@ -174,7 +182,7 @@ mod tests { #[test] fn test_query_basic() { let args = parse_args(r#"{"query": "test"}"#); - let q = build_search_query(&args); + let q = build_search_query(&args, &[]); assert_eq!(q["size"], 10); assert_eq!(q["query"]["bool"]["must"][0]["match"]["content"], "test"); @@ -185,7 +193,7 @@ mod tests { #[test] fn test_query_with_room_filter() { let args = parse_args(r#"{"query": "hello", "room": "design"}"#); - let q = build_search_query(&args); + let q = build_search_query(&args, &[]); let filters = q["query"]["bool"]["filter"].as_array().unwrap(); assert_eq!(filters.len(), 2); @@ -195,7 +203,7 @@ mod tests { #[test] fn test_query_with_sender_filter() { let args = parse_args(r#"{"query": "hello", "sender": "Bob"}"#); - let q = build_search_query(&args); + let q = build_search_query(&args, &[]); let filters = q["query"]["bool"]["filter"].as_array().unwrap(); assert_eq!(filters.len(), 2); @@ -205,7 +213,7 @@ mod tests { #[test] fn test_query_with_room_and_sender() { let args = parse_args(r#"{"query": "hello", "room": "dev", "sender": "Carol"}"#); - let q = build_search_query(&args); + let q = build_search_query(&args, &[]); let filters = q["query"]["bool"]["filter"].as_array().unwrap(); assert_eq!(filters.len(), 3); @@ -220,7 +228,7 @@ mod tests { "after": "1710000000000", "before": "1710100000000" }"#); - let q = build_search_query(&args); + let q = build_search_query(&args, &[]); let filters = q["query"]["bool"]["filter"].as_array().unwrap(); let range_filter = &filters[1]["range"]["timestamp"]; @@ -231,7 +239,7 @@ mod tests { #[test] fn test_query_with_after_only() { let args = parse_args(r#"{"query": "hello", "after": "1710000000000"}"#); - let q = build_search_query(&args); + let q = build_search_query(&args, &[]); let filters = q["query"]["bool"]["filter"].as_array().unwrap(); let range_filter = &filters[1]["range"]["timestamp"]; @@ -242,7 +250,7 @@ mod tests { #[test] fn test_query_with_custom_limit() { let args = parse_args(r#"{"query": "hello", "limit": 50}"#); - let q = build_search_query(&args); + let q = build_search_query(&args, &[]); assert_eq!(q["size"], 50); } @@ -256,7 +264,7 @@ mod tests { "before": "2000", "limit": 5 }"#); - let q = build_search_query(&args); + let q = build_search_query(&args, &[]); assert_eq!(q["size"], 5); let filters = q["query"]["bool"]["filter"].as_array().unwrap(); @@ -267,7 +275,7 @@ mod tests { #[test] fn test_invalid_timestamp_ignored() { let args = parse_args(r#"{"query": "hello", "after": "not-a-number"}"#); - let q = build_search_query(&args); + let q = build_search_query(&args, &[]); let filters = q["query"]["bool"]["filter"].as_array().unwrap(); // Only the redacted filter, no range since parse failed @@ -277,14 +285,14 @@ mod tests { #[test] fn test_wildcard_query_uses_match_all() { let args = parse_args(r#"{"query": "*"}"#); - let q = build_search_query(&args); + let q = build_search_query(&args, &[]); assert!(q["query"]["bool"]["must"][0]["match_all"].is_object()); } #[test] fn test_empty_query_uses_match_all() { let args = parse_args(r#"{"query": ""}"#); - let q = build_search_query(&args); + let q = build_search_query(&args, &[]); assert!(q["query"]["bool"]["must"][0]["match_all"].is_object()); } @@ -292,7 +300,7 @@ mod tests { fn test_room_filter_uses_keyword_field() { // room_name is mapped as "keyword" in OpenSearch — no .keyword subfield let args = parse_args(r#"{"query": "test", "room": "general"}"#); - let q = build_search_query(&args); + let q = build_search_query(&args, &[]); let filters = q["query"]["bool"]["filter"].as_array().unwrap(); // Should be room_name, NOT room_name.keyword assert_eq!(filters[1]["term"]["room_name"], "general"); @@ -301,7 +309,7 @@ mod tests { #[test] fn test_source_fields() { let args = parse_args(r#"{"query": "test"}"#); - let q = build_search_query(&args); + let q = build_search_query(&args, &[]); let source = q["_source"].as_array().unwrap(); let fields: Vec<&str> = source.iter().map(|v| v.as_str().unwrap()).collect();