Files
sol/src/tools/search.rs
Sienna Meridian Satterwhite 7dbc8a3121 room overlap access control for cross-room search
search_archive, get_room_context, and sol.search() (in run_script)
enforce a configurable member overlap threshold. results from a
room are only visible if >=25% of that room's members are also in
the requesting room.

system-level filter applied at the opensearch query layer — sol
never sees results from excluded rooms.
2026-03-23 01:42:20 +00:00

323 lines
10 KiB
Rust

use opensearch::OpenSearch;
use serde::Deserialize;
use serde_json::json;
use tracing::{debug, info};
#[derive(Debug, Deserialize)]
pub struct SearchArgs {
pub query: String,
#[serde(default)]
pub room: Option<String>,
#[serde(default)]
pub sender: Option<String>,
#[serde(default)]
pub after: Option<String>,
#[serde(default)]
pub before: Option<String>,
#[serde(default = "default_limit")]
pub limit: usize,
#[serde(default)]
pub semantic: Option<bool>,
}
/// Number of hits returned when a search call does not specify `limit`.
fn default_limit() -> usize {
    const DEFAULT_SEARCH_LIMIT: usize = 10;
    DEFAULT_SEARCH_LIMIT
}
/// Build the OpenSearch query body from parsed SearchArgs. Extracted for testability.
/// `allowed_room_ids` restricts results to rooms that pass the member overlap check.
pub fn build_search_query(args: &SearchArgs, allowed_room_ids: &[String]) -> serde_json::Value {
// Handle empty/wildcard queries as match_all
let must = if args.query.is_empty() || args.query == "*" {
vec![json!({ "match_all": {} })]
} else {
vec![json!({
"match": { "content": args.query }
})]
};
let mut filter = vec![json!({
"term": { "redacted": false }
})];
// Restrict to rooms that pass the member overlap threshold.
// This is a system-level security filter — Sol never sees results from excluded rooms.
if !allowed_room_ids.is_empty() {
filter.push(json!({ "terms": { "room_id": allowed_room_ids } }));
}
if let Some(ref room) = args.room {
filter.push(json!({ "term": { "room_name": room } }));
}
if let Some(ref sender) = args.sender {
filter.push(json!({ "term": { "sender_name": sender } }));
}
let mut range = serde_json::Map::new();
if let Some(ref after) = args.after {
if let Ok(ts) = after.parse::<i64>() {
range.insert("gte".into(), json!(ts));
}
}
if let Some(ref before) = args.before {
if let Ok(ts) = before.parse::<i64>() {
range.insert("lte".into(), json!(ts));
}
}
if !range.is_empty() {
filter.push(json!({ "range": { "timestamp": range } }));
}
json!({
"size": args.limit,
"query": {
"bool": {
"must": must,
"filter": filter
}
},
"sort": [{ "timestamp": "desc" }],
"_source": ["event_id", "room_name", "sender_name", "timestamp", "content"]
})
}
/// Execute a `search_archive` tool call against `index` and render the hits as plain text.
///
/// `args_json` holds the raw JSON tool arguments (see [`SearchArgs`]). `allowed_room_ids` is
/// forwarded to [`build_search_query`] to restrict which rooms may appear in the results.
///
/// Returns one numbered line per hit, newest first (per the query's sort), or
/// `"No results found."` when the search matched nothing.
///
/// # Errors
///
/// Fails in any of these cases:
/// - `args_json` does not deserialize into [`SearchArgs`];
/// - the request cannot be sent, or its body cannot be read;
/// - OpenSearch answers with a non-success HTTP status. The status and the response body are
///   included in the error, instead of being misreported as an empty result set.
pub async fn search_archive(
    client: &OpenSearch,
    index: &str,
    args_json: &str,
    allowed_room_ids: &[String],
) -> anyhow::Result<String> {
    let args: SearchArgs = serde_json::from_str(args_json)?;
    let query_body = build_search_query(&args, allowed_room_ids);
    info!(
        query = args.query.as_str(),
        room = args.room.as_deref().unwrap_or("*"),
        sender = args.sender.as_deref().unwrap_or("*"),
        after = args.after.as_deref().unwrap_or("*"),
        before = args.before.as_deref().unwrap_or("*"),
        limit = args.limit,
        query_json = %query_body,
        "Executing search"
    );

    let response = client
        .search(opensearch::SearchParts::Index(&[index]))
        .body(query_body)
        .send()
        .await?;

    // An error response (malformed query, missing index, ...) carries no `hits` array. Without
    // this check it would fall through to the "No results found." path and hide the failure.
    let status = response.status_code();
    if !status.is_success() {
        let error_body = response.text().await.unwrap_or_default();
        anyhow::bail!("OpenSearch search failed with status {status}: {error_body}");
    }

    let body: serde_json::Value = response.json().await?;
    let hit_count = body["hits"]["total"]["value"].as_i64().unwrap_or(0);
    info!(hit_count, "Search results");

    let hits = match body["hits"]["hits"].as_array() {
        Some(hits) if !hits.is_empty() => hits,
        _ => return Ok("No results found.".into()),
    };

    Ok(hits
        .iter()
        .enumerate()
        .map(|(i, hit)| format_search_hit(i + 1, hit))
        .collect())
}

