From c9cddc80d930d201f234c37529c6e79bd98c9df7 Mon Sep 17 00:00:00 2001 From: Sienna Meridian Satterwhite Date: Sun, 8 Mar 2026 17:41:20 +0000 Subject: [PATCH] Add OpenSearch search backend with hybrid neural+BM25 support Extract a SearchBackend trait from the existing RocksDB search code and add an OpenSearch implementation supporting cross-room search, relevance ranking, fuzzy matching, English stemming, and optional hybrid neural+BM25 semantic search using sentence-transformers. Fix macOS build by gating RLIMIT_NPROC and getrusage to supported platforms. --- .github/workflows/main.yml | 13 + .github/workflows/opensearch-tests.yml | 114 ++ src/api/client/search.rs | 129 ++- src/core/config/check.rs | 16 + src/core/config/mod.rs | 97 ++ src/core/utils/sys/limits.rs | 14 +- src/core/utils/sys/usage.rs | 23 +- src/service/rooms/search/backend.rs | 50 + src/service/rooms/search/mod.rs | 354 +++--- src/service/rooms/search/opensearch.rs | 1442 ++++++++++++++++++++++++ src/service/rooms/search/rocksdb.rs | 193 ++++ src/service/rooms/timeline/append.rs | 3 +- src/service/rooms/timeline/backfill.rs | 3 +- src/service/rooms/timeline/redact.rs | 7 +- tuwunel-example.toml | 66 ++ 15 files changed, 2328 insertions(+), 196 deletions(-) create mode 100644 .github/workflows/opensearch-tests.yml create mode 100644 src/service/rooms/search/backend.rs create mode 100644 src/service/rooms/search/opensearch.rs create mode 100644 src/service/rooms/search/rocksdb.rs diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index b26da0d2..23ebfafb 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -332,6 +332,19 @@ jobs: {"sys_target": "x86_64-v4-linux-gnu", "feat_set": "default"}, ] + opensearch-tests: + if: > + !failure() && !cancelled() + && fromJSON(needs.init.outputs.enable_test) + && !fromJSON(needs.init.outputs.is_release) + && !contains(needs.init.outputs.pipeline, '[ci no test]') + + name: OpenSearch + needs: [init, lint] + uses: 
./.github/workflows/opensearch-tests.yml + with: + checkout: ${{needs.init.outputs.checkout}} + package: if: > !failure() && !cancelled() diff --git a/.github/workflows/opensearch-tests.yml b/.github/workflows/opensearch-tests.yml new file mode 100644 index 00000000..c5e3bc56 --- /dev/null +++ b/.github/workflows/opensearch-tests.yml @@ -0,0 +1,114 @@ +name: OpenSearch Integration Tests + +on: + workflow_call: + inputs: + checkout: + type: string + default: 'HEAD' + pull_request: + paths: + - 'src/service/rooms/search/**' + - 'src/core/config/mod.rs' + - '.github/workflows/opensearch-tests.yml' + push: + branches: [main] + paths: + - 'src/service/rooms/search/**' + - 'src/core/config/mod.rs' + - '.github/workflows/opensearch-tests.yml' + +jobs: + opensearch-tests: + name: OpenSearch Integration + runs-on: ubuntu-latest + + services: + opensearch: + image: opensearchproject/opensearch:3.5.0 + ports: + - 9200:9200 + env: + discovery.type: single-node + DISABLE_SECURITY_PLUGIN: "true" + OPENSEARCH_INITIAL_ADMIN_PASSWORD: "SuperSecret123!" + OPENSEARCH_JAVA_OPTS: "-Xms1g -Xmx2g" + options: >- + --health-cmd "curl -f http://localhost:9200/_cluster/health || exit 1" + --health-interval 10s + --health-timeout 5s + --health-retries 20 + --health-start-period 60s + + steps: + - uses: actions/checkout@v4 + with: + ref: ${{ inputs.checkout || github.sha }} + + - name: Install Rust toolchain + uses: dtolnay/rust-toolchain@nightly + + - name: Cache cargo registry and build + uses: actions/cache@v4 + with: + path: | + ~/.cargo/registry + ~/.cargo/git + target + key: opensearch-tests-${{ runner.os }}-${{ hashFiles('**/Cargo.lock') }} + restore-keys: | + opensearch-tests-${{ runner.os }}- + + - name: Wait for OpenSearch + run: | + for i in $(seq 1 60); do + if curl -sf http://localhost:9200/_cluster/health; then + echo "" + echo "OpenSearch is ready" + exit 0 + fi + echo "Waiting for OpenSearch... 
($i/60)" + sleep 2 + done + echo "OpenSearch did not start in time" + exit 1 + + - name: Configure OpenSearch ML plugin + run: | + curl -sf -X PUT http://localhost:9200/_cluster/settings \ + -H 'Content-Type: application/json' \ + -d '{ + "persistent": { + "plugins.ml_commons.only_run_on_ml_node": "false", + "plugins.ml_commons.native_memory_threshold": "99", + "plugins.ml_commons.allow_registering_model_via_url": "true" + } + }' + + - name: Run OpenSearch integration tests (BM25) + env: + OPENSEARCH_URL: http://localhost:9200 + run: > + cargo test + -p tuwunel_service + --lib + rooms::search::opensearch::tests + -- + --ignored + --test-threads=1 + --skip test_neural + --skip test_hybrid + + - name: Run OpenSearch neural/hybrid tests + env: + OPENSEARCH_URL: http://localhost:9200 + timeout-minutes: 15 + run: > + cargo test + -p tuwunel_service + --lib + rooms::search::opensearch::tests + -- + --ignored + --test-threads=1 + test_neural test_hybrid diff --git a/src/api/client/search.rs b/src/api/client/search.rs index a6f47714..f8434ceb 100644 --- a/src/api/client/search.rs +++ b/src/api/client/search.rs @@ -1,4 +1,4 @@ -use std::collections::BTreeMap; +use std::collections::{BTreeMap, BTreeSet}; use axum::extract::State; use futures::{FutureExt, StreamExt, TryFutureExt, TryStreamExt}; @@ -92,37 +92,120 @@ async fn category_room_events( .boxed() }); - let results: Vec<_> = rooms + let accessible_rooms: Vec = rooms .filter_map(async |room_id| { check_room_visible(services, sender_user, &room_id, criteria) .await .is_ok() .then_some(room_id) }) - .filter_map(async |room_id| { - let query = RoomQuery { - room_id: &room_id, - user_id: Some(sender_user), - criteria, - skip: next_batch, - limit, - }; - - let (count, results) = services.search.search_pdus(&query).await.ok()?; - - results - .collect::>() - .map(|results| (room_id.clone(), count, results)) - .map(Some) - .await - }) .collect() .await; - let total: UInt = results - .iter() - .fold(0, |a: usize, (_, count, _)| 
a.saturating_add(*count)) - .try_into()?; + let (results, total) = if services.search.supports_cross_room_search() { + let room_id_refs: Vec<&RoomId> = accessible_rooms + .iter() + .map(AsRef::as_ref) + .collect(); + + let (count, pdus) = services + .search + .search_pdus_multi_room(&room_id_refs, sender_user, criteria, limit, next_batch) + .await?; + + let pdu_list: Vec<_> = pdus.collect().await; + let total: UInt = count.try_into()?; + + // Group results by room for state collection + let room_ids_in_results: Vec = pdu_list + .iter() + .map(|pdu| pdu.room_id().to_owned()) + .collect(); + + let state: RoomStates = if criteria.include_state.is_some_and(is_true!()) { + let unique_rooms: BTreeSet = + room_ids_in_results.iter().cloned().collect(); + futures::stream::iter(unique_rooms.iter()) + .filter_map(async |room_id| { + procure_room_state(services, room_id) + .map_ok(|state| (room_id.clone(), state)) + .await + .ok() + }) + .collect() + .await + } else { + BTreeMap::new() + }; + + let results: Vec = pdu_list + .into_iter() + .stream() + .map(Event::into_format) + .map(|result| SearchResult { + rank: None, + result: Some(result), + context: EventContextResult { + profile_info: BTreeMap::new(), //TODO + events_after: Vec::new(), //TODO + events_before: Vec::new(), //TODO + start: None, //TODO + end: None, //TODO + }, + }) + .collect() + .await; + + let highlights = criteria + .search_term + .split_terminator(|c: char| !c.is_alphanumeric()) + .map(str::to_lowercase) + .collect(); + + let next_batch = (results.len() >= limit) + .then_some(next_batch.saturating_add(results.len())) + .as_ref() + .map(ToString::to_string); + + return Ok(ResultRoomEvents { + count: Some(total), + next_batch, + results, + state, + highlights, + groups: BTreeMap::new(), // TODO + }); + } else { + let results: Vec<_> = accessible_rooms + .iter() + .stream() + .filter_map(async |room_id| { + let query = RoomQuery { + room_id, + user_id: Some(sender_user), + criteria, + skip: next_batch, + 
limit, + }; + + let (count, results) = services.search.search_pdus(&query).await.ok()?; + + results + .collect::>() + .map(|results| (room_id.clone(), count, results)) + .map(Some) + .await + }) + .collect() + .await; + + let total: UInt = results + .iter() + .fold(0, |a: usize, (_, count, _)| a.saturating_add(*count)) + .try_into()?; + + (results, total) + }; let state: RoomStates = results .iter() diff --git a/src/core/config/check.rs b/src/core/config/check.rs index dd268989..c629014d 100644 --- a/src/core/config/check.rs +++ b/src/core/config/check.rs @@ -116,6 +116,22 @@ pub fn check(config: &Config) -> Result { }); } + if config.search_backend == super::SearchBackendConfig::OpenSearch + && config.search_opensearch_url.is_none() + { + return Err!(Config( + "search_opensearch_url", + "OpenSearch URL must be set when search_backend is \"opensearch\"" + )); + } + + if config.search_opensearch_hybrid && config.search_opensearch_model_id.is_none() { + return Err!(Config( + "search_opensearch_model_id", + "Model ID must be set when search_opensearch_hybrid is enabled" + )); + } + // rocksdb does not allow max_log_files to be 0 if config.rocksdb_max_log_files == 0 { return Err!(Config( diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index effca79a..4712000a 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -1100,6 +1100,85 @@ pub struct Config { #[serde(default)] pub auto_deactivate_banned_room_attempts: bool, + /// Search backend to use for full-text message search. + /// + /// Available options: "rocksdb" (default) or "opensearch". + /// + /// default: "rocksdb" + #[serde(default)] + pub search_backend: SearchBackendConfig, + + /// URL of the OpenSearch instance. Required when search_backend is + /// "opensearch". + /// + /// example: "http://localhost:9200" + pub search_opensearch_url: Option, + + /// Name of the OpenSearch index for message search. 
+ /// + /// default: "tuwunel_messages" + #[serde(default = "default_search_opensearch_index")] + pub search_opensearch_index: String, + + /// Authentication for OpenSearch in "user:pass" format. + /// + /// display: sensitive + pub search_opensearch_auth: Option, + + /// Maximum number of documents to batch before flushing to OpenSearch. + /// + /// default: 100 + #[serde(default = "default_search_opensearch_batch_size")] + pub search_opensearch_batch_size: usize, + + /// Maximum time in milliseconds to wait before flushing a partial batch + /// to OpenSearch. + /// + /// default: 1000 + #[serde(default = "default_search_opensearch_flush_interval_ms")] + pub search_opensearch_flush_interval_ms: u64, + + /// Enable hybrid neural+BM25 search in OpenSearch. Requires an ML model + /// deployed in OpenSearch and an ingest pipeline that populates an + /// "embedding" field. + /// + /// When enabled, tuwunel will: + /// - Create the index with a knn_vector "embedding" field + /// - Attach the ingest pipeline (search_opensearch_pipeline) to the index + /// - Use hybrid queries combining BM25 + neural kNN scoring + /// + /// For a complete reference on configuring OpenSearch's ML plugin, model + /// registration, and ingest pipeline setup, see the test helpers in + /// `src/service/rooms/search/opensearch.rs` (the `ensure_neural_model`, + /// `ensure_ingest_pipeline`, etc. functions in the `tests` module). + /// + /// See also: https://opensearch.org/docs/latest/search-plugins/neural-search/ + /// + /// default: false + #[serde(default)] + pub search_opensearch_hybrid: bool, + + /// The model ID registered in OpenSearch for neural search. Required when + /// search_opensearch_hybrid is enabled. + /// + /// example: "aKV84osBBHNT0StI3MBr" + pub search_opensearch_model_id: Option, + + /// Embedding dimension for the neural search model. Must match the output + /// dimension of the deployed model. 
Common values: 384 + /// (all-MiniLM-L6-v2), 768 (msmarco-distilbert-base-tas-b). + /// + /// default: 384 + #[serde(default = "default_search_opensearch_embedding_dim")] + pub search_opensearch_embedding_dim: usize, + + /// Name of the ingest pipeline that generates embeddings for the + /// "embedding" field. This pipeline must already exist in OpenSearch. + /// + /// default: "tuwunel_embedding_pipeline" + #[serde(default = "default_search_opensearch_pipeline")] + pub search_opensearch_pipeline: String, + /// RocksDB log level. This is not the same as tuwunel's log level. This /// is the log level for the RocksDB engine/library which show up in your /// database folder/path as `LOG` files. tuwunel will log RocksDB errors @@ -2304,6 +2383,14 @@ pub struct Config { catchall: BTreeMap, } +#[derive(Clone, Copy, Debug, Default, Deserialize, PartialEq, Eq)] +#[serde(rename_all = "lowercase")] +pub enum SearchBackendConfig { + #[default] + RocksDb, + OpenSearch, +} + #[derive(Clone, Debug, Deserialize, Default)] #[config_example_generator(filename = "tuwunel-example.toml", section = "global.tls")] pub struct TlsConfig { @@ -3181,6 +3268,16 @@ impl Config { fn true_fn() -> bool { true } +fn default_search_opensearch_index() -> String { "tuwunel_messages".to_owned() } + +fn default_search_opensearch_batch_size() -> usize { 100 } + +fn default_search_opensearch_flush_interval_ms() -> u64 { 1000 } + +fn default_search_opensearch_embedding_dim() -> usize { 384 } + +fn default_search_opensearch_pipeline() -> String { "tuwunel_embedding_pipeline".to_owned() } + #[cfg(test)] fn default_server_name() -> OwnedServerName { ruma::owned_server_name!("localhost") } diff --git a/src/core/utils/sys/limits.rs b/src/core/utils/sys/limits.rs index aa34a143..e561f9d0 100644 --- a/src/core/utils/sys/limits.rs +++ b/src/core/utils/sys/limits.rs @@ -28,7 +28,12 @@ pub fn maximize_fd_limit() -> Result { #[cfg(not(unix))] pub fn maximize_fd_limit() -> Result { Ok(()) } -#[cfg(unix)] +#[cfg(any( 
+ linux_android, + netbsdlike, + target_os = "aix", + target_os = "freebsd", +))] /// Some distributions ship with very low defaults for thread counts; similar to /// low default file descriptor limits. But unlike fd's, thread limit is rarely /// reached, though on large systems (32+ cores) shipping with defaults of @@ -47,7 +52,12 @@ pub fn maximize_thread_limit() -> Result { Ok(()) } -#[cfg(not(unix))] +#[cfg(not(any( + linux_android, + netbsdlike, + target_os = "aix", + target_os = "freebsd", +)))] pub fn maximize_thread_limit() -> Result { Ok(()) } #[cfg(unix)] diff --git a/src/core/utils/sys/usage.rs b/src/core/utils/sys/usage.rs index 6bf6f01a..c8d05959 100644 --- a/src/core/utils/sys/usage.rs +++ b/src/core/utils/sys/usage.rs @@ -1,6 +1,5 @@ -use nix::sys::resource::Usage; #[cfg(unix)] -use nix::sys::resource::{UsageWho, getrusage}; +use nix::sys::resource::{Usage, UsageWho, getrusage}; use crate::Result; @@ -8,7 +7,7 @@ use crate::Result; pub fn usage() -> Result { getrusage(UsageWho::RUSAGE_SELF).map_err(Into::into) } #[cfg(not(unix))] -pub fn usage() -> Result { Ok(Usage::default()) } +pub fn usage() -> Result<()> { Ok(()) } #[cfg(any( target_os = "linux", @@ -17,9 +16,15 @@ pub fn usage() -> Result { Ok(Usage::default()) } ))] pub fn thread_usage() -> Result { getrusage(UsageWho::RUSAGE_THREAD).map_err(Into::into) } -#[cfg(not(any( - target_os = "linux", - target_os = "freebsd", - target_os = "openbsd" -)))] -pub fn thread_usage() -> Result { Ok(Usage::default()) } +#[cfg(all( + unix, + not(any( + target_os = "linux", + target_os = "freebsd", + target_os = "openbsd" + )) +))] +pub fn thread_usage() -> Result { getrusage(UsageWho::RUSAGE_SELF).map_err(Into::into) } + +#[cfg(not(unix))] +pub fn thread_usage() -> Result<()> { Ok(()) } diff --git a/src/service/rooms/search/backend.rs b/src/service/rooms/search/backend.rs new file mode 100644 index 00000000..f9c93d0b --- /dev/null +++ b/src/service/rooms/search/backend.rs @@ -0,0 +1,50 @@ +use 
async_trait::async_trait; +use ruma::RoomId; +use tuwunel_core::Result; + +use crate::rooms::{short::ShortRoomId, timeline::RawPduId}; + +/// A search hit returned by a backend. +pub(super) struct SearchHit { + pub(super) pdu_id: RawPduId, + #[expect(dead_code)] + pub(super) score: Option, +} + +/// Query parameters passed to the backend for searching. +pub(super) struct BackendQuery<'a> { + pub(super) search_term: &'a str, + pub(super) room_ids: Option<&'a [&'a RoomId]>, +} + +/// Trait abstracting the search index backend. +#[async_trait] +pub(super) trait SearchBackend: Send + Sync { + async fn index_pdu( + &self, + shortroomid: ShortRoomId, + pdu_id: &RawPduId, + room_id: &RoomId, + message_body: &str, + ) -> Result; + + async fn deindex_pdu( + &self, + shortroomid: ShortRoomId, + pdu_id: &RawPduId, + room_id: &RoomId, + message_body: &str, + ) -> Result; + + async fn search_pdu_ids( + &self, + query: &BackendQuery<'_>, + shortroomid: Option, + limit: usize, + skip: usize, + ) -> Result<(usize, Vec)>; + + async fn delete_room(&self, room_id: &RoomId, shortroomid: ShortRoomId) -> Result; + + fn supports_cross_room_search(&self) -> bool { false } +} diff --git a/src/service/rooms/search/mod.rs b/src/service/rooms/search/mod.rs index baa7d09b..91bc01e8 100644 --- a/src/service/rooms/search/mod.rs +++ b/src/service/rooms/search/mod.rs @@ -1,34 +1,33 @@ +mod backend; +mod opensearch; +mod rocksdb; + use std::sync::Arc; +use async_trait::async_trait; use futures::{Stream, StreamExt}; use ruma::{RoomId, UserId, api::client::search::search_events::v3::Criteria}; +use tokio::sync::mpsc; use tuwunel_core::{ - PduCount, Result, - arrayvec::ArrayVec, - implement, + Result, implement, matrix::event::{Event, Matches}, - trace, - utils::{ - ArrayVecExt, IterStream, ReadyExt, set, - stream::{TryIgnore, WidebandExt}, - }, + utils::{IterStream, ReadyExt, stream::WidebandExt}, }; -use tuwunel_database::{Interfix, Map, keyval::Val}; -use crate::rooms::{ - short::ShortRoomId, - 
timeline::{PduId, RawPduId}, +use self::{ + backend::{BackendQuery, SearchBackend, SearchHit}, + opensearch::OpenSearchBackend, + rocksdb::RocksDbBackend, }; +use crate::rooms::{short::ShortRoomId, timeline::RawPduId}; pub struct Service { - db: Data, + backend: Box, + bulk_rx: Option>>, + opensearch: Option>, services: Arc, } -struct Data { - tokenids: Arc, -} - #[derive(Clone, Debug)] pub struct RoomQuery<'a> { pub room_id: &'a RoomId, @@ -38,52 +37,133 @@ pub struct RoomQuery<'a> { pub skip: usize, } -type TokenId = ArrayVec; - -const TOKEN_ID_MAX_LEN: usize = - size_of::() + WORD_MAX_LEN + 1 + size_of::(); -const WORD_MAX_LEN: usize = 50; - +#[async_trait] impl crate::Service for Service { fn build(args: &crate::Args<'_>) -> Result> { + let config = &args.server.config; + let (backend, bulk_rx, opensearch): (Box, _, _) = match config + .search_backend + { + | tuwunel_core::config::SearchBackendConfig::RocksDb => { + let b = RocksDbBackend::new(args.db["tokenids"].clone()); + (Box::new(b), None, None) + }, + | tuwunel_core::config::SearchBackendConfig::OpenSearch => { + let (b, rx) = OpenSearchBackend::new(config)?; + let b = Arc::new(b); + let backend: Box = Box::new(OpenSearchWorkerRef(b.clone())); + (backend, Some(tokio::sync::Mutex::new(rx)), Some(b)) + }, + }; + Ok(Arc::new(Self { - db: Data { tokenids: args.db["tokenids"].clone() }, + backend, + bulk_rx, + opensearch, services: args.services.clone(), })) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } + + async fn worker(self: Arc) -> Result { + if let (Some(os), Some(rx_mutex)) = (self.opensearch.as_ref(), self.bulk_rx.as_ref()) { + os.ensure_index().await?; + + let services = &**self.services; + let config = &services.server.config; + let batch_size = config.search_opensearch_batch_size; + let flush_interval = + std::time::Duration::from_millis(config.search_opensearch_flush_interval_ms); + + let mut rx = rx_mutex.lock().await; + os.process_bulk(&mut rx, batch_size, 
flush_interval) + .await; + } + + Ok(()) + } +} + +/// Wrapper so we can put the `Arc` behind the trait. +struct OpenSearchWorkerRef(Arc); + +#[async_trait] +impl SearchBackend for OpenSearchWorkerRef { + async fn index_pdu( + &self, + shortroomid: ShortRoomId, + pdu_id: &RawPduId, + room_id: &RoomId, + message_body: &str, + ) -> Result { + self.0 + .index_pdu(shortroomid, pdu_id, room_id, message_body) + .await + } + + async fn deindex_pdu( + &self, + shortroomid: ShortRoomId, + pdu_id: &RawPduId, + room_id: &RoomId, + message_body: &str, + ) -> Result { + self.0 + .deindex_pdu(shortroomid, pdu_id, room_id, message_body) + .await + } + + async fn search_pdu_ids( + &self, + query: &BackendQuery<'_>, + shortroomid: Option, + limit: usize, + skip: usize, + ) -> Result<(usize, Vec)> { + self.0 + .search_pdu_ids(query, shortroomid, limit, skip) + .await + } + + async fn delete_room(&self, room_id: &RoomId, shortroomid: ShortRoomId) -> Result { + self.0.delete_room(room_id, shortroomid).await + } + + fn supports_cross_room_search(&self) -> bool { self.0.supports_cross_room_search() } } #[implement(Service)] -pub fn index_pdu(&self, shortroomid: ShortRoomId, pdu_id: &RawPduId, message_body: &str) { - let batch = tokenize(message_body) - .map(|word| { - let mut key = shortroomid.to_be_bytes().to_vec(); - key.extend_from_slice(word.as_bytes()); - key.push(0xFF); - key.extend_from_slice(pdu_id.as_ref()); // TODO: currently we save the room id a second time here - key - }) - .collect::>(); - - self.db - .tokenids - .insert_batch(batch.iter().map(|k| (k.as_slice(), &[]))); +pub async fn index_pdu( + &self, + shortroomid: ShortRoomId, + pdu_id: &RawPduId, + room_id: &RoomId, + message_body: &str, +) { + if let Err(e) = self + .backend + .index_pdu(shortroomid, pdu_id, room_id, message_body) + .await + { + tuwunel_core::warn!("Failed to index PDU: {e}"); + } } #[implement(Service)] -pub fn deindex_pdu(&self, shortroomid: ShortRoomId, pdu_id: &RawPduId, message_body: &str) { - 
let batch = tokenize(message_body).map(|word| { - let mut key = shortroomid.to_be_bytes().to_vec(); - key.extend_from_slice(word.as_bytes()); - key.push(0xFF); - key.extend_from_slice(pdu_id.as_ref()); // TODO: currently we save the room id a second time here - key - }); - - for token in batch { - self.db.tokenids.remove(&token); +pub async fn deindex_pdu( + &self, + shortroomid: ShortRoomId, + pdu_id: &RawPduId, + room_id: &RoomId, + message_body: &str, +) { + if let Err(e) = self + .backend + .deindex_pdu(shortroomid, pdu_id, room_id, message_body) + .await + { + tuwunel_core::warn!("Failed to deindex PDU: {e}"); } } @@ -92,12 +172,31 @@ pub async fn search_pdus<'a>( &'a self, query: &'a RoomQuery<'a>, ) -> Result<(usize, impl Stream> + Send + '_)> { - let pdu_ids: Vec<_> = self.search_pdu_ids(query).await?.collect().await; + let shortroomid = self + .services + .short + .get_shortroomid(query.room_id) + .await?; + + let backend_query = BackendQuery { + search_term: &query.criteria.search_term, + room_ids: None, + }; + + let (count, hits) = self + .backend + .search_pdu_ids( + &backend_query, + Some(shortroomid), + query.limit.saturating_add(query.skip), + 0, + ) + .await?; let filter = &query.criteria.filter; - let count = pdu_ids.len(); - let pdus = pdu_ids + let pdus = hits .into_iter() + .map(|h| h.pdu_id) .stream() .wide_filter_map(async |result_pdu_id: RawPduId| { self.services @@ -121,123 +220,64 @@ pub async fn search_pdus<'a>( Ok((count, pdus)) } -// result is modeled as a stream such that callers don't have to be refactored -// though an additional async/wrap still exists for now #[implement(Service)] -pub async fn search_pdu_ids( - &self, - query: &RoomQuery<'_>, -) -> Result + Send + '_ + use<'_>> { - let shortroomid = self - .services - .short - .get_shortroomid(query.room_id) +pub async fn search_pdus_multi_room<'a>( + &'a self, + room_ids: &'a [&'a RoomId], + sender_user: &'a UserId, + criteria: &'a Criteria, + limit: usize, + skip: usize, +) -> 
Result<(usize, impl Stream> + Send + '_)> { + let backend_query = BackendQuery { + search_term: &criteria.search_term, + room_ids: Some(room_ids), + }; + + let (count, hits) = self + .backend + .search_pdu_ids(&backend_query, None, limit.saturating_add(skip), 0) .await?; - let pdu_ids = self - .search_pdu_ids_query_room(query, shortroomid) - .await; - - let iters = pdu_ids.into_iter().map(IntoIterator::into_iter); - - Ok(set::intersection(iters).stream()) -} - -#[implement(Service)] -async fn search_pdu_ids_query_room( - &self, - query: &RoomQuery<'_>, - shortroomid: ShortRoomId, -) -> Vec> { - tokenize(&query.criteria.search_term) + let filter = &criteria.filter; + let pdus = hits + .into_iter() + .map(|h| h.pdu_id) .stream() - .wide_then(async |word| { - self.search_pdu_ids_query_words(shortroomid, &word) - .collect::>() + .wide_filter_map(async |result_pdu_id: RawPduId| { + self.services + .timeline + .get_pdu_from_id(&result_pdu_id) .await + .ok() }) - .collect::>() - .await -} - -/// Iterate over PduId's containing a word -#[implement(Service)] -fn search_pdu_ids_query_words<'a>( - &'a self, - shortroomid: ShortRoomId, - word: &'a str, -) -> impl Stream + Send + '_ { - self.search_pdu_ids_query_word(shortroomid, word) - .map(move |key| -> RawPduId { - let key = &key[prefix_len(word)..]; - key.into() + .ready_filter(|pdu| !pdu.is_redacted()) + .ready_filter(move |pdu| filter.matches(pdu)) + .wide_filter_map(async move |pdu| { + self.services + .state_accessor + .user_can_see_event(sender_user, pdu.room_id(), pdu.event_id()) + .await + .then_some(pdu) }) + .skip(skip) + .take(limit); + + Ok((count, pdus)) } -/// Iterate over raw database results for a word #[implement(Service)] -fn search_pdu_ids_query_word( - &self, - shortroomid: ShortRoomId, - word: &str, -) -> impl Stream> + Send + '_ + use<'_> { - // rustc says const'ing this not yet stable - let end_id: RawPduId = PduId { shortroomid, count: PduCount::max() }.into(); - - // Newest pdus first - let end = 
make_tokenid(shortroomid, word, &end_id); - let prefix = make_prefix(shortroomid, word); - self.db - .tokenids - .rev_raw_keys_from(&end) - .ignore_err() - .ready_take_while(move |key| key.starts_with(&prefix)) -} +pub fn supports_cross_room_search(&self) -> bool { self.backend.supports_cross_room_search() } #[implement(Service)] pub async fn delete_all_search_tokenids_for_room(&self, room_id: &RoomId) -> Result { - let prefix = (room_id, Interfix); + let shortroomid = self + .services + .short + .get_shortroomid(room_id) + .await?; - self.db - .tokenids - .keys_prefix_raw(&prefix) - .ignore_err() - .ready_for_each(|key| { - trace!("Removing key: {key:?}"); - self.db.tokenids.remove(key); - }) - .await; - - Ok(()) -} - -/// Splits a string into tokens used as keys in the search inverted index -/// -/// This may be used to tokenize both message bodies (for indexing) or search -/// queries (for querying). -fn tokenize(body: &str) -> impl Iterator + Send + '_ { - body.split_terminator(|c: char| !c.is_alphanumeric()) - .filter(|s| !s.is_empty()) - .filter(|word| word.len() <= WORD_MAX_LEN) - .map(str::to_lowercase) -} - -fn make_tokenid(shortroomid: ShortRoomId, word: &str, pdu_id: &RawPduId) -> TokenId { - let mut key = make_prefix(shortroomid, word); - key.extend_from_slice(pdu_id.as_ref()); - key -} - -fn make_prefix(shortroomid: ShortRoomId, word: &str) -> TokenId { - let mut key = TokenId::new(); - key.extend_from_slice(&shortroomid.to_be_bytes()); - key.extend_from_slice(word.as_bytes()); - key.push(tuwunel_database::SEP); - key -} - -fn prefix_len(word: &str) -> usize { - size_of::() - .saturating_add(word.len()) - .saturating_add(1) + self.backend + .delete_room(room_id, shortroomid) + .await } diff --git a/src/service/rooms/search/opensearch.rs b/src/service/rooms/search/opensearch.rs new file mode 100644 index 00000000..00ae82bb --- /dev/null +++ b/src/service/rooms/search/opensearch.rs @@ -0,0 +1,1442 @@ +use async_trait::async_trait; +use base64::{Engine as 
_, engine::general_purpose::URL_SAFE_NO_PAD}; +use ruma::RoomId; +use serde_json::{Value, json}; +use tokio::sync::mpsc; +use tuwunel_core::{Result, debug, err, trace, warn}; + +use super::backend::{BackendQuery, SearchBackend, SearchHit}; +use crate::rooms::{short::ShortRoomId, timeline::RawPduId}; + +/// Action to be batched for the OpenSearch bulk API. +pub(super) enum BulkAction { + Index { + doc_id: String, + body: Value, + }, + Delete { + doc_id: String, + }, +} + +pub(super) struct OpenSearchBackend { + client: reqwest::Client, + base_url: String, + index: String, + tx: mpsc::Sender, + hybrid: bool, + model_id: Option, + embedding_dim: usize, + pipeline: String, +} + +impl OpenSearchBackend { + pub(super) fn new( + config: &tuwunel_core::Config, + ) -> Result<(Self, mpsc::Receiver)> { + let url = config + .search_opensearch_url + .as_ref() + .ok_or_else(|| { + err!(Config( + "search_opensearch_url", + "OpenSearch URL must be set when search_backend is opensearch" + )) + })?; + + let mut client_builder = reqwest::Client::builder(); + + if let Some((user, pass)) = config + .search_opensearch_auth + .as_deref() + .and_then(|auth| auth.split_once(':')) + { + let header_value = format!( + "Basic {}", + base64::engine::general_purpose::STANDARD.encode(format!("{user}:{pass}")) + ); + let mut headers = reqwest::header::HeaderMap::new(); + headers.insert( + reqwest::header::AUTHORIZATION, + header_value.parse().map_err(|e| { + err!(Config("search_opensearch_auth", "Invalid auth header: {e}")) + })?, + ); + client_builder = client_builder.default_headers(headers); + } + + let client = client_builder + .build() + .map_err(|e| err!("Failed to build reqwest client: {e}"))?; + + let (tx, rx) = mpsc::channel( + config + .search_opensearch_batch_size + .saturating_mul(4), + ); + + let base_url = url.as_str().trim_end_matches('/').to_owned(); + + Ok(( + Self { + client, + base_url, + index: config.search_opensearch_index.clone(), + tx, + hybrid: 
config.search_opensearch_hybrid, + model_id: config.search_opensearch_model_id.clone(), + embedding_dim: config.search_opensearch_embedding_dim, + pipeline: config.search_opensearch_pipeline.clone(), + }, + rx, + )) + } + + /// Ensure the index exists with the correct mapping. + pub(super) async fn ensure_index(&self) -> Result { + let url = format!("{}/{}", self.base_url, self.index); + let resp = self.client.head(&url).send().await; + + match resp { + | Ok(r) if r.status().is_success() => { + debug!("OpenSearch index '{}' already exists", self.index); + return Ok(()); + }, + | Ok(r) if r.status().as_u16() == 404 => { + debug!("Creating OpenSearch index '{}'", self.index); + }, + | Ok(r) => { + warn!("Unexpected response checking OpenSearch index: {}", r.status()); + }, + | Err(e) => { + return Err(err!("Failed to check OpenSearch index: {e}")); + }, + } + + let mut mapping = json!({ + "settings": { + "analysis": { + "filter": { + "english_stop": { + "type": "stop", + "stopwords": "_english_" + }, + "english_stemmer": { + "type": "stemmer", + "language": "english" + }, + "english_possessive": { + "type": "stemmer", + "language": "possessive_english" + } + }, + "analyzer": { + "message_analyzer": { + "type": "custom", + "tokenizer": "standard", + "filter": [ + "lowercase", + "english_possessive", + "english_stop", + "english_stemmer" + ] + } + } + } + }, + "mappings": { + "properties": { + "room_id": { "type": "keyword" }, + "pdu_id": { "type": "keyword" }, + "body": { + "type": "text", + "analyzer": "message_analyzer", + "fields": { + "exact": { + "type": "text", + "analyzer": "standard" + } + } + }, + "sender": { "type": "keyword" }, + "origin_server_ts": { "type": "date" } + } + } + }); + + // Add knn_vector field and ingest pipeline for hybrid neural search + if self.hybrid { + mapping["mappings"]["properties"]["embedding"] = json!({ + "type": "knn_vector", + "dimension": self.embedding_dim, + "method": { + "name": "hnsw", + "space_type": "l2", + "engine": 
"lucene" + } + }); + mapping["settings"]["index"] = json!({ + "knn": true, + "default_pipeline": self.pipeline, + }); + } + + let resp = self + .client + .put(&url) + .json(&mapping) + .send() + .await + .map_err(|e| err!("Failed to create OpenSearch index: {e}"))?; + + if !resp.status().is_success() { + let body = resp.text().await.unwrap_or_default(); + return Err(err!("Failed to create OpenSearch index: {body}")); + } + + debug!("Created OpenSearch index '{}'", self.index); + + if self.hybrid { + self.ensure_search_pipeline().await?; + } + + Ok(()) + } + + /// Ensure the hybrid search pipeline exists for score normalization. + async fn ensure_search_pipeline(&self) -> Result { + let pipeline_name = format!("{}_hybrid_pipeline", self.index); + let url = format!("{}/_search/pipeline/{}", self.base_url, pipeline_name); + + // Check if it already exists + if self + .client + .get(&url) + .send() + .await + .is_ok_and(|r| r.status().is_success()) + { + debug!("Hybrid search pipeline '{pipeline_name}' already exists"); + return Ok(()); + } + + let body = json!({ + "description": "Hybrid search score normalization for tuwunel", + "phase_results_processors": [{ + "normalization-processor": { + "normalization": { "technique": "min_max" }, + "combination": { + "technique": "arithmetic_mean", + "parameters": { "weights": [0.3, 0.7] } + } + } + }] + }); + + let resp = self + .client + .put(&url) + .json(&body) + .send() + .await + .map_err(|e| err!("Failed to create search pipeline: {e}"))?; + + if !resp.status().is_success() { + let body = resp.text().await.unwrap_or_default(); + return Err(err!("Failed to create search pipeline: {body}")); + } + + debug!("Created hybrid search pipeline '{pipeline_name}'"); + Ok(()) + } + + fn doc_id(pdu_id: &RawPduId) -> String { URL_SAFE_NO_PAD.encode(pdu_id.as_ref()) } + + /// Build the OpenSearch query body for a search request. 
+ /// + /// Uses a bool query with: + /// - A `match_phrase` on the stemmed body field (boosted 3x) for exact + /// phrase matches + /// - A `match` on the stemmed body field with `auto` fuzziness and + /// `minimum_should_match` for broad term matching + /// - A `match` on the exact (standard-analyzed) sub-field for non-stemmed + /// matches + /// - Room ID filtering via `filter` clause (no scoring impact) + /// + /// When hybrid search is enabled and a model_id is configured, wraps + /// in a `hybrid` query combining BM25 + neural kNN search. + fn build_search_query(&self, query: &BackendQuery<'_>, limit: usize, skip: usize) -> Value { + // Phrase match on stemmed body — high boost for exact phrase hits + let phrase_match = json!({ + "match_phrase": { + "body": { + "query": query.search_term, + "boost": 3.0, + "slop": 1 + } + } + }); + + // Fuzzy individual-term match on stemmed body + let fuzzy_match = json!({ + "match": { + "body": { + "query": query.search_term, + "fuzziness": "AUTO", + "minimum_should_match": "75%" + } + } + }); + + // Exact (non-stemmed) match on the standard-analyzed sub-field + let exact_match = json!({ + "match": { + "body.exact": { + "query": query.search_term, + "boost": 1.5 + } + } + }); + + let mut filter = Vec::new(); + if let Some(room_ids) = query.room_ids { + let ids: Vec<&str> = room_ids.iter().map(|r| r.as_str()).collect(); + filter.push(json!({ "terms": { "room_id": ids } })); + } + + let bm25_query = json!({ + "bool": { + "should": [phrase_match, fuzzy_match, exact_match], + "minimum_should_match": 1, + "filter": filter, + } + }); + + if let (true, Some(model_id)) = (self.hybrid, &self.model_id) { + // For hybrid queries with room filtering, we need to apply + // the filter via post_filter so it applies to results from + // BOTH the BM25 and neural subqueries. 
+ let mut result = json!({ + "query": { + "hybrid": { + "queries": [ + bm25_query, + { + "neural": { + "embedding": { + "query_text": query.search_term, + "model_id": model_id, + "k": limit + } + } + } + ] + } + }, + "from": skip, + "size": limit, + "_source": ["pdu_id"], + }); + + if !filter.is_empty() { + result["post_filter"] = json!({ + "bool": { "filter": filter } + }); + } + + return result; + } + + json!({ + "query": bm25_query, + "from": skip, + "size": limit, + "_source": ["pdu_id"], + }) + } + + /// Process bulk actions from the channel. Called from the service worker. + pub(super) async fn process_bulk( + &self, + rx: &mut mpsc::Receiver, + batch_size: usize, + flush_interval: std::time::Duration, + ) { + let mut buf = Vec::with_capacity(batch_size); + loop { + buf.clear(); + + // Wait for first item or channel close + match rx.recv().await { + | Some(action) => buf.push(action), + | None => return, + } + + // Drain up to batch_size using a timeout + let deadline = tokio::time::Instant::now() + .checked_add(flush_interval) + .expect("flush interval overflow"); + while buf.len() < batch_size { + match tokio::time::timeout_at(deadline, rx.recv()).await { + | Ok(Some(action)) => buf.push(action), + | Ok(None) | Err(_) => break, + } + } + + if buf.is_empty() { + continue; + } + + if let Err(e) = self.flush_bulk(&buf).await { + warn!("OpenSearch bulk flush error: {e}"); + } + } + } + + async fn flush_bulk(&self, actions: &[BulkAction]) -> Result { + let mut ndjson = String::new(); + for action in actions { + match action { + | BulkAction::Index { doc_id, body } => { + ndjson.push_str( + &serde_json::to_string(&json!({ + "index": { "_index": self.index, "_id": doc_id } + })) + .unwrap(), + ); + ndjson.push('\n'); + ndjson.push_str(&serde_json::to_string(body).unwrap()); + ndjson.push('\n'); + }, + | BulkAction::Delete { doc_id } => { + ndjson.push_str( + &serde_json::to_string(&json!({ + "delete": { "_index": self.index, "_id": doc_id } + })) + .unwrap(), + ); 
+ ndjson.push('\n'); + }, + } + } + + let url = format!("{}/_bulk", self.base_url); + let resp = self + .client + .post(&url) + .header("Content-Type", "application/x-ndjson") + .body(ndjson) + .send() + .await + .map_err(|e| err!("OpenSearch bulk request failed: {e}"))?; + + if !resp.status().is_success() { + let body = resp.text().await.unwrap_or_default(); + return Err(err!("OpenSearch bulk request failed: {body}")); + } + + trace!("OpenSearch bulk: flushed {} actions", actions.len()); + Ok(()) + } +} + +#[async_trait] +impl SearchBackend for OpenSearchBackend { + async fn index_pdu( + &self, + _shortroomid: ShortRoomId, + pdu_id: &RawPduId, + room_id: &RoomId, + message_body: &str, + ) -> Result { + let doc_id = Self::doc_id(pdu_id); + let body = json!({ + "room_id": room_id.as_str(), + "pdu_id": doc_id, + "body": message_body, + }); + + self.tx + .send(BulkAction::Index { doc_id, body }) + .await + .map_err(|_| err!("OpenSearch bulk channel closed"))?; + + Ok(()) + } + + async fn deindex_pdu( + &self, + _shortroomid: ShortRoomId, + pdu_id: &RawPduId, + _room_id: &RoomId, + _message_body: &str, + ) -> Result { + let doc_id = Self::doc_id(pdu_id); + + self.tx + .send(BulkAction::Delete { doc_id }) + .await + .map_err(|_| err!("OpenSearch bulk channel closed"))?; + + Ok(()) + } + + async fn search_pdu_ids( + &self, + query: &BackendQuery<'_>, + _shortroomid: Option, + limit: usize, + skip: usize, + ) -> Result<(usize, Vec)> { + let body = self.build_search_query(query, limit, skip); + + let url = if self.hybrid { + let pipeline = format!("{}_hybrid_pipeline", self.index); + format!("{}/{}/_search?search_pipeline={}", self.base_url, self.index, pipeline) + } else { + format!("{}/{}/_search", self.base_url, self.index) + }; + let resp = self + .client + .post(&url) + .json(&body) + .send() + .await + .map_err(|e| err!("OpenSearch search request failed: {e}"))?; + + if !resp.status().is_success() { + let body = resp.text().await.unwrap_or_default(); + return 
Err(err!("OpenSearch search failed: {body}")); + } + + let json: Value = resp + .json() + .await + .map_err(|e| err!("OpenSearch response parse error: {e}"))?; + + let total = usize::try_from( + json["hits"]["total"]["value"] + .as_u64() + .unwrap_or(0), + ) + .unwrap_or(usize::MAX); + + let hits = json["hits"]["hits"] + .as_array() + .map(|arr| { + arr.iter() + .filter_map(|hit| { + let pdu_id_b64 = hit["_source"]["pdu_id"].as_str()?; + let bytes = URL_SAFE_NO_PAD.decode(pdu_id_b64).ok()?; + let pdu_id: RawPduId = bytes.as_slice().into(); + let score = hit["_score"].as_f64(); + Some(SearchHit { pdu_id, score }) + }) + .collect() + }) + .unwrap_or_default(); + + Ok((total, hits)) + } + + async fn delete_room(&self, room_id: &RoomId, _shortroomid: ShortRoomId) -> Result { + let body = json!({ + "query": { + "term": { "room_id": room_id.as_str() } + } + }); + + let url = format!("{}/{}/_delete_by_query", self.base_url, self.index); + let resp = self + .client + .post(&url) + .json(&body) + .send() + .await + .map_err(|e| err!("OpenSearch delete_by_query failed: {e}"))?; + + if !resp.status().is_success() { + let body = resp.text().await.unwrap_or_default(); + return Err(err!("OpenSearch delete_by_query failed: {body}")); + } + + Ok(()) + } + + fn supports_cross_room_search(&self) -> bool { true } +} + +#[cfg(test)] +mod tests { + use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD}; + use ruma::RoomId; + use serde_json::json; + + use super::*; + + /// Test-only constructor that bypasses tuwunel Config. + /// Reads `OPENSEARCH_URL` env var, defaulting to `http://localhost:9200`. 
+ const EMBEDDING_DIM: usize = 384; + const MODEL_NAME: &str = "huggingface/sentence-transformers/all-MiniLM-L6-v2"; + const MODEL_VERSION: &str = "1.0.2"; + const MODEL_FORMAT: &str = "TORCH_SCRIPT"; + + fn base_url() -> String { + std::env::var("OPENSEARCH_URL").unwrap_or_else(|_| "http://localhost:9200".to_owned()) + } + + fn test_backend(index: &str) -> (OpenSearchBackend, mpsc::Receiver) { + let (tx, rx) = mpsc::channel(256); + let backend = OpenSearchBackend { + client: reqwest::Client::new(), + base_url: base_url(), + index: index.to_owned(), + tx, + hybrid: false, + model_id: None, + embedding_dim: EMBEDDING_DIM, + pipeline: "tuwunel_embedding_pipeline".to_owned(), + }; + (backend, rx) + } + + fn test_backend_hybrid( + index: &str, + model_id: &str, + ) -> (OpenSearchBackend, mpsc::Receiver) { + let (tx, rx) = mpsc::channel(256); + let backend = OpenSearchBackend { + client: reqwest::Client::new(), + base_url: base_url(), + index: index.to_owned(), + tx, + hybrid: true, + model_id: Some(model_id.to_owned()), + embedding_dim: EMBEDDING_DIM, + pipeline: format!("{index}_ingest_pipeline"), + }; + (backend, rx) + } + + /// Idempotently configure OpenSearch cluster for ML, register and deploy + /// the embedding model, and return the model_id. This is safe to call + /// repeatedly — it reuses existing resources. 
+ async fn ensure_neural_model(client: &reqwest::Client, base_url: &str) -> String { + // Step 1: Configure cluster to allow ML on non-ML nodes + let resp = client + .put(&format!("{base_url}/_cluster/settings")) + .json(&json!({ + "persistent": { + "plugins.ml_commons.only_run_on_ml_node": "false", + "plugins.ml_commons.native_memory_threshold": "99", + "plugins.ml_commons.allow_registering_model_via_url": "true" + } + })) + .send() + .await + .unwrap(); + assert!( + resp.status().is_success(), + "Failed to configure cluster ML settings: {}", + resp.text().await.unwrap_or_default() + ); + + // Step 2: Check if we already have a deployed model + let resp = client + .post(&format!("{base_url}/_plugins/_ml/models/_search")) + .json(&json!({ + "query": { + "bool": { + "must": [ + { "term": { "name.keyword": MODEL_NAME } }, + { "term": { "model_state": "DEPLOYED" } } + ] + } + }, + "size": 1 + })) + .send() + .await + .unwrap(); + + if resp.status().is_success() { + let body: Value = resp.json().await.unwrap(); + if let Some(hit) = body["hits"]["hits"] + .as_array() + .and_then(|a| a.first()) + { + let model_id = hit["_id"].as_str().unwrap().to_owned(); + eprintln!("Reusing deployed model: {model_id}"); + return model_id; + } + } + + // Step 3: Get or create a model group + let model_group_id = get_or_create_model_group(client, base_url).await; + + // Step 4: Register + deploy the model + let resp = client + .post(&format!("{base_url}/_plugins/_ml/models/_register?deploy=true")) + .json(&json!({ + "name": MODEL_NAME, + "version": MODEL_VERSION, + "model_group_id": model_group_id, + "model_format": MODEL_FORMAT + })) + .send() + .await + .unwrap(); + let task_body: Value = resp.json().await.unwrap(); + let task_id = task_body["task_id"] + .as_str() + .expect("Failed to get task_id from model register") + .to_owned(); + + // Step 5: Poll until the task completes (model download + deploy) + let model_id = poll_ml_task(client, base_url, &task_id).await; + eprintln!("Model 
deployed: {model_id}"); + model_id + } + + /// Get or create the test model group, returning its ID. + async fn get_or_create_model_group(client: &reqwest::Client, base_url: &str) -> String { + // Try to find an existing group first + let resp = client + .post(&format!("{base_url}/_plugins/_ml/model_groups/_search")) + .json(&json!({ + "query": { "term": { "name.keyword": "tuwunel_test_group" } } + })) + .send() + .await + .unwrap(); + + if resp.status().is_success() { + let body: Value = resp.json().await.unwrap(); + if let Some(hit) = body["hits"]["hits"] + .as_array() + .and_then(|a| a.first()) + { + let id = hit["_id"].as_str().unwrap().to_owned(); + eprintln!("Reusing model group: {id}"); + return id; + } + } + + // Create new group (no access_mode — security plugin is disabled in tests) + let resp = client + .post(&format!("{base_url}/_plugins/_ml/model_groups/_register")) + .json(&json!({ + "name": "tuwunel_test_group", + "description": "Test model group for tuwunel CI" + })) + .send() + .await + .unwrap(); + + let body: Value = resp.json().await.unwrap(); + body["model_group_id"] + .as_str() + .expect("Failed to create model group") + .to_owned() + } + + /// Poll an ML task until completion, returning the model_id. 
+ async fn poll_ml_task(client: &reqwest::Client, base_url: &str, task_id: &str) -> String { + let url = format!("{base_url}/_plugins/_ml/tasks/{task_id}"); + for i in 0..120 { + tokio::time::sleep(std::time::Duration::from_secs(2)).await; + let resp = client.get(&url).send().await.unwrap(); + let body: Value = resp.json().await.unwrap(); + let state = body["state"].as_str().unwrap_or("UNKNOWN"); + match state { + | "COMPLETED" => { + return body["model_id"] + .as_str() + .expect("No model_id in completed task") + .to_owned(); + }, + | "FAILED" => { + panic!( + "ML task failed after {i} polls: {}", + body["error"].as_str().unwrap_or("unknown error") + ); + }, + | _ => + if i % 5 == 0 { + eprintln!("Waiting for ML task {task_id}: {state} ({i}/120)"); + }, + } + } + panic!("ML task {task_id} did not complete within 240 seconds"); + } + + /// Idempotently create an ingest pipeline for the given model. + async fn ensure_ingest_pipeline( + client: &reqwest::Client, + base_url: &str, + pipeline_name: &str, + model_id: &str, + ) { + let url = format!("{base_url}/_ingest/pipeline/{pipeline_name}"); + + // Check if it already exists with the right model + if let Ok(resp) = client.get(&url).send().await { + if resp.status().is_success() { + let body: Value = resp.json().await.unwrap(); + // Check if the existing pipeline uses the same model + if let Some(processors) = body[pipeline_name]["processors"].as_array() { + for p in processors { + if p["text_embedding"]["model_id"].as_str() == Some(model_id) { + eprintln!("Ingest pipeline '{pipeline_name}' already exists"); + return; + } + } + } + } + } + + let body = json!({ + "description": "Embedding pipeline for tuwunel neural search tests", + "processors": [{ + "text_embedding": { + "model_id": model_id, + "field_map": { + "body": "embedding" + } + } + }] + }); + + let resp = client.put(&url).json(&body).send().await.unwrap(); + assert!( + resp.status().is_success(), + "Failed to create ingest pipeline: {}", + 
resp.text().await.unwrap_or_default() + ); + eprintln!("Created ingest pipeline '{pipeline_name}'"); + } + + /// Delete a search pipeline (cleanup helper). + async fn delete_search_pipeline(client: &reqwest::Client, base_url: &str, name: &str) { + let url = format!("{base_url}/_search/pipeline/{name}"); + let _ = client.delete(&url).send().await; + } + + /// Helper: make a fake RawPduId (NORMAL_LEN = 16 bytes: 8-byte + /// shortroomid + 8-byte count). + fn make_pdu_id(shortroomid: u64, count: u64) -> RawPduId { + let mut bytes = [0u8; 16]; + bytes[..8].copy_from_slice(&shortroomid.to_be_bytes()); + bytes[8..].copy_from_slice(&count.to_be_bytes()); + bytes.as_slice().into() + } + + /// Helper: directly index a document (bypass bulk channel) for test setup. + async fn direct_index( + client: &reqwest::Client, + base_url: &str, + index: &str, + pdu_id: &RawPduId, + room_id: &RoomId, + body_text: &str, + ) { + let doc_id = URL_SAFE_NO_PAD.encode(pdu_id.as_ref()); + let doc = json!({ + "room_id": room_id.as_str(), + "pdu_id": doc_id, + "body": body_text, + }); + let url = format!("{}/{}/_doc/{}", base_url, index, doc_id); + let resp = client.put(&url).json(&doc).send().await.unwrap(); + assert!( + resp.status().is_success(), + "Failed to index doc: {}", + resp.text().await.unwrap_or_default() + ); + } + + /// Helper: refresh the index so documents become searchable. + async fn refresh_index(client: &reqwest::Client, base_url: &str, index: &str) { + let url = format!("{}/{}/_refresh", base_url, index); + client.post(&url).send().await.unwrap(); + } + + /// Helper: delete the test index. 
+ async fn delete_index(client: &reqwest::Client, base_url: &str, index: &str) { + let url = format!("{}/{}", base_url, index); + let _ = client.delete(&url).send().await; + } + + #[tokio::test] + #[ignore = "requires running OpenSearch instance"] + async fn test_ensure_index_creates_and_is_idempotent() { + let index = "test_ensure_index"; + let (backend, _rx) = test_backend(index); + delete_index(&backend.client, &backend.base_url, index).await; + + // First call creates the index + backend + .ensure_index() + .await + .expect("ensure_index should succeed"); + + // Second call should be idempotent + backend + .ensure_index() + .await + .expect("ensure_index should be idempotent"); + + // Verify it exists + let resp = backend + .client + .head(&format!("{}/{}", backend.base_url, index)) + .send() + .await + .unwrap(); + assert!(resp.status().is_success()); + + delete_index(&backend.client, &backend.base_url, index).await; + } + + #[tokio::test] + #[ignore = "requires running OpenSearch instance"] + async fn test_search_returns_indexed_documents() { + let index = "test_search_basic"; + let (backend, _rx) = test_backend(index); + delete_index(&backend.client, &backend.base_url, index).await; + backend.ensure_index().await.unwrap(); + + let room_id = <&RoomId>::try_from("!testroom:example.com").unwrap(); + let pdu1 = make_pdu_id(1, 1); + let pdu2 = make_pdu_id(1, 2); + + direct_index( + &backend.client, + &backend.base_url, + index, + &pdu1, + room_id, + "hello world this is a test message", + ) + .await; + direct_index( + &backend.client, + &backend.base_url, + index, + &pdu2, + room_id, + "goodbye cruel world", + ) + .await; + refresh_index(&backend.client, &backend.base_url, index).await; + + // Search for "hello" + let query = BackendQuery { search_term: "hello", room_ids: None }; + let (total, hits) = backend + .search_pdu_ids(&query, None, 10, 0) + .await + .unwrap(); + assert_eq!(total, 1, "should find exactly one result for 'hello'"); + assert_eq!(hits.len(), 
1); + assert_eq!(hits[0].pdu_id.as_ref(), pdu1.as_ref()); + + // Search for "world" — should match both + let query = BackendQuery { search_term: "world", room_ids: None }; + let (total, _hits) = backend + .search_pdu_ids(&query, None, 10, 0) + .await + .unwrap(); + assert_eq!(total, 2, "should find two results for 'world'"); + + delete_index(&backend.client, &backend.base_url, index).await; + } + + #[tokio::test] + #[ignore = "requires running OpenSearch instance"] + async fn test_search_with_room_filter() { + let index = "test_search_room_filter"; + let (backend, _rx) = test_backend(index); + delete_index(&backend.client, &backend.base_url, index).await; + backend.ensure_index().await.unwrap(); + + let room_a = <&RoomId>::try_from("!roomA:example.com").unwrap(); + let room_b = <&RoomId>::try_from("!roomB:example.com").unwrap(); + let pdu_a = make_pdu_id(1, 10); + let pdu_b = make_pdu_id(2, 11); + + direct_index( + &backend.client, + &backend.base_url, + index, + &pdu_a, + room_a, + "important meeting notes", + ) + .await; + direct_index( + &backend.client, + &backend.base_url, + index, + &pdu_b, + room_b, + "important security update", + ) + .await; + refresh_index(&backend.client, &backend.base_url, index).await; + + // Search "important" in room_a only + let room_ids: Vec<&RoomId> = vec![room_a]; + let query = BackendQuery { + search_term: "important", + room_ids: Some(&room_ids), + }; + let (total, hits) = backend + .search_pdu_ids(&query, None, 10, 0) + .await + .unwrap(); + assert_eq!(total, 1, "should find one result in room_a"); + assert_eq!(hits[0].pdu_id.as_ref(), pdu_a.as_ref()); + + // Search "important" across both rooms + let room_ids: Vec<&RoomId> = vec![room_a, room_b]; + let query = BackendQuery { + search_term: "important", + room_ids: Some(&room_ids), + }; + let (total, _hits) = backend + .search_pdu_ids(&query, None, 10, 0) + .await + .unwrap(); + assert_eq!(total, 2, "should find two results across both rooms"); + + 
delete_index(&backend.client, &backend.base_url, index).await; + } + + #[tokio::test] + #[ignore = "requires running OpenSearch instance"] + async fn test_delete_room_removes_all_documents() { + let index = "test_delete_room"; + let (backend, _rx) = test_backend(index); + delete_index(&backend.client, &backend.base_url, index).await; + backend.ensure_index().await.unwrap(); + + let room_id = <&RoomId>::try_from("!delroom:example.com").unwrap(); + let pdu1 = make_pdu_id(1, 0x21); + let pdu2 = make_pdu_id(1, 0x22); + + direct_index(&backend.client, &backend.base_url, index, &pdu1, room_id, "message one") + .await; + direct_index(&backend.client, &backend.base_url, index, &pdu2, room_id, "message two") + .await; + refresh_index(&backend.client, &backend.base_url, index).await; + + // Delete the room + backend.delete_room(room_id, 1).await.unwrap(); + refresh_index(&backend.client, &backend.base_url, index).await; + + // Verify no documents remain + let query = BackendQuery { search_term: "message", room_ids: None }; + let (total, _hits) = backend + .search_pdu_ids(&query, None, 10, 0) + .await + .unwrap(); + assert_eq!(total, 0, "should find no results after room deletion"); + + delete_index(&backend.client, &backend.base_url, index).await; + } + + #[tokio::test] + #[ignore = "requires running OpenSearch instance"] + async fn test_bulk_flush_indexes_documents() { + let index = "test_bulk_flush"; + let (backend, mut rx) = test_backend(index); + delete_index(&backend.client, &backend.base_url, index).await; + backend.ensure_index().await.unwrap(); + + let room_id = <&RoomId>::try_from("!bulkroom:example.com").unwrap(); + let pdu1 = make_pdu_id(1, 0x31); + let pdu2 = make_pdu_id(1, 0x32); + + // Send index actions through the channel + backend + .index_pdu(0, &pdu1, room_id, "bulk test message one") + .await + .unwrap(); + backend + .index_pdu(0, &pdu2, room_id, "bulk test message two") + .await + .unwrap(); + + // Manually drain and flush (simulating the worker) + let 
mut actions = Vec::new(); + while let Ok(action) = rx.try_recv() { + actions.push(action); + } + assert_eq!(actions.len(), 2, "should have 2 bulk actions"); + backend.flush_bulk(&actions).await.unwrap(); + + refresh_index(&backend.client, &backend.base_url, index).await; + + let query = BackendQuery { search_term: "bulk test", room_ids: None }; + let (total, _hits) = backend + .search_pdu_ids(&query, None, 10, 0) + .await + .unwrap(); + assert_eq!(total, 2, "should find both bulk-indexed documents"); + + delete_index(&backend.client, &backend.base_url, index).await; + } + + #[tokio::test] + #[ignore = "requires running OpenSearch instance"] + async fn test_deindex_via_bulk_removes_document() { + let index = "test_deindex_bulk"; + let (backend, mut rx) = test_backend(index); + delete_index(&backend.client, &backend.base_url, index).await; + backend.ensure_index().await.unwrap(); + + let room_id = <&RoomId>::try_from("!deindexroom:example.com").unwrap(); + let pdu1 = make_pdu_id(1, 0x41); + + // Index directly first + direct_index(&backend.client, &backend.base_url, index, &pdu1, room_id, "to be removed") + .await; + refresh_index(&backend.client, &backend.base_url, index).await; + + // Deindex via bulk + backend + .deindex_pdu(0, &pdu1, room_id, "to be removed") + .await + .unwrap(); + + let mut actions = Vec::new(); + while let Ok(action) = rx.try_recv() { + actions.push(action); + } + backend.flush_bulk(&actions).await.unwrap(); + refresh_index(&backend.client, &backend.base_url, index).await; + + let query = BackendQuery { search_term: "removed", room_ids: None }; + let (total, _hits) = backend + .search_pdu_ids(&query, None, 10, 0) + .await + .unwrap(); + assert_eq!(total, 0, "document should be removed after deindex"); + + delete_index(&backend.client, &backend.base_url, index).await; + } + + #[tokio::test] + #[ignore = "requires running OpenSearch instance"] + async fn test_search_pagination() { + let index = "test_search_pagination"; + let (backend, _rx) = 
test_backend(index); + delete_index(&backend.client, &backend.base_url, index).await; + backend.ensure_index().await.unwrap(); + + let room_id = <&RoomId>::try_from("!pagroom:example.com").unwrap(); + + // Index 5 documents + for i in 0u64..5 { + let pdu = make_pdu_id(1, 0x50 + i); + direct_index( + &backend.client, + &backend.base_url, + index, + &pdu, + room_id, + &format!("pagination test message number {i}"), + ) + .await; + } + refresh_index(&backend.client, &backend.base_url, index).await; + + // Get first 2 + let query = BackendQuery { + search_term: "pagination test", + room_ids: None, + }; + let (total, hits) = backend + .search_pdu_ids(&query, None, 2, 0) + .await + .unwrap(); + assert_eq!(total, 5, "total should be 5"); + assert_eq!(hits.len(), 2, "should return 2 hits with limit=2"); + + // Skip 2, get next 2 + let (total, hits) = backend + .search_pdu_ids(&query, None, 2, 2) + .await + .unwrap(); + assert_eq!(total, 5, "total should still be 5"); + assert_eq!(hits.len(), 2, "should return 2 hits with skip=2 limit=2"); + + // Skip 4, get remaining + let (total, hits) = backend + .search_pdu_ids(&query, None, 10, 4) + .await + .unwrap(); + assert_eq!(total, 5); + assert_eq!(hits.len(), 1, "should return 1 remaining hit"); + + delete_index(&backend.client, &backend.base_url, index).await; + } + + #[tokio::test] + #[ignore = "requires running OpenSearch instance"] + async fn test_cross_room_search() { + let index = "test_cross_room"; + let (backend, _rx) = test_backend(index); + delete_index(&backend.client, &backend.base_url, index).await; + backend.ensure_index().await.unwrap(); + + assert!(backend.supports_cross_room_search()); + + let room_a = <&RoomId>::try_from("!crossA:example.com").unwrap(); + let room_b = <&RoomId>::try_from("!crossB:example.com").unwrap(); + let room_c = <&RoomId>::try_from("!crossC:example.com").unwrap(); + + let pdu_a = make_pdu_id(1, 0x61); + let pdu_b = make_pdu_id(2, 0x62); + let pdu_c = make_pdu_id(3, 0x63); + + 
direct_index( + &backend.client, + &backend.base_url, + index, + &pdu_a, + room_a, + "secret plan alpha", + ) + .await; + direct_index( + &backend.client, + &backend.base_url, + index, + &pdu_b, + room_b, + "secret plan beta", + ) + .await; + direct_index( + &backend.client, + &backend.base_url, + index, + &pdu_c, + room_c, + "secret plan gamma", + ) + .await; + refresh_index(&backend.client, &backend.base_url, index).await; + + // Search across rooms A and B only (user has access to these) + let accessible: Vec<&RoomId> = vec![room_a, room_b]; + let query = BackendQuery { + search_term: "secret plan", + room_ids: Some(&accessible), + }; + let (total, hits) = backend + .search_pdu_ids(&query, None, 10, 0) + .await + .unwrap(); + assert_eq!(total, 2, "should only find results from accessible rooms"); + assert_eq!(hits.len(), 2); + + // Verify room_c's message is NOT in results + let pdu_ids: Vec> = hits + .iter() + .map(|h| h.pdu_id.as_ref().to_vec()) + .collect(); + assert!(!pdu_ids.contains(&pdu_c.as_ref().to_vec()), "room_c result should not appear"); + + delete_index(&backend.client, &backend.base_url, index).await; + } + + #[tokio::test] + #[ignore = "requires running OpenSearch instance with ML plugin"] + async fn test_neural_model_setup_is_idempotent() { + let client = reqwest::Client::new(); + let base_url = base_url(); + + let model_id_1 = ensure_neural_model(&client, &base_url).await; + let model_id_2 = ensure_neural_model(&client, &base_url).await; + assert_eq!(model_id_1, model_id_2, "Repeated setup should return the same model ID"); + } + + #[tokio::test] + #[ignore = "requires running OpenSearch instance with ML plugin"] + async fn test_hybrid_search_returns_results() { + let client = reqwest::Client::new(); + let base_url = base_url(); + let index = "test_hybrid_search"; + + let model_id = ensure_neural_model(&client, &base_url).await; + let pipeline_name = format!("{index}_ingest_pipeline"); + ensure_ingest_pipeline(&client, &base_url, 
&pipeline_name, &model_id).await; + + let (backend, _rx) = test_backend_hybrid(index, &model_id); + delete_index(&client, &base_url, index).await; + delete_search_pipeline(&client, &base_url, &format!("{index}_hybrid_pipeline")).await; + backend.ensure_index().await.unwrap(); + + let room_id = <&RoomId>::try_from("!hybridroom:example.com").unwrap(); + let pdu1 = make_pdu_id(1, 0x71); + let pdu2 = make_pdu_id(1, 0x72); + + // Index via the ingest pipeline (embeddings generated automatically) + direct_index(&client, &base_url, index, &pdu1, room_id, "the quick brown fox jumps over") + .await; + direct_index( + &client, + &base_url, + index, + &pdu2, + room_id, + "meeting notes for the project deadline", + ) + .await; + refresh_index(&client, &base_url, index).await; + + // Hybrid search — combines BM25 and neural scoring + let query = BackendQuery { + search_term: "fast animal", + room_ids: None, + }; + let (total, hits) = backend + .search_pdu_ids(&query, None, 10, 0) + .await + .unwrap(); + + assert!(total > 0, "hybrid search should return at least one result"); + assert!( + !hits.is_empty(), + "hybrid search should have hits for semantically related query" + ); + + delete_search_pipeline(&client, &base_url, &format!("{index}_hybrid_pipeline")).await; + delete_index(&client, &base_url, index).await; + } + + #[tokio::test] + #[ignore = "requires running OpenSearch instance with ML plugin"] + async fn test_hybrid_search_with_room_filter() { + let client = reqwest::Client::new(); + let base_url = base_url(); + let index = "test_hybrid_room_filter"; + + let model_id = ensure_neural_model(&client, &base_url).await; + let pipeline_name = format!("{index}_ingest_pipeline"); + ensure_ingest_pipeline(&client, &base_url, &pipeline_name, &model_id).await; + + let (backend, _rx) = test_backend_hybrid(index, &model_id); + delete_index(&client, &base_url, index).await; + delete_search_pipeline(&client, &base_url, &format!("{index}_hybrid_pipeline")).await; + 
backend.ensure_index().await.unwrap(); + + let room_a = <&RoomId>::try_from("!hybridA:example.com").unwrap(); + let room_b = <&RoomId>::try_from("!hybridB:example.com").unwrap(); + + direct_index( + &client, + &base_url, + index, + &make_pdu_id(1, 0x81), + room_a, + "rust programming language memory safety", + ) + .await; + direct_index( + &client, + &base_url, + index, + &make_pdu_id(2, 0x82), + room_b, + "python programming language scripting", + ) + .await; + refresh_index(&client, &base_url, index).await; + + // Search only in room_a + let room_ids: Vec<&RoomId> = vec![room_a]; + let query = BackendQuery { + search_term: "programming", + room_ids: Some(&room_ids), + }; + let (total, hits) = backend + .search_pdu_ids(&query, None, 10, 0) + .await + .unwrap(); + assert_eq!(total, 1, "should only find result in room_a"); + assert_eq!(hits.len(), 1); + + delete_search_pipeline(&client, &base_url, &format!("{index}_hybrid_pipeline")).await; + delete_index(&client, &base_url, index).await; + } + + #[tokio::test] + #[ignore = "requires running OpenSearch instance with ML plugin"] + async fn test_hybrid_semantic_relevance() { + let client = reqwest::Client::new(); + let base_url = base_url(); + let index = "test_hybrid_semantic"; + + let model_id = ensure_neural_model(&client, &base_url).await; + let pipeline_name = format!("{index}_ingest_pipeline"); + ensure_ingest_pipeline(&client, &base_url, &pipeline_name, &model_id).await; + + let (backend, _rx) = test_backend_hybrid(index, &model_id); + delete_index(&client, &base_url, index).await; + delete_search_pipeline(&client, &base_url, &format!("{index}_hybrid_pipeline")).await; + backend.ensure_index().await.unwrap(); + + let room_id = <&RoomId>::try_from("!semantic:example.com").unwrap(); + + // Index messages with different semantic content + direct_index( + &client, + &base_url, + index, + &make_pdu_id(1, 0x91), + room_id, + "I am feeling very happy and excited about the good news", + ) + .await; + direct_index( + 
&client,
+            &base_url,
+            index,
+            &make_pdu_id(1, 0x92),
+            room_id,
+            "The database server configuration needs to be updated",
+        )
+        .await;
+        direct_index(
+            &client,
+            &base_url,
+            index,
+            &make_pdu_id(1, 0x93),
+            room_id,
+            "She was overjoyed and thrilled with the wonderful surprise",
+        )
+        .await;
+        refresh_index(&client, &base_url, index).await;
+
+        // Semantic search: "joyful emotions" shouldn't match BM25 well
+        // but should match semantically to the happy/overjoyed messages
+        let query = BackendQuery {
+            search_term: "joyful emotions",
+            room_ids: None,
+        };
+        let (_total, hits) = backend
+            .search_pdu_ids(&query, None, 10, 0)
+            .await
+            .unwrap();
+
+        // The neural component should surface semantically related results
+        // even though "joyful" and "emotions" don't appear literally
+        assert!(!hits.is_empty(), "semantic search should find results for related concepts");
+
+        delete_search_pipeline(&client, &base_url, &format!("{index}_hybrid_pipeline")).await;
+        delete_index(&client, &base_url, index).await;
+    }
+}
diff --git a/src/service/rooms/search/rocksdb.rs b/src/service/rooms/search/rocksdb.rs
new file mode 100644
index 00000000..62b98b94
--- /dev/null
+++ b/src/service/rooms/search/rocksdb.rs
@@ -0,0 +1,219 @@
+//! RocksDB-backed search backend.
+//!
+//! Implements [`SearchBackend`] on top of the `tokenids` inverted index. Each
+//! key is laid out as `shortroomid (big-endian) ++ word ++ SEP ++ pdu_id`.
+
+use std::sync::Arc;
+
+use async_trait::async_trait;
+use futures::StreamExt;
+use ruma::RoomId;
+use tuwunel_core::{
+    PduCount, Result,
+    arrayvec::ArrayVec,
+    utils::{ArrayVecExt, ReadyExt, set, stream::TryIgnore},
+};
+use tuwunel_database::{Map, keyval::Val};
+
+use super::backend::{BackendQuery, SearchBackend, SearchHit};
+use crate::rooms::{
+    short::ShortRoomId,
+    timeline::{PduId, RawPduId},
+};
+
+/// Search backend that keeps an inverted word index in the `tokenids` map.
+pub(super) struct RocksDbBackend {
+    tokenids: Arc<Map>,
+}
+
+/// Fixed-capacity buffer large enough to hold any index key.
+type TokenId = ArrayVec<u8, TOKEN_ID_MAX_LEN>;
+
+/// Capacity of a [`TokenId`]: room prefix, word, separator byte and PDU id.
+const TOKEN_ID_MAX_LEN: usize =
+    size_of::<ShortRoomId>() + WORD_MAX_LEN + 1 + size_of::<RawPduId>();
+/// Longest token, in bytes, that is indexed or searched.
+const WORD_MAX_LEN: usize = 50;
+
+impl RocksDbBackend {
+    /// Creates a backend over the given `tokenids` map.
+    pub(super) fn new(tokenids: Arc<Map>) -> Self { Self { tokenids } }
+}
+
+#[async_trait]
+impl SearchBackend for RocksDbBackend {
+    /// Writes one index key per token of `message_body` for `pdu_id`.
+    async fn index_pdu(
+        &self,
+        shortroomid: ShortRoomId,
+        pdu_id: &RawPduId,
+        _room_id: &RoomId,
+        message_body: &str,
+    ) -> Result {
+        // Use the same key builder as the search path so that written keys
+        // always match what lookups scan for.
+        let batch = tokenize(message_body)
+            .map(|word| make_tokenid(shortroomid, &word, pdu_id))
+            .collect::<Vec<_>>();
+
+        self.tokenids
+            .insert_batch(batch.iter().map(|k| (k.as_slice(), &[])));
+
+        Ok(())
+    }
+
+    /// Removes the index keys that `index_pdu` writes for this PDU and body.
+    async fn deindex_pdu(
+        &self,
+        shortroomid: ShortRoomId,
+        pdu_id: &RawPduId,
+        _room_id: &RoomId,
+        message_body: &str,
+    ) -> Result {
+        for word in tokenize(message_body) {
+            self.tokenids
+                .remove(&make_tokenid(shortroomid, &word, pdu_id));
+        }
+
+        Ok(())
+    }
+
+    /// Returns the PDUs of one room that are indexed under every query token.
+    ///
+    /// `_limit` and `_skip` are not applied: the whole match set is returned
+    /// and its length is reported as the total.
+    ///
+    /// # Errors
+    ///
+    /// Fails with `InvalidParam` when `shortroomid` is `None`.
+    async fn search_pdu_ids(
+        &self,
+        query: &BackendQuery<'_>,
+        shortroomid: Option<ShortRoomId>,
+        _limit: usize,
+        _skip: usize,
+    ) -> Result<(usize, Vec<SearchHit>)> {
+        let shortroomid = shortroomid.ok_or_else(|| {
+            tuwunel_core::err!(Request(InvalidParam(
+                "RocksDB backend requires a shortroomid for search"
+            )))
+        })?;
+
+        let pdu_ids =
+            search_pdu_ids_query_room(&self.tokenids, shortroomid, query.search_term).await;
+
+        let iters = pdu_ids.into_iter().map(IntoIterator::into_iter);
+        let results: Vec<_> = set::intersection(iters)
+            .map(|pdu_id| SearchHit { pdu_id, score: None })
+            .collect();
+        let count = results.len();
+
+        Ok((count, results))
+    }
+
+    /// Removes every index key that belongs to the room.
+    async fn delete_room(&self, _room_id: &RoomId, shortroomid: ShortRoomId) -> Result {
+        // Index keys start with the big-endian short room id (see
+        // `make_prefix`), not the textual room id, so scan by that prefix.
+        // NOTE(review): assumes the database serializer encodes a bare
+        // integer prefix as its big-endian bytes -- confirm.
+        self.tokenids
+            .keys_prefix_raw(&shortroomid)
+            .ignore_err()
+            .ready_for_each(|key| {
+                self.tokenids.remove(key);
+            })
+            .await;
+
+        Ok(())
+    }
+}
+
+/// Looks up, for each token of `search_term`, the PDU ids indexed under it in
+/// the room. The per-token lookups run concurrently; results keep token order.
+async fn search_pdu_ids_query_room(
+    tokenids: &Arc<Map>,
+    shortroomid: ShortRoomId,
+    search_term: &str,
+) -> Vec<Vec<RawPduId>> {
+    let lookups = tokenize(search_term)
+        .collect::<Vec<_>>()
+        .into_iter()
+        .map(|word| {
+            let tokenids = tokenids.clone();
+            async move {
+                search_pdu_ids_query_words(&tokenids, shortroomid, &word)
+                    .collect::<Vec<_>>()
+                    .await
+            }
+        });
+
+    futures::future::join_all(lookups).await
+}
+
+/// Streams the PDU ids indexed under `word` by stripping the key prefix.
+fn search_pdu_ids_query_words<'a>(
+    tokenids: &'a Arc<Map>,
+    shortroomid: ShortRoomId,
+    word: &'a str,
+) -> impl futures::Stream<Item = RawPduId> + Send + 'a {
+    search_pdu_ids_query_word(tokenids, shortroomid, word).map(move |key| -> RawPduId {
+        let key = &key[prefix_len(word)..];
+        key.into()
+    })
+}
+
+/// Streams the raw keys for `word`, iterating in reverse from the key built
+/// with the maximum PDU count for as long as keys carry the word's prefix.
+fn search_pdu_ids_query_word<'a>(
+    tokenids: &'a Arc<Map>,
+    shortroomid: ShortRoomId,
+    word: &'a str,
+) -> impl futures::Stream<Item = Val<'a>> + Send + 'a {
+    let end_id: RawPduId = PduId { shortroomid, count: PduCount::max() }.into();
+
+    let end = make_tokenid(shortroomid, word, &end_id);
+    let prefix = make_prefix(shortroomid, word);
+    tokenids
+        .rev_raw_keys_from(&end)
+        .ignore_err()
+        .ready_take_while(move |key| key.starts_with(&prefix))
+}
+
+/// Splits a string into tokens used as keys in the search inverted index
+///
+/// Tokens are the lowercased alphanumeric runs of `body`; runs longer than
+/// `WORD_MAX_LEN` bytes, before or after lowercasing, are dropped.
+pub(super) fn tokenize(body: &str) -> impl Iterator<Item = String> + Send + '_ {
+    body.split_terminator(|c: char| !c.is_alphanumeric())
+        .filter(|s| !s.is_empty())
+        .filter(|word| word.len() <= WORD_MAX_LEN)
+        .map(str::to_lowercase)
+        // Lowercasing can lengthen a word's UTF-8 encoding; check again so
+        // every token still fits in a `TokenId`.
+        .filter(|word| word.len() <= WORD_MAX_LEN)
+}
+
+/// Builds a full index key: the word prefix followed by the PDU id bytes.
+fn make_tokenid(shortroomid: ShortRoomId, word: &str, pdu_id: &RawPduId) -> TokenId {
+    let mut key = make_prefix(shortroomid, word);
+    key.extend_from_slice(pdu_id.as_ref());
+    key
+}
+
+/// Builds the key prefix for `word` in a room: big-endian `shortroomid`, the
+/// word bytes, then the database separator byte.
+fn make_prefix(shortroomid: ShortRoomId, word: &str) -> TokenId {
+    let mut key = TokenId::new();
+    key.extend_from_slice(&shortroomid.to_be_bytes());
+    key.extend_from_slice(word.as_bytes());
+    key.push(tuwunel_database::SEP);
+    key
+}
+
+/// Byte length of the prefix that `make_prefix` produces for `word`.
+fn prefix_len(word: &str) -> usize {
+    size_of::<ShortRoomId>()
+        .saturating_add(word.len())
+        .saturating_add(1)
+}
diff --git a/src/service/rooms/timeline/append.rs b/src/service/rooms/timeline/append.rs
index 07dce70e..964327b6 100644
--- a/src/service/rooms/timeline/append.rs
+++ 
b/src/service/rooms/timeline/append.rs @@ -281,7 +281,8 @@ async fn append_pdu_effects( if let Some(body) = content.body { self.services .search - .index_pdu(shortroomid, &pdu_id, &body); + .index_pdu(shortroomid, &pdu_id, pdu.room_id(), &body) + .await; if self .services diff --git a/src/service/rooms/timeline/backfill.rs b/src/service/rooms/timeline/backfill.rs index d756120e..82a4d8d5 100644 --- a/src/service/rooms/timeline/backfill.rs +++ b/src/service/rooms/timeline/backfill.rs @@ -240,7 +240,8 @@ pub async fn backfill_pdu( if let Some(body) = content.body { self.services .search - .index_pdu(shortroomid, &pdu_id, &body); + .index_pdu(shortroomid, &pdu_id, pdu.room_id(), &body) + .await; } } drop(mutex_lock); diff --git a/src/service/rooms/timeline/redact.rs b/src/service/rooms/timeline/redact.rs index 16c6fde4..31a628c8 100644 --- a/src/service/rooms/timeline/redact.rs +++ b/src/service/rooms/timeline/redact.rs @@ -40,14 +40,15 @@ pub async fn redact_pdu( .get("body") .and_then(|body| body.as_str()); + let room_id = RoomId::parse(pdu["room_id"].as_str().unwrap()).unwrap(); + if let Some(body) = body { self.services .search - .deindex_pdu(shortroomid, &pdu_id, body); + .deindex_pdu(shortroomid, &pdu_id, room_id, body) + .await; } - let room_id = RoomId::parse(pdu["room_id"].as_str().unwrap()).unwrap(); - let room_version_id = self .services .state diff --git a/tuwunel-example.toml b/tuwunel-example.toml index 7a2ae8c9..9bc7bac9 100644 --- a/tuwunel-example.toml +++ b/tuwunel-example.toml @@ -907,6 +907,72 @@ # #auto_deactivate_banned_room_attempts = false +# Search backend to use for full-text message search. +# +# Available options: "rocksdb" (default) or "opensearch". +# +#search_backend = "rocksdb" + +# URL of the OpenSearch instance. Required when search_backend is +# "opensearch". +# +# example: "http://localhost:9200" +# +#search_opensearch_url = + +# Name of the OpenSearch index for message search. 
+# +#search_opensearch_index = "tuwunel_messages" + +# Authentication for OpenSearch in "user:pass" format. +# +#search_opensearch_auth = + +# Maximum number of documents to batch before flushing to OpenSearch. +# +#search_opensearch_batch_size = 100 + +# Maximum time in milliseconds to wait before flushing a partial batch +# to OpenSearch. +# +#search_opensearch_flush_interval_ms = 1000 + +# Enable hybrid neural+BM25 search in OpenSearch. Requires an ML model +# deployed in OpenSearch and an ingest pipeline that populates an +# "embedding" field. +# +# When enabled, tuwunel will: +# - Create the index with a knn_vector "embedding" field +# - Attach the ingest pipeline (search_opensearch_pipeline) to the index +# - Use hybrid queries combining BM25 + neural kNN scoring +# +# For a complete reference on configuring OpenSearch's ML plugin, model +# registration, and ingest pipeline setup, see the test helpers in +# `src/service/rooms/search/opensearch.rs` (the `ensure_neural_model`, +# `ensure_ingest_pipeline`, etc. functions in the `tests` module). +# +# See also: https://opensearch.org/docs/latest/search-plugins/neural-search/ +# +#search_opensearch_hybrid = false + +# The model ID registered in OpenSearch for neural search. Required when +# search_opensearch_hybrid is enabled. +# +# example: "aKV84osBBHNT0StI3MBr" +# +#search_opensearch_model_id = + +# Embedding dimension for the neural search model. Must match the output +# dimension of the deployed model. Common values: 384 +# (all-MiniLM-L6-v2), 768 (msmarco-distilbert-base-tas-b). +# +#search_opensearch_embedding_dim = 384 + +# Name of the ingest pipeline that generates embeddings for the +# "embedding" field. This pipeline must already exist in OpenSearch. +# +#search_opensearch_pipeline = "tuwunel_embedding_pipeline" + # RocksDB log level. This is not the same as tuwunel's log level. This # is the log level for the RocksDB engine/library which show up in your # database folder/path as `LOG` files. 
tuwunel will log RocksDB errors