From 56f3f5ea1545a5d4f1db07346e6433a428d7717c Mon Sep 17 00:00:00 2001 From: dasha_uwu Date: Sat, 24 Jan 2026 11:29:56 +0500 Subject: [PATCH] Limited use registration token support Co-authored-by: Ginger Signed-off-by: Jason Volk --- src/admin/admin.rs | 20 ++- src/admin/mod.rs | 1 + src/admin/token/commands.rs | 61 ++++++++ src/admin/token/mod.rs | 36 +++++ src/api/client/register.rs | 17 +-- src/database/maps.rs | 4 + src/service/globals/mod.rs | 27 +--- src/service/mod.rs | 1 + src/service/registration_tokens/data.rs | 194 ++++++++++++++++++++++++ src/service/registration_tokens/mod.rs | 161 ++++++++++++++++++++ src/service/services.rs | 5 +- src/service/uiaa/mod.rs | 10 +- 12 files changed, 494 insertions(+), 43 deletions(-) create mode 100644 src/admin/token/commands.rs create mode 100644 src/admin/token/mod.rs create mode 100644 src/service/registration_tokens/data.rs create mode 100644 src/service/registration_tokens/mod.rs diff --git a/src/admin/admin.rs b/src/admin/admin.rs index 1eaa83a3..2d2b98e8 100644 --- a/src/admin/admin.rs +++ b/src/admin/admin.rs @@ -2,10 +2,17 @@ use clap::Parser; use tuwunel_core::Result; use crate::{ - appservice, appservice::AppserviceCommand, check, check::CheckCommand, context::Context, - debug, debug::DebugCommand, federation, federation::FederationCommand, media, - media::MediaCommand, query, query::QueryCommand, room, room::RoomCommand, server, - server::ServerCommand, user, user::UserCommand, + appservice::{self, AppserviceCommand}, + check::{self, CheckCommand}, + context::Context, + debug::{self, DebugCommand}, + federation::{self, FederationCommand}, + media::{self, MediaCommand}, + query::{self, QueryCommand}, + room::{self, RoomCommand}, + server::{self, ServerCommand}, + token::{self, TokenCommand}, + user::{self, UserCommand}, }; #[derive(Debug, Parser)] @@ -46,6 +53,10 @@ pub(super) enum AdminCommand { #[command(subcommand)] /// - Low-level queries for database getters and iterators Query(QueryCommand), + + #[command(subcommand)] + /// - Commands for managing registration tokens + Token(TokenCommand), } #[tracing::instrument(skip_all, name = "command")] @@ -62,5 +73,6 @@ pub(super) async fn process(command: AdminCommand, context: &Context<'_>) -> Res | Debug(command) => debug::process(command, context).await, | Query(command) => query::process(command, context).await, | Check(command) => check::process(command, context).await, + | Token(command) => token::process(command, context).await, } } diff --git a/src/admin/mod.rs b/src/admin/mod.rs index a2e592d3..6c765f4b 100644 --- a/src/admin/mod.rs +++ b/src/admin/mod.rs @@ -15,6 +15,7 @@ pub(crate) mod media; pub(crate) mod query; pub(crate) mod room; pub(crate) mod server; +pub(crate) mod token; pub(crate) mod user; pub(crate) use tuwunel_macros::{admin_command, admin_command_dispatch}; diff --git a/src/admin/token/commands.rs b/src/admin/token/commands.rs new file mode 100644 index 00000000..0ab29a58 --- /dev/null +++ b/src/admin/token/commands.rs @@ -0,0 +1,61 @@ +use futures::StreamExt; +use tuwunel_core::{Result, utils}; +use tuwunel_macros::admin_command; +use tuwunel_service::registration_tokens::TokenExpires; + +#[admin_command] +pub(super) async fn issue( + &self, + max_uses: Option, + max_age: Option, + once: bool, +) -> Result { + let expires = TokenExpires { + max_uses: max_uses.or_else(|| once.then_some(1)), + max_age: max_age + .map(|max_age| { + let duration = utils::time::parse_duration(&max_age)?; + utils::time::timepoint_from_now(duration) + }) + .transpose()?, + }; + + let (token, info) = self + .services + .registration_tokens + .issue_token(expires) + .await?; + + self.write_str(&format!("New registration token issued: `{token}` - {info}",)) + .await +} + +#[admin_command] +pub(super) async fn revoke(&self, token: String) -> Result { + self.services + .registration_tokens + .revoke_token(&token) + .await?; + + self.write_str("Token revoked successfully.") + .await +} + +#[admin_command] +pub(super) async fn list(&self) -> Result { + let tokens: Vec<_> = self + .services + .registration_tokens + .iterate_tokens() + .collect() + .await; + + self.write_str(&format!("Found {} registration tokens:\n", tokens.len())) + .await?; + + for token in tokens { + self.write_str(&format!("- {token}\n")).await?; + } + + Ok(()) +} diff --git a/src/admin/token/mod.rs b/src/admin/token/mod.rs new file mode 100644 index 00000000..8af9dbdb --- /dev/null +++ b/src/admin/token/mod.rs @@ -0,0 +1,36 @@ +mod commands; + +use clap::Subcommand; +use tuwunel_core::Result; + +use crate::admin_command_dispatch; + +#[admin_command_dispatch] +#[derive(Debug, Subcommand)] +pub(crate) enum TokenCommand { + /// - Issue a new registration token + Issue { + /// The maximum number of times this token is allowed to be used before + /// it expires. + #[arg(long)] + max_uses: Option, + + /// The maximum age of this token (e.g. 30s, 5m, 7d). It will expire + /// after this much time has passed. + #[arg(long)] + max_age: Option, + + /// A shortcut for `--max-uses 1`. + #[arg(long)] + once: bool, + }, + + /// - Revoke a registration token + Revoke { + /// The token to revoke. + token: String, + }, + + /// - List all registration tokens + List, +} diff --git a/src/api/client/register.rs b/src/api/client/register.rs index 28c96be5..2f26cd1e 100644 --- a/src/api/client/register.rs +++ b/src/api/client/register.rs @@ -263,12 +263,7 @@ pub(crate) async fn register_route( // UIAA let mut uiaainfo; - let skip_auth = if !services - .globals - .get_registration_tokens() - .is_empty() - && !is_guest - { + let skip_auth = if services.registration_tokens.is_enabled().await && !is_guest { // Registration token required uiaainfo = UiaaInfo { flows: vec![AuthFlow { @@ -424,13 +419,15 @@ pub(crate) async fn check_registration_token_validity( State(services): State, body: Ruma, ) -> Result { - let tokens = services.globals.get_registration_tokens(); - - if tokens.is_empty() { + if !services.registration_tokens.is_enabled().await { return Err!(Request(Forbidden("Server does not allow token registration"))); } - let valid = tokens.contains(&body.token); + let valid = services + .registration_tokens + .is_token_valid(&body.token) + .await + .is_ok(); Ok(check_registration_token_validity::v1::Response { valid }) } diff --git a/src/database/maps.rs b/src/database/maps.rs index 7fee7225..e418551f 100644 --- a/src/database/maps.rs +++ b/src/database/maps.rs @@ -175,6 +175,10 @@ pub(super) static MAPS: &[Descriptor] = &[ name: "referencedevents", ..descriptor::RANDOM }, + Descriptor { + name: "registrationtoken_info", + ..descriptor::RANDOM_SMALL + }, Descriptor { name: "roomid_knockedcount", ..descriptor::RANDOM_SMALL diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index d816d78a..fb48acf5 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -1,6 +1,6 @@ mod data; -use std::{collections::HashSet, ops::Range, sync::Arc}; +use std::{ops::Range, sync::Arc}; use data::Data; use ruma::{OwnedUserId, RoomAliasId, ServerName, UserId}; @@ -106,31 +106,6 @@ impl Service { #[must_use] pub fn is_read_only(&self) -> bool { self.db.db.is_read_only() } - pub fn get_registration_tokens(&self) -> HashSet { - let mut tokens = HashSet::new(); - if let Some(file) = &self - .server - .config - .registration_token_file - .as_ref() - { - match std::fs::read_to_string(file) { - | Err(e) => error!("Failed to read the registration token file: {e}"), - | Ok(text) => { - text.split_ascii_whitespace().for_each(|token| { - tokens.insert(token.to_owned()); - }); - }, - } - } - - if let Some(token) = &self.server.config.registration_token { - tokens.insert(token.to_owned()); - } - - tokens - } - pub fn init_rustls_provider(&self) -> Result { if rustls::crypto::CryptoProvider::get_default().is_none() { rustls::crypto::aws_lc_rs::default_provider() diff --git a/src/service/mod.rs b/src/service/mod.rs index e89a56a5..ce58b8fa 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -23,6 +23,7 @@ pub mod membership; pub mod oauth; pub mod presence; pub mod pusher; +pub mod registration_tokens; pub mod resolver; pub mod rooms; pub mod sending; diff --git a/src/service/registration_tokens/data.rs b/src/service/registration_tokens/data.rs new file mode 100644 index 00000000..85a1b3bd --- /dev/null +++ b/src/service/registration_tokens/data.rs @@ -0,0 +1,194 @@ +use std::{sync::Arc, time::SystemTime}; + +use futures::Stream; +use serde::{Deserialize, Serialize}; +use tuwunel_core::{ + Err, Result, + utils::{ + self, + stream::{ReadyExt, TryIgnore}, + }, +}; +use tuwunel_database::{Database, Deserialized, Json, Map}; + +pub(super) struct Data { + registrationtoken_info: Arc, +} + +/// Metadata of a registration token. +#[derive(Debug, Serialize, Deserialize)] +pub struct DatabaseTokenInfo { + /// The number of times this token has been used to create an account. + pub uses: u64, + /// When this token will expire, if it expires. + pub expires: TokenExpires, +} + +impl DatabaseTokenInfo { + pub(super) fn new(expires: TokenExpires) -> Self { Self { uses: 0, expires } } + + /// Determine whether this token info represents a valid token, i.e. one + /// that has not expired according to its [`Self::expires`] property. If + /// [`Self::expires`] is [`None`], this function will always return `true`. + #[must_use] + pub fn is_valid(&self) -> bool { + if let Some(max_uses) = self.expires.max_uses + && max_uses >= self.uses + { + return false; + } + + if let Some(max_age) = self.expires.max_age { + let now = SystemTime::now(); + + if now > max_age { + return false; + } + } + + true + } +} + +impl std::fmt::Display for DatabaseTokenInfo { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Token used {} times. {}", self.uses, self.expires)?; + + Ok(()) + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct TokenExpires { + pub max_uses: Option, + pub max_age: Option, +} + +impl std::fmt::Display for TokenExpires { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut msgs = vec![]; + + if let Some(max_uses) = self.max_uses { + msgs.push(format!("after {max_uses} uses")); + } + + if let Some(max_age) = self.max_age { + let now = SystemTime::now(); + let expires_at = utils::time::format(max_age, "%F %T"); + + match max_age.duration_since(now) { + | Ok(duration) => { + let expires_in = utils::time::pretty(duration); + msgs.push(format!("in {expires_in} ({expires_at})")); + }, + | Err(_) => { + write!(f, "Expired at {expires_at}")?; + return Ok(()); + }, + } + } + + if !msgs.is_empty() { + write!(f, "Expires {}.", msgs.join(" or "))?; + } else { + write!(f, "Never expires.")?; + } + + Ok(()) + } +} + +impl Data { + pub(super) fn new(db: &Arc) -> Self { + Self { + registrationtoken_info: db["registrationtoken_info"].clone(), + } + } + + /// Associate a registration token with its metadata in the database. + pub(super) async fn save_token( + &self, + token: &str, + expires: TokenExpires, + ) -> Result { + if self + .registrationtoken_info + .exists(token) + .await + .is_err() + { + let info = DatabaseTokenInfo::new(expires); + + self.registrationtoken_info + .raw_put(token, Json(&info)); + + Ok(info) + } else { + Err!(Request(InvalidParam("Registration token already exists"))) + } + } + + /// Delete a registration token. + pub(super) async fn revoke_token(&self, token: &str) -> Result { + if self + .registrationtoken_info + .exists(token) + .await + .is_ok() + { + self.registrationtoken_info.remove(token); + + Ok(()) + } else { + Err!(Request(NotFound("Registration token not found"))) + } + } + + /// Look up a registration token's metadata. + pub(super) async fn check_token(&self, token: &str, consume: bool) -> bool { + let info = self + .registrationtoken_info + .get(token) + .await + .deserialized::() + .ok(); + + info.map(|mut info| { + if !info.is_valid() { + self.registrationtoken_info.remove(token); + return false; + } + + if consume { + info.uses = info.uses.saturating_add(1); + + if info.is_valid() { + self.registrationtoken_info + .raw_put(token, Json(info)); + } else { + self.registrationtoken_info.remove(token); + } + } + + true + }) + .unwrap_or(false) + } + + /// Iterate over all valid tokens and delete expired ones. + pub(super) fn iterate_and_clean_tokens( + &self, + ) -> impl Stream + Send + '_ { + self.registrationtoken_info + .stream() + .ignore_err() + .ready_filter_map(|(token, info): (&str, DatabaseTokenInfo)| { + if info.is_valid() { + Some((token, info)) + } else { + self.registrationtoken_info.remove(token); + None + } + }) + } +} diff --git a/src/service/registration_tokens/mod.rs b/src/service/registration_tokens/mod.rs new file mode 100644 index 00000000..78978a27 --- /dev/null +++ b/src/service/registration_tokens/mod.rs @@ -0,0 +1,161 @@ +mod data; + +use std::{collections::HashSet, sync::Arc}; + +use data::Data; +pub use data::{DatabaseTokenInfo, TokenExpires}; +use futures::{Stream, StreamExt, pin_mut}; +use tuwunel_core::{ + Err, Result, error, + utils::{self, IterStream}, +}; + +const RANDOM_TOKEN_LENGTH: usize = 16; + +pub struct Service { + db: Data, + services: Arc, +} + +/// A validated registration token which may be used to create an account. +#[derive(Debug)] +pub struct ValidToken { + pub token: String, + pub source: ValidTokenSource, +} + +impl std::fmt::Display for ValidToken { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "`{}` --- {}", self.token, &self.source) + } +} + +impl PartialEq for ValidToken { + fn eq(&self, other: &str) -> bool { self.token == other } +} + +/// The source of a valid database token. +#[derive(Debug)] +pub enum ValidTokenSource { + /// The static token set in the homeserver's config file, which is + /// always valid. + ConfigFile, + /// A database token which has been checked to be valid. + Database(DatabaseTokenInfo), +} + +impl std::fmt::Display for ValidTokenSource { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + | Self::ConfigFile => write!(f, "Token defined in config."), + | Self::Database(info) => info.fmt(f), + } + } +} + +impl crate::Service for Service { + fn build(args: &crate::Args<'_>) -> Result> { + Ok(Arc::new(Self { + db: Data::new(args.db), + services: args.services.clone(), + })) + } + + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } +} + +impl Service { + /// Issue a new registration token and save it in the database. + pub async fn issue_token( + &self, + expires: TokenExpires, + ) -> Result<(String, DatabaseTokenInfo)> { + let token = utils::random_string(RANDOM_TOKEN_LENGTH); + + let info = self.db.save_token(&token, expires).await?; + + Ok((token, info)) + } + + pub async fn is_enabled(&self) -> bool { + let stream = self.iterate_tokens(); + + pin_mut!(stream); + + stream.next().await.is_some() + } + + pub fn get_config_tokens(&self) -> HashSet { + let mut tokens = HashSet::new(); + if let Some(file) = &self + .services + .server + .config + .registration_token_file + .as_ref() + { + match std::fs::read_to_string(file) { + | Err(e) => error!("Failed to read the registration token file: {e}"), + | Ok(text) => { + text.split_ascii_whitespace().for_each(|token| { + tokens.insert(token.to_owned()); + }); + }, + } + } + + if let Some(token) = &self.services.server.config.registration_token { + tokens.insert(token.to_owned()); + } + + tokens + } + + pub async fn is_token_valid(&self, token: &str) -> Result { self.check(token, false).await } + + pub async fn try_consume(&self, token: &str) -> Result { self.check(token, true).await } + + async fn check(&self, token: &str, consume: bool) -> Result { + if self.get_config_tokens().contains(token) || self.db.check_token(token, consume).await { + return Ok(()); + } + + Err!(Request(Forbidden("Registration token not valid"))) + } + + /// Try to revoke a valid token. + /// + /// Note that tokens set in the config file cannot be revoked. + pub async fn revoke_token(&self, token: &str) -> Result { + if self.get_config_tokens().contains(token) { + return Err!( + "The token set in the config file cannot be revoked. Edit the config file to \ + change it." + ); + } + + self.db.revoke_token(token).await + } + + /// Iterate over all valid registration tokens. + pub fn iterate_tokens(&self) -> impl Stream + Send + '_ { + let config_tokens = self + .get_config_tokens() + .into_iter() + .map(|token| ValidToken { + token, + source: ValidTokenSource::ConfigFile, + }) + .stream(); + + let db_tokens = self + .db + .iterate_and_clean_tokens() + .map(|(token, info)| ValidToken { + token: token.to_owned(), + source: ValidTokenSource::Database(info), + }); + + config_tokens.chain(db_tokens) + } +} diff --git a/src/service/services.rs b/src/service/services.rs index a7229fdc..0ab66689 100644 --- a/src/service/services.rs +++ b/src/service/services.rs @@ -12,7 +12,7 @@ use crate::{ account_data, admin, appservice, client, config, deactivate, emergency, federation, globals, key_backups, manager::Manager, - media, membership, oauth, presence, pusher, resolver, + media, membership, oauth, presence, pusher, registration_tokens, resolver, rooms::{self, retention}, sending, server_keys, service::{Args, Service}, @@ -62,6 +62,7 @@ pub struct Services { pub deactivate: Arc, pub oauth: Arc, pub retention: Arc, + pub registration_tokens: Arc, manager: Mutex>>, pub server: Arc, @@ -121,6 +122,7 @@ pub async fn build(server: Arc) -> Result> { deactivate: deactivate::Service::build(&args)?, oauth: oauth::Service::build(&args)?, retention: retention::Service::build(&args)?, + registration_tokens: registration_tokens::Service::build(&args)?, manager: Mutex::new(None), server, @@ -181,6 +183,7 @@ pub(crate) fn services(&self) -> impl Iterator> + Send { cast!(self.deactivate), cast!(self.oauth), cast!(self.retention), + cast!(self.registration_tokens), ] .into_iter() } diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index 20a28f20..a4fb795d 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -143,8 +143,14 @@ pub async fn try_auth( uiaainfo.completed.push(AuthType::Password); }, | AuthData::RegistrationToken(t) => { - let tokens = self.services.globals.get_registration_tokens(); - if tokens.contains(t.token.trim()) { + let token = t.token.trim(); + if self + .services + .registration_tokens + .try_consume(token) + .await + .is_ok() + { uiaainfo .completed .push(AuthType::RegistrationToken);