Limited use registration token support

Co-authored-by: Ginger <ginger@gingershaped.computer>
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
dasha_uwu
2026-01-24 11:29:56 +05:00
committed by Jason Volk
parent 0dbe79df8e
commit 56f3f5ea15
12 changed files with 494 additions and 43 deletions

View File

@@ -2,10 +2,17 @@ use clap::Parser;
use tuwunel_core::Result;
use crate::{
appservice, appservice::AppserviceCommand, check, check::CheckCommand, context::Context,
debug, debug::DebugCommand, federation, federation::FederationCommand, media,
media::MediaCommand, query, query::QueryCommand, room, room::RoomCommand, server,
server::ServerCommand, user, user::UserCommand,
appservice::{self, AppserviceCommand},
check::{self, CheckCommand},
context::Context,
debug::{self, DebugCommand},
federation::{self, FederationCommand},
media::{self, MediaCommand},
query::{self, QueryCommand},
room::{self, RoomCommand},
server::{self, ServerCommand},
token::{self, TokenCommand},
user::{self, UserCommand},
};
#[derive(Debug, Parser)]
@@ -46,6 +53,10 @@ pub(super) enum AdminCommand {
#[command(subcommand)]
/// - Low-level queries for database getters and iterators
Query(QueryCommand),
#[command(subcommand)]
/// - Commands for managing registration tokens
Token(TokenCommand),
}
#[tracing::instrument(skip_all, name = "command")]
@@ -62,5 +73,6 @@ pub(super) async fn process(command: AdminCommand, context: &Context<'_>) -> Res
| Debug(command) => debug::process(command, context).await,
| Query(command) => query::process(command, context).await,
| Check(command) => check::process(command, context).await,
| Token(command) => token::process(command, context).await,
}
}

View File

@@ -15,6 +15,7 @@ pub(crate) mod media;
pub(crate) mod query;
pub(crate) mod room;
pub(crate) mod server;
pub(crate) mod token;
pub(crate) mod user;
pub(crate) use tuwunel_macros::{admin_command, admin_command_dispatch};

View File

@@ -0,0 +1,61 @@
use futures::StreamExt;
use tuwunel_core::{Result, utils};
use tuwunel_macros::admin_command;
use tuwunel_service::registration_tokens::TokenExpires;
#[admin_command]
pub(super) async fn issue(
&self,
max_uses: Option<u64>,
max_age: Option<String>,
once: bool,
) -> Result {
let expires = TokenExpires {
max_uses: max_uses.or_else(|| once.then_some(1)),
max_age: max_age
.map(|max_age| {
let duration = utils::time::parse_duration(&max_age)?;
utils::time::timepoint_from_now(duration)
})
.transpose()?,
};
let (token, info) = self
.services
.registration_tokens
.issue_token(expires)
.await?;
self.write_str(&format!("New registration token issued: `{token}` - {info}",))
.await
}
#[admin_command]
pub(super) async fn revoke(&self, token: String) -> Result {
self.services
.registration_tokens
.revoke_token(&token)
.await?;
self.write_str("Token revoked successfully.")
.await
}
#[admin_command]
pub(super) async fn list(&self) -> Result {
let tokens: Vec<_> = self
.services
.registration_tokens
.iterate_tokens()
.collect()
.await;
self.write_str(&format!("Found {} registration tokens:\n", tokens.len()))
.await?;
for token in tokens {
self.write_str(&format!("- {token}\n")).await?;
}
Ok(())
}

36
src/admin/token/mod.rs Normal file
View File

@@ -0,0 +1,36 @@
mod commands;
use clap::Subcommand;
use tuwunel_core::Result;
use crate::admin_command_dispatch;
#[admin_command_dispatch]
#[derive(Debug, Subcommand)]
pub(crate) enum TokenCommand {
/// - Issue a new registration token
Issue {
/// The maximum number of times this token is allowed to be used before
/// it expires.
#[arg(long)]
max_uses: Option<u64>,
/// The maximum age of this token (e.g. 30s, 5m, 7d). It will expire
/// after this much time has passed.
#[arg(long)]
max_age: Option<String>,
/// A shortcut for `--max-uses 1`.
#[arg(long)]
once: bool,
},
/// - Revoke a registration token
Revoke {
/// The token to revoke.
token: String,
},
/// - List all registration tokens
List,
}

