Compare commits
42 Commits
feat/searc
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
ba031865a4
|
|||
|
890b7c6c57
|
|||
|
|
489ff6f2a2
|
||
|
|
ce8abf6bf1
|
||
|
|
6a3588ed0b | ||
|
|
7e21b9d730 | ||
|
|
1a5b552cd6 | ||
|
|
529a2b91a4 | ||
|
|
a656aba615 | ||
|
|
e9864bc4e7 | ||
|
|
a554280559 | ||
|
|
02ee1a55a0 | ||
|
|
3ceeb8655f | ||
|
|
cd66cd843b | ||
|
|
b5b6e3f1fd | ||
|
|
e31778bdb2 | ||
|
|
aa847e4844 | ||
|
|
2a1d34bee1 | ||
|
|
64dd481140 | ||
|
|
715d0a11c6 | ||
|
|
beb9fa0ecd | ||
|
|
e70bc5d665 | ||
|
|
d15b30de64 | ||
|
|
f3db71b32e | ||
|
|
13c038e254 | ||
|
|
b07c61fab8 | ||
|
|
e5d01a2045 | ||
|
|
0de031c765 | ||
|
|
0d43411447 | ||
|
|
cf7a4dc88d | ||
|
|
3fcfcafdd2 | ||
|
|
dfedef4d19 | ||
|
|
c960a9dbc3 | ||
|
|
57d4ae243a | ||
|
|
75301ff596 | ||
|
|
1d537d4a37 | ||
|
|
14b9c5df45 | ||
|
|
ad896bb091 | ||
|
|
2b81e189cb | ||
|
|
31e7dc2735 | ||
|
|
d2836e9f50 | ||
|
|
55ee0d8ab6 |
2
.github/workflows/autocopr.yml
vendored
2
.github/workflows/autocopr.yml
vendored
@@ -20,7 +20,7 @@ jobs:
|
||||
contents: write
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v5
|
||||
- uses: actions/checkout@v6
|
||||
|
||||
- uses: aidandenlinger/autocopr@v1 # Or a specific release tag, or commit
|
||||
with:
|
||||
|
||||
4
.github/workflows/bake.yml
vendored
4
.github/workflows/bake.yml
vendored
@@ -116,7 +116,7 @@ jobs:
|
||||
machine: ${{fromJSON(inputs.machines)}}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
fetch-depth: 0
|
||||
@@ -228,7 +228,7 @@ jobs:
|
||||
!failure() && !cancelled()
|
||||
&& fromJSON(inputs.artifact)[matrix.bake_target].dst
|
||||
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
path: _artifact/*
|
||||
name: ${{matrix.cargo_profile}}-${{matrix.feat_set}}-${{matrix.sys_target}}-${{fromJSON(inputs.artifact)[matrix.bake_target].dst}}
|
||||
|
||||
13
.github/workflows/main.yml
vendored
13
.github/workflows/main.yml
vendored
@@ -332,19 +332,6 @@ jobs:
|
||||
{"sys_target": "x86_64-v4-linux-gnu", "feat_set": "default"},
|
||||
]
|
||||
|
||||
opensearch-tests:
|
||||
if: >
|
||||
!failure() && !cancelled()
|
||||
&& fromJSON(needs.init.outputs.enable_test)
|
||||
&& !fromJSON(needs.init.outputs.is_release)
|
||||
&& !contains(needs.init.outputs.pipeline, '[ci no test]')
|
||||
|
||||
name: OpenSearch
|
||||
needs: [init, lint]
|
||||
uses: ./.github/workflows/opensearch-tests.yml
|
||||
with:
|
||||
checkout: ${{needs.init.outputs.checkout}}
|
||||
|
||||
package:
|
||||
if: >
|
||||
!failure() && !cancelled()
|
||||
|
||||
114
.github/workflows/opensearch-tests.yml
vendored
114
.github/workflows/opensearch-tests.yml
vendored
@@ -1,114 +0,0 @@
|
||||
name: OpenSearch Integration Tests
|
||||
|
||||
on:
|
||||
workflow_call:
|
||||
inputs:
|
||||
checkout:
|
||||
type: string
|
||||
default: 'HEAD'
|
||||
pull_request:
|
||||
paths:
|
||||
- 'src/service/rooms/search/**'
|
||||
- 'src/core/config/mod.rs'
|
||||
- '.github/workflows/opensearch-tests.yml'
|
||||
push:
|
||||
branches: [main]
|
||||
paths:
|
||||
- 'src/service/rooms/search/**'
|
||||
- 'src/core/config/mod.rs'
|
||||
- '.github/workflows/opensearch-tests.yml'
|
||||
|
||||
jobs:
|
||||
opensearch-tests:
|
||||
name: OpenSearch Integration
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
services:
|
||||
opensearch:
|
||||
image: opensearchproject/opensearch:3.5.0
|
||||
ports:
|
||||
- 9200:9200
|
||||
env:
|
||||
discovery.type: single-node
|
||||
DISABLE_SECURITY_PLUGIN: "true"
|
||||
OPENSEARCH_INITIAL_ADMIN_PASSWORD: "SuperSecret123!"
|
||||
OPENSEARCH_JAVA_OPTS: "-Xms1g -Xmx2g"
|
||||
options: >-
|
||||
--health-cmd "curl -f http://localhost:9200/_cluster/health || exit 1"
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 20
|
||||
--health-start-period 60s
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ inputs.checkout || github.sha }}
|
||||
|
||||
- name: Install Rust toolchain
|
||||
uses: dtolnay/rust-toolchain@nightly
|
||||
|
||||
- name: Cache cargo registry and build
|
||||
uses: actions/cache@v4
|
||||
with:
|
||||
path: |
|
||||
~/.cargo/registry
|
||||
~/.cargo/git
|
||||
target
|
||||
key: opensearch-tests-${{ runner.os }}-${{ hashFiles('**/Cargo.lock') }}
|
||||
restore-keys: |
|
||||
opensearch-tests-${{ runner.os }}-
|
||||
|
||||
- name: Wait for OpenSearch
|
||||
run: |
|
||||
for i in $(seq 1 60); do
|
||||
if curl -sf http://localhost:9200/_cluster/health; then
|
||||
echo ""
|
||||
echo "OpenSearch is ready"
|
||||
exit 0
|
||||
fi
|
||||
echo "Waiting for OpenSearch... ($i/60)"
|
||||
sleep 2
|
||||
done
|
||||
echo "OpenSearch did not start in time"
|
||||
exit 1
|
||||
|
||||
- name: Configure OpenSearch ML plugin
|
||||
run: |
|
||||
curl -sf -X PUT http://localhost:9200/_cluster/settings \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{
|
||||
"persistent": {
|
||||
"plugins.ml_commons.only_run_on_ml_node": "false",
|
||||
"plugins.ml_commons.native_memory_threshold": "99",
|
||||
"plugins.ml_commons.allow_registering_model_via_url": "true"
|
||||
}
|
||||
}'
|
||||
|
||||
- name: Run OpenSearch integration tests (BM25)
|
||||
env:
|
||||
OPENSEARCH_URL: http://localhost:9200
|
||||
run: >
|
||||
cargo test
|
||||
-p tuwunel_service
|
||||
--lib
|
||||
rooms::search::opensearch::tests
|
||||
--
|
||||
--ignored
|
||||
--test-threads=1
|
||||
--skip test_neural
|
||||
--skip test_hybrid
|
||||
|
||||
- name: Run OpenSearch neural/hybrid tests
|
||||
env:
|
||||
OPENSEARCH_URL: http://localhost:9200
|
||||
timeout-minutes: 15
|
||||
run: >
|
||||
cargo test
|
||||
-p tuwunel_service
|
||||
--lib
|
||||
rooms::search::opensearch::tests
|
||||
--
|
||||
--ignored
|
||||
--test-threads=1
|
||||
test_neural test_hybrid
|
||||
14
.github/workflows/publish.yml
vendored
14
.github/workflows/publish.yml
vendored
@@ -103,21 +103,21 @@ jobs:
|
||||
include: ${{fromJSON(inputs.includes)}}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
fetch-depth: 0
|
||||
fetch-tags: true
|
||||
|
||||
- name: GitHub Login
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@v4
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.ghcr_token }}
|
||||
|
||||
- name: DockerHub Login
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@v4
|
||||
with:
|
||||
registry: docker.io
|
||||
username: ${{ inputs.docker_acct }}
|
||||
@@ -173,14 +173,14 @@ jobs:
|
||||
|
||||
steps:
|
||||
- name: GitHub Login
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@v4
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.ghcr_token }}
|
||||
|
||||
- name: DockerHub Login
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@v4
|
||||
with:
|
||||
registry: docker.io
|
||||
username: ${{ inputs.docker_acct }}
|
||||
@@ -280,14 +280,14 @@ jobs:
|
||||
permissions: write-all
|
||||
steps:
|
||||
- name: GitHub Login
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@v4
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.ghcr_token }}
|
||||
|
||||
- name: DockerHub Login
|
||||
uses: docker/login-action@v3
|
||||
uses: docker/login-action@v4
|
||||
with:
|
||||
registry: docker.io
|
||||
username: ${{ inputs.docker_acct }}
|
||||
|
||||
6
.github/workflows/test.yml
vendored
6
.github/workflows/test.yml
vendored
@@ -389,7 +389,7 @@ jobs:
|
||||
include: ${{fromJSON(inputs.includes)}}
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- uses: actions/checkout@v6
|
||||
- name: Execute
|
||||
id: execute
|
||||
env:
|
||||
@@ -420,7 +420,7 @@ jobs:
|
||||
- if: success() || failure() && steps.execute.outcome == 'failure'
|
||||
name: Upload New Results
|
||||
id: upload-result
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: complement_results-${{matrix.feat_set}}-${{matrix.sys_name}}-${{matrix.sys_target}}.jsonl
|
||||
path: ./tests/complement/results.jsonl
|
||||
@@ -428,7 +428,7 @@ jobs:
|
||||
- if: success() || (failure() && steps.execute.outcome == 'failure')
|
||||
name: Upload Log Output
|
||||
id: upload-output
|
||||
uses: actions/upload-artifact@v4
|
||||
uses: actions/upload-artifact@v6
|
||||
with:
|
||||
name: complement_output-${{matrix.feat_set}}-${{matrix.sys_name}}-${{matrix.sys_target}}.jsonl
|
||||
path: ./tests/complement/logs.jsonl
|
||||
|
||||
592
Cargo.lock
generated
592
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
26
Cargo.toml
26
Cargo.toml
@@ -29,7 +29,7 @@ keywords = [
|
||||
license = "Apache-2.0"
|
||||
readme = "README.md"
|
||||
repository = "https://github.com/matrix-construct/tuwunel"
|
||||
rust-version = "1.91.1"
|
||||
rust-version = "1.94.0"
|
||||
version = "1.5.1"
|
||||
|
||||
[workspace.metadata.crane]
|
||||
@@ -74,7 +74,7 @@ features = [
|
||||
version = "0.7"
|
||||
|
||||
[workspace.dependencies.axum-extra]
|
||||
version = "0.10"
|
||||
version = "0.12"
|
||||
default-features = false
|
||||
features = [
|
||||
"cookie",
|
||||
@@ -107,7 +107,7 @@ features = [
|
||||
version = "1.11"
|
||||
|
||||
[workspace.dependencies.bytesize]
|
||||
version = "2.1"
|
||||
version = "2.3"
|
||||
|
||||
[workspace.dependencies.cargo_toml]
|
||||
version = "0.22"
|
||||
@@ -143,7 +143,7 @@ features = [
|
||||
version = "0.8.3"
|
||||
|
||||
[workspace.dependencies.const-str]
|
||||
version = "0.7"
|
||||
version = "1.1"
|
||||
|
||||
[workspace.dependencies.criterion]
|
||||
version = "0.7"
|
||||
@@ -189,7 +189,7 @@ version = "0.12"
|
||||
default-features = false
|
||||
|
||||
[workspace.dependencies.http]
|
||||
version = "1.3"
|
||||
version = "1.4"
|
||||
|
||||
[workspace.dependencies.http-body-util]
|
||||
version = "0.1"
|
||||
@@ -266,7 +266,7 @@ version = "0.1"
|
||||
version = "1.0"
|
||||
|
||||
[workspace.dependencies.minicbor]
|
||||
version = "2.1"
|
||||
version = "2.2"
|
||||
features = ["std"]
|
||||
|
||||
[workspace.dependencies.minicbor-serde]
|
||||
@@ -396,7 +396,7 @@ version = "0.4"
|
||||
default-features = false
|
||||
|
||||
[workspace.dependencies.sentry]
|
||||
version = "0.45"
|
||||
version = "0.46"
|
||||
default-features = false
|
||||
features = [
|
||||
"backtrace",
|
||||
@@ -412,10 +412,10 @@ features = [
|
||||
]
|
||||
|
||||
[workspace.dependencies.sentry-tower]
|
||||
version = "0.45"
|
||||
version = "0.46"
|
||||
|
||||
[workspace.dependencies.sentry-tracing]
|
||||
version = "0.45"
|
||||
version = "0.46"
|
||||
|
||||
[workspace.dependencies.serde]
|
||||
version = "1.0"
|
||||
@@ -490,7 +490,7 @@ version = "2.0"
|
||||
default-features = false
|
||||
|
||||
[workspace.dependencies.tokio]
|
||||
version = "1.49"
|
||||
version = "1.50"
|
||||
default-features = false
|
||||
features = [
|
||||
"fs",
|
||||
@@ -508,9 +508,9 @@ features = [
|
||||
version = "0.4"
|
||||
|
||||
[workspace.dependencies.toml]
|
||||
version = "0.9"
|
||||
version = "1.0"
|
||||
default-features = false
|
||||
features = ["parse"]
|
||||
features = ["parse", "serde"]
|
||||
|
||||
[workspace.dependencies.tower]
|
||||
version = "0.5"
|
||||
@@ -984,7 +984,7 @@ unnecessary_safety_doc = "warn"
|
||||
unnecessary_self_imports = "warn"
|
||||
unneeded_field_pattern = "warn"
|
||||
unseparated_literal_suffix = "warn"
|
||||
#unwrap_used = "warn" # TODO
|
||||
unwrap_used = "warn"
|
||||
verbose_file_reads = "warn"
|
||||
|
||||
###################
|
||||
|
||||
@@ -7,6 +7,8 @@ excessive-nesting-threshold = 8
|
||||
type-complexity-threshold = 250 # reduce me to ~200
|
||||
cognitive-complexity-threshold = 100 # TODO reduce me ALARA
|
||||
|
||||
allow-unwrap-in-tests = true
|
||||
|
||||
#disallowed-macros = [
|
||||
# { path = "log::error", reason = "use tuwunel_core::error" },
|
||||
# { path = "log::warn", reason = "use tuwunel_core::warn" },
|
||||
|
||||
@@ -45,7 +45,7 @@
|
||||
file = ./rust-toolchain.toml;
|
||||
|
||||
# See also `rust-toolchain.toml`
|
||||
sha256 = "sha256-SDu4snEWjuZU475PERvu+iO50Mi39KVjqCeJeNvpguU=";
|
||||
sha256 = "sha256-qqF33vNuAdU5vua96VKVIwuc43j4EFeEXbjQ6+l4mO4=";
|
||||
};
|
||||
|
||||
mkScope =
|
||||
|
||||
@@ -9,7 +9,7 @@
|
||||
# If you're having trouble making the relevant changes, bug a maintainer.
|
||||
|
||||
[toolchain]
|
||||
channel = "1.91.1"
|
||||
channel = "1.94.0"
|
||||
profile = "minimal"
|
||||
components = [
|
||||
# For rust-analyzer
|
||||
|
||||
@@ -25,6 +25,7 @@ use tuwunel_core::{
|
||||
tokio_metrics::TaskMonitor,
|
||||
trace, utils,
|
||||
utils::{
|
||||
math::Expected,
|
||||
stream::{IterStream, ReadyExt},
|
||||
string::EMPTY,
|
||||
time::now_secs,
|
||||
@@ -377,7 +378,7 @@ pub(super) async fn sign_json(&self) -> Result {
|
||||
return Err!("Expected code block in command body. Add --help for details.");
|
||||
}
|
||||
|
||||
let string = self.body[1..self.body.len().checked_sub(1).unwrap()].join("\n");
|
||||
let string = self.body[1..self.body.len().expected_sub(1)].join("\n");
|
||||
let mut value = serde_json::from_str(&string).map_err(|e| err!("Invalid json: {e}"))?;
|
||||
|
||||
self.services.server_keys.sign_json(&mut value)?;
|
||||
@@ -395,7 +396,7 @@ pub(super) async fn verify_json(&self) -> Result {
|
||||
return Err!("Expected code block in command body. Add --help for details.");
|
||||
}
|
||||
|
||||
let string = self.body[1..self.body.len().checked_sub(1).unwrap()].join("\n");
|
||||
let string = self.body[1..self.body.len().expected_sub(1)].join("\n");
|
||||
|
||||
let value = serde_json::from_str::<CanonicalJsonObject>(&string)
|
||||
.map_err(|e| err!("Invalid json: {e}"))?;
|
||||
@@ -552,9 +553,11 @@ pub(super) async fn force_set_room_state_from_server(
|
||||
};
|
||||
|
||||
let pdu = if value["type"] == "m.room.create" {
|
||||
PduEvent::from_rid_val(&room_id, &event_id, value.clone()).map_err(invalid_pdu_err)?
|
||||
PduEvent::from_object_and_roomid_and_eventid(&room_id, &event_id, value.clone())
|
||||
.map_err(invalid_pdu_err)?
|
||||
} else {
|
||||
PduEvent::from_id_val(&event_id, value.clone()).map_err(invalid_pdu_err)?
|
||||
PduEvent::from_object_and_eventid(&event_id, value.clone())
|
||||
.map_err(invalid_pdu_err)?
|
||||
};
|
||||
|
||||
if !value.contains_key("room_id") {
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#![expect(clippy::too_many_arguments)]
|
||||
#![allow(unused_features)] // 1.96.0-nightly 2026-03-07 bug
|
||||
|
||||
pub(crate) mod admin;
|
||||
pub(crate) mod context;
|
||||
|
||||
@@ -464,6 +464,10 @@ pub(super) async fn force_join_all_local_users(
|
||||
.collect::<Vec<_>>()
|
||||
.await
|
||||
{
|
||||
if user_id == &self.services.globals.server_user {
|
||||
continue;
|
||||
}
|
||||
|
||||
match self
|
||||
.services
|
||||
.membership
|
||||
|
||||
@@ -4,18 +4,21 @@ use axum::extract::State;
|
||||
use axum_client_ip::InsecureClientIp;
|
||||
use reqwest::Url;
|
||||
use ruma::{
|
||||
Mxc, UserId,
|
||||
MilliSecondsSinceUnixEpoch, Mxc, UserId,
|
||||
api::client::{
|
||||
authenticated_media::{
|
||||
get_content, get_content_as_filename, get_content_thumbnail, get_media_config,
|
||||
get_media_preview,
|
||||
},
|
||||
media::create_content,
|
||||
media::{create_content, create_content_async, create_mxc_uri},
|
||||
},
|
||||
};
|
||||
use tuwunel_core::{
|
||||
Err, Result, err,
|
||||
utils::{self, content_disposition::make_content_disposition, math::ruma_from_usize},
|
||||
utils::{
|
||||
self, content_disposition::make_content_disposition, math::ruma_from_usize,
|
||||
time::now_millis,
|
||||
},
|
||||
};
|
||||
use tuwunel_service::{
|
||||
Services,
|
||||
@@ -80,6 +83,78 @@ pub(crate) async fn create_content_route(
|
||||
})
|
||||
}
|
||||
|
||||
/// # `POST /_matrix/media/v1/create`
|
||||
///
|
||||
/// Create a new MXC URI without content.
|
||||
#[tracing::instrument(
|
||||
name = "media_create_mxc",
|
||||
level = "debug",
|
||||
skip_all,
|
||||
fields(%client),
|
||||
)]
|
||||
pub(crate) async fn create_mxc_uri_route(
|
||||
State(services): State<crate::State>,
|
||||
InsecureClientIp(client): InsecureClientIp,
|
||||
body: Ruma<create_mxc_uri::v1::Request>,
|
||||
) -> Result<create_mxc_uri::v1::Response> {
|
||||
let user = body.sender_user();
|
||||
let mxc = Mxc {
|
||||
server_name: services.globals.server_name(),
|
||||
media_id: &utils::random_string(MXC_LENGTH),
|
||||
};
|
||||
|
||||
// safe because even if it overflows, it will be greater than the current time
|
||||
// and the unused media will be deleted anyway
|
||||
let unused_expires_at = now_millis().saturating_add(
|
||||
services
|
||||
.server
|
||||
.config
|
||||
.media_create_unused_expiration_time
|
||||
.saturating_mul(1000),
|
||||
);
|
||||
services
|
||||
.media
|
||||
.create_pending(&mxc, user, unused_expires_at)
|
||||
.await?;
|
||||
|
||||
Ok(create_mxc_uri::v1::Response {
|
||||
content_uri: mxc.to_string().into(),
|
||||
unused_expires_at: ruma::UInt::new(unused_expires_at).map(MilliSecondsSinceUnixEpoch),
|
||||
})
|
||||
}
|
||||
|
||||
/// # `PUT /_matrix/media/v3/upload/{serverName}/{mediaId}`
|
||||
///
|
||||
/// Upload content to a MXC URI that was created earlier.
|
||||
#[tracing::instrument(
|
||||
name = "media_upload_async",
|
||||
level = "debug",
|
||||
skip_all,
|
||||
fields(%client),
|
||||
)]
|
||||
pub(crate) async fn create_content_async_route(
|
||||
State(services): State<crate::State>,
|
||||
InsecureClientIp(client): InsecureClientIp,
|
||||
body: Ruma<create_content_async::v3::Request>,
|
||||
) -> Result<create_content_async::v3::Response> {
|
||||
let user = body.sender_user();
|
||||
let mxc = Mxc {
|
||||
server_name: &body.server_name,
|
||||
media_id: &body.media_id,
|
||||
};
|
||||
|
||||
let filename = body.filename.as_deref();
|
||||
let content_type = body.content_type.as_deref();
|
||||
let content_disposition = make_content_disposition(None, content_type, filename);
|
||||
|
||||
services
|
||||
.media
|
||||
.upload_pending(&mxc, user, Some(&content_disposition), content_type, &body.file)
|
||||
.await?;
|
||||
|
||||
Ok(create_content_async::v3::Response {})
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/v1/media/thumbnail/{serverName}/{mediaId}`
|
||||
///
|
||||
/// Load media thumbnail from our server or over federation.
|
||||
@@ -296,7 +371,11 @@ async fn fetch_thumbnail_meta(
|
||||
timeout_ms: Duration,
|
||||
dim: &Dim,
|
||||
) -> Result<FileMeta> {
|
||||
if let Some(filemeta) = services.media.get_thumbnail(mxc, dim).await? {
|
||||
if let Some(filemeta) = services
|
||||
.media
|
||||
.get_thumbnail_with_timeout(mxc, dim, timeout_ms)
|
||||
.await?
|
||||
{
|
||||
return Ok(filemeta);
|
||||
}
|
||||
|
||||
@@ -316,7 +395,11 @@ async fn fetch_file_meta(
|
||||
user: &UserId,
|
||||
timeout_ms: Duration,
|
||||
) -> Result<FileMeta> {
|
||||
if let Some(filemeta) = services.media.get(mxc).await? {
|
||||
if let Some(filemeta) = services
|
||||
.media
|
||||
.get_with_timeout(mxc, timeout_ms)
|
||||
.await?
|
||||
{
|
||||
return Ok(filemeta);
|
||||
}
|
||||
|
||||
|
||||
@@ -145,7 +145,11 @@ pub(crate) async fn get_content_legacy_route(
|
||||
media_id: &body.media_id,
|
||||
};
|
||||
|
||||
match services.media.get(&mxc).await? {
|
||||
match services
|
||||
.media
|
||||
.get_with_timeout(&mxc, body.timeout_ms)
|
||||
.await?
|
||||
{
|
||||
| Some(FileMeta {
|
||||
content,
|
||||
content_type,
|
||||
@@ -236,7 +240,11 @@ pub(crate) async fn get_content_as_filename_legacy_route(
|
||||
media_id: &body.media_id,
|
||||
};
|
||||
|
||||
match services.media.get(&mxc).await? {
|
||||
match services
|
||||
.media
|
||||
.get_with_timeout(&mxc, body.timeout_ms)
|
||||
.await?
|
||||
{
|
||||
| Some(FileMeta {
|
||||
content,
|
||||
content_type,
|
||||
@@ -327,7 +335,11 @@ pub(crate) async fn get_content_thumbnail_legacy_route(
|
||||
};
|
||||
|
||||
let dim = Dim::from_ruma(body.width, body.height, body.method.clone())?;
|
||||
match services.media.get_thumbnail(&mxc, &dim).await? {
|
||||
match services
|
||||
.media
|
||||
.get_thumbnail_with_timeout(&mxc, &dim, body.timeout_ms)
|
||||
.await?
|
||||
{
|
||||
| Some(FileMeta {
|
||||
content,
|
||||
content_type,
|
||||
|
||||
@@ -46,6 +46,11 @@ pub(crate) async fn get_member_events_route(
|
||||
|
||||
let membership = body.membership.as_ref();
|
||||
let not_membership = body.not_membership.as_ref();
|
||||
let membership_filter = |content: &RoomMemberEventContent| {
|
||||
membership.is_none_or(is_equal_to!(&content.membership))
|
||||
&& not_membership.is_none_or(is_not_equal_to!(&content.membership))
|
||||
};
|
||||
|
||||
Ok(get_member_events::v3::Response {
|
||||
chunk: services
|
||||
.state_accessor
|
||||
@@ -55,12 +60,8 @@ pub(crate) async fn get_member_events_route(
|
||||
.map(at!(1))
|
||||
.ready_filter(|pdu| {
|
||||
pdu.get_content::<RoomMemberEventContent>()
|
||||
.is_ok_and(|content| {
|
||||
let event_membership = content.membership;
|
||||
|
||||
membership.is_none_or(is_equal_to!(&event_membership))
|
||||
&& not_membership.is_none_or(is_not_equal_to!(&event_membership))
|
||||
})
|
||||
.as_ref()
|
||||
.is_ok_and(membership_filter)
|
||||
})
|
||||
.map(Event::into_format)
|
||||
.collect()
|
||||
|
||||
@@ -16,6 +16,7 @@ pub(super) mod media_legacy;
|
||||
pub(super) mod membership;
|
||||
pub(super) mod message;
|
||||
pub(super) mod openid;
|
||||
pub(super) mod oidc;
|
||||
pub(super) mod presence;
|
||||
pub(super) mod profile;
|
||||
pub(super) mod push;
|
||||
@@ -63,6 +64,7 @@ pub(super) use media_legacy::*;
|
||||
pub(super) use membership::*;
|
||||
pub(super) use message::*;
|
||||
pub(super) use openid::*;
|
||||
pub(super) use oidc::*;
|
||||
pub(super) use presence::*;
|
||||
pub(super) use profile::*;
|
||||
pub(super) use push::*;
|
||||
|
||||
222
src/api/client/oidc.rs
Normal file
222
src/api/client/oidc.rs
Normal file
@@ -0,0 +1,222 @@
|
||||
use std::time::SystemTime;
|
||||
|
||||
use axum::{Json, extract::State, response::{IntoResponse, Redirect}};
|
||||
use axum_extra::{TypedHeader, headers::{Authorization, authorization::Bearer}};
|
||||
use http::StatusCode;
|
||||
use ruma::OwnedDeviceId;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tuwunel_core::{Err, Result, err, info, utils};
|
||||
use tuwunel_service::{oauth::oidc_server::{DcrRequest, IdTokenClaims, OidcAuthRequest, OidcServer, ProviderMetadata}, users::device::generate_refresh_token};
|
||||
|
||||
const OIDC_REQ_ID_LENGTH: usize = 32;
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct AuthIssuerResponse { issuer: String }
|
||||
|
||||
pub(crate) async fn auth_issuer_route(State(services): State<crate::State>) -> Result<impl IntoResponse> {
|
||||
let issuer = oidc_issuer_url(&services)?;
|
||||
Ok(Json(AuthIssuerResponse { issuer }))
|
||||
}
|
||||
|
||||
pub(crate) async fn openid_configuration_route(State(services): State<crate::State>) -> Result<impl IntoResponse> {
|
||||
Ok(Json(oidc_metadata(&services)?))
|
||||
}
|
||||
|
||||
fn oidc_metadata(services: &tuwunel_service::Services) -> Result<ProviderMetadata> {
|
||||
let issuer = oidc_issuer_url(services)?;
|
||||
let base = issuer.trim_end_matches('/').to_owned();
|
||||
|
||||
Ok(ProviderMetadata {
|
||||
issuer,
|
||||
authorization_endpoint: format!("{base}/_tuwunel/oidc/authorize"),
|
||||
token_endpoint: format!("{base}/_tuwunel/oidc/token"),
|
||||
registration_endpoint: Some(format!("{base}/_tuwunel/oidc/registration")),
|
||||
revocation_endpoint: Some(format!("{base}/_tuwunel/oidc/revoke")),
|
||||
jwks_uri: format!("{base}/_tuwunel/oidc/jwks"),
|
||||
userinfo_endpoint: Some(format!("{base}/_tuwunel/oidc/userinfo")),
|
||||
account_management_uri: Some(format!("{base}/_tuwunel/oidc/account")),
|
||||
account_management_actions_supported: Some(vec!["org.matrix.profile".to_owned(), "org.matrix.sessions_list".to_owned(), "org.matrix.session_view".to_owned(), "org.matrix.session_end".to_owned(), "org.matrix.cross_signing_reset".to_owned()]),
|
||||
response_types_supported: vec!["code".to_owned()],
|
||||
response_modes_supported: Some(vec!["query".to_owned(), "fragment".to_owned()]),
|
||||
grant_types_supported: Some(vec!["authorization_code".to_owned(), "refresh_token".to_owned()]),
|
||||
code_challenge_methods_supported: Some(vec!["S256".to_owned()]),
|
||||
token_endpoint_auth_methods_supported: Some(vec!["none".to_owned(), "client_secret_basic".to_owned(), "client_secret_post".to_owned()]),
|
||||
scopes_supported: Some(vec!["openid".to_owned(), "urn:matrix:org.matrix.msc2967.client:api:*".to_owned(), "urn:matrix:org.matrix.msc2967.client:device:*".to_owned()]),
|
||||
subject_types_supported: Some(vec!["public".to_owned()]),
|
||||
id_token_signing_alg_values_supported: Some(vec!["ES256".to_owned()]),
|
||||
prompt_values_supported: Some(vec!["create".to_owned()]),
|
||||
claim_types_supported: Some(vec!["normal".to_owned()]),
|
||||
claims_supported: Some(vec!["iss".to_owned(), "sub".to_owned(), "aud".to_owned(), "exp".to_owned(), "iat".to_owned(), "nonce".to_owned()]),
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) async fn registration_route(State(services): State<crate::State>, Json(body): Json<DcrRequest>) -> Result<impl IntoResponse> {
|
||||
let Ok(oidc) = get_oidc_server(&services) else { return Err!(Request(NotFound("OIDC server not configured"))); };
|
||||
|
||||
if body.redirect_uris.is_empty() { return Err!(Request(InvalidParam("redirect_uris must not be empty"))); }
|
||||
|
||||
let reg = oidc.register_client(body)?;
|
||||
info!("OIDC client registered: {} ({})", reg.client_id, reg.client_name.as_deref().unwrap_or("unnamed"));
|
||||
|
||||
Ok((StatusCode::CREATED, Json(serde_json::json!({"client_id": reg.client_id, "client_id_issued_at": reg.registered_at, "redirect_uris": reg.redirect_uris, "client_name": reg.client_name, "client_uri": reg.client_uri, "logo_uri": reg.logo_uri, "contacts": reg.contacts, "token_endpoint_auth_method": reg.token_endpoint_auth_method, "grant_types": reg.grant_types, "response_types": reg.response_types, "application_type": reg.application_type, "policy_uri": reg.policy_uri, "tos_uri": reg.tos_uri, "software_id": reg.software_id, "software_version": reg.software_version}))))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub(crate) struct AuthorizeParams {
|
||||
client_id: String, redirect_uri: String, response_type: String, scope: String,
|
||||
state: Option<String>, nonce: Option<String>, code_challenge: Option<String>,
|
||||
code_challenge_method: Option<String>, #[serde(default, rename = "prompt")] _prompt: Option<String>,
|
||||
}
|
||||
|
||||
pub(crate) async fn authorize_route(State(services): State<crate::State>, request: axum::extract::Request) -> Result<impl IntoResponse> {
|
||||
let params: AuthorizeParams = serde_html_form::from_str(request.uri().query().unwrap_or_default())?;
|
||||
let Ok(oidc) = get_oidc_server(&services) else { return Err!(Request(NotFound("OIDC server not configured"))); };
|
||||
|
||||
if params.response_type != "code" { return Err!(Request(InvalidParam("Only response_type=code is supported"))); }
|
||||
|
||||
oidc.validate_redirect_uri(¶ms.client_id, ¶ms.redirect_uri).await?;
|
||||
|
||||
let req_id = utils::random_string(OIDC_REQ_ID_LENGTH);
|
||||
let now = SystemTime::now();
|
||||
|
||||
oidc.store_auth_request(&req_id, &OidcAuthRequest {
|
||||
client_id: params.client_id, redirect_uri: params.redirect_uri, scope: params.scope,
|
||||
state: params.state, nonce: params.nonce, code_challenge: params.code_challenge,
|
||||
code_challenge_method: params.code_challenge_method, created_at: now,
|
||||
expires_at: now.checked_add(OidcServer::auth_request_lifetime()).unwrap_or(now),
|
||||
});
|
||||
|
||||
let default_idp = services.config.identity_provider.values().find(|idp| idp.default).or_else(|| services.config.identity_provider.values().next()).ok_or_else(|| err!(Config("identity_provider", "No identity provider configured")))?;
|
||||
let idp_id = default_idp.id();
|
||||
|
||||
let base = oidc_issuer_url(&services)?;
|
||||
let base = base.trim_end_matches('/');
|
||||
|
||||
let mut complete_url = url::Url::parse(&format!("{base}/_tuwunel/oidc/_complete")).map_err(|_| err!(error!("Failed to build complete URL")))?;
|
||||
complete_url.query_pairs_mut().append_pair("oidc_req_id", &req_id);
|
||||
|
||||
let mut sso_url = url::Url::parse(&format!("{base}/_matrix/client/v3/login/sso/redirect/{idp_id}")).map_err(|_| err!(error!("Failed to build SSO URL")))?;
|
||||
sso_url.query_pairs_mut().append_pair("redirectUrl", complete_url.as_str());
|
||||
|
||||
Ok(Redirect::temporary(sso_url.as_str()))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub(crate) struct CompleteParams { oidc_req_id: String, #[serde(rename = "loginToken")] login_token: String }
|
||||
|
||||
pub(crate) async fn complete_route(State(services): State<crate::State>, request: axum::extract::Request) -> Result<impl IntoResponse> {
|
||||
let params: CompleteParams = serde_html_form::from_str(request.uri().query().unwrap_or_default())?;
|
||||
let Ok(oidc) = get_oidc_server(&services) else { return Err!(Request(NotFound("OIDC server not configured"))); };
|
||||
|
||||
let user_id = services.users.find_from_login_token(¶ms.login_token).await.map_err(|_| err!(Request(Forbidden("Invalid or expired login token"))))?;
|
||||
let auth_req = oidc.take_auth_request(¶ms.oidc_req_id).await?;
|
||||
let code = oidc.create_auth_code(&auth_req, user_id);
|
||||
|
||||
let mut redirect_url = url::Url::parse(&auth_req.redirect_uri).map_err(|_| err!(Request(InvalidParam("Invalid redirect_uri"))))?;
|
||||
redirect_url.query_pairs_mut().append_pair("code", &code);
|
||||
if let Some(state) = &auth_req.state { redirect_url.query_pairs_mut().append_pair("state", state); }
|
||||
|
||||
Ok(Redirect::temporary(redirect_url.as_str()))
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub(crate) struct TokenRequest {
|
||||
grant_type: String, code: Option<String>, redirect_uri: Option<String>, client_id: Option<String>,
|
||||
code_verifier: Option<String>, refresh_token: Option<String>, #[serde(rename = "scope")] _scope: Option<String>,
|
||||
}
|
||||
|
||||
pub(crate) async fn token_route(State(services): State<crate::State>, axum::extract::Form(body): axum::extract::Form<TokenRequest>) -> impl IntoResponse {
|
||||
match body.grant_type.as_str() {
|
||||
| "authorization_code" => token_authorization_code(&services, &body).await.unwrap_or_else(|e| oauth_error(StatusCode::INTERNAL_SERVER_ERROR, "server_error", &e.to_string())),
|
||||
| "refresh_token" => token_refresh(&services, &body).await.unwrap_or_else(|e| oauth_error(StatusCode::INTERNAL_SERVER_ERROR, "server_error", &e.to_string())),
|
||||
| _ => oauth_error(StatusCode::BAD_REQUEST, "unsupported_grant_type", "Unsupported grant_type"),
|
||||
}
|
||||
}
|
||||
|
||||
async fn token_authorization_code(services: &tuwunel_service::Services, body: &TokenRequest) -> Result<http::Response<axum::body::Body>> {
|
||||
let code = body.code.as_deref().ok_or_else(|| err!(Request(InvalidParam("code is required"))))?;
|
||||
let redirect_uri = body.redirect_uri.as_deref().ok_or_else(|| err!(Request(InvalidParam("redirect_uri is required"))))?;
|
||||
let client_id = body.client_id.as_deref().ok_or_else(|| err!(Request(InvalidParam("client_id is required"))))?;
|
||||
|
||||
let oidc = get_oidc_server(services)?;
|
||||
let session = oidc.exchange_auth_code(code, client_id, redirect_uri, body.code_verifier.as_deref()).await?;
|
||||
|
||||
let user_id = &session.user_id;
|
||||
let (access_token, expires_in) = services.users.generate_access_token(true);
|
||||
let refresh_token = generate_refresh_token();
|
||||
|
||||
let device_id: Option<OwnedDeviceId> = extract_device_id(&session.scope).map(OwnedDeviceId::from);
|
||||
let client_name = oidc.get_client(client_id).await.ok().and_then(|c| c.client_name);
|
||||
let device_display_name = client_name.as_deref().unwrap_or("OIDC Client");
|
||||
let device_id = services.users.create_device(user_id, device_id.as_deref(), (Some(&access_token), expires_in), Some(&refresh_token), Some(device_display_name), None).await?;
|
||||
|
||||
info!("{user_id} logged in via OIDC (device {device_id})");
|
||||
|
||||
let id_token = if session.scope.contains("openid") {
|
||||
let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap_or_default().as_secs();
|
||||
let issuer = oidc_issuer_url(services)?;
|
||||
let claims = IdTokenClaims { iss: issuer, sub: user_id.to_string(), aud: client_id.to_owned(), exp: now.saturating_add(3600), iat: now, nonce: session.nonce, at_hash: Some(OidcServer::at_hash(&access_token)) };
|
||||
Some(oidc.sign_id_token(&claims)?)
|
||||
} else { None };
|
||||
|
||||
let mut response = serde_json::json!({"access_token": access_token, "token_type": "Bearer", "scope": session.scope, "refresh_token": refresh_token});
|
||||
if let Some(expires_in) = expires_in { response["expires_in"] = serde_json::json!(expires_in.as_secs()); }
|
||||
if let Some(id_token) = id_token { response["id_token"] = serde_json::json!(id_token); }
|
||||
|
||||
Ok(Json(response).into_response())
|
||||
}
|
||||
|
||||
async fn token_refresh(services: &tuwunel_service::Services, body: &TokenRequest) -> Result<http::Response<axum::body::Body>> {
|
||||
let refresh_token = body.refresh_token.as_deref().ok_or_else(|| err!(Request(InvalidParam("refresh_token is required"))))?;
|
||||
let (user_id, device_id, _) = services.users.find_from_token(refresh_token).await.map_err(|_| err!(Request(Forbidden("Invalid refresh token"))))?;
|
||||
|
||||
let (new_access_token, expires_in) = services.users.generate_access_token(true);
|
||||
let new_refresh_token = generate_refresh_token();
|
||||
services.users.set_access_token(&user_id, &device_id, &new_access_token, expires_in, Some(&new_refresh_token)).await?;
|
||||
|
||||
let mut response = serde_json::json!({"access_token": new_access_token, "token_type": "Bearer", "refresh_token": new_refresh_token});
|
||||
if let Some(expires_in) = expires_in { response["expires_in"] = serde_json::json!(expires_in.as_secs()); }
|
||||
|
||||
Ok(Json(response).into_response())
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub(crate) struct RevokeRequest { token: String, #[serde(default, rename = "token_type_hint")] _token_type_hint: Option<String> }
|
||||
|
||||
pub(crate) async fn revoke_route(State(services): State<crate::State>, axum::extract::Form(body): axum::extract::Form<RevokeRequest>) -> Result<impl IntoResponse> {
|
||||
if let Ok((user_id, device_id, _)) = services.users.find_from_token(&body.token).await { services.users.remove_device(&user_id, &device_id).await; }
|
||||
Ok(Json(serde_json::json!({})))
|
||||
}
|
||||
|
||||
pub(crate) async fn jwks_route(State(services): State<crate::State>) -> Result<impl IntoResponse> {
|
||||
let oidc = get_oidc_server(&services)?;
|
||||
Ok(Json(oidc.jwks()))
|
||||
}
|
||||
|
||||
pub(crate) async fn userinfo_route(State(services): State<crate::State>, TypedHeader(Authorization(bearer)): TypedHeader<Authorization<Bearer>>) -> Result<impl IntoResponse> {
|
||||
let token = bearer.token();
|
||||
let Ok((user_id, _device_id, _expires)) = services.users.find_from_token(token).await else { return Err!(Request(Unauthorized("Invalid access token"))); };
|
||||
let displayname = services.users.displayname(&user_id).await.ok();
|
||||
let avatar_url = services.users.avatar_url(&user_id).await.ok();
|
||||
Ok(Json(serde_json::json!({"sub": user_id.to_string(), "name": displayname, "picture": avatar_url})))
|
||||
}
|
||||
|
||||
pub(crate) async fn account_route() -> impl IntoResponse {
|
||||
axum::response::Html("<html><body><h1>Account Management</h1><p>Account management is not yet implemented. Please use your identity provider to manage your account.</p></body></html>")
|
||||
}
|
||||
|
||||
fn oauth_error(status: StatusCode, error: &str, description: &str) -> http::Response<axum::body::Body> {
|
||||
(status, Json(serde_json::json!({"error": error, "error_description": description}))).into_response()
|
||||
}
|
||||
|
||||
fn scope_contains_token(scope: &str, token: &str) -> bool { scope.split_whitespace().any(|t| t == token) }
|
||||
|
||||
fn get_oidc_server(services: &tuwunel_service::Services) -> Result<&OidcServer> {
|
||||
services.oauth.oidc_server.as_deref().ok_or_else(|| err!(Request(NotFound("OIDC server not configured"))))
|
||||
}
|
||||
|
||||
fn oidc_issuer_url(services: &tuwunel_service::Services) -> Result<String> {
|
||||
services.config.well_known.client.as_ref().map(|url| { let s = url.to_string(); if s.ends_with('/') { s } else { s + "/" } }).ok_or_else(|| err!(Config("well_known.client", "well_known.client must be set for OIDC server")))
|
||||
}
|
||||
|
||||
fn extract_device_id(scope: &str) -> Option<String> { scope.split_whitespace().find_map(|s| s.strip_prefix("urn:matrix:org.matrix.msc2967.client:device:")).map(ToOwned::to_owned) }
|
||||
@@ -233,8 +233,8 @@ pub(crate) async fn register_route(
|
||||
let proposed_user_id = UserId::parse_with_server_name(
|
||||
utils::random_string(RANDOM_USER_ID_LENGTH).to_lowercase(),
|
||||
services.globals.server_name(),
|
||||
)
|
||||
.unwrap();
|
||||
)?;
|
||||
|
||||
if !services.users.exists(&proposed_user_id).await {
|
||||
break proposed_user_id;
|
||||
}
|
||||
@@ -295,8 +295,7 @@ pub(crate) async fn register_route(
|
||||
let (worked, uiaainfo) = services
|
||||
.uiaa
|
||||
.try_auth(
|
||||
&UserId::parse_with_server_name("", services.globals.server_name())
|
||||
.unwrap(),
|
||||
&UserId::parse_with_server_name("", services.globals.server_name())?,
|
||||
"".into(),
|
||||
auth,
|
||||
&uiaainfo,
|
||||
@@ -311,8 +310,7 @@ pub(crate) async fn register_route(
|
||||
| Some(ref json) => {
|
||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||
services.uiaa.create(
|
||||
&UserId::parse_with_server_name("", services.globals.server_name())
|
||||
.unwrap(),
|
||||
&UserId::parse_with_server_name("", services.globals.server_name())?,
|
||||
"".into(),
|
||||
&uiaainfo,
|
||||
json,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
use std::collections::{BTreeMap, BTreeSet};
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use axum::extract::State;
|
||||
use futures::{FutureExt, StreamExt, TryFutureExt, TryStreamExt};
|
||||
@@ -92,120 +92,37 @@ async fn category_room_events(
|
||||
.boxed()
|
||||
});
|
||||
|
||||
let accessible_rooms: Vec<OwnedRoomId> = rooms
|
||||
let results: Vec<_> = rooms
|
||||
.filter_map(async |room_id| {
|
||||
check_room_visible(services, sender_user, &room_id, criteria)
|
||||
.await
|
||||
.is_ok()
|
||||
.then_some(room_id)
|
||||
})
|
||||
.filter_map(async |room_id| {
|
||||
let query = RoomQuery {
|
||||
room_id: &room_id,
|
||||
user_id: Some(sender_user),
|
||||
criteria,
|
||||
skip: next_batch,
|
||||
limit,
|
||||
};
|
||||
|
||||
let (count, results) = services.search.search_pdus(&query).await.ok()?;
|
||||
|
||||
results
|
||||
.collect::<Vec<_>>()
|
||||
.map(|results| (room_id.clone(), count, results))
|
||||
.map(Some)
|
||||
.await
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
|
||||
let (results, total) = if services.search.supports_cross_room_search() {
|
||||
let room_id_refs: Vec<&RoomId> = accessible_rooms
|
||||
.iter()
|
||||
.map(AsRef::as_ref)
|
||||
.collect();
|
||||
|
||||
let (count, pdus) = services
|
||||
.search
|
||||
.search_pdus_multi_room(&room_id_refs, sender_user, criteria, limit, next_batch)
|
||||
.await?;
|
||||
|
||||
let pdu_list: Vec<_> = pdus.collect().await;
|
||||
let total: UInt = count.try_into()?;
|
||||
|
||||
// Group results by room for state collection
|
||||
let room_ids_in_results: Vec<OwnedRoomId> = pdu_list
|
||||
.iter()
|
||||
.map(|pdu| pdu.room_id().to_owned())
|
||||
.collect();
|
||||
|
||||
let state: RoomStates = if criteria.include_state.is_some_and(is_true!()) {
|
||||
let unique_rooms: BTreeSet<OwnedRoomId> =
|
||||
room_ids_in_results.iter().cloned().collect();
|
||||
futures::stream::iter(unique_rooms.iter())
|
||||
.filter_map(async |room_id| {
|
||||
procure_room_state(services, room_id)
|
||||
.map_ok(|state| (room_id.clone(), state))
|
||||
.await
|
||||
.ok()
|
||||
})
|
||||
.collect()
|
||||
.await
|
||||
} else {
|
||||
BTreeMap::new()
|
||||
};
|
||||
|
||||
let results: Vec<SearchResult> = pdu_list
|
||||
.into_iter()
|
||||
.stream()
|
||||
.map(Event::into_format)
|
||||
.map(|result| SearchResult {
|
||||
rank: None,
|
||||
result: Some(result),
|
||||
context: EventContextResult {
|
||||
profile_info: BTreeMap::new(), //TODO
|
||||
events_after: Vec::new(), //TODO
|
||||
events_before: Vec::new(), //TODO
|
||||
start: None, //TODO
|
||||
end: None, //TODO
|
||||
},
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
|
||||
let highlights = criteria
|
||||
.search_term
|
||||
.split_terminator(|c: char| !c.is_alphanumeric())
|
||||
.map(str::to_lowercase)
|
||||
.collect();
|
||||
|
||||
let next_batch = (results.len() >= limit)
|
||||
.then_some(next_batch.saturating_add(results.len()))
|
||||
.as_ref()
|
||||
.map(ToString::to_string);
|
||||
|
||||
return Ok(ResultRoomEvents {
|
||||
count: Some(total),
|
||||
next_batch,
|
||||
results,
|
||||
state,
|
||||
highlights,
|
||||
groups: BTreeMap::new(), // TODO
|
||||
});
|
||||
} else {
|
||||
let results: Vec<_> = accessible_rooms
|
||||
.iter()
|
||||
.stream()
|
||||
.filter_map(async |room_id| {
|
||||
let query = RoomQuery {
|
||||
room_id,
|
||||
user_id: Some(sender_user),
|
||||
criteria,
|
||||
skip: next_batch,
|
||||
limit,
|
||||
};
|
||||
|
||||
let (count, results) = services.search.search_pdus(&query).await.ok()?;
|
||||
|
||||
results
|
||||
.collect::<Vec<_>>()
|
||||
.map(|results| (room_id.clone(), count, results))
|
||||
.map(Some)
|
||||
.await
|
||||
})
|
||||
.collect()
|
||||
.await;
|
||||
|
||||
let total: UInt = results
|
||||
.iter()
|
||||
.fold(0, |a: usize, (_, count, _)| a.saturating_add(*count))
|
||||
.try_into()?;
|
||||
|
||||
(results, total)
|
||||
};
|
||||
let total: UInt = results
|
||||
.iter()
|
||||
.fold(0, |a: usize, (_, count, _)| a.saturating_add(*count))
|
||||
.try_into()?;
|
||||
|
||||
let state: RoomStates = results
|
||||
.iter()
|
||||
|
||||
@@ -79,6 +79,7 @@ pub(crate) async fn get_login_types_route(
|
||||
if list_idps && identity_providers.is_empty() =>
|
||||
false,
|
||||
| LoginType::Password(_) => services.config.login_with_password,
|
||||
| LoginType::Jwt(_) => services.config.jwt.enable,
|
||||
| _ => true,
|
||||
})
|
||||
.collect(),
|
||||
|
||||
@@ -51,7 +51,7 @@ static VERSIONS: [&str; 17] = [
|
||||
"v1.15", /* custom profile fields */
|
||||
];
|
||||
|
||||
static UNSTABLE_FEATURES: [&str; 18] = [
|
||||
static UNSTABLE_FEATURES: [&str; 22] = [
|
||||
"org.matrix.e2e_cross_signing",
|
||||
// private read receipts (https://github.com/matrix-org/matrix-spec-proposals/pull/2285)
|
||||
"org.matrix.msc2285.stable",
|
||||
@@ -86,4 +86,12 @@ static UNSTABLE_FEATURES: [&str; 18] = [
|
||||
"org.matrix.simplified_msc3575",
|
||||
// Allow room moderators to view redacted event content (https://github.com/matrix-org/matrix-spec-proposals/pull/2815)
|
||||
"fi.mau.msc2815",
|
||||
// OIDC-native auth: authorization code grant (https://github.com/matrix-org/matrix-spec-proposals/pull/2964)
|
||||
"org.matrix.msc2964",
|
||||
// OIDC-native auth: auth issuer discovery (https://github.com/matrix-org/matrix-spec-proposals/pull/2965)
|
||||
"org.matrix.msc2965",
|
||||
// OIDC-native auth: dynamic client registration (https://github.com/matrix-org/matrix-spec-proposals/pull/2966)
|
||||
"org.matrix.msc2966",
|
||||
// OIDC-native auth: API scopes (https://github.com/matrix-org/matrix-spec-proposals/pull/2967)
|
||||
"org.matrix.msc2967",
|
||||
];
|
||||
|
||||
@@ -35,20 +35,19 @@ pub(crate) async fn turn_server_route(
|
||||
)
|
||||
.expect("time is valid");
|
||||
|
||||
let user = body.sender_user.unwrap_or_else(|| {
|
||||
let random_user_id = || {
|
||||
UserId::parse_with_server_name(
|
||||
utils::random_string(RANDOM_USER_ID_LENGTH).to_lowercase(),
|
||||
&services.server.name,
|
||||
)
|
||||
.unwrap()
|
||||
});
|
||||
};
|
||||
|
||||
let user = body.sender_user.map_or_else(random_user_id, Ok)?;
|
||||
let username: String = format!("{}:{}", expiry.get(), user);
|
||||
|
||||
let mut mac = HmacSha1::new_from_slice(turn_secret.as_bytes())
|
||||
.expect("HMAC can take key of any size");
|
||||
mac.update(username.as_bytes());
|
||||
|
||||
mac.update(username.as_bytes());
|
||||
let password: String = general_purpose::STANDARD.encode(mac.finalize().into_bytes());
|
||||
|
||||
(username, password)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#![expect(clippy::toplevel_ref_arg)]
|
||||
#![expect(clippy::duration_suboptimal_units)] // remove after MSRV 1.91
|
||||
#![allow(unused_features)] // 1.96.0-nightly 2026-03-07 bug
|
||||
|
||||
pub mod client;
|
||||
pub mod router;
|
||||
|
||||
@@ -158,6 +158,8 @@ pub fn build(router: Router<State>, server: &Server) -> Router<State> {
|
||||
.ruma_route(&client::turn_server_route)
|
||||
.ruma_route(&client::send_event_to_device_route)
|
||||
.ruma_route(&client::create_content_route)
|
||||
.ruma_route(&client::create_mxc_uri_route)
|
||||
.ruma_route(&client::create_content_async_route)
|
||||
.ruma_route(&client::get_content_thumbnail_route)
|
||||
.ruma_route(&client::get_content_route)
|
||||
.ruma_route(&client::get_content_as_filename_route)
|
||||
@@ -195,6 +197,20 @@ pub fn build(router: Router<State>, server: &Server) -> Router<State> {
|
||||
.ruma_route(&client::well_known_support)
|
||||
.ruma_route(&client::well_known_client)
|
||||
.route("/_tuwunel/server_version", get(client::tuwunel_server_version))
|
||||
// OIDC server endpoints (next-gen auth, MSC2965/2964/2966/2967)
|
||||
.route("/_matrix/client/unstable/org.matrix.msc2965/auth_issuer", get(client::auth_issuer_route))
|
||||
.route("/_matrix/client/v1/auth_issuer", get(client::auth_issuer_route))
|
||||
.route("/_matrix/client/unstable/org.matrix.msc2965/auth_metadata", get(client::openid_configuration_route))
|
||||
.route("/_matrix/client/v1/auth_metadata", get(client::openid_configuration_route))
|
||||
.route("/.well-known/openid-configuration", get(client::openid_configuration_route))
|
||||
.route("/_tuwunel/oidc/registration", post(client::registration_route))
|
||||
.route("/_tuwunel/oidc/authorize", get(client::authorize_route))
|
||||
.route("/_tuwunel/oidc/_complete", get(client::complete_route))
|
||||
.route("/_tuwunel/oidc/token", post(client::token_route))
|
||||
.route("/_tuwunel/oidc/revoke", post(client::revoke_route))
|
||||
.route("/_tuwunel/oidc/jwks", get(client::jwks_route))
|
||||
.route("/_tuwunel/oidc/userinfo", get(client::userinfo_route))
|
||||
.route("/_tuwunel/oidc/account", get(client::account_route))
|
||||
.ruma_route(&client::room_initial_sync_route)
|
||||
.route("/client/server.json", get(client::syncv3_client_server_json));
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ pub struct State {
|
||||
services: *const Services,
|
||||
}
|
||||
|
||||
#[clippy::has_significant_drop]
|
||||
pub struct Guard {
|
||||
services: Arc<Services>,
|
||||
}
|
||||
|
||||
@@ -61,7 +61,7 @@ pub(crate) async fn get_server_keys_route(
|
||||
}
|
||||
|
||||
fn valid_until_ts() -> MilliSecondsSinceUnixEpoch {
|
||||
let dur = Duration::from_secs(86400 * 7);
|
||||
let dur = Duration::from_hours(168);
|
||||
let timepoint = timepoint_from_now(dur).expect("SystemTime should not overflow");
|
||||
MilliSecondsSinceUnixEpoch::from_system_time(timepoint).expect("UInt should not overflow")
|
||||
}
|
||||
|
||||
@@ -304,6 +304,12 @@ pub fn stats_reset() -> Result { notify(&mallctl!("stats.mutexes.reset")) }
|
||||
|
||||
pub fn prof_reset() -> Result { notify(&mallctl!("prof.reset")) }
|
||||
|
||||
pub fn prof_dump() -> Result { notify(&mallctl!("prof.dump")) }
|
||||
|
||||
pub fn prof_gdump(enable: bool) -> Result<bool> {
|
||||
set::<u8>(&mallctl!("prof.gdump"), enable.into()).map(is_nonzero!())
|
||||
}
|
||||
|
||||
pub fn prof_enable(enable: bool) -> Result<bool> {
|
||||
set::<u8>(&mallctl!("prof.active"), enable.into()).map(is_nonzero!())
|
||||
}
|
||||
@@ -312,6 +318,10 @@ pub fn is_prof_enabled() -> Result<bool> {
|
||||
get::<u8>(&mallctl!("prof.active")).map(is_nonzero!())
|
||||
}
|
||||
|
||||
pub fn prof_interval() -> Result<u64> {
|
||||
get::<u64>(&mallctl!("prof.interval")).and_then(math::try_into)
|
||||
}
|
||||
|
||||
pub fn trim<I: Into<Option<usize>> + Copy>(arena: I) -> Result {
|
||||
decay(arena).and_then(|()| purge(arena))
|
||||
}
|
||||
|
||||
@@ -116,22 +116,6 @@ pub fn check(config: &Config) -> Result {
|
||||
});
|
||||
}
|
||||
|
||||
if config.search_backend == super::SearchBackendConfig::OpenSearch
|
||||
&& config.search_opensearch_url.is_none()
|
||||
{
|
||||
return Err!(Config(
|
||||
"search_opensearch_url",
|
||||
"OpenSearch URL must be set when search_backend is \"opensearch\""
|
||||
));
|
||||
}
|
||||
|
||||
if config.search_opensearch_hybrid && config.search_opensearch_model_id.is_none() {
|
||||
return Err!(Config(
|
||||
"search_opensearch_model_id",
|
||||
"Model ID must be set when search_opensearch_hybrid is enabled"
|
||||
));
|
||||
}
|
||||
|
||||
// rocksdb does not allow max_log_files to be 0
|
||||
if config.rocksdb_max_log_files == 0 {
|
||||
return Err!(Config(
|
||||
|
||||
@@ -408,6 +408,33 @@ pub struct Config {
|
||||
#[serde(default = "default_max_request_size")]
|
||||
pub max_request_size: usize,
|
||||
|
||||
/// Maximum number of concurrently pending (asynchronous) media uploads a
|
||||
/// user can have.
|
||||
///
|
||||
/// default: 5
|
||||
#[serde(default = "default_max_pending_media_uploads")]
|
||||
pub max_pending_media_uploads: usize,
|
||||
|
||||
/// The time in seconds before an unused pending MXC URI expires and is
|
||||
/// removed.
|
||||
///
|
||||
/// default: 86400 (24 hours)
|
||||
#[serde(default = "default_media_create_unused_expiration_time")]
|
||||
pub media_create_unused_expiration_time: u64,
|
||||
|
||||
/// The maximum number of media create requests per second allowed from a
|
||||
/// single user.
|
||||
///
|
||||
/// default: 10
|
||||
#[serde(default = "default_media_rc_create_per_second")]
|
||||
pub media_rc_create_per_second: u32,
|
||||
|
||||
/// The maximum burst count for media create requests from a single user.
|
||||
///
|
||||
/// default: 50
|
||||
#[serde(default = "default_media_rc_create_burst_count")]
|
||||
pub media_rc_create_burst_count: u32,
|
||||
|
||||
/// default: 192
|
||||
#[serde(default = "default_max_fetch_prev_events")]
|
||||
pub max_fetch_prev_events: u16,
|
||||
@@ -1100,85 +1127,6 @@ pub struct Config {
|
||||
#[serde(default)]
|
||||
pub auto_deactivate_banned_room_attempts: bool,
|
||||
|
||||
/// Search backend to use for full-text message search.
|
||||
///
|
||||
/// Available options: "rocksdb" (default) or "opensearch".
|
||||
///
|
||||
/// default: "rocksdb"
|
||||
#[serde(default)]
|
||||
pub search_backend: SearchBackendConfig,
|
||||
|
||||
/// URL of the OpenSearch instance. Required when search_backend is
|
||||
/// "opensearch".
|
||||
///
|
||||
/// example: "http://localhost:9200"
|
||||
pub search_opensearch_url: Option<Url>,
|
||||
|
||||
/// Name of the OpenSearch index for message search.
|
||||
///
|
||||
/// default: "tuwunel_messages"
|
||||
#[serde(default = "default_search_opensearch_index")]
|
||||
pub search_opensearch_index: String,
|
||||
|
||||
/// Authentication for OpenSearch in "user:pass" format.
|
||||
///
|
||||
/// display: sensitive
|
||||
pub search_opensearch_auth: Option<String>,
|
||||
|
||||
/// Maximum number of documents to batch before flushing to OpenSearch.
|
||||
///
|
||||
/// default: 100
|
||||
#[serde(default = "default_search_opensearch_batch_size")]
|
||||
pub search_opensearch_batch_size: usize,
|
||||
|
||||
/// Maximum time in milliseconds to wait before flushing a partial batch
|
||||
/// to OpenSearch.
|
||||
///
|
||||
/// default: 1000
|
||||
#[serde(default = "default_search_opensearch_flush_interval_ms")]
|
||||
pub search_opensearch_flush_interval_ms: u64,
|
||||
|
||||
/// Enable hybrid neural+BM25 search in OpenSearch. Requires an ML model
|
||||
/// deployed in OpenSearch and an ingest pipeline that populates an
|
||||
/// "embedding" field.
|
||||
///
|
||||
/// When enabled, tuwunel will:
|
||||
/// - Create the index with a knn_vector "embedding" field
|
||||
/// - Attach the ingest pipeline (search_opensearch_pipeline) to the index
|
||||
/// - Use hybrid queries combining BM25 + neural kNN scoring
|
||||
///
|
||||
/// For a complete reference on configuring OpenSearch's ML plugin, model
|
||||
/// registration, and ingest pipeline setup, see the test helpers in
|
||||
/// `src/service/rooms/search/opensearch.rs` (the `ensure_neural_model`,
|
||||
/// `ensure_ingest_pipeline`, etc. functions in the `tests` module).
|
||||
///
|
||||
/// See also: https://opensearch.org/docs/latest/search-plugins/neural-search/
|
||||
///
|
||||
/// default: false
|
||||
#[serde(default)]
|
||||
pub search_opensearch_hybrid: bool,
|
||||
|
||||
/// The model ID registered in OpenSearch for neural search. Required when
|
||||
/// search_opensearch_hybrid is enabled.
|
||||
///
|
||||
/// example: "aKV84osBBHNT0StI3MBr"
|
||||
pub search_opensearch_model_id: Option<String>,
|
||||
|
||||
/// Embedding dimension for the neural search model. Must match the output
|
||||
/// dimension of the deployed model. Common values: 384
|
||||
/// (all-MiniLM-L6-v2), 768 (msmarco-distilbert-base-tas-b).
|
||||
///
|
||||
/// default: 384
|
||||
#[serde(default = "default_search_opensearch_embedding_dim")]
|
||||
pub search_opensearch_embedding_dim: usize,
|
||||
|
||||
/// Name of the ingest pipeline that generates embeddings for the
|
||||
/// "embedding" field. This pipeline must already exist in OpenSearch.
|
||||
///
|
||||
/// default: "tuwunel_embedding_pipeline"
|
||||
#[serde(default = "default_search_opensearch_pipeline")]
|
||||
pub search_opensearch_pipeline: String,
|
||||
|
||||
/// RocksDB log level. This is not the same as tuwunel's log level. This
|
||||
/// is the log level for the RocksDB engine/library which show up in your
|
||||
/// database folder/path as `LOG` files. tuwunel will log RocksDB errors
|
||||
@@ -2383,14 +2331,6 @@ pub struct Config {
|
||||
catchall: BTreeMap<String, IgnoredAny>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Default, Deserialize, PartialEq, Eq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum SearchBackendConfig {
|
||||
#[default]
|
||||
RocksDb,
|
||||
OpenSearch,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Default)]
|
||||
#[config_example_generator(filename = "tuwunel-example.toml", section = "global.tls")]
|
||||
pub struct TlsConfig {
|
||||
@@ -3268,16 +3208,6 @@ impl Config {
|
||||
|
||||
fn true_fn() -> bool { true }
|
||||
|
||||
fn default_search_opensearch_index() -> String { "tuwunel_messages".to_owned() }
|
||||
|
||||
fn default_search_opensearch_batch_size() -> usize { 100 }
|
||||
|
||||
fn default_search_opensearch_flush_interval_ms() -> u64 { 1000 }
|
||||
|
||||
fn default_search_opensearch_embedding_dim() -> usize { 384 }
|
||||
|
||||
fn default_search_opensearch_pipeline() -> String { "tuwunel_embedding_pipeline".to_owned() }
|
||||
|
||||
#[cfg(test)]
|
||||
fn default_server_name() -> OwnedServerName { ruma::owned_server_name!("localhost") }
|
||||
|
||||
@@ -3342,6 +3272,10 @@ fn default_dns_timeout() -> u64 { 10 }
|
||||
fn default_ip_lookup_strategy() -> u8 { 5 }
|
||||
|
||||
fn default_max_request_size() -> usize { 24 * 1024 * 1024 }
|
||||
fn default_max_pending_media_uploads() -> usize { 5 }
|
||||
fn default_media_create_unused_expiration_time() -> u64 { 86400 }
|
||||
fn default_media_rc_create_per_second() -> u32 { 10 }
|
||||
fn default_media_rc_create_burst_count() -> u32 { 50 }
|
||||
|
||||
fn default_request_conn_timeout() -> u64 { 10 }
|
||||
|
||||
@@ -3394,7 +3328,7 @@ fn default_jaeger_filter() -> String {
|
||||
fn default_tracing_flame_output_path() -> String { "./tracing.folded".to_owned() }
|
||||
|
||||
fn default_trusted_servers() -> Vec<OwnedServerName> {
|
||||
vec![OwnedServerName::try_from("matrix.org").unwrap()]
|
||||
vec![OwnedServerName::try_from("matrix.org").expect("valid ServerName")]
|
||||
}
|
||||
|
||||
/// do debug logging by default for debug builds
|
||||
@@ -3596,7 +3530,7 @@ fn default_client_sync_timeout_max() -> u64 { 90000 }
|
||||
fn default_access_token_ttl() -> u64 { 604_800 }
|
||||
|
||||
fn default_deprioritize_joins_through_servers() -> RegexSet {
|
||||
RegexSet::new([r"matrix\.org"]).unwrap()
|
||||
RegexSet::new([r"matrix\.org"]).expect("valid set of regular expressions")
|
||||
}
|
||||
|
||||
fn default_one_time_key_limit() -> usize { 256 }
|
||||
|
||||
@@ -57,12 +57,18 @@ pub(super) fn bad_request_code(kind: &ErrorKind) -> StatusCode {
|
||||
use ErrorKind::*;
|
||||
|
||||
match kind {
|
||||
// 504
|
||||
| NotYetUploaded => StatusCode::GATEWAY_TIMEOUT,
|
||||
|
||||
// 429
|
||||
| LimitExceeded { .. } => StatusCode::TOO_MANY_REQUESTS,
|
||||
|
||||
// 413
|
||||
| TooLarge => StatusCode::PAYLOAD_TOO_LARGE,
|
||||
|
||||
// 409
|
||||
| CannotOverwriteMedia => StatusCode::CONFLICT,
|
||||
|
||||
// 405
|
||||
| Unrecognized => StatusCode::METHOD_NOT_ALLOWED,
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ use std::sync::Arc;
|
||||
use super::Capture;
|
||||
|
||||
/// Capture instance scope guard.
|
||||
#[clippy::has_significant_drop]
|
||||
pub struct Guard {
|
||||
pub(super) capture: Arc<Capture>,
|
||||
}
|
||||
|
||||
@@ -7,6 +7,14 @@ use crate::{Result, debug_error, err, matrix::room_version};
|
||||
///
|
||||
/// Returns a tuple of the new `EventId` and the PDU as a `BTreeMap<String,
|
||||
/// CanonicalJsonValue>`.
|
||||
#[tracing::instrument(
|
||||
name = "gen_event_id",
|
||||
level = "debug",
|
||||
skip_all,
|
||||
fields(
|
||||
len = pdu.get().len(),
|
||||
)
|
||||
)]
|
||||
pub fn gen_event_id_canonical_json(
|
||||
pdu: &RawJsonValue,
|
||||
room_version_id: &RoomVersionId,
|
||||
@@ -23,6 +31,14 @@ pub fn gen_event_id_canonical_json(
|
||||
/// Generates a correct eventId for the PDU. For v1/v2 incoming PDU's the
|
||||
/// value's event_id is passed through. For all outgoing PDU's and for v3+
|
||||
/// incoming PDU's it is generated.
|
||||
#[tracing::instrument(
|
||||
level = "debug",
|
||||
skip_all,
|
||||
fields(
|
||||
members = value.len(),
|
||||
room_version = ?room_version_id,
|
||||
)
|
||||
)]
|
||||
pub fn gen_event_id(
|
||||
value: &CanonicalJsonObject,
|
||||
room_version_id: &RoomVersionId,
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
mod builder;
|
||||
mod count;
|
||||
pub mod format;
|
||||
mod format;
|
||||
mod hashes;
|
||||
mod id;
|
||||
mod raw_id;
|
||||
@@ -13,6 +13,7 @@ use std::cmp::Ordering;
|
||||
use ruma::{
|
||||
CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId,
|
||||
OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, UInt, UserId, events::TimelineEventType,
|
||||
room_version_rules::RoomVersionRules,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::value::RawValue as RawJsonValue;
|
||||
@@ -22,7 +23,10 @@ pub use self::{
|
||||
Count as PduCount, Id as PduId, Pdu as PduEvent, RawId as RawPduId,
|
||||
builder::{Builder, Builder as PduBuilder},
|
||||
count::Count,
|
||||
format::check::check_pdu_format,
|
||||
format::{
|
||||
check::{check_room_id, check_rules},
|
||||
from_incoming_federation, into_outgoing_federation,
|
||||
},
|
||||
hashes::EventHashes as EventHash,
|
||||
id::Id,
|
||||
raw_id::*,
|
||||
@@ -99,28 +103,61 @@ pub const MAX_PREV_EVENTS: usize = 20;
|
||||
pub const MAX_AUTH_EVENTS: usize = 10;
|
||||
|
||||
impl Pdu {
|
||||
pub fn from_rid_val(
|
||||
pub fn from_object_and_roomid_and_eventid(
|
||||
room_id: &RoomId,
|
||||
event_id: &EventId,
|
||||
mut json: CanonicalJsonObject,
|
||||
) -> Result<Self> {
|
||||
let room_id = CanonicalJsonValue::String(room_id.into());
|
||||
json.insert("room_id".into(), room_id);
|
||||
|
||||
Self::from_id_val(event_id, json)
|
||||
Self::from_object_and_eventid(event_id, json)
|
||||
}
|
||||
|
||||
pub fn from_id_val(event_id: &EventId, mut json: CanonicalJsonObject) -> Result<Self> {
|
||||
pub fn from_object_and_eventid(
|
||||
event_id: &EventId,
|
||||
mut json: CanonicalJsonObject,
|
||||
) -> Result<Self> {
|
||||
let event_id = CanonicalJsonValue::String(event_id.into());
|
||||
json.insert("event_id".into(), event_id);
|
||||
|
||||
Self::from_val(&json)
|
||||
Self::from_object(json)
|
||||
}
|
||||
|
||||
pub fn from_val(json: &CanonicalJsonObject) -> Result<Self> {
|
||||
serde_json::to_value(json)
|
||||
.and_then(serde_json::from_value)
|
||||
.map_err(Into::into)
|
||||
pub fn from_object_federation(
|
||||
room_id: &RoomId,
|
||||
event_id: &EventId,
|
||||
json: CanonicalJsonObject,
|
||||
rules: &RoomVersionRules,
|
||||
) -> Result<(Self, CanonicalJsonObject)> {
|
||||
let json = from_incoming_federation(room_id, event_id, json, rules);
|
||||
let pdu = Self::from_object_checked(json.clone(), rules)?;
|
||||
check_room_id(&pdu, room_id)?;
|
||||
Ok((pdu, json))
|
||||
}
|
||||
|
||||
pub fn from_object_checked(
|
||||
json: CanonicalJsonObject,
|
||||
rules: &RoomVersionRules,
|
||||
) -> Result<Self> {
|
||||
check_rules(&json, &rules.event_format)?;
|
||||
Self::from_object(json)
|
||||
}
|
||||
|
||||
pub fn from_object(json: CanonicalJsonObject) -> Result<Self> {
|
||||
let json = CanonicalJsonValue::Object(json);
|
||||
Self::from_value(json)
|
||||
}
|
||||
|
||||
pub fn from_raw_value(json: &RawJsonValue) -> Result<Self> {
|
||||
let json: CanonicalJsonValue = json.into();
|
||||
Self::from_value(json)
|
||||
}
|
||||
|
||||
pub fn from_value(json: CanonicalJsonValue) -> Result<Self> {
|
||||
serde_json::from_value(json.into()).map_err(Into::into)
|
||||
}
|
||||
|
||||
pub fn from_raw_json(json: &RawJsonValue) -> Result<Self> {
|
||||
Self::deserialize(json).map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -5,11 +5,9 @@ use ruma::{
|
||||
room_version_rules::{EventsReferenceFormatVersion, RoomVersionRules},
|
||||
};
|
||||
|
||||
use crate::{
|
||||
Result, extract_variant, is_equal_to,
|
||||
matrix::{PduEvent, room_version},
|
||||
};
|
||||
use crate::{extract_variant, is_equal_to, matrix::room_version};
|
||||
|
||||
#[must_use]
|
||||
pub fn into_outgoing_federation(
|
||||
mut pdu_json: CanonicalJsonObject,
|
||||
room_version: &RoomVersionId,
|
||||
@@ -68,12 +66,13 @@ fn mutate_outgoing_reference_format(value: &mut CanonicalJsonValue) {
|
||||
});
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn from_incoming_federation(
|
||||
room_id: &RoomId,
|
||||
event_id: &EventId,
|
||||
pdu_json: &mut CanonicalJsonObject,
|
||||
mut pdu_json: CanonicalJsonObject,
|
||||
room_rules: &RoomVersionRules,
|
||||
) -> Result<PduEvent> {
|
||||
) -> CanonicalJsonObject {
|
||||
if matches!(room_rules.events_reference_format, EventsReferenceFormatVersion::V1) {
|
||||
if let Some(value) = pdu_json.get_mut("auth_events") {
|
||||
mutate_incoming_reference_format(value);
|
||||
@@ -95,9 +94,7 @@ pub fn from_incoming_federation(
|
||||
pdu_json.insert("event_id".into(), CanonicalJsonValue::String(event_id.into()));
|
||||
}
|
||||
|
||||
check::check_pdu_format(pdu_json, &room_rules.event_format)?;
|
||||
|
||||
PduEvent::from_val(pdu_json)
|
||||
pdu_json
|
||||
}
|
||||
|
||||
fn mutate_incoming_reference_format(value: &mut CanonicalJsonValue) {
|
||||
|
||||
@@ -4,10 +4,21 @@ use ruma::{
|
||||
};
|
||||
use serde_json::to_string as to_json_string;
|
||||
|
||||
use crate::{
|
||||
Err, Result, err,
|
||||
matrix::pdu::{MAX_AUTH_EVENTS, MAX_PDU_BYTES, MAX_PREV_EVENTS},
|
||||
};
|
||||
use super::super::{MAX_AUTH_EVENTS, MAX_PDU_BYTES, MAX_PREV_EVENTS, Pdu};
|
||||
use crate::{Err, Result, err};
|
||||
|
||||
pub fn check_room_id(pdu: &Pdu, room_id: &RoomId) -> Result {
|
||||
if pdu.room_id != room_id {
|
||||
return Err!(Request(InvalidParam(error!(
|
||||
pdu_event_id = ?pdu.event_id,
|
||||
pdu_room_id = ?pdu.room_id,
|
||||
?room_id,
|
||||
"Event in wrong room",
|
||||
))));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check that the given canonicalized PDU respects the event format of the room
|
||||
/// version and the [size limits] from the Matrix specification.
|
||||
@@ -31,7 +42,7 @@ use crate::{
|
||||
///
|
||||
/// [size limits]: https://spec.matrix.org/latest/client-server-api/#size-limits
|
||||
/// [checks performed on receipt of a PDU]: https://spec.matrix.org/latest/server-server-api/#checks-performed-on-receipt-of-a-pdu
|
||||
pub fn check_pdu_format(pdu: &CanonicalJsonObject, rules: &EventFormatRules) -> Result {
|
||||
pub fn check_rules(pdu: &CanonicalJsonObject, rules: &EventFormatRules) -> Result {
|
||||
// Check the PDU size, it must occur on the full PDU with signatures.
|
||||
let json = to_json_string(&pdu)
|
||||
.map_err(|e| err!(Request(BadJson("Failed to serialize canonical JSON: {e}"))))?;
|
||||
@@ -200,7 +211,7 @@ mod tests {
|
||||
};
|
||||
use serde_json::{from_value as from_json_value, json};
|
||||
|
||||
use super::check_pdu_format;
|
||||
use super::check_rules as check_pdu_format;
|
||||
|
||||
/// Construct a PDU valid for the event format of room v1.
|
||||
fn pdu_v1() -> CanonicalJsonObject {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#![type_length_limit = "12288"]
|
||||
#![allow(unused_features)] // 1.96.0-nightly 2026-03-07 bug
|
||||
|
||||
pub mod alloc;
|
||||
pub mod config;
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
mod expect_into;
|
||||
mod expected;
|
||||
mod tried;
|
||||
|
||||
@@ -5,7 +6,7 @@ use std::convert::TryFrom;
|
||||
|
||||
pub use checked_ops::checked_ops;
|
||||
|
||||
pub use self::{expected::Expected, tried::Tried};
|
||||
pub use self::{expect_into::ExpectInto, expected::Expected, tried::Tried};
|
||||
use crate::{Err, Error, Result, debug::type_name, err};
|
||||
|
||||
/// Checked arithmetic expression. Returns a Result<R, Error::Arithmetic>
|
||||
@@ -100,6 +101,11 @@ pub fn ruma_from_usize(val: usize) -> ruma::UInt {
|
||||
#[expect(clippy::as_conversions, clippy::cast_possible_truncation)]
|
||||
pub fn usize_from_u64_truncated(val: u64) -> usize { val as usize }
|
||||
|
||||
#[inline]
|
||||
pub fn expect_into<Dst: TryFrom<Src>, Src>(src: Src) -> Dst {
|
||||
try_into(src).expect("failed conversion from Src to Dst")
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn try_into<Dst: TryFrom<Src>, Src>(src: Src) -> Result<Dst> {
|
||||
Dst::try_from(src).map_err(|_| {
|
||||
|
||||
12
src/core/utils/math/expect_into.rs
Normal file
12
src/core/utils/math/expect_into.rs
Normal file
@@ -0,0 +1,12 @@
|
||||
pub trait ExpectInto {
|
||||
#[inline]
|
||||
#[must_use]
|
||||
fn expect_into<Dst: TryFrom<Self>>(self) -> Dst
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
super::expect_into::<Dst, Self>(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ExpectInto for T {}
|
||||
@@ -16,6 +16,7 @@ pub struct MutexMap<Key, Val> {
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[clippy::has_significant_drop]
|
||||
pub struct Guard<Key, Val> {
|
||||
map: Map<Key, Val>,
|
||||
val: Omg<Val>,
|
||||
|
||||
@@ -5,7 +5,11 @@ pub mod usage;
|
||||
|
||||
use std::path::PathBuf;
|
||||
|
||||
pub use self::{compute::available_parallelism, limits::*, usage::*};
|
||||
pub use self::{
|
||||
compute::available_parallelism,
|
||||
limits::*,
|
||||
usage::{statm, thread_usage, usage},
|
||||
};
|
||||
use crate::{Result, at};
|
||||
|
||||
/// Return a possibly corrected std::env::current_exe() even if the path is
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
#[cfg(unix)]
|
||||
use nix::sys::resource::{Resource, getrlimit};
|
||||
use nix::unistd::{SysconfVar, sysconf};
|
||||
|
||||
use crate::{Result, apply, debug, utils::math::usize_from_u64_truncated};
|
||||
use crate::{Result, apply, debug, utils::math::ExpectInto};
|
||||
|
||||
#[cfg(unix)]
|
||||
/// This is needed for opening lots of file descriptors, which tends to
|
||||
@@ -28,12 +29,7 @@ pub fn maximize_fd_limit() -> Result {
|
||||
#[cfg(not(unix))]
|
||||
pub fn maximize_fd_limit() -> Result { Ok(()) }
|
||||
|
||||
#[cfg(any(
|
||||
linux_android,
|
||||
netbsdlike,
|
||||
target_os = "aix",
|
||||
target_os = "freebsd",
|
||||
))]
|
||||
#[cfg(all(unix, not(target_os = "macos")))]
|
||||
/// Some distributions ship with very low defaults for thread counts; similar to
|
||||
/// low default file descriptor limits. But unlike fd's, thread limit is rarely
|
||||
/// reached, though on large systems (32+ cores) shipping with defaults of
|
||||
@@ -52,60 +48,62 @@ pub fn maximize_thread_limit() -> Result {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(not(any(
|
||||
linux_android,
|
||||
netbsdlike,
|
||||
target_os = "aix",
|
||||
target_os = "freebsd",
|
||||
)))]
|
||||
#[cfg(any(not(unix), target_os = "macos"))]
|
||||
pub fn maximize_thread_limit() -> Result { Ok(()) }
|
||||
|
||||
#[cfg(unix)]
|
||||
#[inline]
|
||||
pub fn max_file_descriptors() -> Result<(usize, usize)> {
|
||||
getrlimit(Resource::RLIMIT_NOFILE)
|
||||
.map(apply!(2, usize_from_u64_truncated))
|
||||
.map(apply!(2, ExpectInto::expect_into))
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
#[inline]
|
||||
pub fn max_file_descriptors() -> Result<(usize, usize)> { Ok((usize::MAX, usize::MAX)) }
|
||||
|
||||
#[cfg(unix)]
|
||||
#[inline]
|
||||
pub fn max_stack_size() -> Result<(usize, usize)> {
|
||||
getrlimit(Resource::RLIMIT_STACK)
|
||||
.map(apply!(2, usize_from_u64_truncated))
|
||||
.map(apply!(2, ExpectInto::expect_into))
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
#[cfg(not(unix))]
|
||||
#[inline]
|
||||
pub fn max_stack_size() -> Result<(usize, usize)> { Ok((usize::MAX, usize::MAX)) }
|
||||
|
||||
#[cfg(any(linux_android, netbsdlike, target_os = "freebsd",))]
|
||||
#[cfg(all(unix, not(target_os = "macos")))]
|
||||
#[inline]
|
||||
pub fn max_memory_locked() -> Result<(usize, usize)> {
|
||||
getrlimit(Resource::RLIMIT_MEMLOCK)
|
||||
.map(apply!(2, usize_from_u64_truncated))
|
||||
.map(apply!(2, ExpectInto::expect_into))
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
#[cfg(not(any(linux_android, netbsdlike, target_os = "freebsd",)))]
|
||||
#[cfg(any(not(unix), target_os = "macos"))]
|
||||
#[inline]
|
||||
pub fn max_memory_locked() -> Result<(usize, usize)> { Ok((usize::MIN, usize::MIN)) }
|
||||
|
||||
#[cfg(any(
|
||||
linux_android,
|
||||
netbsdlike,
|
||||
target_os = "aix",
|
||||
target_os = "freebsd",
|
||||
))]
|
||||
#[cfg(all(unix, not(target_os = "macos")))]
|
||||
#[inline]
|
||||
pub fn max_threads() -> Result<(usize, usize)> {
|
||||
getrlimit(Resource::RLIMIT_NPROC)
|
||||
.map(apply!(2, usize_from_u64_truncated))
|
||||
.map(apply!(2, ExpectInto::expect_into))
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
#[cfg(not(any(
|
||||
linux_android,
|
||||
netbsdlike,
|
||||
target_os = "aix",
|
||||
target_os = "freebsd",
|
||||
)))]
|
||||
#[cfg(any(not(unix), target_os = "macos"))]
|
||||
#[inline]
|
||||
pub fn max_threads() -> Result<(usize, usize)> { Ok((usize::MAX, usize::MAX)) }
|
||||
|
||||
/// Get the system's page size in bytes.
|
||||
#[inline]
|
||||
pub fn page_size() -> Result<usize> {
|
||||
sysconf(SysconfVar::PAGE_SIZE)?
|
||||
.unwrap_or(-1)
|
||||
.try_into()
|
||||
.map_err(Into::into)
|
||||
}
|
||||
|
||||
@@ -1,13 +1,80 @@
|
||||
use nix::sys::resource::Usage;
|
||||
#[cfg(unix)]
|
||||
use nix::sys::resource::{Usage, UsageWho, getrusage};
|
||||
use nix::sys::resource::{UsageWho, getrusage};
|
||||
|
||||
use crate::Result;
|
||||
use crate::{Result, expected};
|
||||
|
||||
pub fn virt() -> Result<usize> {
|
||||
Ok(statm_bytes()?
|
||||
.next()
|
||||
.expect("incomplete statm contents"))
|
||||
}
|
||||
|
||||
pub fn res() -> Result<usize> {
|
||||
Ok(statm_bytes()?
|
||||
.nth(1)
|
||||
.expect("incomplete statm contents"))
|
||||
}
|
||||
|
||||
pub fn shm() -> Result<usize> {
|
||||
Ok(statm_bytes()?
|
||||
.nth(2)
|
||||
.expect("incomplete statm contents"))
|
||||
}
|
||||
|
||||
pub fn code() -> Result<usize> {
|
||||
Ok(statm_bytes()?
|
||||
.nth(3)
|
||||
.expect("incomplete statm contents"))
|
||||
}
|
||||
|
||||
pub fn data() -> Result<usize> {
|
||||
Ok(statm_bytes()?
|
||||
.nth(5)
|
||||
.expect("incomplete statm contents"))
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn statm_bytes() -> Result<impl Iterator<Item = usize>> {
|
||||
let page_size = super::page_size()?;
|
||||
|
||||
Ok(statm()?.map(move |pages| expected!(pages * page_size)))
|
||||
}
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
#[inline]
|
||||
pub fn statm() -> Result<impl Iterator<Item = usize>> {
|
||||
use std::{fs::File, io::Read, str};
|
||||
|
||||
use crate::{Error, arrayvec::ArrayVec};
|
||||
|
||||
File::open("/proc/self/statm")
|
||||
.map_err(Error::from)
|
||||
.and_then(|mut fp| {
|
||||
let mut buf = [0; 96];
|
||||
let len = fp.read(&mut buf)?;
|
||||
let vals = str::from_utf8(&buf[0..len])
|
||||
.expect("non-utf8 content in statm")
|
||||
.split_ascii_whitespace()
|
||||
.map(|val| {
|
||||
val.parse()
|
||||
.expect("non-integer value in statm contents")
|
||||
})
|
||||
.collect::<ArrayVec<usize, 12>>();
|
||||
|
||||
Ok(vals.into_iter())
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "linux"))]
|
||||
#[inline]
|
||||
pub fn statm() -> Result<impl Iterator<Item = usize>> { Ok([0, 0, 0, 0, 0, 0].into_iter()) }
|
||||
|
||||
#[cfg(unix)]
|
||||
pub fn usage() -> Result<Usage> { getrusage(UsageWho::RUSAGE_SELF).map_err(Into::into) }
|
||||
|
||||
#[cfg(not(unix))]
|
||||
pub fn usage() -> Result<()> { Ok(()) }
|
||||
pub fn usage() -> Result<Usage> { Ok(Usage::default()) }
|
||||
|
||||
#[cfg(any(
|
||||
target_os = "linux",
|
||||
@@ -16,15 +83,11 @@ pub fn usage() -> Result<()> { Ok(()) }
|
||||
))]
|
||||
pub fn thread_usage() -> Result<Usage> { getrusage(UsageWho::RUSAGE_THREAD).map_err(Into::into) }
|
||||
|
||||
#[cfg(all(
|
||||
unix,
|
||||
not(any(
|
||||
target_os = "linux",
|
||||
target_os = "freebsd",
|
||||
target_os = "openbsd"
|
||||
))
|
||||
))]
|
||||
pub fn thread_usage() -> Result<Usage> { getrusage(UsageWho::RUSAGE_SELF).map_err(Into::into) }
|
||||
|
||||
#[cfg(not(unix))]
|
||||
pub fn thread_usage() -> Result<()> { Ok(()) }
|
||||
#[cfg(not(any(
|
||||
target_os = "linux",
|
||||
target_os = "freebsd",
|
||||
target_os = "openbsd"
|
||||
)))]
|
||||
pub fn thread_usage() -> Result<Usage> {
|
||||
unimplemented!("RUSAGE_THREAD available on this platform")
|
||||
}
|
||||
|
||||
@@ -312,3 +312,13 @@ async fn set_difference_sorted_stream2() {
|
||||
.await;
|
||||
assert_eq!(r, &["aaa", "eee", "hhh"]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn page_size() {
|
||||
use crate::utils::sys::page_size;
|
||||
|
||||
let val = page_size().expect("Failed to get system page size");
|
||||
println!("{val:?}");
|
||||
|
||||
assert!(val != 0, "page size was zero");
|
||||
}
|
||||
|
||||
@@ -47,6 +47,7 @@ pub struct State<F: Fn(u64) -> Result + Sync> {
|
||||
release: F,
|
||||
}
|
||||
|
||||
#[clippy::has_significant_drop]
|
||||
pub struct Permit<F: Fn(u64) -> Result + Sync> {
|
||||
/// Link back to the shared-state.
|
||||
state: Arc<Counter<F>>,
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
#![cfg(test)]
|
||||
#![allow(unused_features)] // 1.96.0-nightly 2026-03-07 bug
|
||||
|
||||
use criterion::{Criterion, criterion_group, criterion_main};
|
||||
|
||||
criterion_group!(benches, ser_str);
|
||||
|
||||
@@ -2,6 +2,7 @@ use std::sync::Arc;
|
||||
|
||||
use crate::{Database, Engine};
|
||||
|
||||
#[clippy::has_significant_drop]
|
||||
pub struct Cork {
|
||||
engine: Arc<Engine>,
|
||||
flush: bool,
|
||||
|
||||
@@ -3,6 +3,7 @@ mod cf_opts;
|
||||
pub(crate) mod context;
|
||||
mod db_opts;
|
||||
pub(crate) mod descriptor;
|
||||
mod events;
|
||||
mod files;
|
||||
mod logger;
|
||||
mod memory_usage;
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::{cmp, convert::TryFrom};
|
||||
use rocksdb::{Cache, DBRecoveryMode, Env, LogLevel, Options, statistics::StatsLevel};
|
||||
use tuwunel_core::{Config, Result, utils};
|
||||
|
||||
use super::{cf_opts::cache_size_f64, logger::handle as handle_log};
|
||||
use super::{cf_opts::cache_size_f64, events::Events, logger::handle as handle_log};
|
||||
use crate::util::map_err;
|
||||
|
||||
/// Create database-wide options suitable for opening the database. This also
|
||||
@@ -22,6 +22,7 @@ pub(crate) fn db_options(config: &Config, env: &Env, row_cache: &Cache) -> Resul
|
||||
|
||||
// Logging
|
||||
set_logging_defaults(&mut opts, config);
|
||||
opts.add_event_listener(Events::new(config, env));
|
||||
|
||||
// Processing
|
||||
opts.set_max_background_jobs(num_threads::<i32>(config)?);
|
||||
|
||||
243
src/database/engine/events.rs
Normal file
243
src/database/engine/events.rs
Normal file
@@ -0,0 +1,243 @@
|
||||
use rocksdb::{
|
||||
Env,
|
||||
event_listener::{
|
||||
CompactionJobInfo, DBBackgroundErrorReason, DBWriteStallCondition, EventListener,
|
||||
FlushJobInfo, IngestionInfo, MemTableInfo, MutableStatus, SubcompactionJobInfo,
|
||||
WriteStallInfo,
|
||||
},
|
||||
};
|
||||
use tuwunel_core::{Config, debug, debug::INFO_SPAN_LEVEL, debug_info, error, info, warn};
|
||||
|
||||
pub(super) struct Events;
|
||||
|
||||
impl Events {
|
||||
pub(super) fn new(_config: &Config, _env: &Env) -> Self { Self {} }
|
||||
}
|
||||
|
||||
impl EventListener for Events {
|
||||
#[tracing::instrument(name = "error", level = "error", skip_all)]
|
||||
fn on_background_error(&self, reason: DBBackgroundErrorReason, _status: MutableStatus) {
|
||||
error!(error = ?reason, "Critical RocksDB Error");
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "stall", level = "warn", skip_all)]
|
||||
fn on_stall_conditions_changed(&self, info: &WriteStallInfo) {
|
||||
let col = info.cf_name();
|
||||
let col = col
|
||||
.as_deref()
|
||||
.map(str::from_utf8)
|
||||
.expect("column has a name")
|
||||
.expect("column name is valid utf8");
|
||||
|
||||
let prev = info.prev();
|
||||
match info.cur() {
|
||||
| DBWriteStallCondition::KStopped => {
|
||||
error!(?col, ?prev, "Database Stalled");
|
||||
},
|
||||
| DBWriteStallCondition::KDelayed if prev == DBWriteStallCondition::KStopped => {
|
||||
warn!(?col, ?prev, "Database Stall Recovering");
|
||||
},
|
||||
| DBWriteStallCondition::KDelayed => {
|
||||
warn!(?col, ?prev, "Database Stalling");
|
||||
},
|
||||
| DBWriteStallCondition::KNormal
|
||||
if prev == DBWriteStallCondition::KStopped
|
||||
|| prev == DBWriteStallCondition::KDelayed =>
|
||||
{
|
||||
info!(?col, ?prev, "Database Stall Recovered");
|
||||
},
|
||||
| DBWriteStallCondition::KNormal => {
|
||||
debug!(?col, ?prev, "Database Normal");
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "compaction",
|
||||
level = INFO_SPAN_LEVEL,
|
||||
skip_all,
|
||||
)]
|
||||
fn on_compaction_begin(&self, info: &CompactionJobInfo) {
|
||||
let col = info.cf_name();
|
||||
let col = col
|
||||
.as_deref()
|
||||
.map(str::from_utf8)
|
||||
.expect("column has a name")
|
||||
.expect("column name is valid utf8");
|
||||
|
||||
let level = (info.base_input_level(), info.output_level());
|
||||
let records = (info.input_records(), info.output_records());
|
||||
let bytes = (info.total_input_bytes(), info.total_output_bytes());
|
||||
let files = (
|
||||
info.input_file_count(),
|
||||
info.output_file_count(),
|
||||
info.num_input_files_at_output_level(),
|
||||
);
|
||||
|
||||
debug!(
|
||||
status = ?info.status(),
|
||||
?level,
|
||||
?files,
|
||||
?records,
|
||||
?bytes,
|
||||
micros = info.elapsed_micros(),
|
||||
errs = info.num_corrupt_keys(),
|
||||
reason = ?info.compaction_reason(),
|
||||
?col,
|
||||
"Compaction Starting",
|
||||
);
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "compaction",
|
||||
level = INFO_SPAN_LEVEL,
|
||||
skip_all,
|
||||
)]
|
||||
fn on_compaction_completed(&self, info: &CompactionJobInfo) {
|
||||
let col = info.cf_name();
|
||||
let col = col
|
||||
.as_deref()
|
||||
.map(str::from_utf8)
|
||||
.expect("column has a name")
|
||||
.expect("column name is valid utf8");
|
||||
|
||||
let level = (info.base_input_level(), info.output_level());
|
||||
let records = (info.input_records(), info.output_records());
|
||||
let bytes = (info.total_input_bytes(), info.total_output_bytes());
|
||||
let files = (
|
||||
info.input_file_count(),
|
||||
info.output_file_count(),
|
||||
info.num_input_files_at_output_level(),
|
||||
);
|
||||
|
||||
debug_info!(
|
||||
status = ?info.status(),
|
||||
?level,
|
||||
?files,
|
||||
?records,
|
||||
?bytes,
|
||||
micros = info.elapsed_micros(),
|
||||
errs = info.num_corrupt_keys(),
|
||||
reason = ?info.compaction_reason(),
|
||||
?col,
|
||||
"Compaction Complete",
|
||||
);
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "compaction", level = "debug", skip_all)]
|
||||
fn on_subcompaction_begin(&self, info: &SubcompactionJobInfo) {
|
||||
let col = info.cf_name();
|
||||
let col = col
|
||||
.as_deref()
|
||||
.map(str::from_utf8)
|
||||
.expect("column has a name")
|
||||
.expect("column name is valid utf8");
|
||||
|
||||
let level = (info.base_input_level(), info.output_level());
|
||||
|
||||
debug!(
|
||||
status = ?info.status(),
|
||||
?level,
|
||||
tid = info.thread_id(),
|
||||
reason = ?info.compaction_reason(),
|
||||
?col,
|
||||
"Compaction Starting",
|
||||
);
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "compaction", level = "debug", skip_all)]
|
||||
fn on_subcompaction_completed(&self, info: &SubcompactionJobInfo) {
|
||||
let col = info.cf_name();
|
||||
let col = col
|
||||
.as_deref()
|
||||
.map(str::from_utf8)
|
||||
.expect("column has a name")
|
||||
.expect("column name is valid utf8");
|
||||
|
||||
let level = (info.base_input_level(), info.output_level());
|
||||
|
||||
debug!(
|
||||
status = ?info.status(),
|
||||
?level,
|
||||
tid = info.thread_id(),
|
||||
reason = ?info.compaction_reason(),
|
||||
?col,
|
||||
"Compaction Complete",
|
||||
);
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "flush",
|
||||
level = INFO_SPAN_LEVEL,
|
||||
skip_all,
|
||||
)]
|
||||
fn on_flush_begin(&self, info: &FlushJobInfo) {
|
||||
let col = info.cf_name();
|
||||
let col = col
|
||||
.as_deref()
|
||||
.map(str::from_utf8)
|
||||
.expect("column has a name")
|
||||
.expect("column name is valid utf8");
|
||||
|
||||
debug!(
|
||||
seq_start = info.smallest_seqno(),
|
||||
seq_end = info.largest_seqno(),
|
||||
slow = info.triggered_writes_slowdown(),
|
||||
stop = info.triggered_writes_stop(),
|
||||
reason = ?info.flush_reason(),
|
||||
?col,
|
||||
"Flush Starting",
|
||||
);
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "flush",
|
||||
level = INFO_SPAN_LEVEL,
|
||||
skip_all,
|
||||
)]
|
||||
fn on_flush_completed(&self, info: &FlushJobInfo) {
|
||||
let col = info.cf_name();
|
||||
let col = col
|
||||
.as_deref()
|
||||
.map(str::from_utf8)
|
||||
.expect("column has a name")
|
||||
.expect("column name is valid utf8");
|
||||
|
||||
debug_info!(
|
||||
seq_start = info.smallest_seqno(),
|
||||
seq_end = info.largest_seqno(),
|
||||
slow = info.triggered_writes_slowdown(),
|
||||
stop = info.triggered_writes_stop(),
|
||||
reason = ?info.flush_reason(),
|
||||
?col,
|
||||
"Flush Complete",
|
||||
);
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "memtable",
|
||||
level = INFO_SPAN_LEVEL,
|
||||
skip_all,
|
||||
)]
|
||||
fn on_memtable_sealed(&self, info: &MemTableInfo) {
|
||||
let col = info.cf_name();
|
||||
let col = col
|
||||
.as_deref()
|
||||
.map(str::from_utf8)
|
||||
.expect("column has a name")
|
||||
.expect("column name is valid utf8");
|
||||
|
||||
debug_info!(
|
||||
seq_first = info.first_seqno(),
|
||||
seq_early = info.earliest_seqno(),
|
||||
ents = info.num_entries(),
|
||||
dels = info.num_deletes(),
|
||||
?col,
|
||||
"Buffer Filled",
|
||||
);
|
||||
}
|
||||
|
||||
fn on_external_file_ingested(&self, _info: &IngestionInfo) {
|
||||
unimplemented!();
|
||||
}
|
||||
}
|
||||
@@ -128,6 +128,11 @@ pub(super) static MAPS: &[Descriptor] = &[
|
||||
name: "mediaid_file",
|
||||
..descriptor::RANDOM_SMALL
|
||||
},
|
||||
Descriptor {
|
||||
name: "mediaid_pending",
|
||||
ttl: 60 * 60 * 24 * 7,
|
||||
..descriptor::RANDOM_SMALL_CACHE
|
||||
},
|
||||
Descriptor {
|
||||
name: "mediaid_user",
|
||||
..descriptor::RANDOM_SMALL
|
||||
@@ -140,6 +145,22 @@ pub(super) static MAPS: &[Descriptor] = &[
|
||||
name: "oauthuniqid_oauthid",
|
||||
..descriptor::RANDOM_SMALL
|
||||
},
|
||||
Descriptor {
|
||||
name: "oidc_signingkey",
|
||||
..descriptor::RANDOM_SMALL
|
||||
},
|
||||
Descriptor {
|
||||
name: "oidcclientid_registration",
|
||||
..descriptor::RANDOM_SMALL
|
||||
},
|
||||
Descriptor {
|
||||
name: "oidccode_authsession",
|
||||
..descriptor::RANDOM_SMALL
|
||||
},
|
||||
Descriptor {
|
||||
name: "oidcreqid_authrequest",
|
||||
..descriptor::RANDOM_SMALL
|
||||
},
|
||||
Descriptor {
|
||||
name: "onetimekeyid_onetimekeys",
|
||||
..descriptor::RANDOM_SMALL
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#![type_length_limit = "65536"]
|
||||
#![allow(unused_features)] // 1.96.0-nightly 2026-03-07 bug
|
||||
|
||||
extern crate rust_rocksdb as rocksdb;
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
#![allow(unused_features)] // 1.96.0-nightly 2026-03-07 bug
|
||||
#![expect(clippy::needless_borrows_for_generic_args)]
|
||||
|
||||
use std::fmt::Debug;
|
||||
@@ -813,3 +814,47 @@ fn serde_tuple_option_none_none_none() {
|
||||
assert_eq!(None, cc.0);
|
||||
assert_eq!(bb, cc);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serde_tuple_integer_string() {
|
||||
let integer: u64 = 123_456;
|
||||
let user_id: &UserId = "@user:example.com".try_into().unwrap();
|
||||
|
||||
let mut a = integer.to_be_bytes().to_vec();
|
||||
a.push(0xFF);
|
||||
a.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
let b = (integer, user_id);
|
||||
let s = serialize_to_vec(&b).expect("failed to serialize (integer,string) tuple");
|
||||
|
||||
assert_eq!(a, s);
|
||||
|
||||
let c: (u64, &UserId) =
|
||||
de::from_slice(&s).expect("failed to deserialize (integer,string) tuple");
|
||||
|
||||
assert_eq!(c, b, "deserialized (integer,string) tuple did not match");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serde_tuple_string_integer_string() {
|
||||
let room_id: &RoomId = "!room:example.com".try_into().unwrap();
|
||||
let integer: u64 = 123_456;
|
||||
let user_id: &UserId = "@user:example.com".try_into().unwrap();
|
||||
|
||||
let mut a = Vec::new();
|
||||
a.extend_from_slice(room_id.as_bytes());
|
||||
a.push(0xFF);
|
||||
a.extend_from_slice(&integer.to_be_bytes());
|
||||
a.push(0xFF);
|
||||
a.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
let b = (room_id, integer, user_id);
|
||||
let s = serialize_to_vec(&b).expect("failed to serialize (string,integer,string) tuple");
|
||||
|
||||
assert_eq!(a, s);
|
||||
|
||||
let c: (&RoomId, u64, &UserId) =
|
||||
de::from_slice(&s).expect("failed to deserialize (integer,string) tuple");
|
||||
|
||||
assert_eq!(c, b, "deserialized (string,integer,string) tuple did not match");
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#![allow(unused_features)] // 1.96.0-nightly 2026-03-07 bug
|
||||
|
||||
mod admin;
|
||||
mod cargo;
|
||||
mod config;
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
#![cfg(test)]
|
||||
#![allow(unused_features)] // 1.96.0-nightly 2026-03-07 bug
|
||||
|
||||
use criterion::{Criterion, criterion_group, criterion_main};
|
||||
use tracing::Level;
|
||||
use tuwunel::{Args, Server, runtime};
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#![type_length_limit = "4096"] //TODO: reduce me
|
||||
#![expect(clippy::duration_suboptimal_units)] // remove after MSRV 1.91
|
||||
#![allow(unused_features)] // 1.96.0-nightly 2026-03-07 bug
|
||||
|
||||
pub mod args;
|
||||
pub mod logging;
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
#![allow(unused_features)] // 1.96.0-nightly 2026-03-07 bug
|
||||
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
use tuwunel::{Server, args, restart, runtime};
|
||||
|
||||
@@ -27,7 +27,7 @@ const WORKER_THREAD_MIN: usize = 2;
|
||||
const BLOCKING_THREAD_KEEPALIVE: u64 = 36;
|
||||
const BLOCKING_THREAD_NAME: &str = "tuwunel:spawned";
|
||||
const BLOCKING_THREAD_MAX: usize = 1024;
|
||||
const RUNTIME_SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(10000);
|
||||
const RUNTIME_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);
|
||||
#[cfg(all(not(target_env = "msvc"), feature = "jemalloc"))]
|
||||
const DISABLE_MUZZY_THRESHOLD: usize = 8;
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#![cfg(test)]
|
||||
#![allow(unused_features)] // 1.96.0-nightly 2026-03-07 bug
|
||||
|
||||
use insta::{assert_debug_snapshot, with_settings};
|
||||
use tuwunel::{Args, Server, runtime};
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#![cfg(test)]
|
||||
#![allow(unused_features)] // 1.96.0-nightly 2026-03-07 bug
|
||||
|
||||
use insta::{assert_debug_snapshot, with_settings};
|
||||
use tuwunel::{Args, Server, runtime};
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
#![cfg(test)]
|
||||
#![allow(unused_features)] // 1.96.0-nightly 2026-03-07 bug
|
||||
|
||||
use insta::{assert_debug_snapshot, with_settings};
|
||||
use tracing::Level;
|
||||
|
||||
@@ -159,7 +159,8 @@ fn cors_layer(server: &Server) -> CorsLayer {
|
||||
header::AUTHORIZATION,
|
||||
header::CONTENT_TYPE,
|
||||
header::ORIGIN,
|
||||
HeaderName::from_lowercase(b"x-requested-with").unwrap(),
|
||||
HeaderName::from_lowercase(b"x-requested-with")
|
||||
.expect("valid HTTP HeaderName from lowercase."),
|
||||
];
|
||||
|
||||
let allow_origin_list = server
|
||||
@@ -181,7 +182,7 @@ fn cors_layer(server: &Server) -> CorsLayer {
|
||||
};
|
||||
|
||||
CorsLayer::new()
|
||||
.max_age(Duration::from_secs(86400))
|
||||
.max_age(Duration::from_hours(24))
|
||||
.allow_methods(METHODS)
|
||||
.allow_headers(headers)
|
||||
.allow_origin(allow_origin)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
#![type_length_limit = "32768"] //TODO: reduce me
|
||||
#![expect(clippy::duration_suboptimal_units)] // remove after MSRV 1.91
|
||||
#![allow(unused_features)] // 1.96.0-nightly 2026-03-07 bug
|
||||
|
||||
mod handle;
|
||||
mod layers;
|
||||
|
||||
@@ -21,6 +21,9 @@ pub(crate) async fn run(services: Arc<Services>) -> Result {
|
||||
// Install the admin room callback here for now
|
||||
tuwunel_admin::init(&services.admin).await;
|
||||
|
||||
// Execute configured startup commands.
|
||||
services.admin.startup_execute().await?;
|
||||
|
||||
// Setup shutdown/signal handling
|
||||
let handle = ServerHandle::new();
|
||||
let sigs = server
|
||||
@@ -65,7 +68,7 @@ pub(crate) async fn start(server: Arc<Server>) -> Result<Arc<Services>> {
|
||||
let services = Services::build(server).await?.start().await?;
|
||||
|
||||
#[cfg(all(feature = "systemd", target_os = "linux"))]
|
||||
sd_notify::notify(true, &[sd_notify::NotifyState::Ready])
|
||||
sd_notify::notify(false, &[sd_notify::NotifyState::Ready])
|
||||
.expect("failed to notify systemd of ready state");
|
||||
|
||||
debug!("Started");
|
||||
|
||||
@@ -3,7 +3,10 @@ mod plain;
|
||||
mod tls;
|
||||
mod unix;
|
||||
|
||||
use std::sync::{Arc, atomic::Ordering};
|
||||
use std::{
|
||||
net::{SocketAddr, TcpListener},
|
||||
sync::{Arc, atomic::Ordering},
|
||||
};
|
||||
|
||||
use tokio::task::JoinSet;
|
||||
use tuwunel_core::{Result, debug_info};
|
||||
@@ -29,24 +32,42 @@ pub(super) async fn serve(services: Arc<Services>, handle: ServerHandle) -> Resu
|
||||
.await?;
|
||||
}
|
||||
|
||||
let addrs = config.get_bind_addrs();
|
||||
if !addrs.is_empty() {
|
||||
#[cfg_attr(not(feature = "direct_tls"), expect(clippy::redundant_else))]
|
||||
if config.tls.certs.is_some() {
|
||||
#[cfg(feature = "direct_tls")]
|
||||
{
|
||||
services.globals.init_rustls_provider()?;
|
||||
tls::serve(server, &app, &handle.handle_ip, &mut join_set, &addrs).await?;
|
||||
}
|
||||
let systemd_listeners: Vec<_> = systemd_listeners().collect();
|
||||
let systemd_listeners_is_empty = systemd_listeners.is_empty();
|
||||
let (listeners, addrs): (Vec<_>, Vec<_>) = config
|
||||
.get_bind_addrs()
|
||||
.into_iter()
|
||||
.filter(|_| systemd_listeners_is_empty)
|
||||
.map(|addr| {
|
||||
let listener = TcpListener::bind(addr)
|
||||
.expect("Failed to bind configured TcpListener to {addr:?}");
|
||||
|
||||
#[cfg(not(feature = "direct_tls"))]
|
||||
return tuwunel_core::Err!(Config(
|
||||
"tls",
|
||||
"tuwunel was not built with direct TLS support (\"direct_tls\")"
|
||||
));
|
||||
} else {
|
||||
plain::serve(server, &app, &handle.handle_ip, &mut join_set, &addrs);
|
||||
(listener, addr)
|
||||
})
|
||||
.chain(systemd_listeners)
|
||||
.inspect(|(listener, _)| {
|
||||
listener
|
||||
.set_nonblocking(true)
|
||||
.expect("Failed to set non-blocking");
|
||||
})
|
||||
.unzip();
|
||||
|
||||
#[cfg_attr(not(feature = "direct_tls"), expect(clippy::redundant_else))]
|
||||
if config.tls.certs.is_some() {
|
||||
#[cfg(feature = "direct_tls")]
|
||||
{
|
||||
services.globals.init_rustls_provider()?;
|
||||
tls::serve(server, &app, &handle.handle_ip, &mut join_set, &listeners, &addrs)
|
||||
.await?;
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "direct_tls"))]
|
||||
return tuwunel_core::Err!(Config(
|
||||
"tls",
|
||||
"tuwunel was not built with direct TLS support (\"direct_tls\")"
|
||||
));
|
||||
} else {
|
||||
plain::serve(server, &app, &handle.handle_ip, &mut join_set, &listeners, &addrs)?;
|
||||
}
|
||||
|
||||
assert!(!join_set.is_empty(), "at least one listener should be installed");
|
||||
@@ -75,3 +96,27 @@ pub(super) async fn serve(services: Arc<Services>, handle: ServerHandle) -> Resu
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(all(feature = "systemd", target_os = "linux"))]
|
||||
fn systemd_listeners() -> impl Iterator<Item = (TcpListener, SocketAddr)> {
|
||||
sd_notify::listen_fds()
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.filter_map(|fd| {
|
||||
use std::os::fd::FromRawFd;
|
||||
|
||||
debug_assert!(fd >= 3, "fdno probably not a listener socket");
|
||||
// SAFETY: systemd should already take care of providing
|
||||
// the correct TCP socket, so we just use it via raw fd
|
||||
let listener = unsafe { TcpListener::from_raw_fd(fd) };
|
||||
|
||||
let Ok(addr) = listener.local_addr() else {
|
||||
return None;
|
||||
};
|
||||
|
||||
Some((listener, addr))
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(any(not(feature = "systemd"), not(target_os = "linux")))]
|
||||
fn systemd_listeners() -> impl Iterator<Item = (TcpListener, SocketAddr)> { std::iter::empty() }
|
||||
|
||||
@@ -1,26 +1,34 @@
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
use std::{
|
||||
net::{SocketAddr, TcpListener},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use axum::Router;
|
||||
use axum_server::Handle;
|
||||
use tokio::task::JoinSet;
|
||||
use tuwunel_core::{Server, info};
|
||||
use tuwunel_core::{Result, Server, info};
|
||||
|
||||
pub(super) fn serve(
|
||||
server: &Arc<Server>,
|
||||
router: &Router,
|
||||
handle: &Handle<SocketAddr>,
|
||||
join_set: &mut JoinSet<Result<(), std::io::Error>>,
|
||||
listeners: &[TcpListener],
|
||||
addrs: &[SocketAddr],
|
||||
) {
|
||||
) -> Result {
|
||||
let router = router
|
||||
.clone()
|
||||
.into_make_service_with_connect_info::<SocketAddr>();
|
||||
for addr in addrs {
|
||||
let acceptor = axum_server::bind(*addr)
|
||||
|
||||
for listener in listeners {
|
||||
let acceptor = axum_server::from_tcp(listener.try_clone()?)?
|
||||
.handle(handle.clone())
|
||||
.serve(router.clone());
|
||||
|
||||
join_set.spawn_on(acceptor, server.runtime());
|
||||
}
|
||||
|
||||
info!("Listening on {addrs:?}");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
use std::{
|
||||
net::{SocketAddr, TcpListener},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use axum::Router;
|
||||
use axum_server::Handle;
|
||||
@@ -11,17 +14,30 @@ pub(super) async fn serve(
|
||||
app: &Router,
|
||||
handle: &Handle<SocketAddr>,
|
||||
join_set: &mut JoinSet<core::result::Result<(), std::io::Error>>,
|
||||
listeners: &[TcpListener],
|
||||
addrs: &[SocketAddr],
|
||||
) -> Result {
|
||||
let tls = &server.config.tls;
|
||||
let certs = tls.certs.as_ref().unwrap();
|
||||
let key = tls.key.as_ref().unwrap();
|
||||
|
||||
let certs = tls
|
||||
.certs
|
||||
.as_ref()
|
||||
.ok_or_else(|| err!(Config("tls.certs", "Invalid or missing TLS certificates")))?;
|
||||
|
||||
let key = tls
|
||||
.key
|
||||
.as_ref()
|
||||
.ok_or_else(|| err!(Config("tls.key", "Invalid or missingTLS key")))?;
|
||||
|
||||
info!(
|
||||
"Note: It is strongly recommended that you use a reverse proxy instead of running \
|
||||
tuwunel directly with TLS."
|
||||
);
|
||||
debug!("Using direct TLS. Certificate path {certs} and certificate private key path {key}",);
|
||||
|
||||
debug!(
|
||||
"Using direct TLS. Certificate path {certs:?} and certificate private key path {key:?}"
|
||||
);
|
||||
|
||||
let conf = RustlsConfig::from_pem_file(certs, key)
|
||||
.await
|
||||
.map_err(|e| err!(Config("tls", "Failed to load certificates or key: {e}")))?;
|
||||
@@ -29,15 +45,18 @@ pub(super) async fn serve(
|
||||
let app = app
|
||||
.clone()
|
||||
.into_make_service_with_connect_info::<SocketAddr>();
|
||||
|
||||
if tls.dual_protocol {
|
||||
for addr in addrs {
|
||||
join_set.spawn_on(
|
||||
axum_server_dual_protocol::bind_dual_protocol(*addr, conf.clone())
|
||||
.set_upgrade(false)
|
||||
.handle(handle.clone())
|
||||
.serve(app.clone()),
|
||||
server.runtime(),
|
||||
);
|
||||
for listener in listeners {
|
||||
let acceptor = axum_server_dual_protocol::from_tcp_dual_protocol(
|
||||
listener.try_clone()?,
|
||||
conf.clone(),
|
||||
)?
|
||||
.set_upgrade(false)
|
||||
.handle(handle.clone())
|
||||
.serve(app.clone());
|
||||
|
||||
join_set.spawn_on(acceptor, server.runtime());
|
||||
}
|
||||
|
||||
warn!(
|
||||
@@ -45,13 +64,12 @@ pub(super) async fn serve(
|
||||
(HTTP) connections too (insecure!)",
|
||||
);
|
||||
} else {
|
||||
for addr in addrs {
|
||||
join_set.spawn_on(
|
||||
axum_server::bind_rustls(*addr, conf.clone())
|
||||
.handle(handle.clone())
|
||||
.serve(app.clone()),
|
||||
server.runtime(),
|
||||
);
|
||||
for listener in listeners {
|
||||
let acceptor = axum_server::from_tcp_rustls(listener.try_clone()?, conf.clone())?
|
||||
.handle(handle.clone())
|
||||
.serve(app.clone());
|
||||
|
||||
join_set.spawn_on(acceptor, server.runtime());
|
||||
}
|
||||
|
||||
info!("Listening on {addrs:?} with TLS certificate {certs}");
|
||||
|
||||
@@ -104,6 +104,7 @@ lru-cache.workspace = true
|
||||
rand.workspace = true
|
||||
regex.workspace = true
|
||||
reqwest.workspace = true
|
||||
ring.workspace = true
|
||||
ruma.workspace = true
|
||||
rustls.workspace = true
|
||||
rustyline-async.workspace = true
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use ruma::events::room::message::RoomMessageEventContent;
|
||||
use tokio::time::{Duration, sleep};
|
||||
use tokio::task::yield_now;
|
||||
use tuwunel_core::{Err, Result, debug, debug_info, error, implement, info};
|
||||
|
||||
pub(super) const SIGNAL: &str = "SIGUSR2";
|
||||
@@ -16,7 +16,7 @@ pub(super) async fn console_auto_start(&self) {
|
||||
.admin_console_automatic
|
||||
{
|
||||
// Allow more of the startup sequence to execute before spawning
|
||||
tokio::task::yield_now().await;
|
||||
yield_now().await;
|
||||
self.console.start();
|
||||
}
|
||||
}
|
||||
@@ -31,7 +31,7 @@ pub(super) async fn console_auto_stop(&self) {
|
||||
|
||||
/// Execute admin commands after startup
|
||||
#[implement(super::Service)]
|
||||
pub(super) async fn startup_execute(&self) -> Result {
|
||||
pub async fn startup_execute(&self) -> Result {
|
||||
// List of commands to execute
|
||||
let commands = &self.services.server.config.admin_execute;
|
||||
|
||||
@@ -46,9 +46,6 @@ pub(super) async fn startup_execute(&self) -> Result {
|
||||
.config
|
||||
.admin_execute_errors_ignore;
|
||||
|
||||
//TODO: remove this after run-states are broadcast
|
||||
sleep(Duration::from_millis(500)).await;
|
||||
|
||||
for (i, command) in commands.iter().enumerate() {
|
||||
if let Err(e) = self.execute_command(i, command.clone()).await
|
||||
&& !errors
|
||||
@@ -56,7 +53,7 @@ pub(super) async fn startup_execute(&self) -> Result {
|
||||
return Err(e);
|
||||
}
|
||||
|
||||
tokio::task::yield_now().await;
|
||||
yield_now().await;
|
||||
}
|
||||
|
||||
// The smoketest functionality is placed here for now and simply initiates
|
||||
@@ -98,7 +95,7 @@ pub(super) async fn signal_execute(&self) -> Result {
|
||||
return Err(e);
|
||||
}
|
||||
|
||||
tokio::task::yield_now().await;
|
||||
yield_now().await;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
||||
@@ -86,7 +86,6 @@ impl crate::Service for Service {
|
||||
.expect("locked for writing")
|
||||
.insert(sender);
|
||||
|
||||
self.startup_execute().await?;
|
||||
self.console_auto_start().await;
|
||||
|
||||
loop {
|
||||
|
||||
@@ -52,13 +52,18 @@ where
|
||||
.map(BytesMut::freeze);
|
||||
|
||||
let mut parts = http_request.uri().clone().into_parts();
|
||||
let old_path_and_query = parts.path_and_query.unwrap().as_str().to_owned();
|
||||
let old_path_and_query = parts
|
||||
.path_and_query
|
||||
.expect("valid request uri path and query")
|
||||
.as_str()
|
||||
.to_owned();
|
||||
|
||||
let symbol = if old_path_and_query.contains('?') { "&" } else { "?" };
|
||||
|
||||
parts.path_and_query = Some(
|
||||
(old_path_and_query + symbol + "access_token=" + hs_token)
|
||||
.parse()
|
||||
.unwrap(),
|
||||
.expect("valid path and query"),
|
||||
);
|
||||
*http_request.uri_mut() = parts
|
||||
.try_into()
|
||||
|
||||
@@ -1,3 +1,6 @@
|
||||
#![cfg(test)]
|
||||
#![allow(unused_features)] // 1.96.0-nightly 2026-03-07 bug
|
||||
|
||||
use std::{
|
||||
borrow::Borrow,
|
||||
collections::HashMap,
|
||||
|
||||
@@ -49,13 +49,16 @@ impl Deref for Service {
|
||||
fn handle_reload(&self) -> Result {
|
||||
if self.server.config.config_reload_signal {
|
||||
#[cfg(all(feature = "systemd", target_os = "linux"))]
|
||||
sd_notify::notify(true, &[sd_notify::NotifyState::Reloading])
|
||||
.expect("failed to notify systemd of reloading state");
|
||||
sd_notify::notify(false, &[
|
||||
sd_notify::NotifyState::Reloading,
|
||||
sd_notify::NotifyState::monotonic_usec_now().expect("failed to get monotonic time"),
|
||||
])
|
||||
.expect("failed to notify systemd of reloading state");
|
||||
|
||||
self.reload(iter::empty())?;
|
||||
|
||||
#[cfg(all(feature = "systemd", target_os = "linux"))]
|
||||
sd_notify::notify(true, &[sd_notify::NotifyState::Ready])
|
||||
sd_notify::notify(false, &[sd_notify::NotifyState::Ready])
|
||||
.expect("failed to notify systemd of ready state");
|
||||
}
|
||||
|
||||
|
||||
@@ -35,7 +35,7 @@ pub async fn format_pdu_into(
|
||||
.as_ref()
|
||||
.or(room_version)
|
||||
{
|
||||
pdu_json = pdu::format::into_outgoing_federation(pdu_json, room_version);
|
||||
pdu_json = pdu::into_outgoing_federation(pdu_json, room_version);
|
||||
} else {
|
||||
pdu_json.remove("event_id");
|
||||
}
|
||||
|
||||
@@ -3,7 +3,7 @@ use std::{panic::AssertUnwindSafe, sync::Arc, time::Duration};
|
||||
use futures::{FutureExt, TryFutureExt};
|
||||
use tokio::{
|
||||
sync::{Mutex, MutexGuard},
|
||||
task::{JoinHandle, JoinSet},
|
||||
task::{JoinHandle, JoinSet, yield_now},
|
||||
time::sleep,
|
||||
};
|
||||
use tuwunel_core::{
|
||||
@@ -60,6 +60,8 @@ impl Manager {
|
||||
self.start_worker(&mut workers, &service)?;
|
||||
}
|
||||
|
||||
yield_now().await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -72,7 +74,7 @@ impl Manager {
|
||||
}
|
||||
}
|
||||
|
||||
async fn worker(&self) -> Result {
|
||||
async fn worker(self: &Arc<Self>) -> Result {
|
||||
loop {
|
||||
let mut workers = self.workers.lock().await;
|
||||
tokio::select! {
|
||||
@@ -95,7 +97,7 @@ impl Manager {
|
||||
}
|
||||
|
||||
async fn handle_result(
|
||||
&self,
|
||||
self: &Arc<Self>,
|
||||
workers: &mut WorkersLocked<'_>,
|
||||
result: WorkerResult,
|
||||
) -> Result {
|
||||
@@ -108,7 +110,7 @@ impl Manager {
|
||||
|
||||
#[expect(clippy::unused_self)]
|
||||
fn handle_finished(
|
||||
&self,
|
||||
self: &Arc<Self>,
|
||||
_workers: &mut WorkersLocked<'_>,
|
||||
service: &Arc<dyn Service>,
|
||||
) -> Result {
|
||||
@@ -117,7 +119,7 @@ impl Manager {
|
||||
}
|
||||
|
||||
async fn handle_error(
|
||||
&self,
|
||||
self: &Arc<Self>,
|
||||
workers: &mut WorkersLocked<'_>,
|
||||
service: &Arc<dyn Service>,
|
||||
error: Error,
|
||||
@@ -143,7 +145,7 @@ impl Manager {
|
||||
|
||||
/// Start the worker in a task for the service.
|
||||
fn start_worker(
|
||||
&self,
|
||||
self: &Arc<Self>,
|
||||
workers: &mut WorkersLocked<'_>,
|
||||
service: &Arc<dyn Service>,
|
||||
) -> Result {
|
||||
@@ -155,7 +157,7 @@ impl Manager {
|
||||
}
|
||||
|
||||
debug!("Service {:?} worker starting...", service.name());
|
||||
workers.spawn_on(worker(service.clone()), self.server.runtime());
|
||||
workers.spawn_on(worker(service.clone(), self.clone()), self.server.runtime());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -172,7 +174,7 @@ impl Manager {
|
||||
skip_all,
|
||||
fields(service = %service.name()),
|
||||
)]
|
||||
async fn worker(service: Arc<dyn Service>) -> WorkerResult {
|
||||
async fn worker(service: Arc<dyn Service>, _mgr: Arc<Manager>) -> WorkerResult {
|
||||
let service_ = Arc::clone(&service);
|
||||
let result = AssertUnwindSafe(service_.worker())
|
||||
.catch_unwind()
|
||||
|
||||
@@ -1,17 +1,22 @@
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use futures::{StreamExt, pin_mut};
|
||||
use ruma::{Mxc, OwnedMxcUri, UserId, http_headers::ContentDisposition};
|
||||
use ruma::{Mxc, OwnedMxcUri, OwnedUserId, UserId, http_headers::ContentDisposition};
|
||||
use tuwunel_core::{
|
||||
Err, Result, debug, debug_info, err,
|
||||
utils::{ReadyExt, str_from_bytes, stream::TryIgnore, string_from_bytes},
|
||||
utils::{
|
||||
ReadyExt, str_from_bytes,
|
||||
stream::{TryExpect, TryIgnore},
|
||||
string_from_bytes,
|
||||
},
|
||||
};
|
||||
use tuwunel_database::{Database, Interfix, Map, serialize_key};
|
||||
use tuwunel_database::{Database, Deserialized, Ignore, Interfix, Map, serialize_key};
|
||||
|
||||
use super::{preview::UrlPreviewData, thumbnail::Dim};
|
||||
|
||||
pub(crate) struct Data {
|
||||
mediaid_file: Arc<Map>,
|
||||
mediaid_pending: Arc<Map>,
|
||||
mediaid_user: Arc<Map>,
|
||||
url_previews: Arc<Map>,
|
||||
}
|
||||
@@ -27,6 +32,7 @@ impl Data {
|
||||
pub(super) fn new(db: &Arc<Database>) -> Self {
|
||||
Self {
|
||||
mediaid_file: db["mediaid_file"].clone(),
|
||||
mediaid_pending: db["mediaid_pending"].clone(),
|
||||
mediaid_user: db["mediaid_user"].clone(),
|
||||
url_previews: db["url_previews"].clone(),
|
||||
}
|
||||
@@ -52,6 +58,55 @@ impl Data {
|
||||
Ok(key.to_vec())
|
||||
}
|
||||
|
||||
/// Insert a pending MXC URI into the database
|
||||
pub(super) fn insert_pending_mxc(
|
||||
&self,
|
||||
mxc: &Mxc<'_>,
|
||||
user: &UserId,
|
||||
unused_expires_at: u64,
|
||||
) {
|
||||
let value = (unused_expires_at, user);
|
||||
debug!(?mxc, ?user, ?unused_expires_at, "Inserting pending");
|
||||
|
||||
self.mediaid_pending
|
||||
.raw_put(mxc.to_string(), value);
|
||||
}
|
||||
|
||||
/// Remove a pending MXC URI from the database
|
||||
pub(super) fn remove_pending_mxc(&self, mxc: &Mxc<'_>) {
|
||||
self.mediaid_pending.remove(&mxc.to_string());
|
||||
}
|
||||
|
||||
/// Count the number of pending MXC URIs for a specific user
|
||||
pub(super) async fn count_pending_mxc_for_user(&self, user_id: &UserId) -> (usize, u64) {
|
||||
type KeyVal<'a> = (Ignore, (u64, &'a UserId));
|
||||
|
||||
self.mediaid_pending
|
||||
.stream()
|
||||
.expect_ok()
|
||||
.ready_filter(|(_, (_, pending_user_id)): &KeyVal<'_>| user_id == *pending_user_id)
|
||||
.ready_fold(
|
||||
(0_usize, u64::MAX),
|
||||
|(count, earliest_expiration), (_, (expires_at, _))| {
|
||||
(count.saturating_add(1), earliest_expiration.min(expires_at))
|
||||
},
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Search for a pending MXC URI in the database
|
||||
pub(super) async fn search_pending_mxc(&self, mxc: &Mxc<'_>) -> Result<(OwnedUserId, u64)> {
|
||||
type Value<'a> = (u64, OwnedUserId);
|
||||
|
||||
self.mediaid_pending
|
||||
.get(&mxc.to_string())
|
||||
.await
|
||||
.deserialized()
|
||||
.map(|(expires_at, user_id): Value<'_>| (user_id, expires_at))
|
||||
.inspect(|(user_id, expires_at)| debug!(?mxc, ?user_id, ?expires_at, "Found pending"))
|
||||
.map_err(|e| err!(Request(NotFound("Pending not found or error: {e}"))))
|
||||
}
|
||||
|
||||
pub(super) async fn delete_file_mxc(&self, mxc: &Mxc<'_>) {
|
||||
debug!("MXC URI: {mxc}");
|
||||
|
||||
|
||||
@@ -5,18 +5,29 @@ mod preview;
|
||||
mod remote;
|
||||
mod tests;
|
||||
mod thumbnail;
|
||||
use std::{path::PathBuf, sync::Arc, time::SystemTime};
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
path::PathBuf,
|
||||
sync::{Arc, Mutex},
|
||||
time::{Duration, Instant, SystemTime},
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use base64::{Engine as _, engine::general_purpose};
|
||||
use ruma::{Mxc, OwnedMxcUri, UserId, http_headers::ContentDisposition};
|
||||
use http::StatusCode;
|
||||
use ruma::{
|
||||
Mxc, OwnedMxcUri, OwnedUserId, UserId,
|
||||
api::client::error::{ErrorKind, RetryAfter},
|
||||
http_headers::ContentDisposition,
|
||||
};
|
||||
use tokio::{
|
||||
fs,
|
||||
io::{AsyncReadExt, AsyncWriteExt, BufReader},
|
||||
sync::Notify,
|
||||
};
|
||||
use tuwunel_core::{
|
||||
Err, Result, debug, debug_error, debug_info, debug_warn, err, error, trace,
|
||||
utils::{self, MutexMap},
|
||||
Err, Error, Result, debug, debug_error, debug_info, debug_warn, err, error, trace,
|
||||
utils::{self, MutexMap, time::now_millis},
|
||||
warn,
|
||||
};
|
||||
|
||||
@@ -30,10 +41,19 @@ pub struct FileMeta {
|
||||
pub content_disposition: Option<ContentDisposition>,
|
||||
}
|
||||
|
||||
/// For MSC2246
|
||||
struct MXCState {
|
||||
/// Save the notifier for each pending media upload
|
||||
notifiers: Mutex<HashMap<OwnedMxcUri, Arc<Notify>>>,
|
||||
/// Save the ratelimiter for each user
|
||||
ratelimiter: Mutex<HashMap<OwnedUserId, (Instant, f64)>>,
|
||||
}
|
||||
|
||||
pub struct Service {
|
||||
url_preview_mutex: MutexMap<String, ()>,
|
||||
pub(super) db: Data,
|
||||
services: Arc<crate::services::OnceServices>,
|
||||
url_preview_mutex: MutexMap<String, ()>,
|
||||
mxc_state: MXCState,
|
||||
}
|
||||
|
||||
/// generated MXC ID (`media-id`) length
|
||||
@@ -49,9 +69,13 @@ pub const CORP_CROSS_ORIGIN: &str = "cross-origin";
|
||||
impl crate::Service for Service {
|
||||
fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
|
||||
Ok(Arc::new(Self {
|
||||
url_preview_mutex: MutexMap::new(),
|
||||
db: Data::new(args.db),
|
||||
services: args.services.clone(),
|
||||
url_preview_mutex: MutexMap::new(),
|
||||
mxc_state: MXCState {
|
||||
notifiers: Mutex::new(HashMap::new()),
|
||||
ratelimiter: Mutex::new(HashMap::new()),
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
@@ -65,6 +89,113 @@ impl crate::Service for Service {
|
||||
}
|
||||
|
||||
impl Service {
|
||||
/// Create a pending media upload ID.
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
pub async fn create_pending(
|
||||
&self,
|
||||
mxc: &Mxc<'_>,
|
||||
user: &UserId,
|
||||
unused_expires_at: u64,
|
||||
) -> Result {
|
||||
let config = &self.services.server.config;
|
||||
|
||||
// Rate limiting (rc_media_create)
|
||||
let rate = f64::from(config.media_rc_create_per_second);
|
||||
let burst = f64::from(config.media_rc_create_burst_count);
|
||||
|
||||
// Check rate limiting
|
||||
if rate > 0.0 && burst > 0.0 {
|
||||
let now = Instant::now();
|
||||
let mut ratelimiter = self.mxc_state.ratelimiter.lock()?;
|
||||
|
||||
let (last_time, tokens) = ratelimiter
|
||||
.entry(user.to_owned())
|
||||
.or_insert_with(|| (now, burst));
|
||||
|
||||
let elapsed = now.duration_since(*last_time).as_secs_f64();
|
||||
let new_tokens = elapsed.mul_add(rate, *tokens).min(burst);
|
||||
|
||||
if new_tokens >= 1.0 {
|
||||
*last_time = now;
|
||||
*tokens = new_tokens - 1.0;
|
||||
} else {
|
||||
return Err(Error::Request(
|
||||
ErrorKind::LimitExceeded { retry_after: None },
|
||||
"Too many pending media creation requests.".into(),
|
||||
StatusCode::TOO_MANY_REQUESTS,
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
let max_uploads = config.max_pending_media_uploads;
|
||||
let (current_uploads, earliest_expiration) =
|
||||
self.db.count_pending_mxc_for_user(user).await;
|
||||
|
||||
// Check if the user has reached the maximum number of pending media uploads
|
||||
if current_uploads >= max_uploads {
|
||||
let retry_after = earliest_expiration.saturating_sub(now_millis());
|
||||
return Err(Error::Request(
|
||||
ErrorKind::LimitExceeded {
|
||||
retry_after: Some(RetryAfter::Delay(Duration::from_millis(retry_after))),
|
||||
},
|
||||
"Maximum number of pending media uploads reached.".into(),
|
||||
StatusCode::TOO_MANY_REQUESTS,
|
||||
));
|
||||
}
|
||||
|
||||
self.db
|
||||
.insert_pending_mxc(mxc, user, unused_expires_at);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Uploads content to a pending media ID.
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
pub async fn upload_pending(
|
||||
&self,
|
||||
mxc: &Mxc<'_>,
|
||||
user: &UserId,
|
||||
content_disposition: Option<&ContentDisposition>,
|
||||
content_type: Option<&str>,
|
||||
file: &[u8],
|
||||
) -> Result {
|
||||
let Ok((owner_id, expires_at)) = self.db.search_pending_mxc(mxc).await else {
|
||||
if self.get_metadata(mxc).await.is_some() {
|
||||
return Err!(Request(CannotOverwriteMedia("Media ID already has content")));
|
||||
}
|
||||
|
||||
return Err!(Request(NotFound("Media not found")));
|
||||
};
|
||||
|
||||
if owner_id != user {
|
||||
return Err!(Request(Forbidden("You did not create this media ID")));
|
||||
}
|
||||
|
||||
let current_time = now_millis();
|
||||
if expires_at < current_time {
|
||||
return Err!(Request(NotFound("Pending media ID expired")));
|
||||
}
|
||||
|
||||
self.create(mxc, Some(user), content_disposition, content_type, file)
|
||||
.await?;
|
||||
|
||||
self.db.remove_pending_mxc(mxc);
|
||||
|
||||
let mxc_uri: OwnedMxcUri = mxc.to_string().into();
|
||||
if let Some(notifier) = self
|
||||
.mxc_state
|
||||
.notifiers
|
||||
.lock()?
|
||||
.get(&mxc_uri)
|
||||
.cloned()
|
||||
{
|
||||
notifier.notify_waiters();
|
||||
self.mxc_state.notifiers.lock()?.remove(&mxc_uri);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Uploads a file.
|
||||
pub async fn create(
|
||||
&self,
|
||||
@@ -168,6 +299,71 @@ impl Service {
|
||||
}
|
||||
}
|
||||
|
||||
/// Download a file and wait up to a timeout_ms if it is pending.
|
||||
pub async fn get_with_timeout(
|
||||
&self,
|
||||
mxc: &Mxc<'_>,
|
||||
timeout_duration: Duration,
|
||||
) -> Result<Option<FileMeta>> {
|
||||
if let Some(meta) = self.get(mxc).await? {
|
||||
return Ok(Some(meta));
|
||||
}
|
||||
|
||||
let Ok(_pending) = self.db.search_pending_mxc(mxc).await else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let notifier = self
|
||||
.mxc_state
|
||||
.notifiers
|
||||
.lock()?
|
||||
.entry(mxc.to_string().into())
|
||||
.or_insert_with(|| Arc::new(Notify::new()))
|
||||
.clone();
|
||||
|
||||
if tokio::time::timeout(timeout_duration, notifier.notified())
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return Err!(Request(NotYetUploaded("Media has not been uploaded yet")));
|
||||
}
|
||||
|
||||
self.get(mxc).await
|
||||
}
|
||||
|
||||
/// Download a thumbnail and wait up to a timeout_ms if it is pending.
|
||||
pub async fn get_thumbnail_with_timeout(
|
||||
&self,
|
||||
mxc: &Mxc<'_>,
|
||||
dim: &Dim,
|
||||
timeout_duration: Duration,
|
||||
) -> Result<Option<FileMeta>> {
|
||||
if let Some(meta) = self.get_thumbnail(mxc, dim).await? {
|
||||
return Ok(Some(meta));
|
||||
}
|
||||
|
||||
let Ok(_pending) = self.db.search_pending_mxc(mxc).await else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let notifier = self
|
||||
.mxc_state
|
||||
.notifiers
|
||||
.lock()?
|
||||
.entry(mxc.to_string().into())
|
||||
.or_insert_with(|| Arc::new(Notify::new()))
|
||||
.clone();
|
||||
|
||||
if tokio::time::timeout(timeout_duration, notifier.notified())
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return Err!(Request(NotYetUploaded("Media has not been uploaded yet")));
|
||||
}
|
||||
|
||||
self.get_thumbnail(mxc, dim).await
|
||||
}
|
||||
|
||||
/// Gets all the MXC URIs in our media database
|
||||
pub async fn get_all_mxcs(&self) -> Result<Vec<OwnedMxcUri>> {
|
||||
let all_keys = self.db.get_all_media_keys().await;
|
||||
|
||||
@@ -29,7 +29,7 @@ use serde_json::value::RawValue as RawJsonValue;
|
||||
use tuwunel_core::{
|
||||
Err, Result, at, debug, debug_error, debug_info, debug_warn, err, error, implement, info,
|
||||
matrix::{event::gen_event_id_canonical_json, room_version},
|
||||
pdu::{PduBuilder, check_pdu_format, format::from_incoming_federation},
|
||||
pdu::{Pdu, PduBuilder, check_rules},
|
||||
trace,
|
||||
utils::{self, BoolExt, IterStream, ReadyExt, future::TryExtExt, math::Expected, shuffle},
|
||||
warn,
|
||||
@@ -310,8 +310,8 @@ pub async fn join_remote(
|
||||
%shortroomid,
|
||||
"Initialized room. Parsing join event..."
|
||||
);
|
||||
let parsed_join_pdu =
|
||||
from_incoming_federation(room_id, &event_id, &mut join_event, &room_version_rules)?;
|
||||
let (parsed_join_pdu, join_event) =
|
||||
Pdu::from_object_federation(room_id, &event_id, join_event, &room_version_rules)?;
|
||||
|
||||
let resp_state = &response.state;
|
||||
let resp_auth = &response.auth_chain;
|
||||
@@ -337,12 +337,12 @@ pub async fn join_remote(
|
||||
})
|
||||
.inspect_err(|e| debug_error!("Invalid send_join state event: {e:?}"))
|
||||
.ready_filter_map(Result::ok)
|
||||
.ready_filter_map(|(event_id, mut value)| {
|
||||
from_incoming_federation(room_id, &event_id, &mut value, &room_version_rules)
|
||||
.ready_filter_map(|(event_id, value)| {
|
||||
Pdu::from_object_federation(room_id, &event_id, value, &room_version_rules)
|
||||
.inspect_err(|e| {
|
||||
debug_warn!("Invalid PDU in send_join response: {e:?}: {value:#?}");
|
||||
debug_warn!("Invalid PDU {event_id:?} in send_join response: {e:?}");
|
||||
})
|
||||
.map(move |pdu| (event_id, pdu, value))
|
||||
.map(move |(pdu, value)| (event_id, pdu, value))
|
||||
.ok()
|
||||
})
|
||||
.fold(HashMap::new(), async |mut state, (event_id, pdu, value)| {
|
||||
@@ -757,7 +757,7 @@ async fn create_join_event(
|
||||
.server_keys
|
||||
.gen_id_hash_and_sign_event(&mut event, room_version_id)?;
|
||||
|
||||
check_pdu_format(&event, &room_version_rules.event_format)?;
|
||||
check_rules(&event, &room_version_rules.event_format)?;
|
||||
|
||||
Ok((event, event_id, join_authorized_via_users_server))
|
||||
}
|
||||
|
||||
@@ -308,7 +308,7 @@ async fn knock_room_helper_local(
|
||||
|
||||
info!("Parsing knock event");
|
||||
|
||||
let parsed_knock_pdu = PduEvent::from_id_val(&event_id, knock_event.clone())
|
||||
let parsed_knock_pdu = PduEvent::from_object_and_eventid(&event_id, knock_event.clone())
|
||||
.map_err(|e| err!(BadServerResponse("Invalid knock event PDU: {e:?}")))?;
|
||||
|
||||
info!("Updating membership locally to knock state with provided stripped state events");
|
||||
@@ -480,7 +480,7 @@ async fn knock_room_helper_remote(
|
||||
.await;
|
||||
|
||||
info!("Parsing knock event");
|
||||
let parsed_knock_pdu = PduEvent::from_id_val(&event_id, knock_event.clone())
|
||||
let parsed_knock_pdu = PduEvent::from_object_and_eventid(&event_id, knock_event.clone())
|
||||
.map_err(|e| err!(BadServerResponse("Invalid knock event PDU: {e:?}")))?;
|
||||
|
||||
info!("Going through send_knock response knock state events");
|
||||
|
||||
@@ -16,7 +16,7 @@ use ruma::{
|
||||
};
|
||||
use tuwunel_core::{
|
||||
Err, Result, debug_info, debug_warn, err, implement,
|
||||
matrix::{PduCount, pdu::check_pdu_format, room_version},
|
||||
matrix::{PduCount, pdu::check_rules, room_version},
|
||||
pdu::PduBuilder,
|
||||
utils::{
|
||||
self, FutureBoolExt,
|
||||
@@ -353,7 +353,7 @@ async fn remote_leave(
|
||||
.server_keys
|
||||
.gen_id_hash_and_sign_event(&mut event, &room_version_id)?;
|
||||
|
||||
check_pdu_format(&event, &room_version_rules.event_format)?;
|
||||
check_rules(&event, &room_version_rules.event_format)?;
|
||||
|
||||
self.services
|
||||
.federation
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
#![recursion_limit = "256"]
|
||||
#![type_length_limit = "98304"]
|
||||
#![expect(refining_impl_trait)]
|
||||
#![expect(clippy::duration_suboptimal_units)] // remove after MSRV 1.91
|
||||
#![allow(unused_features)] // 1.96.0-nightly 2026-03-07 bug
|
||||
|
||||
mod manager;
|
||||
mod migrations;
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod oidc_server;
|
||||
pub mod providers;
|
||||
pub mod sessions;
|
||||
pub mod user_info;
|
||||
@@ -14,13 +15,14 @@ use ruma::UserId;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value as JsonValue;
|
||||
use tuwunel_core::{
|
||||
Err, Result, err, implement,
|
||||
Err, Result, err, implement, info, warn,
|
||||
utils::{hash::sha256, result::LogErr, stream::ReadyExt},
|
||||
};
|
||||
use url::Url;
|
||||
|
||||
use self::{providers::Providers, sessions::Sessions};
|
||||
use self::{oidc_server::OidcServer, providers::Providers, sessions::Sessions};
|
||||
pub use self::{
|
||||
oidc_server::ProviderMetadata,
|
||||
providers::{Provider, ProviderId},
|
||||
sessions::{CODE_VERIFIER_LENGTH, SESSION_ID_LENGTH, Session, SessionId},
|
||||
user_info::UserInfo,
|
||||
@@ -31,16 +33,33 @@ pub struct Service {
|
||||
services: SelfServices,
|
||||
pub providers: Arc<Providers>,
|
||||
pub sessions: Arc<Sessions>,
|
||||
/// OIDC server (authorization server) for next-gen Matrix auth.
|
||||
/// Only initialized when identity providers are configured.
|
||||
pub oidc_server: Option<Arc<OidcServer>>,
|
||||
}
|
||||
|
||||
impl crate::Service for Service {
|
||||
fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
|
||||
let providers = Arc::new(Providers::build(args));
|
||||
let sessions = Arc::new(Sessions::build(args, providers.clone()));
|
||||
|
||||
let oidc_server = if !args.server.config.identity_provider.is_empty()
|
||||
|| args.server.config.well_known.client.is_some()
|
||||
{
|
||||
if args.server.config.identity_provider.is_empty() {
|
||||
warn!("OIDC server enabled (well_known.client is set) but no identity_provider configured; authorization flow will not work");
|
||||
}
|
||||
info!("Initializing OIDC server for next-gen auth (MSC2965)");
|
||||
Some(Arc::new(OidcServer::build(args)?))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Arc::new(Self {
|
||||
services: args.services.clone(),
|
||||
sessions,
|
||||
providers,
|
||||
oidc_server,
|
||||
}))
|
||||
}
|
||||
|
||||
|
||||
217
src/service/oauth/oidc_server.rs
Normal file
217
src/service/oauth/oidc_server.rs
Normal file
@@ -0,0 +1,217 @@
|
||||
use std::{sync::Arc, time::{Duration, SystemTime}};
|
||||
|
||||
use base64::{Engine as _, engine::general_purpose::URL_SAFE_NO_PAD as b64};
|
||||
use ring::{rand::SystemRandom, signature::{self, EcdsaKeyPair, KeyPair}};
|
||||
use ruma::OwnedUserId;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tuwunel_core::{Err, Result, err, info, jwt, utils};
|
||||
use tuwunel_database::{Cbor, Deserialized, Map};
|
||||
|
||||
const AUTH_CODE_LENGTH: usize = 64;
|
||||
const OIDC_CLIENT_ID_LENGTH: usize = 32;
|
||||
const AUTH_CODE_LIFETIME: Duration = Duration::from_secs(600);
|
||||
const AUTH_REQUEST_LIFETIME: Duration = Duration::from_secs(600);
|
||||
const SIGNING_KEY_DB_KEY: &str = "oidc_signing_key";
|
||||
|
||||
pub struct OidcServer { db: Data, signing_key_der: Vec<u8>, jwk: serde_json::Value, key_id: String }
|
||||
|
||||
struct Data { oidc_signingkey: Arc<Map>, oidcclientid_registration: Arc<Map>, oidccode_authsession: Arc<Map>, oidcreqid_authrequest: Arc<Map> }
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
pub struct DcrRequest {
|
||||
pub redirect_uris: Vec<String>, pub client_name: Option<String>, pub client_uri: Option<String>,
|
||||
pub logo_uri: Option<String>, #[serde(default)] pub contacts: Vec<String>,
|
||||
pub token_endpoint_auth_method: Option<String>, pub grant_types: Option<Vec<String>>,
|
||||
pub response_types: Option<Vec<String>>, pub application_type: Option<String>,
|
||||
pub policy_uri: Option<String>, pub tos_uri: Option<String>,
|
||||
pub software_id: Option<String>, pub software_version: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct OidcClientRegistration {
|
||||
pub client_id: String, pub redirect_uris: Vec<String>, pub client_name: Option<String>,
|
||||
pub client_uri: Option<String>, pub logo_uri: Option<String>, pub contacts: Vec<String>,
|
||||
pub token_endpoint_auth_method: String, pub grant_types: Vec<String>, pub response_types: Vec<String>,
|
||||
pub application_type: Option<String>, pub policy_uri: Option<String>, pub tos_uri: Option<String>,
|
||||
pub software_id: Option<String>, pub software_version: Option<String>, pub registered_at: u64,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct AuthCodeSession {
|
||||
pub code: String, pub client_id: String, pub redirect_uri: String, pub scope: String,
|
||||
pub state: Option<String>, pub nonce: Option<String>, pub code_challenge: Option<String>,
|
||||
pub code_challenge_method: Option<String>, pub user_id: OwnedUserId,
|
||||
pub created_at: SystemTime, pub expires_at: SystemTime,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct OidcAuthRequest {
|
||||
pub client_id: String, pub redirect_uri: String, pub scope: String, pub state: Option<String>,
|
||||
pub nonce: Option<String>, pub code_challenge: Option<String>, pub code_challenge_method: Option<String>,
|
||||
pub created_at: SystemTime, pub expires_at: SystemTime,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct SigningKeyData { key_der: Vec<u8>, key_id: String }
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ProviderMetadata {
|
||||
pub issuer: String, pub authorization_endpoint: String, pub token_endpoint: String,
|
||||
pub registration_endpoint: Option<String>, pub revocation_endpoint: Option<String>, pub jwks_uri: String,
|
||||
pub userinfo_endpoint: Option<String>, pub account_management_uri: Option<String>,
|
||||
pub account_management_actions_supported: Option<Vec<String>>, pub response_types_supported: Vec<String>,
|
||||
pub response_modes_supported: Option<Vec<String>>, pub grant_types_supported: Option<Vec<String>>,
|
||||
pub code_challenge_methods_supported: Option<Vec<String>>, pub token_endpoint_auth_methods_supported: Option<Vec<String>>,
|
||||
pub scopes_supported: Option<Vec<String>>, pub subject_types_supported: Option<Vec<String>>,
|
||||
pub id_token_signing_alg_values_supported: Option<Vec<String>>, pub prompt_values_supported: Option<Vec<String>>,
|
||||
pub claim_types_supported: Option<Vec<String>>, pub claims_supported: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct IdTokenClaims {
|
||||
pub iss: String, pub sub: String, pub aud: String, pub exp: u64, pub iat: u64,
|
||||
#[serde(skip_serializing_if = "Option::is_none")] pub nonce: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")] pub at_hash: Option<String>,
|
||||
}
|
||||
|
||||
impl ProviderMetadata {
	/// Consume the metadata and render it as a `serde_json::Value` for the
	/// discovery endpoint. The `unwrap` cannot fail: every field is a plain
	/// string/vec/option, all of which serialize infallibly.
	#[must_use]
	pub fn into_json(self) -> serde_json::Value { serde_json::to_value(self).unwrap() }
}
|
||||
|
||||
impl OidcServer {
	/// Construct the OIDC service: wire up the four backing DB maps, load
	/// (or lazily create and persist) the ES256 signing key, and derive the
	/// public JWK from it.
	///
	/// NOTE(review): any DB read error — not only "key absent" — falls into
	/// the `Err(_)` arm and silently generates a *new* signing key,
	/// rotating the kid. Confirm that is the intended behavior for
	/// transient storage failures.
	pub(crate) fn build(args: &crate::Args<'_>) -> Result<Self> {
		let db = Data {
			oidc_signingkey: args.db["oidc_signingkey"].clone(),
			oidcclientid_registration: args.db["oidcclientid_registration"].clone(),
			oidccode_authsession: args.db["oidccode_authsession"].clone(),
			oidcreqid_authrequest: args.db["oidcreqid_authrequest"].clone(),
		};

		// Load the persisted signing key; on any failure, generate and
		// persist a fresh one so the server can always sign ID tokens.
		let (signing_key_der, key_id) = match db.oidc_signingkey.get_blocking(SIGNING_KEY_DB_KEY).and_then(|handle| handle.deserialized::<Cbor<SigningKeyData>>().map(|cbor| cbor.0)) {
			| Ok(data) => { info!("Loaded existing OIDC signing key (kid={})", data.key_id); (data.key_der, data.key_id) },
			| Err(_) => {
				let (key_der, key_id) = Self::generate_signing_key()?;
				info!("Generated new OIDC signing key (kid={key_id})");
				let data = SigningKeyData { key_der: key_der.clone(), key_id: key_id.clone() };
				db.oidc_signingkey.raw_put(SIGNING_KEY_DB_KEY, Cbor(&data));
				(key_der, key_id)
			},
		};

		let jwk = Self::build_jwk(&signing_key_der, &key_id)?;
		Ok(Self { db, signing_key_der, jwk, key_id })
	}

	/// Generate a fresh ECDSA P-256 keypair (PKCS#8 DER) plus a random
	/// 16-character key id for use as the JWK `kid`.
	fn generate_signing_key() -> Result<(Vec<u8>, String)> {
		let rng = SystemRandom::new();
		let alg = &signature::ECDSA_P256_SHA256_FIXED_SIGNING;
		let pkcs8 = EcdsaKeyPair::generate_pkcs8(alg, &rng).map_err(|e| err!(error!("Failed to generate ECDSA key: {e}")))?;
		let key_id = utils::random_string(16);
		Ok((pkcs8.as_ref().to_vec(), key_id))
	}

	/// Derive the public JWK (EC / P-256 / ES256) from the PKCS#8 signing
	/// key. ring exposes the public key as a 65-byte uncompressed SEC1
	/// point, so byte 0 is the 0x04 prefix, bytes 1..33 are X and 33..65
	/// are Y, each base64-encoded into the JWK.
	fn build_jwk(signing_key_der: &[u8], key_id: &str) -> Result<serde_json::Value> {
		let rng = SystemRandom::new();
		let alg = &signature::ECDSA_P256_SHA256_FIXED_SIGNING;
		let key_pair = EcdsaKeyPair::from_pkcs8(alg, signing_key_der, &rng).map_err(|e| err!(error!("Failed to load ECDSA key: {e}")))?;
		let public_bytes = key_pair.public_key().as_ref();
		let x = b64.encode(&public_bytes[1..33]);
		let y = b64.encode(&public_bytes[33..65]);
		Ok(serde_json::json!({"kty": "EC", "crv": "P-256", "use": "sig", "alg": "ES256", "kid": key_id, "x": x, "y": y}))
	}

	/// Dynamic client registration (RFC 7591 style): mint a random
	/// client_id, apply defaults (`none` auth method, code +
	/// refresh_token grants, `code` response type), persist the
	/// registration, and return it.
	pub fn register_client(&self, request: DcrRequest) -> Result<OidcClientRegistration> {
		let client_id = utils::random_string(OIDC_CLIENT_ID_LENGTH);
		let auth_method = request.token_endpoint_auth_method.unwrap_or_else(|| "none".to_owned());
		let registration = OidcClientRegistration {
			client_id: client_id.clone(),
			redirect_uris: request.redirect_uris,
			client_name: request.client_name,
			client_uri: request.client_uri,
			logo_uri: request.logo_uri,
			contacts: request.contacts,
			token_endpoint_auth_method: auth_method,
			grant_types: request.grant_types.unwrap_or_else(|| vec!["authorization_code".to_owned(), "refresh_token".to_owned()]),
			response_types: request.response_types.unwrap_or_else(|| vec!["code".to_owned()]),
			application_type: request.application_type,
			policy_uri: request.policy_uri,
			tos_uri: request.tos_uri,
			software_id: request.software_id,
			software_version: request.software_version,
			// Registration timestamp in Unix seconds; clock-before-epoch
			// degrades to 0 rather than panicking.
			registered_at: SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap_or_default().as_secs(),
		};
		self.db.oidcclientid_registration.raw_put(&*client_id, Cbor(&registration));
		Ok(registration)
	}

	/// Look up a registered client by id; any lookup/decode failure is
	/// reported uniformly as "Unknown client_id".
	pub async fn get_client(&self, client_id: &str) -> Result<OidcClientRegistration> {
		self.db.oidcclientid_registration.get(client_id).await.deserialized::<Cbor<_>>().map(|cbor: Cbor<OidcClientRegistration>| cbor.0).map_err(|_| err!(Request(NotFound("Unknown client_id"))))
	}

	/// Verify that `redirect_uri` exactly matches one of the URIs
	/// registered for `client_id` (exact string comparison — no prefix or
	/// pattern matching).
	pub async fn validate_redirect_uri(&self, client_id: &str, redirect_uri: &str) -> Result {
		let client = self.get_client(client_id).await?;
		if client.redirect_uris.iter().any(|uri| uri == redirect_uri) { Ok(()) } else { Err!(Request(InvalidParam("redirect_uri not registered for this client"))) }
	}

	/// Persist a pending authorization request under `req_id`.
	pub fn store_auth_request(&self, req_id: &str, request: &OidcAuthRequest) { self.db.oidcreqid_authrequest.raw_put(req_id, Cbor(request)); }

	/// Fetch and consume a pending authorization request (one-time use:
	/// the stored entry is removed before the expiry check, so even an
	/// expired request cannot be retried).
	pub async fn take_auth_request(&self, req_id: &str) -> Result<OidcAuthRequest> {
		let request: OidcAuthRequest = self.db.oidcreqid_authrequest.get(req_id).await.deserialized::<Cbor<_>>().map(|cbor: Cbor<OidcAuthRequest>| cbor.0).map_err(|_| err!(Request(NotFound("Unknown or expired authorization request"))))?;
		self.db.oidcreqid_authrequest.remove(req_id);
		if SystemTime::now() > request.expires_at { return Err!(Request(NotFound("Authorization request has expired"))); }
		Ok(request)
	}

	/// Issue an authorization code for an approved request: copy the
	/// request's parameters plus the authenticated user into an
	/// `AuthCodeSession` with an `AUTH_CODE_LIFETIME` expiry, persist it
	/// keyed by the code, and return the code string.
	#[must_use]
	pub fn create_auth_code(&self, auth_req: &OidcAuthRequest, user_id: OwnedUserId) -> String {
		let code = utils::random_string(AUTH_CODE_LENGTH);
		let now = SystemTime::now();
		let session = AuthCodeSession {
			code: code.clone(),
			client_id: auth_req.client_id.clone(),
			redirect_uri: auth_req.redirect_uri.clone(),
			scope: auth_req.scope.clone(),
			state: auth_req.state.clone(),
			nonce: auth_req.nonce.clone(),
			code_challenge: auth_req.code_challenge.clone(),
			code_challenge_method: auth_req.code_challenge_method.clone(),
			// Saturate to `now` if the lifetime would overflow SystemTime.
			user_id, created_at: now, expires_at: now.checked_add(AUTH_CODE_LIFETIME).unwrap_or(now),
		};
		self.db.oidccode_authsession.raw_put(&*code, Cbor(&session));
		code
	}

	/// Redeem an authorization code at the token endpoint.
	///
	/// The code is single-use (removed before any validation besides
	/// existence), then checked for: expiry, client_id match, exact
	/// redirect_uri match, and — when the code carries a PKCE challenge —
	/// verifier presence, RFC 7636 verifier syntax, and challenge
	/// equality under the session's method ("S256" when unset, or
	/// "plain"). All failures are `Forbidden` except an unsupported
	/// challenge method, which is `InvalidParam`.
	pub async fn exchange_auth_code(&self, code: &str, client_id: &str, redirect_uri: &str, code_verifier: Option<&str>) -> Result<AuthCodeSession> {
		let session: AuthCodeSession = self.db.oidccode_authsession.get(code).await.deserialized::<Cbor<_>>().map(|cbor: Cbor<AuthCodeSession>| cbor.0).map_err(|_| err!(Request(Forbidden("Invalid or expired authorization code"))))?;
		self.db.oidccode_authsession.remove(code);
		if SystemTime::now() > session.expires_at { return Err!(Request(Forbidden("Authorization code has expired"))); }
		if session.client_id != client_id { return Err!(Request(Forbidden("client_id mismatch"))); }
		if session.redirect_uri != redirect_uri { return Err!(Request(Forbidden("redirect_uri mismatch"))); }

		if let Some(challenge) = &session.code_challenge {
			let Some(verifier) = code_verifier else { return Err!(Request(Forbidden("code_verifier required for PKCE"))); };
			Self::validate_code_verifier(verifier)?;
			let method = session.code_challenge_method.as_deref().unwrap_or("S256");
			let computed = match method {
				// S256: challenge = base64(sha256(verifier)).
				| "S256" => { let hash = utils::hash::sha256::hash(verifier.as_bytes()); b64.encode(hash) },
				| "plain" => verifier.to_owned(),
				| _ => return Err!(Request(InvalidParam("Unsupported code_challenge_method"))),
			};
			if computed != *challenge { return Err!(Request(Forbidden("PKCE verification failed"))); }
		}

		Ok(session)
	}

	/// Validate code_verifier per RFC 7636 Section 4.1: must be 43-128
	/// characters using only unreserved characters [A-Z] / [a-z] / [0-9] /
	/// "-" / "." / "_" / "~".
	///
	/// Length is checked in bytes, which equals characters here because
	/// the subsequent check rejects anything non-ASCII.
	fn validate_code_verifier(verifier: &str) -> Result {
		if !(43..=128).contains(&verifier.len()) {
			return Err!(Request(InvalidParam("code_verifier must be 43-128 characters")));
		}
		if !verifier.bytes().all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'.' || b == b'_' || b == b'~') {
			return Err!(Request(InvalidParam("code_verifier contains invalid characters")));
		}
		Ok(())
	}

	/// Sign `claims` into a compact JWT using the server's ES256 key,
	/// with the key id stamped into the JOSE header for JWKS lookup.
	pub fn sign_id_token(&self, claims: &IdTokenClaims) -> Result<String> {
		let mut header = jwt::Header::new(jwt::Algorithm::ES256);
		header.kid = Some(self.key_id.clone());
		let key = jwt::EncodingKey::from_ec_der(&self.signing_key_der);
		jwt::encode(&header, claims, &key).map_err(|e| err!(error!("Failed to sign ID token: {e}")))
	}

	/// The JWKS document: a single-element key set containing this
	/// server's public signing JWK.
	#[must_use] pub fn jwks(&self) -> serde_json::Value { serde_json::json!({"keys": [self.jwk.clone()]}) }

	/// Compute the `at_hash` claim: base64 of the left-most 16 bytes
	/// (half) of SHA-256 of the access token, matching OIDC Core's
	/// at_hash definition for ES256.
	#[must_use] pub fn at_hash(access_token: &str) -> String { let hash = utils::hash::sha256::hash(access_token.as_bytes()); b64.encode(&hash[..16]) }

	/// How long a stored authorization request remains redeemable.
	#[must_use] pub fn auth_request_lifetime() -> Duration { AUTH_REQUEST_LIFETIME }
}
|
||||
@@ -112,8 +112,8 @@ impl Resolver {
|
||||
fn configure_opts(server: &Arc<Server>, mut opts: ResolverOpts) -> ResolverOpts {
|
||||
let config = &server.config;
|
||||
|
||||
opts.negative_max_ttl = Some(Duration::from_secs(60 * 60 * 24 * 30));
|
||||
opts.positive_max_ttl = Some(Duration::from_secs(60 * 60 * 24 * 7));
|
||||
opts.negative_max_ttl = Some(Duration::from_hours(720));
|
||||
opts.positive_max_ttl = Some(Duration::from_hours(168));
|
||||
opts.timeout = Duration::from_secs(config.dns_timeout);
|
||||
opts.attempts = config.dns_attempts as usize;
|
||||
opts.try_tcp_on_error = config.dns_tcp_fallback;
|
||||
|
||||
@@ -10,7 +10,7 @@ use ruma::{
|
||||
ServerName, api::federation::event::get_event,
|
||||
};
|
||||
use tuwunel_core::{
|
||||
debug, debug_error, debug_warn, implement,
|
||||
debug, debug_error, debug_warn, expected, implement,
|
||||
matrix::{PduEvent, event::gen_event_id_canonical_json, pdu::MAX_AUTH_EVENTS},
|
||||
trace,
|
||||
utils::stream::{BroadbandExt, IterStream, ReadyExt},
|
||||
@@ -30,7 +30,11 @@ use tuwunel_core::{
|
||||
#[tracing::instrument(
|
||||
level = "debug",
|
||||
skip_all,
|
||||
fields(%origin),
|
||||
fields(
|
||||
%origin,
|
||||
events = %events.clone().count(),
|
||||
lev = %recursion_level,
|
||||
),
|
||||
)]
|
||||
pub(super) async fn fetch_auth<'a, Events>(
|
||||
&self,
|
||||
@@ -38,9 +42,10 @@ pub(super) async fn fetch_auth<'a, Events>(
|
||||
room_id: &RoomId,
|
||||
events: Events,
|
||||
room_version: &RoomVersionId,
|
||||
recursion_level: usize,
|
||||
) -> Vec<(PduEvent, Option<CanonicalJsonObject>)>
|
||||
where
|
||||
Events: Iterator<Item = &'a EventId> + Send,
|
||||
Events: Iterator<Item = &'a EventId> + Clone + Send,
|
||||
{
|
||||
let events_with_auth_events: Vec<_> = events
|
||||
.stream()
|
||||
@@ -66,8 +71,8 @@ where
|
||||
.stream()
|
||||
.ready_filter(|(next_id, _)| {
|
||||
let backed_off = self.is_backed_off(next_id, Range {
|
||||
start: Duration::from_secs(5 * 60),
|
||||
end: Duration::from_secs(60 * 60 * 24),
|
||||
start: Duration::from_mins(5),
|
||||
end: Duration::from_hours(24),
|
||||
});
|
||||
|
||||
!backed_off
|
||||
@@ -79,6 +84,7 @@ where
|
||||
&next_id,
|
||||
value.clone(),
|
||||
room_version,
|
||||
expected!(recursion_level + 1),
|
||||
true,
|
||||
));
|
||||
|
||||
@@ -135,8 +141,8 @@ async fn fetch_auth_chain(
|
||||
}
|
||||
|
||||
if self.is_backed_off(&next_id, Range {
|
||||
start: Duration::from_secs(2 * 60),
|
||||
end: Duration::from_secs(60 * 60 * 8),
|
||||
start: Duration::from_mins(2),
|
||||
end: Duration::from_hours(8),
|
||||
}) {
|
||||
debug_warn!("Backed off from {next_id}");
|
||||
continue;
|
||||
|
||||
@@ -7,18 +7,23 @@ use ruma::{
|
||||
};
|
||||
use tuwunel_core::{
|
||||
Result, debug_warn, err, implement,
|
||||
matrix::{Event, PduEvent, pdu::MAX_PREV_EVENTS},
|
||||
matrix::{
|
||||
Event, PduEvent,
|
||||
pdu::{MAX_PREV_EVENTS, check_room_id},
|
||||
},
|
||||
utils::stream::IterStream,
|
||||
};
|
||||
|
||||
use super::check_room_id;
|
||||
use crate::rooms::state_res;
|
||||
|
||||
#[implement(super::Service)]
|
||||
#[tracing::instrument(
|
||||
level = "debug",
|
||||
skip_all,
|
||||
fields(%origin),
|
||||
fields(
|
||||
%origin,
|
||||
events = %initial_set.clone().count(),
|
||||
),
|
||||
)]
|
||||
#[expect(clippy::type_complexity)]
|
||||
pub(super) async fn fetch_prev<'a, Events>(
|
||||
@@ -27,10 +32,11 @@ pub(super) async fn fetch_prev<'a, Events>(
|
||||
room_id: &RoomId,
|
||||
initial_set: Events,
|
||||
room_version: &RoomVersionId,
|
||||
recursion_level: usize,
|
||||
first_ts_in_room: MilliSecondsSinceUnixEpoch,
|
||||
) -> Result<(Vec<OwnedEventId>, HashMap<OwnedEventId, (PduEvent, CanonicalJsonObject)>)>
|
||||
where
|
||||
Events: Iterator<Item = &'a EventId> + Send,
|
||||
Events: Iterator<Item = &'a EventId> + Clone + Send,
|
||||
{
|
||||
let mut todo_outlier_stack: FuturesOrdered<_> = initial_set
|
||||
.stream()
|
||||
@@ -46,7 +52,7 @@ where
|
||||
.map(async |event_id| {
|
||||
let events = once(event_id.as_ref());
|
||||
let auth = self
|
||||
.fetch_auth(origin, room_id, events, room_version)
|
||||
.fetch_auth(origin, room_id, events, room_version, recursion_level)
|
||||
.await;
|
||||
|
||||
(event_id, auth)
|
||||
@@ -65,7 +71,7 @@ where
|
||||
continue;
|
||||
};
|
||||
|
||||
check_room_id(room_id, &pdu)?;
|
||||
check_room_id(&pdu, room_id)?;
|
||||
|
||||
let limit = self.services.server.config.max_fetch_prev_events;
|
||||
if amount > limit {
|
||||
@@ -104,7 +110,13 @@ where
|
||||
let prev_prev = prev_prev.to_owned();
|
||||
let fetch = async move {
|
||||
let fetch = self
|
||||
.fetch_auth(origin, room_id, once(prev_prev.as_ref()), room_version)
|
||||
.fetch_auth(
|
||||
origin,
|
||||
room_id,
|
||||
once(prev_prev.as_ref()),
|
||||
room_version,
|
||||
recursion_level,
|
||||
)
|
||||
.await;
|
||||
|
||||
(prev_prev, fetch)
|
||||
|
||||
@@ -24,6 +24,7 @@ pub(super) async fn fetch_state(
|
||||
room_id: &RoomId,
|
||||
event_id: &EventId,
|
||||
room_version: &RoomVersionId,
|
||||
recursion_level: usize,
|
||||
create_event_id: &EventId,
|
||||
) -> Result<Option<HashMap<u64, OwnedEventId>>> {
|
||||
let res = self
|
||||
@@ -39,7 +40,7 @@ pub(super) async fn fetch_state(
|
||||
debug!("Fetching state events");
|
||||
let state_ids = res.pdu_ids.iter().map(AsRef::as_ref);
|
||||
let state_vec = self
|
||||
.fetch_auth(origin, room_id, state_ids, room_version)
|
||||
.fetch_auth(origin, room_id, state_ids, room_version, recursion_level)
|
||||
.boxed()
|
||||
.await;
|
||||
|
||||
|
||||
@@ -124,9 +124,10 @@ pub async fn handle_incoming_pdu<'a>(
|
||||
}
|
||||
|
||||
let room_version = room_version::from_create_event(create_event)?;
|
||||
let recursion_level = 0;
|
||||
|
||||
let (incoming_pdu, pdu) = self
|
||||
.handle_outlier_pdu(origin, room_id, event_id, pdu, &room_version, false)
|
||||
.handle_outlier_pdu(origin, room_id, event_id, pdu, &room_version, recursion_level, false)
|
||||
.await?;
|
||||
|
||||
// 8. if not timeline event: stop
|
||||
@@ -158,7 +159,14 @@ pub async fn handle_incoming_pdu<'a>(
|
||||
// 9. Fetch any missing prev events doing all checks listed here starting at 1.
|
||||
// These are timeline events
|
||||
let (sorted_prev_events, mut eventid_info) = self
|
||||
.fetch_prev(origin, room_id, incoming_pdu.prev_events(), &room_version, first_ts_in_room)
|
||||
.fetch_prev(
|
||||
origin,
|
||||
room_id,
|
||||
incoming_pdu.prev_events(),
|
||||
&room_version,
|
||||
recursion_level,
|
||||
first_ts_in_room,
|
||||
)
|
||||
.await?;
|
||||
|
||||
trace!(
|
||||
@@ -180,6 +188,7 @@ pub async fn handle_incoming_pdu<'a>(
|
||||
event_id,
|
||||
eventid_info,
|
||||
&room_version,
|
||||
recursion_level,
|
||||
first_ts_in_room,
|
||||
&prev_id,
|
||||
create_event.event_id(),
|
||||
@@ -199,7 +208,7 @@ pub async fn handle_incoming_pdu<'a>(
|
||||
},
|
||||
| Err(e) => {
|
||||
self.back_off(&prev_id);
|
||||
warn!(?i, ?prev_id, "Prev event processing failed: {e}");
|
||||
warn!(?i, ?prev_id, ?event_id, ?room_id, "Prev event processing failed: {e}");
|
||||
|
||||
Ok((prev_id, Err(e)))
|
||||
},
|
||||
@@ -221,6 +230,7 @@ pub async fn handle_incoming_pdu<'a>(
|
||||
incoming_pdu,
|
||||
pdu,
|
||||
&room_version,
|
||||
recursion_level,
|
||||
create_event.event_id(),
|
||||
)
|
||||
.boxed()
|
||||
|
||||
@@ -5,16 +5,21 @@ use ruma::{
|
||||
use tuwunel_core::{
|
||||
Err, Result, debug, debug_info, err, implement,
|
||||
matrix::{Event, PduEvent, event::TypeExt, room_version},
|
||||
pdu::{check_pdu_format, format::from_incoming_federation},
|
||||
ref_at, trace,
|
||||
utils::{future::TryExtExt, stream::IterStream},
|
||||
warn,
|
||||
};
|
||||
|
||||
use super::check_room_id;
|
||||
use crate::rooms::state_res;
|
||||
|
||||
#[implement(super::Service)]
|
||||
#[cfg_attr(unabridged, tracing::instrument(
|
||||
name = "outlier",
|
||||
level = "debug",
|
||||
skip_all,
|
||||
fields(lev = %recursion_level)
|
||||
))]
|
||||
#[expect(clippy::too_many_arguments)]
|
||||
pub(super) async fn handle_outlier_pdu(
|
||||
&self,
|
||||
origin: &ServerName,
|
||||
@@ -22,9 +27,10 @@ pub(super) async fn handle_outlier_pdu(
|
||||
event_id: &EventId,
|
||||
mut pdu_json: CanonicalJsonObject,
|
||||
room_version: &RoomVersionId,
|
||||
recursion_level: usize,
|
||||
auth_events_known: bool,
|
||||
) -> Result<(PduEvent, CanonicalJsonObject)> {
|
||||
debug!(?event_id, ?auth_events_known, "handle outlier");
|
||||
debug!(?event_id, ?auth_events_known, %recursion_level, "handle outlier");
|
||||
|
||||
// 1. Remove unsigned field
|
||||
pdu_json.remove("unsigned");
|
||||
@@ -33,7 +39,7 @@ pub(super) async fn handle_outlier_pdu(
|
||||
// anywhere?: https://matrix.org/docs/spec/rooms/v6#canonical-json
|
||||
// 2. Check signatures, otherwise drop
|
||||
// 3. check content hash, redact if doesn't match
|
||||
let mut pdu_json = match self
|
||||
let pdu_json = match self
|
||||
.services
|
||||
.server_keys
|
||||
.verify_event(&pdu_json, Some(room_version))
|
||||
@@ -49,7 +55,8 @@ pub(super) async fn handle_outlier_pdu(
|
||||
)));
|
||||
};
|
||||
|
||||
let Ok(obj) = ruma::canonical_json::redact(pdu_json, &rules.redaction, None) else {
|
||||
let Ok(pdu_json) = ruma::canonical_json::redact(pdu_json, &rules.redaction, None)
|
||||
else {
|
||||
return Err!(Request(InvalidParam("Redaction failed")));
|
||||
};
|
||||
|
||||
@@ -60,7 +67,7 @@ pub(super) async fn handle_outlier_pdu(
|
||||
)));
|
||||
}
|
||||
|
||||
obj
|
||||
pdu_json
|
||||
},
|
||||
| Err(e) => {
|
||||
return Err!(Request(InvalidParam(debug_error!(
|
||||
@@ -69,15 +76,11 @@ pub(super) async fn handle_outlier_pdu(
|
||||
},
|
||||
};
|
||||
|
||||
let room_rules = room_version::rules(room_version)?;
|
||||
|
||||
check_pdu_format(&pdu_json, &room_rules.event_format)?;
|
||||
|
||||
// Now that we have checked the signature and hashes we can make mutations and
|
||||
// convert to our PduEvent type.
|
||||
let event = from_incoming_federation(room_id, event_id, &mut pdu_json, &room_rules)?;
|
||||
|
||||
check_room_id(room_id, &event)?;
|
||||
let room_rules = room_version::rules(room_version)?;
|
||||
let (event, pdu_json) =
|
||||
PduEvent::from_object_federation(room_id, event_id, pdu_json, &room_rules)?;
|
||||
|
||||
if !auth_events_known {
|
||||
// 4. fetch any missing auth events doing all checks listed here starting at 1.
|
||||
@@ -86,7 +89,14 @@ pub(super) async fn handle_outlier_pdu(
|
||||
// the auth events are also rejected "due to auth events"
|
||||
// NOTE: Step 5 is not applied anymore because it failed too often
|
||||
debug!("Fetching auth events");
|
||||
Box::pin(self.fetch_auth(origin, room_id, event.auth_events(), room_version)).await;
|
||||
Box::pin(self.fetch_auth(
|
||||
origin,
|
||||
room_id,
|
||||
event.auth_events(),
|
||||
room_version,
|
||||
recursion_level,
|
||||
))
|
||||
.await;
|
||||
}
|
||||
|
||||
// 6. Reject "due to auth events" if the event doesn't pass auth based on the
|
||||
|
||||
@@ -26,6 +26,7 @@ pub(super) async fn handle_prev_pdu(
|
||||
event_id: &EventId,
|
||||
eventid_info: Option<(PduEvent, CanonicalJsonObject)>,
|
||||
room_version: &RoomVersionId,
|
||||
recursion_level: usize,
|
||||
first_ts_in_room: MilliSecondsSinceUnixEpoch,
|
||||
prev_id: &EventId,
|
||||
create_event_id: &EventId,
|
||||
@@ -50,8 +51,8 @@ pub(super) async fn handle_prev_pdu(
|
||||
}
|
||||
|
||||
if self.is_backed_off(prev_id, Range {
|
||||
start: Duration::from_secs(5 * 60),
|
||||
end: Duration::from_secs(60 * 60 * 24),
|
||||
start: Duration::from_mins(5),
|
||||
end: Duration::from_hours(24),
|
||||
}) {
|
||||
debug!(?prev_id, "Backing off from prev_event");
|
||||
return Ok(None);
|
||||
@@ -63,6 +64,7 @@ pub(super) async fn handle_prev_pdu(
|
||||
pdu,
|
||||
json,
|
||||
room_version,
|
||||
recursion_level,
|
||||
create_event_id,
|
||||
)
|
||||
.boxed()
|
||||
|
||||
@@ -19,10 +19,10 @@ use std::{
|
||||
};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use ruma::{EventId, OwnedEventId, OwnedRoomId, RoomId};
|
||||
use ruma::{EventId, OwnedEventId, OwnedRoomId};
|
||||
use tuwunel_core::{
|
||||
Err, Result, implement,
|
||||
matrix::{Event, PduEvent},
|
||||
Result, implement,
|
||||
matrix::PduEvent,
|
||||
utils::{MutexMap, bytes::pretty, continue_exponential_backoff},
|
||||
};
|
||||
|
||||
@@ -146,16 +146,3 @@ async fn event_exists(&self, event_id: &EventId) -> bool {
|
||||
async fn event_fetch(&self, event_id: &EventId) -> Result<PduEvent> {
|
||||
self.services.timeline.get_pdu(event_id).await
|
||||
}
|
||||
|
||||
fn check_room_id<Pdu: Event>(room_id: &RoomId, pdu: &Pdu) -> Result {
|
||||
if pdu.room_id() != room_id {
|
||||
return Err!(Request(InvalidParam(error!(
|
||||
pdu_event_id = ?pdu.event_id(),
|
||||
pdu_room_id = ?pdu.room_id(),
|
||||
?room_id,
|
||||
"Found event from room in room",
|
||||
))));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
use ruma::{CanonicalJsonObject, CanonicalJsonValue, OwnedEventId, OwnedRoomId};
|
||||
use serde_json::value::RawValue as RawJsonValue;
|
||||
use tuwunel_core::{
|
||||
Result, err, implement, matrix::event::gen_event_id_canonical_json, result::FlatOk,
|
||||
};
|
||||
use tuwunel_core::{Result, err, implement, matrix::event::gen_event_id, result::FlatOk};
|
||||
|
||||
type Parsed = (OwnedRoomId, OwnedEventId, CanonicalJsonObject);
|
||||
|
||||
#[implement(super::Service)]
|
||||
#[tracing::instrument(
|
||||
name = "parse_incoming",
|
||||
level = "trace",
|
||||
skip_all,
|
||||
fields(
|
||||
len = pdu.get().len(),
|
||||
)
|
||||
)]
|
||||
pub async fn parse_incoming_pdu(&self, pdu: &RawJsonValue) -> Result<Parsed> {
|
||||
let value = serde_json::from_str::<CanonicalJsonObject>(pdu.get()).map_err(|e| {
|
||||
let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| {
|
||||
err!(BadServerResponse(debug_error!("Error parsing incoming event: {e} {pdu:#?}")))
|
||||
})?;
|
||||
|
||||
@@ -25,9 +31,9 @@ pub async fn parse_incoming_pdu(&self, pdu: &RawJsonValue) -> Result<Parsed> {
|
||||
.await
|
||||
.map_err(|_| err!("Server is not in room {room_id}"))?;
|
||||
|
||||
let (event_id, value) = gen_event_id_canonical_json(pdu, &room_version_id).map_err(|e| {
|
||||
err!(Request(InvalidParam("Could not convert event to canonical json: {e}")))
|
||||
})?;
|
||||
|
||||
Ok((room_id, event_id, value))
|
||||
gen_event_id(&value, &room_version_id)
|
||||
.map(move |event_id| (room_id, event_id, value))
|
||||
.map_err(|e| {
|
||||
err!(Request(InvalidParam("Could not convert event to canonical json: {e}")))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ use ruma::{
|
||||
};
|
||||
use tuwunel_core::{
|
||||
Err, Result, debug, debug_info, err, implement, is_equal_to,
|
||||
matrix::{Event, EventTypeExt, PduEvent, StateKey, pdu::check_pdu_format, room_version},
|
||||
matrix::{Event, EventTypeExt, PduEvent, StateKey, pdu::check_rules, room_version},
|
||||
trace,
|
||||
utils::stream::{BroadbandExt, ReadyExt},
|
||||
warn,
|
||||
@@ -24,15 +24,18 @@ use crate::rooms::{
|
||||
name = "upgrade",
|
||||
level = "debug",
|
||||
ret(level = "debug"),
|
||||
skip_all
|
||||
skip_all,
|
||||
fields(lev = %recursion_level)
|
||||
)]
|
||||
#[expect(clippy::too_many_arguments)]
|
||||
pub(super) async fn upgrade_outlier_to_timeline_pdu(
|
||||
&self,
|
||||
origin: &ServerName,
|
||||
room_id: &RoomId,
|
||||
incoming_pdu: PduEvent,
|
||||
val: CanonicalJsonObject,
|
||||
pdu_json: CanonicalJsonObject,
|
||||
room_version: &RoomVersionId,
|
||||
recursion_level: usize,
|
||||
create_event_id: &EventId,
|
||||
) -> Result<Option<(RawPduId, bool)>> {
|
||||
// Skip the PDU if we already have it as a timeline event
|
||||
@@ -61,7 +64,7 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
|
||||
let room_rules = room_version::rules(room_version)?;
|
||||
|
||||
trace!(format = ?room_rules.event_format, "Checking format");
|
||||
check_pdu_format(&val, &room_rules.event_format)?;
|
||||
check_rules(&pdu_json, &room_rules.event_format)?;
|
||||
|
||||
// 10. Fetch missing state and auth chain events by calling /state_ids at
|
||||
// backwards extremities doing all the checks in this list starting at 1.
|
||||
@@ -79,7 +82,14 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
|
||||
|
||||
if state_at_incoming_event.is_none() {
|
||||
state_at_incoming_event = self
|
||||
.fetch_state(origin, room_id, incoming_pdu.event_id(), room_version, create_event_id)
|
||||
.fetch_state(
|
||||
origin,
|
||||
room_id,
|
||||
incoming_pdu.event_id(),
|
||||
room_version,
|
||||
recursion_level,
|
||||
create_event_id,
|
||||
)
|
||||
.boxed()
|
||||
.await?;
|
||||
}
|
||||
@@ -270,7 +280,7 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
|
||||
.timeline
|
||||
.append_incoming_pdu(
|
||||
&incoming_pdu,
|
||||
val,
|
||||
pdu_json,
|
||||
extremities,
|
||||
state_ids_compressed,
|
||||
soft_fail,
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
use std::{
|
||||
sync::Arc,
|
||||
time::{Duration, SystemTime, UNIX_EPOCH},
|
||||
};
|
||||
use std::{sync::Arc, time::Duration};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use ruma::{CanonicalJsonObject, EventId};
|
||||
use tuwunel_core::{
|
||||
Result, debug_info, expected, implement, matrix::pdu::PduEvent, utils::TryReadyExt,
|
||||
Result, debug_info, expected, implement,
|
||||
matrix::pdu::PduEvent,
|
||||
utils::{TryReadyExt, time::now},
|
||||
};
|
||||
use tuwunel_database::{Deserialized, Json, Map};
|
||||
|
||||
@@ -35,11 +34,7 @@ impl crate::Service for Service {
|
||||
if retention_seconds != 0 {
|
||||
debug_info!("Cleaning up retained events");
|
||||
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs();
|
||||
|
||||
let now = now().as_secs();
|
||||
let count = self
|
||||
.timeredacted_eventid
|
||||
.keys::<(u64, &EventId)>()
|
||||
@@ -59,7 +54,7 @@ impl crate::Service for Service {
|
||||
}
|
||||
|
||||
tokio::select! {
|
||||
() = tokio::time::sleep(Duration::from_secs(60 * 60)) => {},
|
||||
() = tokio::time::sleep(Duration::from_hours(1)) => {},
|
||||
() = self.services.server.until_shutdown() => return Ok(())
|
||||
};
|
||||
}
|
||||
@@ -104,10 +99,7 @@ pub async fn save_original_pdu(
|
||||
return;
|
||||
}
|
||||
|
||||
let now = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.unwrap()
|
||||
.as_secs();
|
||||
let now = now().as_secs();
|
||||
|
||||
self.eventid_originalpdu
|
||||
.raw_put(event_id, Json(pdu));
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
use async_trait::async_trait;
|
||||
use ruma::RoomId;
|
||||
use tuwunel_core::Result;
|
||||
|
||||
use crate::rooms::{short::ShortRoomId, timeline::RawPduId};
|
||||
|
||||
/// A search hit returned by a backend.
|
||||
pub(super) struct SearchHit {
|
||||
pub(super) pdu_id: RawPduId,
|
||||
#[expect(dead_code)]
|
||||
pub(super) score: Option<f64>,
|
||||
}
|
||||
|
||||
/// Query parameters passed to the backend for searching.
|
||||
pub(super) struct BackendQuery<'a> {
|
||||
pub(super) search_term: &'a str,
|
||||
pub(super) room_ids: Option<&'a [&'a RoomId]>,
|
||||
}
|
||||
|
||||
/// Trait abstracting the search index backend.
|
||||
#[async_trait]
|
||||
pub(super) trait SearchBackend: Send + Sync {
|
||||
async fn index_pdu(
|
||||
&self,
|
||||
shortroomid: ShortRoomId,
|
||||
pdu_id: &RawPduId,
|
||||
room_id: &RoomId,
|
||||
message_body: &str,
|
||||
) -> Result;
|
||||
|
||||
async fn deindex_pdu(
|
||||
&self,
|
||||
shortroomid: ShortRoomId,
|
||||
pdu_id: &RawPduId,
|
||||
room_id: &RoomId,
|
||||
message_body: &str,
|
||||
) -> Result;
|
||||
|
||||
async fn search_pdu_ids(
|
||||
&self,
|
||||
query: &BackendQuery<'_>,
|
||||
shortroomid: Option<ShortRoomId>,
|
||||
limit: usize,
|
||||
skip: usize,
|
||||
) -> Result<(usize, Vec<SearchHit>)>;
|
||||
|
||||
async fn delete_room(&self, room_id: &RoomId, shortroomid: ShortRoomId) -> Result;
|
||||
|
||||
fn supports_cross_room_search(&self) -> bool { false }
|
||||
}
|
||||
@@ -1,33 +1,34 @@
|
||||
mod backend;
|
||||
mod opensearch;
|
||||
mod rocksdb;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::{Stream, StreamExt};
|
||||
use ruma::{RoomId, UserId, api::client::search::search_events::v3::Criteria};
|
||||
use tokio::sync::mpsc;
|
||||
use tuwunel_core::{
|
||||
Result, implement,
|
||||
PduCount, Result,
|
||||
arrayvec::ArrayVec,
|
||||
implement,
|
||||
matrix::event::{Event, Matches},
|
||||
utils::{IterStream, ReadyExt, stream::WidebandExt},
|
||||
trace,
|
||||
utils::{
|
||||
ArrayVecExt, IterStream, ReadyExt, set,
|
||||
stream::{TryIgnore, WidebandExt},
|
||||
},
|
||||
};
|
||||
use tuwunel_database::{Interfix, Map, keyval::Val};
|
||||
|
||||
use self::{
|
||||
backend::{BackendQuery, SearchBackend, SearchHit},
|
||||
opensearch::OpenSearchBackend,
|
||||
rocksdb::RocksDbBackend,
|
||||
use crate::rooms::{
|
||||
short::ShortRoomId,
|
||||
timeline::{PduId, RawPduId},
|
||||
};
|
||||
use crate::rooms::{short::ShortRoomId, timeline::RawPduId};
|
||||
|
||||
pub struct Service {
|
||||
backend: Box<dyn SearchBackend>,
|
||||
bulk_rx: Option<tokio::sync::Mutex<mpsc::Receiver<opensearch::BulkAction>>>,
|
||||
opensearch: Option<Arc<OpenSearchBackend>>,
|
||||
db: Data,
|
||||
services: Arc<crate::services::OnceServices>,
|
||||
}
|
||||
|
||||
struct Data {
|
||||
tokenids: Arc<Map>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct RoomQuery<'a> {
|
||||
pub room_id: &'a RoomId,
|
||||
@@ -37,133 +38,52 @@ pub struct RoomQuery<'a> {
|
||||
pub skip: usize,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
type TokenId = ArrayVec<u8, TOKEN_ID_MAX_LEN>;
|
||||
|
||||
const TOKEN_ID_MAX_LEN: usize =
|
||||
size_of::<ShortRoomId>() + WORD_MAX_LEN + 1 + size_of::<RawPduId>();
|
||||
const WORD_MAX_LEN: usize = 50;
|
||||
|
||||
impl crate::Service for Service {
|
||||
fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
|
||||
let config = &args.server.config;
|
||||
let (backend, bulk_rx, opensearch): (Box<dyn SearchBackend>, _, _) = match config
|
||||
.search_backend
|
||||
{
|
||||
| tuwunel_core::config::SearchBackendConfig::RocksDb => {
|
||||
let b = RocksDbBackend::new(args.db["tokenids"].clone());
|
||||
(Box::new(b), None, None)
|
||||
},
|
||||
| tuwunel_core::config::SearchBackendConfig::OpenSearch => {
|
||||
let (b, rx) = OpenSearchBackend::new(config)?;
|
||||
let b = Arc::new(b);
|
||||
let backend: Box<dyn SearchBackend> = Box::new(OpenSearchWorkerRef(b.clone()));
|
||||
(backend, Some(tokio::sync::Mutex::new(rx)), Some(b))
|
||||
},
|
||||
};
|
||||
|
||||
Ok(Arc::new(Self {
|
||||
backend,
|
||||
bulk_rx,
|
||||
opensearch,
|
||||
db: Data { tokenids: args.db["tokenids"].clone() },
|
||||
services: args.services.clone(),
|
||||
}))
|
||||
}
|
||||
|
||||
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
|
||||
|
||||
async fn worker(self: Arc<Self>) -> Result {
|
||||
if let (Some(os), Some(rx_mutex)) = (self.opensearch.as_ref(), self.bulk_rx.as_ref()) {
|
||||
os.ensure_index().await?;
|
||||
|
||||
let services = &**self.services;
|
||||
let config = &services.server.config;
|
||||
let batch_size = config.search_opensearch_batch_size;
|
||||
let flush_interval =
|
||||
std::time::Duration::from_millis(config.search_opensearch_flush_interval_ms);
|
||||
|
||||
let mut rx = rx_mutex.lock().await;
|
||||
os.process_bulk(&mut rx, batch_size, flush_interval)
|
||||
.await;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Wrapper so we can put the `Arc<OpenSearchBackend>` behind the trait.
|
||||
struct OpenSearchWorkerRef(Arc<OpenSearchBackend>);
|
||||
|
||||
#[async_trait]
|
||||
impl SearchBackend for OpenSearchWorkerRef {
|
||||
async fn index_pdu(
|
||||
&self,
|
||||
shortroomid: ShortRoomId,
|
||||
pdu_id: &RawPduId,
|
||||
room_id: &RoomId,
|
||||
message_body: &str,
|
||||
) -> Result {
|
||||
self.0
|
||||
.index_pdu(shortroomid, pdu_id, room_id, message_body)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn deindex_pdu(
|
||||
&self,
|
||||
shortroomid: ShortRoomId,
|
||||
pdu_id: &RawPduId,
|
||||
room_id: &RoomId,
|
||||
message_body: &str,
|
||||
) -> Result {
|
||||
self.0
|
||||
.deindex_pdu(shortroomid, pdu_id, room_id, message_body)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn search_pdu_ids(
|
||||
&self,
|
||||
query: &BackendQuery<'_>,
|
||||
shortroomid: Option<ShortRoomId>,
|
||||
limit: usize,
|
||||
skip: usize,
|
||||
) -> Result<(usize, Vec<SearchHit>)> {
|
||||
self.0
|
||||
.search_pdu_ids(query, shortroomid, limit, skip)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn delete_room(&self, room_id: &RoomId, shortroomid: ShortRoomId) -> Result {
|
||||
self.0.delete_room(room_id, shortroomid).await
|
||||
}
|
||||
|
||||
fn supports_cross_room_search(&self) -> bool { self.0.supports_cross_room_search() }
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub async fn index_pdu(
|
||||
&self,
|
||||
shortroomid: ShortRoomId,
|
||||
pdu_id: &RawPduId,
|
||||
room_id: &RoomId,
|
||||
message_body: &str,
|
||||
) {
|
||||
if let Err(e) = self
|
||||
.backend
|
||||
.index_pdu(shortroomid, pdu_id, room_id, message_body)
|
||||
.await
|
||||
{
|
||||
tuwunel_core::warn!("Failed to index PDU: {e}");
|
||||
}
|
||||
pub fn index_pdu(&self, shortroomid: ShortRoomId, pdu_id: &RawPduId, message_body: &str) {
|
||||
let batch = tokenize(message_body)
|
||||
.map(|word| {
|
||||
let mut key = shortroomid.to_be_bytes().to_vec();
|
||||
key.extend_from_slice(word.as_bytes());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(pdu_id.as_ref()); // TODO: currently we save the room id a second time here
|
||||
key
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
self.db
|
||||
.tokenids
|
||||
.insert_batch(batch.iter().map(|k| (k.as_slice(), &[])));
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub async fn deindex_pdu(
|
||||
&self,
|
||||
shortroomid: ShortRoomId,
|
||||
pdu_id: &RawPduId,
|
||||
room_id: &RoomId,
|
||||
message_body: &str,
|
||||
) {
|
||||
if let Err(e) = self
|
||||
.backend
|
||||
.deindex_pdu(shortroomid, pdu_id, room_id, message_body)
|
||||
.await
|
||||
{
|
||||
tuwunel_core::warn!("Failed to deindex PDU: {e}");
|
||||
pub fn deindex_pdu(&self, shortroomid: ShortRoomId, pdu_id: &RawPduId, message_body: &str) {
|
||||
let batch = tokenize(message_body).map(|word| {
|
||||
let mut key = shortroomid.to_be_bytes().to_vec();
|
||||
key.extend_from_slice(word.as_bytes());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(pdu_id.as_ref()); // TODO: currently we save the room id a second time here
|
||||
key
|
||||
});
|
||||
|
||||
for token in batch {
|
||||
self.db.tokenids.remove(&token);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -172,31 +92,12 @@ pub async fn search_pdus<'a>(
|
||||
&'a self,
|
||||
query: &'a RoomQuery<'a>,
|
||||
) -> Result<(usize, impl Stream<Item = impl Event + use<>> + Send + '_)> {
|
||||
let shortroomid = self
|
||||
.services
|
||||
.short
|
||||
.get_shortroomid(query.room_id)
|
||||
.await?;
|
||||
|
||||
let backend_query = BackendQuery {
|
||||
search_term: &query.criteria.search_term,
|
||||
room_ids: None,
|
||||
};
|
||||
|
||||
let (count, hits) = self
|
||||
.backend
|
||||
.search_pdu_ids(
|
||||
&backend_query,
|
||||
Some(shortroomid),
|
||||
query.limit.saturating_add(query.skip),
|
||||
0,
|
||||
)
|
||||
.await?;
|
||||
let pdu_ids: Vec<_> = self.search_pdu_ids(query).await?.collect().await;
|
||||
|
||||
let filter = &query.criteria.filter;
|
||||
let pdus = hits
|
||||
let count = pdu_ids.len();
|
||||
let pdus = pdu_ids
|
||||
.into_iter()
|
||||
.map(|h| h.pdu_id)
|
||||
.stream()
|
||||
.wide_filter_map(async |result_pdu_id: RawPduId| {
|
||||
self.services
|
||||
@@ -220,64 +121,123 @@ pub async fn search_pdus<'a>(
|
||||
Ok((count, pdus))
|
||||
}
|
||||
|
||||
// result is modeled as a stream such that callers don't have to be refactored
|
||||
// though an additional async/wrap still exists for now
|
||||
#[implement(Service)]
|
||||
pub async fn search_pdus_multi_room<'a>(
|
||||
&'a self,
|
||||
room_ids: &'a [&'a RoomId],
|
||||
sender_user: &'a UserId,
|
||||
criteria: &'a Criteria,
|
||||
limit: usize,
|
||||
skip: usize,
|
||||
) -> Result<(usize, impl Stream<Item = impl Event + use<>> + Send + '_)> {
|
||||
let backend_query = BackendQuery {
|
||||
search_term: &criteria.search_term,
|
||||
room_ids: Some(room_ids),
|
||||
};
|
||||
|
||||
let (count, hits) = self
|
||||
.backend
|
||||
.search_pdu_ids(&backend_query, None, limit.saturating_add(skip), 0)
|
||||
.await?;
|
||||
|
||||
let filter = &criteria.filter;
|
||||
let pdus = hits
|
||||
.into_iter()
|
||||
.map(|h| h.pdu_id)
|
||||
.stream()
|
||||
.wide_filter_map(async |result_pdu_id: RawPduId| {
|
||||
self.services
|
||||
.timeline
|
||||
.get_pdu_from_id(&result_pdu_id)
|
||||
.await
|
||||
.ok()
|
||||
})
|
||||
.ready_filter(|pdu| !pdu.is_redacted())
|
||||
.ready_filter(move |pdu| filter.matches(pdu))
|
||||
.wide_filter_map(async move |pdu| {
|
||||
self.services
|
||||
.state_accessor
|
||||
.user_can_see_event(sender_user, pdu.room_id(), pdu.event_id())
|
||||
.await
|
||||
.then_some(pdu)
|
||||
})
|
||||
.skip(skip)
|
||||
.take(limit);
|
||||
|
||||
Ok((count, pdus))
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub fn supports_cross_room_search(&self) -> bool { self.backend.supports_cross_room_search() }
|
||||
|
||||
#[implement(Service)]
|
||||
pub async fn delete_all_search_tokenids_for_room(&self, room_id: &RoomId) -> Result {
|
||||
pub async fn search_pdu_ids(
|
||||
&self,
|
||||
query: &RoomQuery<'_>,
|
||||
) -> Result<impl Stream<Item = RawPduId> + Send + '_ + use<'_>> {
|
||||
let shortroomid = self
|
||||
.services
|
||||
.short
|
||||
.get_shortroomid(room_id)
|
||||
.get_shortroomid(query.room_id)
|
||||
.await?;
|
||||
|
||||
self.backend
|
||||
.delete_room(room_id, shortroomid)
|
||||
let pdu_ids = self
|
||||
.search_pdu_ids_query_room(query, shortroomid)
|
||||
.await;
|
||||
|
||||
let iters = pdu_ids.into_iter().map(IntoIterator::into_iter);
|
||||
|
||||
Ok(set::intersection(iters).stream())
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
async fn search_pdu_ids_query_room(
|
||||
&self,
|
||||
query: &RoomQuery<'_>,
|
||||
shortroomid: ShortRoomId,
|
||||
) -> Vec<Vec<RawPduId>> {
|
||||
tokenize(&query.criteria.search_term)
|
||||
.stream()
|
||||
.wide_then(async |word| {
|
||||
self.search_pdu_ids_query_words(shortroomid, &word)
|
||||
.collect::<Vec<_>>()
|
||||
.await
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
.await
|
||||
}
|
||||
|
||||
/// Iterate over PduId's containing a word
|
||||
#[implement(Service)]
|
||||
fn search_pdu_ids_query_words<'a>(
|
||||
&'a self,
|
||||
shortroomid: ShortRoomId,
|
||||
word: &'a str,
|
||||
) -> impl Stream<Item = RawPduId> + Send + '_ {
|
||||
self.search_pdu_ids_query_word(shortroomid, word)
|
||||
.map(move |key| -> RawPduId {
|
||||
let key = &key[prefix_len(word)..];
|
||||
key.into()
|
||||
})
|
||||
}
|
||||
|
||||
/// Iterate over raw database results for a word
|
||||
#[implement(Service)]
|
||||
fn search_pdu_ids_query_word(
|
||||
&self,
|
||||
shortroomid: ShortRoomId,
|
||||
word: &str,
|
||||
) -> impl Stream<Item = Val<'_>> + Send + '_ + use<'_> {
|
||||
// rustc says const'ing this not yet stable
|
||||
let end_id: RawPduId = PduId { shortroomid, count: PduCount::max() }.into();
|
||||
|
||||
// Newest pdus first
|
||||
let end = make_tokenid(shortroomid, word, &end_id);
|
||||
let prefix = make_prefix(shortroomid, word);
|
||||
self.db
|
||||
.tokenids
|
||||
.rev_raw_keys_from(&end)
|
||||
.ignore_err()
|
||||
.ready_take_while(move |key| key.starts_with(&prefix))
|
||||
}
|
||||
|
||||
#[implement(Service)]
|
||||
pub async fn delete_all_search_tokenids_for_room(&self, room_id: &RoomId) -> Result {
|
||||
let prefix = (room_id, Interfix);
|
||||
|
||||
self.db
|
||||
.tokenids
|
||||
.keys_prefix_raw(&prefix)
|
||||
.ignore_err()
|
||||
.ready_for_each(|key| {
|
||||
trace!("Removing key: {key:?}");
|
||||
self.db.tokenids.remove(key);
|
||||
})
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Splits a string into tokens used as keys in the search inverted index
|
||||
///
|
||||
/// This may be used to tokenize both message bodies (for indexing) or search
|
||||
/// queries (for querying).
|
||||
fn tokenize(body: &str) -> impl Iterator<Item = String> + Send + '_ {
|
||||
body.split_terminator(|c: char| !c.is_alphanumeric())
|
||||
.filter(|s| !s.is_empty())
|
||||
.filter(|word| word.len() <= WORD_MAX_LEN)
|
||||
.map(str::to_lowercase)
|
||||
}
|
||||
|
||||
fn make_tokenid(shortroomid: ShortRoomId, word: &str, pdu_id: &RawPduId) -> TokenId {
|
||||
let mut key = make_prefix(shortroomid, word);
|
||||
key.extend_from_slice(pdu_id.as_ref());
|
||||
key
|
||||
}
|
||||
|
||||
fn make_prefix(shortroomid: ShortRoomId, word: &str) -> TokenId {
|
||||
let mut key = TokenId::new();
|
||||
key.extend_from_slice(&shortroomid.to_be_bytes());
|
||||
key.extend_from_slice(word.as_bytes());
|
||||
key.push(tuwunel_database::SEP);
|
||||
key
|
||||
}
|
||||
|
||||
fn prefix_len(word: &str) -> usize {
|
||||
size_of::<ShortRoomId>()
|
||||
.saturating_add(word.len())
|
||||
.saturating_add(1)
|
||||
}
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,193 +0,0 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::StreamExt;
|
||||
use ruma::RoomId;
|
||||
use tuwunel_core::{
|
||||
PduCount, Result,
|
||||
arrayvec::ArrayVec,
|
||||
utils::{ArrayVecExt, ReadyExt, set, stream::TryIgnore},
|
||||
};
|
||||
use tuwunel_database::{Interfix, Map, keyval::Val};
|
||||
|
||||
use super::backend::{BackendQuery, SearchBackend, SearchHit};
|
||||
use crate::rooms::{
|
||||
short::ShortRoomId,
|
||||
timeline::{PduId, RawPduId},
|
||||
};
|
||||
|
||||
pub(super) struct RocksDbBackend {
|
||||
tokenids: Arc<Map>,
|
||||
}
|
||||
|
||||
type TokenId = ArrayVec<u8, TOKEN_ID_MAX_LEN>;
|
||||
|
||||
const TOKEN_ID_MAX_LEN: usize =
|
||||
size_of::<ShortRoomId>() + WORD_MAX_LEN + 1 + size_of::<RawPduId>();
|
||||
const WORD_MAX_LEN: usize = 50;
|
||||
|
||||
impl RocksDbBackend {
|
||||
pub(super) fn new(tokenids: Arc<Map>) -> Self { Self { tokenids } }
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl SearchBackend for RocksDbBackend {
|
||||
async fn index_pdu(
|
||||
&self,
|
||||
shortroomid: ShortRoomId,
|
||||
pdu_id: &RawPduId,
|
||||
_room_id: &RoomId,
|
||||
message_body: &str,
|
||||
) -> Result {
|
||||
let batch = tokenize(message_body)
|
||||
.map(|word| {
|
||||
let mut key = shortroomid.to_be_bytes().to_vec();
|
||||
key.extend_from_slice(word.as_bytes());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(pdu_id.as_ref());
|
||||
key
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
self.tokenids
|
||||
.insert_batch(batch.iter().map(|k| (k.as_slice(), &[])));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn deindex_pdu(
|
||||
&self,
|
||||
shortroomid: ShortRoomId,
|
||||
pdu_id: &RawPduId,
|
||||
_room_id: &RoomId,
|
||||
message_body: &str,
|
||||
) -> Result {
|
||||
let batch = tokenize(message_body).map(|word| {
|
||||
let mut key = shortroomid.to_be_bytes().to_vec();
|
||||
key.extend_from_slice(word.as_bytes());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(pdu_id.as_ref());
|
||||
key
|
||||
});
|
||||
|
||||
for token in batch {
|
||||
self.tokenids.remove(&token);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn search_pdu_ids(
|
||||
&self,
|
||||
query: &BackendQuery<'_>,
|
||||
shortroomid: Option<ShortRoomId>,
|
||||
_limit: usize,
|
||||
_skip: usize,
|
||||
) -> Result<(usize, Vec<SearchHit>)> {
|
||||
let shortroomid = shortroomid.ok_or_else(|| {
|
||||
tuwunel_core::err!(Request(InvalidParam(
|
||||
"RocksDB backend requires a shortroomid for search"
|
||||
)))
|
||||
})?;
|
||||
|
||||
let pdu_ids =
|
||||
search_pdu_ids_query_room(&self.tokenids, shortroomid, query.search_term).await;
|
||||
|
||||
let iters = pdu_ids.into_iter().map(IntoIterator::into_iter);
|
||||
let results: Vec<_> = set::intersection(iters)
|
||||
.map(|pdu_id| SearchHit { pdu_id, score: None })
|
||||
.collect();
|
||||
let count = results.len();
|
||||
|
||||
Ok((count, results))
|
||||
}
|
||||
|
||||
async fn delete_room(&self, room_id: &RoomId, _shortroomid: ShortRoomId) -> Result {
|
||||
let prefix = (room_id, Interfix);
|
||||
|
||||
self.tokenids
|
||||
.keys_prefix_raw(&prefix)
|
||||
.ignore_err()
|
||||
.ready_for_each(|key| {
|
||||
self.tokenids.remove(key);
|
||||
})
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn search_pdu_ids_query_room(
|
||||
tokenids: &Arc<Map>,
|
||||
shortroomid: ShortRoomId,
|
||||
search_term: &str,
|
||||
) -> Vec<Vec<RawPduId>> {
|
||||
tokenize(search_term)
|
||||
.collect::<Vec<_>>()
|
||||
.into_iter()
|
||||
.map(|word| {
|
||||
let tokenids = tokenids.clone();
|
||||
async move {
|
||||
search_pdu_ids_query_words(&tokenids, shortroomid, &word)
|
||||
.collect::<Vec<_>>()
|
||||
.await
|
||||
}
|
||||
})
|
||||
.collect::<futures::stream::FuturesUnordered<_>>()
|
||||
.collect::<Vec<_>>()
|
||||
.await
|
||||
}
|
||||
|
||||
fn search_pdu_ids_query_words<'a>(
|
||||
tokenids: &'a Arc<Map>,
|
||||
shortroomid: ShortRoomId,
|
||||
word: &'a str,
|
||||
) -> impl futures::Stream<Item = RawPduId> + Send + 'a {
|
||||
search_pdu_ids_query_word(tokenids, shortroomid, word).map(move |key| -> RawPduId {
|
||||
let key = &key[prefix_len(word)..];
|
||||
key.into()
|
||||
})
|
||||
}
|
||||
|
||||
fn search_pdu_ids_query_word<'a>(
|
||||
tokenids: &'a Arc<Map>,
|
||||
shortroomid: ShortRoomId,
|
||||
word: &'a str,
|
||||
) -> impl futures::Stream<Item = Val<'a>> + Send + 'a {
|
||||
let end_id: RawPduId = PduId { shortroomid, count: PduCount::max() }.into();
|
||||
|
||||
let end = make_tokenid(shortroomid, word, &end_id);
|
||||
let prefix = make_prefix(shortroomid, word);
|
||||
tokenids
|
||||
.rev_raw_keys_from(&end)
|
||||
.ignore_err()
|
||||
.ready_take_while(move |key| key.starts_with(&prefix))
|
||||
}
|
||||
|
||||
/// Splits a string into tokens used as keys in the search inverted index
|
||||
pub(super) fn tokenize(body: &str) -> impl Iterator<Item = String> + Send + '_ {
|
||||
body.split_terminator(|c: char| !c.is_alphanumeric())
|
||||
.filter(|s| !s.is_empty())
|
||||
.filter(|word| word.len() <= WORD_MAX_LEN)
|
||||
.map(str::to_lowercase)
|
||||
}
|
||||
|
||||
fn make_tokenid(shortroomid: ShortRoomId, word: &str, pdu_id: &RawPduId) -> TokenId {
|
||||
let mut key = make_prefix(shortroomid, word);
|
||||
key.extend_from_slice(pdu_id.as_ref());
|
||||
key
|
||||
}
|
||||
|
||||
fn make_prefix(shortroomid: ShortRoomId, word: &str) -> TokenId {
|
||||
let mut key = TokenId::new();
|
||||
key.extend_from_slice(&shortroomid.to_be_bytes());
|
||||
key.extend_from_slice(word.as_bytes());
|
||||
key.push(tuwunel_database::SEP);
|
||||
key
|
||||
}
|
||||
|
||||
fn prefix_len(word: &str) -> usize {
|
||||
size_of::<ShortRoomId>()
|
||||
.saturating_add(word.len())
|
||||
.saturating_add(1)
|
||||
}
|
||||
@@ -281,8 +281,7 @@ async fn append_pdu_effects(
|
||||
if let Some(body) = content.body {
|
||||
self.services
|
||||
.search
|
||||
.index_pdu(shortroomid, &pdu_id, pdu.room_id(), &body)
|
||||
.await;
|
||||
.index_pdu(shortroomid, &pdu_id, &body);
|
||||
|
||||
if self
|
||||
.services
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user