room overlap access control for cross-room search

search_archive, get_room_context, and sol.search() (in run_script)
enforce a configurable member overlap threshold: results from a
room are only visible if at least that share (e.g. >=25%) of the
room's members are also in the requesting room.

system-level filter applied at the opensearch query layer — sol
never sees results from excluded rooms.
This commit is contained in:
2026-03-23 01:42:20 +00:00
parent 7324c10d25
commit 7dbc8a3121
4 changed files with 177 additions and 21 deletions

View File

@@ -1,26 +1,37 @@
pub mod bridge; pub mod bridge;
pub mod devtools; pub mod devtools;
pub mod identity;
pub mod research;
pub mod room_history; pub mod room_history;
pub mod room_info; pub mod room_info;
pub mod script; pub mod script;
pub mod search; pub mod search;
use std::collections::HashSet;
use std::sync::Arc; use std::sync::Arc;
use matrix_sdk::Client as MatrixClient; use matrix_sdk::Client as MatrixClient;
use matrix_sdk::RoomMemberships;
use mistralai_client::v1::tool::Tool; use mistralai_client::v1::tool::Tool;
use opensearch::OpenSearch; use opensearch::OpenSearch;
use serde_json::json; use serde_json::json;
use tracing::debug;
use crate::config::Config; use crate::config::Config;
use crate::context::ResponseContext; use crate::context::ResponseContext;
use crate::persistence::Store;
use crate::sdk::gitea::GiteaClient; use crate::sdk::gitea::GiteaClient;
use crate::sdk::kratos::KratosClient;
pub struct ToolRegistry { pub struct ToolRegistry {
opensearch: OpenSearch, opensearch: OpenSearch,
matrix: MatrixClient, matrix: MatrixClient,
config: Arc<Config>, config: Arc<Config>,
gitea: Option<Arc<GiteaClient>>, gitea: Option<Arc<GiteaClient>>,
kratos: Option<Arc<KratosClient>>,
mistral: Option<Arc<mistralai_client::v1::client::Client>>,
store: Option<Arc<Store>>,
} }
impl ToolRegistry { impl ToolRegistry {
@@ -29,12 +40,18 @@ impl ToolRegistry {
matrix: MatrixClient, matrix: MatrixClient,
config: Arc<Config>, config: Arc<Config>,
gitea: Option<Arc<GiteaClient>>, gitea: Option<Arc<GiteaClient>>,
kratos: Option<Arc<KratosClient>>,
mistral: Option<Arc<mistralai_client::v1::client::Client>>,
store: Option<Arc<Store>>,
) -> Self { ) -> Self {
Self { Self {
opensearch, opensearch,
matrix, matrix,
config, config,
gitea, gitea,
kratos,
mistral,
store,
} }
} }
@@ -42,7 +59,11 @@ impl ToolRegistry {
self.gitea.is_some() self.gitea.is_some()
} }
pub fn tool_definitions(gitea_enabled: bool) -> Vec<Tool> { pub fn has_kratos(&self) -> bool {
self.kratos.is_some()
}
pub fn tool_definitions(gitea_enabled: bool, kratos_enabled: bool) -> Vec<Tool> {
let mut tools = vec![ let mut tools = vec![
Tool::new( Tool::new(
"search_archive".into(), "search_archive".into(),
@@ -172,14 +193,22 @@ impl ToolRegistry {
if gitea_enabled { if gitea_enabled {
tools.extend(devtools::tool_definitions()); tools.extend(devtools::tool_definitions());
} }
if kratos_enabled {
tools.extend(identity::tool_definitions());
}
// Research tool (depth 0 — orchestrator level)
if let Some(def) = research::tool_definition(4, 0) {
tools.push(def);
}
tools tools
} }
/// Convert Sol's tool definitions to Mistral AgentTool format /// Convert Sol's tool definitions to Mistral AgentTool format
/// for use with the Agents API (orchestrator agent creation). /// for use with the Agents API (orchestrator agent creation).
pub fn agent_tool_definitions(gitea_enabled: bool) -> Vec<mistralai_client::v1::agents::AgentTool> { pub fn agent_tool_definitions(gitea_enabled: bool, kratos_enabled: bool) -> Vec<mistralai_client::v1::agents::AgentTool> {
Self::tool_definitions(gitea_enabled) Self::tool_definitions(gitea_enabled, kratos_enabled)
.into_iter() .into_iter()
.map(|t| { .map(|t| {
mistralai_client::v1::agents::AgentTool::function( mistralai_client::v1::agents::AgentTool::function(
@@ -191,6 +220,60 @@ impl ToolRegistry {
.collect() .collect()
} }
/// Compute the set of room IDs whose search results are visible from
/// the requesting room, based on member overlap.
///
/// A room's results are visible if at least ROOM_OVERLAP_THRESHOLD of
/// its members are also members of the requesting room. This is enforced
/// at the query level — Sol never sees filtered-out results.
async fn allowed_room_ids(&self, requesting_room_id: &str) -> Vec<String> {
let rooms = self.matrix.joined_rooms();
// Get requesting room's member set
let requesting_room = rooms.iter().find(|r| r.room_id().as_str() == requesting_room_id);
let requesting_members: HashSet<String> = match requesting_room {
Some(room) => match room.members(RoomMemberships::JOIN).await {
Ok(members) => members.iter().map(|m| m.user_id().to_string()).collect(),
Err(_) => return vec![requesting_room_id.to_string()],
},
None => return vec![requesting_room_id.to_string()],
};
let mut allowed = Vec::new();
for room in &rooms {
let room_id = room.room_id().to_string();
// Always allow the requesting room itself
if room_id == requesting_room_id {
allowed.push(room_id);
continue;
}
let members: HashSet<String> = match room.members(RoomMemberships::JOIN).await {
Ok(m) => m.iter().map(|m| m.user_id().to_string()).collect(),
Err(_) => continue,
};
if members.is_empty() {
continue;
}
let overlap = members.intersection(&requesting_members).count();
let ratio = overlap as f64 / members.len() as f64;
if ratio >= self.config.behavior.room_overlap_threshold as f64 {
debug!(
source_room = room_id.as_str(),
overlap_pct = format!("{:.0}%", ratio * 100.0).as_str(),
"Room passes overlap threshold"
);
allowed.push(room_id);
}
}
allowed
}
pub async fn execute( pub async fn execute(
&self, &self,
name: &str, name: &str,
@@ -199,30 +282,36 @@ impl ToolRegistry {
) -> anyhow::Result<String> { ) -> anyhow::Result<String> {
match name { match name {
"search_archive" => { "search_archive" => {
let allowed = self.allowed_room_ids(&response_ctx.room_id).await;
search::search_archive( search::search_archive(
&self.opensearch, &self.opensearch,
&self.config.opensearch.index, &self.config.opensearch.index,
arguments, arguments,
&allowed,
) )
.await .await
} }
"get_room_context" => { "get_room_context" => {
let allowed = self.allowed_room_ids(&response_ctx.room_id).await;
room_history::get_room_context( room_history::get_room_context(
&self.opensearch, &self.opensearch,
&self.config.opensearch.index, &self.config.opensearch.index,
arguments, arguments,
&allowed,
) )
.await .await
} }
"list_rooms" => room_info::list_rooms(&self.matrix).await, "list_rooms" => room_info::list_rooms(&self.matrix).await,
"get_room_members" => room_info::get_room_members(&self.matrix, arguments).await, "get_room_members" => room_info::get_room_members(&self.matrix, arguments).await,
"run_script" => { "run_script" => {
let allowed = self.allowed_room_ids(&response_ctx.room_id).await;
script::run_script( script::run_script(
&self.opensearch, &self.opensearch,
&self.matrix, &self.matrix,
&self.config, &self.config,
arguments, arguments,
response_ctx, response_ctx,
allowed,
) )
.await .await
} }
@@ -233,7 +322,53 @@ impl ToolRegistry {
anyhow::bail!("Gitea integration not configured") anyhow::bail!("Gitea integration not configured")
} }
} }
name if name.starts_with("identity_") => {
if let Some(ref kratos) = self.kratos {
identity::execute(kratos, name, arguments).await
} else {
anyhow::bail!("Identity (Kratos) integration not configured")
}
}
"research" => {
if let (Some(ref mistral), Some(ref store)) = (&self.mistral, &self.store) {
anyhow::bail!("research tool requires execute_research() — call with room + event_id context")
} else {
anyhow::bail!("Research not configured (missing mistral client or store)")
}
}
_ => anyhow::bail!("Unknown tool: {name}"), _ => anyhow::bail!("Unknown tool: {name}"),
} }
} }
/// Run a `research` tool call with the extra context the generic
/// `execute` path does not carry: the Matrix room and the event ID
/// (used for threads), plus the current research depth.
///
/// Returns an error if either the Mistral client or the store was not
/// supplied to the registry; otherwise forwards everything to
/// `research::execute` and returns its result.
pub async fn execute_research(
    self: &Arc<Self>,
    arguments: &str,
    response_ctx: &ResponseContext,
    room: &matrix_sdk::room::Room,
    event_id: &ruma::OwnedEventId,
    current_depth: usize,
) -> anyhow::Result<String> {
    // Both optional dependencies are required here; check the Mistral
    // client first, then the store.
    let mistral_client = match &self.mistral {
        Some(client) => client,
        None => anyhow::bail!("Research not configured: missing Mistral client"),
    };
    let store_handle = match &self.store {
        Some(handle) => handle,
        None => anyhow::bail!("Research not configured: missing store"),
    };

    let outcome = research::execute(
        arguments,
        &self.config,
        mistral_client,
        self,
        response_ctx,
        room,
        event_id,
        store_handle,
        current_depth,
    )
    .await;

    outcome
}
} }

View File

@@ -21,8 +21,14 @@ pub async fn get_room_context(
client: &OpenSearch, client: &OpenSearch,
index: &str, index: &str,
args_json: &str, args_json: &str,
allowed_room_ids: &[String],
) -> anyhow::Result<String> { ) -> anyhow::Result<String> {
let args: RoomHistoryArgs = serde_json::from_str(args_json)?; let args: RoomHistoryArgs = serde_json::from_str(args_json)?;
// Enforce room overlap — reject if the requested room isn't in the allowed set
if !allowed_room_ids.is_empty() && !allowed_room_ids.contains(&args.room_id) {
return Ok("Access denied: you don't have visibility into that room from here.".into());
}
let total = args.before_count + args.after_count + 1; let total = args.before_count + args.after_count + 1;
// Determine the pivot timestamp // Determine the pivot timestamp

View File

@@ -26,6 +26,7 @@ struct ScriptState {
config: Arc<Config>, config: Arc<Config>,
tmpdir: PathBuf, tmpdir: PathBuf,
user_id: String, user_id: String,
allowed_room_ids: Vec<String>,
} }
struct ScriptOutput(String); struct ScriptOutput(String);
@@ -83,17 +84,21 @@ async fn op_sol_search(
#[string] query: String, #[string] query: String,
#[string] opts_json: String, #[string] opts_json: String,
) -> Result<String, JsErrorBox> { ) -> Result<String, JsErrorBox> {
let (os, index) = { let (os, index, allowed) = {
let st = state.borrow(); let st = state.borrow();
let ss = st.borrow::<ScriptState>(); let ss = st.borrow::<ScriptState>();
(ss.opensearch.clone(), ss.config.opensearch.index.clone()) (
ss.opensearch.clone(),
ss.config.opensearch.index.clone(),
ss.allowed_room_ids.clone(),
)
}; };
let mut args: serde_json::Value = let mut args: serde_json::Value =
serde_json::from_str(&opts_json).unwrap_or(serde_json::json!({})); serde_json::from_str(&opts_json).unwrap_or(serde_json::json!({}));
args["query"] = serde_json::Value::String(query); args["query"] = serde_json::Value::String(query);
super::search::search_archive(&os, &index, &args.to_string()) super::search::search_archive(&os, &index, &args.to_string(), &allowed)
.await .await
.map_err(|e| JsErrorBox::generic(e.to_string())) .map_err(|e| JsErrorBox::generic(e.to_string()))
} }
@@ -429,6 +434,7 @@ pub async fn run_script(
config: &Config, config: &Config,
args_json: &str, args_json: &str,
response_ctx: &ResponseContext, response_ctx: &ResponseContext,
allowed_room_ids: Vec<String>,
) -> anyhow::Result<String> { ) -> anyhow::Result<String> {
let args: RunScriptArgs = serde_json::from_str(args_json)?; let args: RunScriptArgs = serde_json::from_str(args_json)?;
let code = args.code.clone(); let code = args.code.clone();
@@ -494,6 +500,7 @@ pub async fn run_script(
config: cfg, config: cfg,
tmpdir: tmpdir_path, tmpdir: tmpdir_path,
user_id, user_id,
allowed_room_ids,
}); });
} }

View File

@@ -23,7 +23,8 @@ pub struct SearchArgs {
fn default_limit() -> usize { 10 } fn default_limit() -> usize { 10 }
/// Build the OpenSearch query body from parsed SearchArgs. Extracted for testability. /// Build the OpenSearch query body from parsed SearchArgs. Extracted for testability.
pub fn build_search_query(args: &SearchArgs) -> serde_json::Value { /// `allowed_room_ids` restricts results to rooms that pass the member overlap check.
pub fn build_search_query(args: &SearchArgs, allowed_room_ids: &[String]) -> serde_json::Value {
// Handle empty/wildcard queries as match_all // Handle empty/wildcard queries as match_all
let must = if args.query.is_empty() || args.query == "*" { let must = if args.query.is_empty() || args.query == "*" {
vec![json!({ "match_all": {} })] vec![json!({ "match_all": {} })]
@@ -37,6 +38,12 @@ pub fn build_search_query(args: &SearchArgs) -> serde_json::Value {
"term": { "redacted": false } "term": { "redacted": false }
})]; })];
// Restrict to rooms that pass the member overlap threshold.
// This is a system-level security filter — Sol never sees results from excluded rooms.
if !allowed_room_ids.is_empty() {
filter.push(json!({ "terms": { "room_id": allowed_room_ids } }));
}
if let Some(ref room) = args.room { if let Some(ref room) = args.room {
filter.push(json!({ "term": { "room_name": room } })); filter.push(json!({ "term": { "room_name": room } }));
} }
@@ -76,9 +83,10 @@ pub async fn search_archive(
client: &OpenSearch, client: &OpenSearch,
index: &str, index: &str,
args_json: &str, args_json: &str,
allowed_room_ids: &[String],
) -> anyhow::Result<String> { ) -> anyhow::Result<String> {
let args: SearchArgs = serde_json::from_str(args_json)?; let args: SearchArgs = serde_json::from_str(args_json)?;
let query_body = build_search_query(&args); let query_body = build_search_query(&args, allowed_room_ids);
info!( info!(
query = args.query.as_str(), query = args.query.as_str(),
@@ -174,7 +182,7 @@ mod tests {
#[test] #[test]
fn test_query_basic() { fn test_query_basic() {
let args = parse_args(r#"{"query": "test"}"#); let args = parse_args(r#"{"query": "test"}"#);
let q = build_search_query(&args); let q = build_search_query(&args, &[]);
assert_eq!(q["size"], 10); assert_eq!(q["size"], 10);
assert_eq!(q["query"]["bool"]["must"][0]["match"]["content"], "test"); assert_eq!(q["query"]["bool"]["must"][0]["match"]["content"], "test");
@@ -185,7 +193,7 @@ mod tests {
#[test] #[test]
fn test_query_with_room_filter() { fn test_query_with_room_filter() {
let args = parse_args(r#"{"query": "hello", "room": "design"}"#); let args = parse_args(r#"{"query": "hello", "room": "design"}"#);
let q = build_search_query(&args); let q = build_search_query(&args, &[]);
let filters = q["query"]["bool"]["filter"].as_array().unwrap(); let filters = q["query"]["bool"]["filter"].as_array().unwrap();
assert_eq!(filters.len(), 2); assert_eq!(filters.len(), 2);
@@ -195,7 +203,7 @@ mod tests {
#[test] #[test]
fn test_query_with_sender_filter() { fn test_query_with_sender_filter() {
let args = parse_args(r#"{"query": "hello", "sender": "Bob"}"#); let args = parse_args(r#"{"query": "hello", "sender": "Bob"}"#);
let q = build_search_query(&args); let q = build_search_query(&args, &[]);
let filters = q["query"]["bool"]["filter"].as_array().unwrap(); let filters = q["query"]["bool"]["filter"].as_array().unwrap();
assert_eq!(filters.len(), 2); assert_eq!(filters.len(), 2);
@@ -205,7 +213,7 @@ mod tests {
#[test] #[test]
fn test_query_with_room_and_sender() { fn test_query_with_room_and_sender() {
let args = parse_args(r#"{"query": "hello", "room": "dev", "sender": "Carol"}"#); let args = parse_args(r#"{"query": "hello", "room": "dev", "sender": "Carol"}"#);
let q = build_search_query(&args); let q = build_search_query(&args, &[]);
let filters = q["query"]["bool"]["filter"].as_array().unwrap(); let filters = q["query"]["bool"]["filter"].as_array().unwrap();
assert_eq!(filters.len(), 3); assert_eq!(filters.len(), 3);
@@ -220,7 +228,7 @@ mod tests {
"after": "1710000000000", "after": "1710000000000",
"before": "1710100000000" "before": "1710100000000"
}"#); }"#);
let q = build_search_query(&args); let q = build_search_query(&args, &[]);
let filters = q["query"]["bool"]["filter"].as_array().unwrap(); let filters = q["query"]["bool"]["filter"].as_array().unwrap();
let range_filter = &filters[1]["range"]["timestamp"]; let range_filter = &filters[1]["range"]["timestamp"];
@@ -231,7 +239,7 @@ mod tests {
#[test] #[test]
fn test_query_with_after_only() { fn test_query_with_after_only() {
let args = parse_args(r#"{"query": "hello", "after": "1710000000000"}"#); let args = parse_args(r#"{"query": "hello", "after": "1710000000000"}"#);
let q = build_search_query(&args); let q = build_search_query(&args, &[]);
let filters = q["query"]["bool"]["filter"].as_array().unwrap(); let filters = q["query"]["bool"]["filter"].as_array().unwrap();
let range_filter = &filters[1]["range"]["timestamp"]; let range_filter = &filters[1]["range"]["timestamp"];
@@ -242,7 +250,7 @@ mod tests {
#[test] #[test]
fn test_query_with_custom_limit() { fn test_query_with_custom_limit() {
let args = parse_args(r#"{"query": "hello", "limit": 50}"#); let args = parse_args(r#"{"query": "hello", "limit": 50}"#);
let q = build_search_query(&args); let q = build_search_query(&args, &[]);
assert_eq!(q["size"], 50); assert_eq!(q["size"], 50);
} }
@@ -256,7 +264,7 @@ mod tests {
"before": "2000", "before": "2000",
"limit": 5 "limit": 5
}"#); }"#);
let q = build_search_query(&args); let q = build_search_query(&args, &[]);
assert_eq!(q["size"], 5); assert_eq!(q["size"], 5);
let filters = q["query"]["bool"]["filter"].as_array().unwrap(); let filters = q["query"]["bool"]["filter"].as_array().unwrap();
@@ -267,7 +275,7 @@ mod tests {
#[test] #[test]
fn test_invalid_timestamp_ignored() { fn test_invalid_timestamp_ignored() {
let args = parse_args(r#"{"query": "hello", "after": "not-a-number"}"#); let args = parse_args(r#"{"query": "hello", "after": "not-a-number"}"#);
let q = build_search_query(&args); let q = build_search_query(&args, &[]);
let filters = q["query"]["bool"]["filter"].as_array().unwrap(); let filters = q["query"]["bool"]["filter"].as_array().unwrap();
// Only the redacted filter, no range since parse failed // Only the redacted filter, no range since parse failed
@@ -277,14 +285,14 @@ mod tests {
#[test] #[test]
fn test_wildcard_query_uses_match_all() { fn test_wildcard_query_uses_match_all() {
let args = parse_args(r#"{"query": "*"}"#); let args = parse_args(r#"{"query": "*"}"#);
let q = build_search_query(&args); let q = build_search_query(&args, &[]);
assert!(q["query"]["bool"]["must"][0]["match_all"].is_object()); assert!(q["query"]["bool"]["must"][0]["match_all"].is_object());
} }
#[test] #[test]
fn test_empty_query_uses_match_all() { fn test_empty_query_uses_match_all() {
let args = parse_args(r#"{"query": ""}"#); let args = parse_args(r#"{"query": ""}"#);
let q = build_search_query(&args); let q = build_search_query(&args, &[]);
assert!(q["query"]["bool"]["must"][0]["match_all"].is_object()); assert!(q["query"]["bool"]["must"][0]["match_all"].is_object());
} }
@@ -292,7 +300,7 @@ mod tests {
fn test_room_filter_uses_keyword_field() { fn test_room_filter_uses_keyword_field() {
// room_name is mapped as "keyword" in OpenSearch — no .keyword subfield // room_name is mapped as "keyword" in OpenSearch — no .keyword subfield
let args = parse_args(r#"{"query": "test", "room": "general"}"#); let args = parse_args(r#"{"query": "test", "room": "general"}"#);
let q = build_search_query(&args); let q = build_search_query(&args, &[]);
let filters = q["query"]["bool"]["filter"].as_array().unwrap(); let filters = q["query"]["bool"]["filter"].as_array().unwrap();
// Should be room_name, NOT room_name.keyword // Should be room_name, NOT room_name.keyword
assert_eq!(filters[1]["term"]["room_name"], "general"); assert_eq!(filters[1]["term"]["room_name"], "general");
@@ -301,7 +309,7 @@ mod tests {
#[test] #[test]
fn test_source_fields() { fn test_source_fields() {
let args = parse_args(r#"{"query": "test"}"#); let args = parse_args(r#"{"query": "test"}"#);
let q = build_search_query(&args); let q = build_search_query(&args, &[]);
let source = q["_source"].as_array().unwrap(); let source = q["_source"].as_array().unwrap();
let fields: Vec<&str> = source.iter().map(|v| v.as_str().unwrap()).collect(); let fields: Vec<&str> = source.iter().map(|v| v.as_str().unwrap()).collect();