View File

@@ -263,12 +263,7 @@ pub(crate) async fn register_route(
// UIAA
let mut uiaainfo;
let skip_auth = if !services
.globals
.get_registration_tokens()
.is_empty()
&& !is_guest
{
let skip_auth = if services.registration_tokens.is_enabled().await && !is_guest {
// Registration token required
uiaainfo = UiaaInfo {
flows: vec![AuthFlow {
@@ -424,13 +419,15 @@ pub(crate) async fn check_registration_token_validity(
State(services): State<crate::State>,
body: Ruma<check_registration_token_validity::v1::Request>,
) -> Result<check_registration_token_validity::v1::Response> {
let tokens = services.globals.get_registration_tokens();
if tokens.is_empty() {
if !services.registration_tokens.is_enabled().await {
return Err!(Request(Forbidden("Server does not allow token registration")));
}
let valid = tokens.contains(&body.token);
let valid = services
.registration_tokens
.is_token_valid(&body.token)
.await
.is_ok();
Ok(check_registration_token_validity::v1::Response { valid })
}

View File

@@ -175,6 +175,10 @@ pub(super) static MAPS: &[Descriptor] = &[
name: "referencedevents",
..descriptor::RANDOM
},
Descriptor {
name: "registrationtoken_info",
..descriptor::RANDOM_SMALL
},
Descriptor {
name: "roomid_knockedcount",
..descriptor::RANDOM_SMALL

View File

@@ -1,6 +1,6 @@
mod data;
use std::{collections::HashSet, ops::Range, sync::Arc};
use std::{ops::Range, sync::Arc};
use data::Data;
use ruma::{OwnedUserId, RoomAliasId, ServerName, UserId};
@@ -106,31 +106,6 @@ impl Service {
#[must_use]
pub fn is_read_only(&self) -> bool { self.db.db.is_read_only() }
pub fn get_registration_tokens(&self) -> HashSet<String> {
let mut tokens = HashSet::new();
if let Some(file) = &self
.server
.config
.registration_token_file
.as_ref()
{
match std::fs::read_to_string(file) {
| Err(e) => error!("Failed to read the registration token file: {e}"),
| Ok(text) => {
text.split_ascii_whitespace().for_each(|token| {
tokens.insert(token.to_owned());
});
},
}
}
if let Some(token) = &self.server.config.registration_token {
tokens.insert(token.to_owned());
}
tokens
}
pub fn init_rustls_provider(&self) -> Result {
if rustls::crypto::CryptoProvider::get_default().is_none() {
rustls::crypto::aws_lc_rs::default_provider()

View File

@@ -23,6 +23,7 @@ pub mod membership;
pub mod oauth;
pub mod presence;
pub mod pusher;
pub mod registration_tokens;
pub mod resolver;
pub mod rooms;
pub mod sending;

View File

@@ -0,0 +1,194 @@
use std::{sync::Arc, time::SystemTime};
use futures::Stream;
use serde::{Deserialize, Serialize};
use tuwunel_core::{
Err, Result,
utils::{
self,
stream::{ReadyExt, TryIgnore},
},
};
use tuwunel_database::{Database, Deserialized, Json, Map};
pub(super) struct Data {
registrationtoken_info: Arc<Map>,
}
/// Metadata of a registration token.
#[derive(Debug, Serialize, Deserialize)]
pub struct DatabaseTokenInfo {
/// The number of times this token has been used to create an account.
pub uses: u64,
/// When this token will expire, if it expires.
pub expires: TokenExpires,
}
impl DatabaseTokenInfo {
pub(super) fn new(expires: TokenExpires) -> Self { Self { uses: 0, expires } }
/// Determine whether this token info represents a valid token, i.e. one
/// that has not expired according to its [`Self::expires`] property. If
/// [`Self::expires`] is [`None`], this function will always return `true`.
#[must_use]
pub fn is_valid(&self) -> bool {
if let Some(max_uses) = self.expires.max_uses
&& max_uses >= self.uses
{
return false;
}
if let Some(max_age) = self.expires.max_age {
let now = SystemTime::now();
if now > max_age {
return false;
}
}
true
}
}
impl std::fmt::Display for DatabaseTokenInfo {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Token used {} times. {}", self.uses, self.expires)?;
Ok(())
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct TokenExpires {
pub max_uses: Option<u64>,
pub max_age: Option<SystemTime>,
}
impl std::fmt::Display for TokenExpires {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut msgs = vec![];
if let Some(max_uses) = self.max_uses {
msgs.push(format!("after {max_uses} uses"));
}
if let Some(max_age) = self.max_age {
let now = SystemTime::now();
let expires_at = utils::time::format(max_age, "%F %T");
match max_age.duration_since(now) {
| Ok(duration) => {
let expires_in = utils::time::pretty(duration);
msgs.push(format!("in {expires_in} ({expires_at})"));
},
| Err(_) => {
write!(f, "Expired at {expires_at}")?;
return Ok(());
},
}
}
if !msgs.is_empty() {
write!(f, "Expires {}.", msgs.join(" or "))?;
} else {
write!(f, "Never expires.")?;
}
Ok(())
}
}
impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self {
Self {
registrationtoken_info: db["registrationtoken_info"].clone(),
}
}
/// Associate a registration token with its metadata in the database.
pub(super) async fn save_token(
&self,
token: &str,
expires: TokenExpires,
) -> Result<DatabaseTokenInfo> {
if self
.registrationtoken_info
.exists(token)
.await
.is_err()
{
let info = DatabaseTokenInfo::new(expires);
self.registrationtoken_info
.raw_put(token, Json(&info));
Ok(info)
} else {
Err!(Request(InvalidParam("Registration token already exists")))
}
}
/// Delete a registration token.
pub(super) async fn revoke_token(&self, token: &str) -> Result {
if self
.registrationtoken_info
.exists(token)
.await
.is_ok()
{
self.registrationtoken_info.remove(token);
Ok(())
} else {
Err!(Request(NotFound("Registration token not found")))
}
}
/// Look up a registration token's metadata.
pub(super) async fn check_token(&self, token: &str, consume: bool) -> bool {
let info = self
.registrationtoken_info
.get(token)
.await
.deserialized::<DatabaseTokenInfo>()
.ok();
info.map(|mut info| {
if !info.is_valid() {
self.registrationtoken_info.remove(token);
return false;
}
if consume {
info.uses = info.uses.saturating_add(1);
if info.is_valid() {
self.registrationtoken_info
.raw_put(token, Json(info));
} else {
self.registrationtoken_info.remove(token);
}
}
true
})
.unwrap_or(false)
}
/// Iterate over all valid tokens and delete expired ones.
pub(super) fn iterate_and_clean_tokens(
&self,
) -> impl Stream<Item = (&str, DatabaseTokenInfo)> + Send + '_ {
self.registrationtoken_info
.stream()
.ignore_err()
.ready_filter_map(|(token, info): (&str, DatabaseTokenInfo)| {
if info.is_valid() {
Some((token, info))
} else {
self.registrationtoken_info.remove(token);
None
}
})
}
}

View File

@@ -0,0 +1,161 @@
mod data;
use std::{collections::HashSet, sync::Arc};
use data::Data;
pub use data::{DatabaseTokenInfo, TokenExpires};
use futures::{Stream, StreamExt, pin_mut};
use tuwunel_core::{
Err, Result, error,
utils::{self, IterStream},
};
const RANDOM_TOKEN_LENGTH: usize = 16;
pub struct Service {
db: Data,
services: Arc<crate::services::OnceServices>,
}
/// A validated registration token which may be used to create an account.
#[derive(Debug)]
pub struct ValidToken {
pub token: String,
pub source: ValidTokenSource,
}
impl std::fmt::Display for ValidToken {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "`{}` --- {}", self.token, &self.source)
}
}
impl PartialEq<str> for ValidToken {
fn eq(&self, other: &str) -> bool { self.token == other }
}
/// The source of a valid database token.
#[derive(Debug)]
pub enum ValidTokenSource {
/// The static token set in the homeserver's config file, which is
/// always valid.
ConfigFile,
/// A database token which has been checked to be valid.
Database(DatabaseTokenInfo),
}
impl std::fmt::Display for ValidTokenSource {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
| Self::ConfigFile => write!(f, "Token defined in config."),
| Self::Database(info) => info.fmt(f),
}
}
}
impl crate::Service for Service {
fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data::new(args.db),
services: args.services.clone(),
}))
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
impl Service {
/// Issue a new registration token and save it in the database.
pub async fn issue_token(
&self,
expires: TokenExpires,
) -> Result<(String, DatabaseTokenInfo)> {
let token = utils::random_string(RANDOM_TOKEN_LENGTH);
let info = self.db.save_token(&token, expires).await?;
Ok((token, info))
}
pub async fn is_enabled(&self) -> bool {
let stream = self.iterate_tokens();
pin_mut!(stream);
stream.next().await.is_some()
}
pub fn get_config_tokens(&self) -> HashSet<String> {
let mut tokens = HashSet::new();
if let Some(file) = &self
.services
.server
.config
.registration_token_file
.as_ref()
{
match std::fs::read_to_string(file) {
| Err(e) => error!("Failed to read the registration token file: {e}"),
| Ok(text) => {
text.split_ascii_whitespace().for_each(|token| {
tokens.insert(token.to_owned());
});
},
}
}
if let Some(token) = &self.services.server.config.registration_token {
tokens.insert(token.to_owned());
}
tokens
}
pub async fn is_token_valid(&self, token: &str) -> Result { self.check(token, false).await }
pub async fn try_consume(&self, token: &str) -> Result { self.check(token, true).await }
async fn check(&self, token: &str, consume: bool) -> Result {
if self.get_config_tokens().contains(token) || self.db.check_token(token, consume).await {
return Ok(());
}
Err!(Request(Forbidden("Registration token not valid")))
}
/// Try to revoke a valid token.
///
/// Note that tokens set in the config file cannot be revoked.
pub async fn revoke_token(&self, token: &str) -> Result {
if self.get_config_tokens().contains(token) {
return Err!(
"The token set in the config file cannot be revoked. Edit the config file to \
change it."
);
}
self.db.revoke_token(token).await
}
/// Iterate over all valid registration tokens.
pub fn iterate_tokens(&self) -> impl Stream<Item = ValidToken> + Send + '_ {
let config_tokens = self
.get_config_tokens()
.into_iter()
.map(|token| ValidToken {
token,
source: ValidTokenSource::ConfigFile,
})
.stream();
let db_tokens = self
.db
.iterate_and_clean_tokens()
.map(|(token, info)| ValidToken {
token: token.to_owned(),
source: ValidTokenSource::Database(info),
});
config_tokens.chain(db_tokens)
}
}

View File

@@ -12,7 +12,7 @@ use crate::{
account_data, admin, appservice, client, config, deactivate, emergency, federation, globals,
key_backups,
manager::Manager,
media, membership, oauth, presence, pusher, resolver,
media, membership, oauth, presence, pusher, registration_tokens, resolver,
rooms::{self, retention},
sending, server_keys,
service::{Args, Service},
@@ -62,6 +62,7 @@ pub struct Services {
pub deactivate: Arc<deactivate::Service>,
pub oauth: Arc<oauth::Service>,
pub retention: Arc<retention::Service>,
pub registration_tokens: Arc<registration_tokens::Service>,
manager: Mutex<Option<Arc<Manager>>>,
pub server: Arc<Server>,
@@ -121,6 +122,7 @@ pub async fn build(server: Arc<Server>) -> Result<Arc<Self>> {
deactivate: deactivate::Service::build(&args)?,
oauth: oauth::Service::build(&args)?,
retention: retention::Service::build(&args)?,
registration_tokens: registration_tokens::Service::build(&args)?,
manager: Mutex::new(None),
server,
@@ -181,6 +183,7 @@ pub(crate) fn services(&self) -> impl Iterator<Item = Arc<dyn Service>> + Send {
cast!(self.deactivate),
cast!(self.oauth),
cast!(self.retention),
cast!(self.registration_tokens),
]
.into_iter()
}

View File

@@ -143,8 +143,14 @@ pub async fn try_auth(
uiaainfo.completed.push(AuthType::Password);
},
| AuthData::RegistrationToken(t) => {
let tokens = self.services.globals.get_registration_tokens();
if tokens.contains(t.token.trim()) {
let token = t.token.trim();
if self
.services
.registration_tokens
.try_consume(token)
.await
.is_ok()
{
uiaainfo
.completed
.push(AuthType::RegistrationToken);