/// Render one search hit as `N. [YYYY-MM-DD HH:MM] #room — sender: content` plus a newline.
///
/// Missing `_source` fields are replaced with placeholders: `unknown` for the sender and room,
/// an empty string for the content, and `unknown date` for a missing or out-of-range timestamp.
fn format_search_hit(position: usize, hit: &serde_json::Value) -> String {
    let src = &hit["_source"];
    let sender = src["sender_name"].as_str().unwrap_or("unknown");
    let room = src["room_name"].as_str().unwrap_or("unknown");
    let content = src["content"].as_str().unwrap_or("");
    // A missing timestamp is reported as unknown rather than rendered as the 1970 epoch.
    let date = src["timestamp"]
        .as_i64()
        .and_then(chrono::DateTime::from_timestamp_millis)
        .map(|d| d.format("%Y-%m-%d %H:%M").to_string())
        .unwrap_or_else(|| "unknown date".into());
    format!("{position}. [{date}] #{room} — {sender}: {content}\n")
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse tool arguments, panicking on malformed test fixtures.
    fn parse_args(json: &str) -> SearchArgs {
        serde_json::from_str(json).unwrap()
    }

    #[test]
    fn test_parse_minimal_args() {
        let args = parse_args(r#"{"query": "hello"}"#);
        assert_eq!(args.query, "hello");
        assert!(args.room.is_none());
        assert!(args.sender.is_none());
        assert!(args.after.is_none());
        assert!(args.before.is_none());
        assert_eq!(args.limit, 10); // default
        assert!(args.semantic.is_none());
    }

    #[test]
    fn test_parse_full_args() {
        let args = parse_args(r#"{
            "query": "meeting notes",
            "room": "general",
            "sender": "Alice",
            "after": "1710000000000",
            "before": "1710100000000",
            "limit": 25,
            "semantic": true
        }"#);
        assert_eq!(args.query, "meeting notes");
        assert_eq!(args.room.as_deref(), Some("general"));
        assert_eq!(args.sender.as_deref(), Some("Alice"));
        assert_eq!(args.after.as_deref(), Some("1710000000000"));
        assert_eq!(args.before.as_deref(), Some("1710100000000"));
        assert_eq!(args.limit, 25);
        assert_eq!(args.semantic, Some(true));
    }

    #[test]
    fn test_query_basic() {
        let args = parse_args(r#"{"query": "test"}"#);
        let q = build_search_query(&args, &[]);
        assert_eq!(q["size"], 10);
        assert_eq!(q["query"]["bool"]["must"][0]["match"]["content"], "test");
        assert_eq!(q["query"]["bool"]["filter"][0]["term"]["redacted"], false);
        assert_eq!(q["sort"][0]["timestamp"], "desc");
    }

    #[test]
    fn test_query_with_room_filter() {
        let args = parse_args(r#"{"query": "hello", "room": "design"}"#);
        let q = build_search_query(&args, &[]);
        let filters = q["query"]["bool"]["filter"].as_array().unwrap();
        assert_eq!(filters.len(), 2);
        assert_eq!(filters[1]["term"]["room_name"], "design");
    }

    #[test]
    fn test_query_with_sender_filter() {
        let args = parse_args(r#"{"query": "hello", "sender": "Bob"}"#);
        let q = build_search_query(&args, &[]);
        let filters = q["query"]["bool"]["filter"].as_array().unwrap();
        assert_eq!(filters.len(), 2);
        assert_eq!(filters[1]["term"]["sender_name"], "Bob");
    }

    #[test]
    fn test_query_with_room_and_sender() {
        let args = parse_args(r#"{"query": "hello", "room": "dev", "sender": "Carol"}"#);
        let q = build_search_query(&args, &[]);
        let filters = q["query"]["bool"]["filter"].as_array().unwrap();
        assert_eq!(filters.len(), 3);
        assert_eq!(filters[1]["term"]["room_name"], "dev");
        assert_eq!(filters[2]["term"]["sender_name"], "Carol");
    }

    #[test]
    fn test_query_with_date_range() {
        let args = parse_args(r#"{
            "query": "hello",
            "after": "1710000000000",
            "before": "1710100000000"
        }"#);
        let q = build_search_query(&args, &[]);
        let filters = q["query"]["bool"]["filter"].as_array().unwrap();
        let range_filter = &filters[1]["range"]["timestamp"];
        assert_eq!(range_filter["gte"], 1710000000000_i64);
        assert_eq!(range_filter["lte"], 1710100000000_i64);
    }

    #[test]
    fn test_query_with_after_only() {
        let args = parse_args(r#"{"query": "hello", "after": "1710000000000"}"#);
        let q = build_search_query(&args, &[]);
        let filters = q["query"]["bool"]["filter"].as_array().unwrap();
        let range_filter = &filters[1]["range"]["timestamp"];
        assert_eq!(range_filter["gte"], 1710000000000_i64);
        assert!(range_filter.get("lte").is_none());
    }

    #[test]
    fn test_query_with_custom_limit() {
        let args = parse_args(r#"{"query": "hello", "limit": 50}"#);
        let q = build_search_query(&args, &[]);
        assert_eq!(q["size"], 50);
    }

    #[test]
    fn test_query_all_filters_combined() {
        let args = parse_args(r#"{
            "query": "architecture",
            "room": "engineering",
            "sender": "Sienna",
            "after": "1000",
            "before": "2000",
            "limit": 5
        }"#);
        let q = build_search_query(&args, &[]);
        assert_eq!(q["size"], 5);
        let filters = q["query"]["bool"]["filter"].as_array().unwrap();
        // redacted=false, room, sender, range = 4 filters
        assert_eq!(filters.len(), 4);
    }

    #[test]
    fn test_invalid_timestamp_ignored() {
        let args = parse_args(r#"{"query": "hello", "after": "not-a-number"}"#);
        let q = build_search_query(&args, &[]);
        let filters = q["query"]["bool"]["filter"].as_array().unwrap();
        // Only the redacted filter, no range since parse failed
        assert_eq!(filters.len(), 1);
    }

    #[test]
    fn test_wildcard_query_uses_match_all() {
        let args = parse_args(r#"{"query": "*"}"#);
        let q = build_search_query(&args, &[]);
        assert!(q["query"]["bool"]["must"][0]["match_all"].is_object());
    }

    #[test]
    fn test_empty_query_uses_match_all() {
        let args = parse_args(r#"{"query": ""}"#);
        let q = build_search_query(&args, &[]);
        assert!(q["query"]["bool"]["must"][0]["match_all"].is_object());
    }

    #[test]
    fn test_room_filter_uses_keyword_field() {
        // room_name is mapped as "keyword" in OpenSearch — no .keyword subfield
        let args = parse_args(r#"{"query": "test", "room": "general"}"#);
        let q = build_search_query(&args, &[]);
        let filters = q["query"]["bool"]["filter"].as_array().unwrap();
        // Should be room_name, NOT room_name.keyword
        assert_eq!(filters[1]["term"]["room_name"], "general");
    }

    #[test]
    fn test_source_fields() {
        let args = parse_args(r#"{"query": "test"}"#);
        let q = build_search_query(&args, &[]);
        let source = q["_source"].as_array().unwrap();
        let fields: Vec<&str> = source.iter().map(|v| v.as_str().unwrap()).collect();
        assert!(fields.contains(&"event_id"));
        assert!(fields.contains(&"room_name"));
        assert!(fields.contains(&"sender_name"));
        assert!(fields.contains(&"timestamp"));
        assert!(fields.contains(&"content"));
    }

    // The room overlap access-control filter is the security boundary for cross-room
    // search; pin that a non-empty allow-list always becomes a `terms` filter on `room_id`.
    #[test]
    fn test_allowed_rooms_add_room_id_terms_filter() {
        let args = parse_args(r#"{"query": "hello"}"#);
        let allowed = vec!["!a:example.org".to_string(), "!b:example.org".to_string()];
        let q = build_search_query(&args, &allowed);
        let filters = q["query"]["bool"]["filter"].as_array().unwrap();
        // redacted=false + room_id terms
        assert_eq!(filters.len(), 2);
        assert_eq!(
            filters[1]["terms"]["room_id"],
            serde_json::json!(["!a:example.org", "!b:example.org"])
        );
    }

    #[test]
    fn test_allowed_rooms_combined_with_user_filters() {
        let args = parse_args(r#"{"query": "hello", "room": "dev", "sender": "Carol"}"#);
        let allowed = vec!["!dev:example.org".to_string()];
        let q = build_search_query(&args, &allowed);
        let filters = q["query"]["bool"]["filter"].as_array().unwrap();
        // redacted=false, room_id terms, room_name, sender_name = 4 filters
        assert_eq!(filters.len(), 4);
        // User-supplied room/sender filters narrow results but never replace the allow-list.
        assert!(filters
            .iter()
            .any(|f| f["terms"]["room_id"] == serde_json::json!(["!dev:example.org"])));
        assert!(filters.iter().any(|f| f["term"]["room_name"] == "dev"));
        assert!(filters.iter().any(|f| f["term"]["sender_name"] == "Carol"));
    }
}