Limited use registration token support
Co-authored-by: Ginger <ginger@gingershaped.computer> Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
mod data;
|
||||
|
||||
use std::{collections::HashSet, ops::Range, sync::Arc};
|
||||
use std::{ops::Range, sync::Arc};
|
||||
|
||||
use data::Data;
|
||||
use ruma::{OwnedUserId, RoomAliasId, ServerName, UserId};
|
||||
@@ -106,31 +106,6 @@ impl Service {
|
||||
#[must_use]
|
||||
pub fn is_read_only(&self) -> bool { self.db.db.is_read_only() }
|
||||
|
||||
pub fn get_registration_tokens(&self) -> HashSet<String> {
|
||||
let mut tokens = HashSet::new();
|
||||
if let Some(file) = &self
|
||||
.server
|
||||
.config
|
||||
.registration_token_file
|
||||
.as_ref()
|
||||
{
|
||||
match std::fs::read_to_string(file) {
|
||||
| Err(e) => error!("Failed to read the registration token file: {e}"),
|
||||
| Ok(text) => {
|
||||
text.split_ascii_whitespace().for_each(|token| {
|
||||
tokens.insert(token.to_owned());
|
||||
});
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(token) = &self.server.config.registration_token {
|
||||
tokens.insert(token.to_owned());
|
||||
}
|
||||
|
||||
tokens
|
||||
}
|
||||
|
||||
pub fn init_rustls_provider(&self) -> Result {
|
||||
if rustls::crypto::CryptoProvider::get_default().is_none() {
|
||||
rustls::crypto::aws_lc_rs::default_provider()
|
||||
|
||||
@@ -23,6 +23,7 @@ pub mod membership;
|
||||
pub mod oauth;
|
||||
pub mod presence;
|
||||
pub mod pusher;
|
||||
pub mod registration_tokens;
|
||||
pub mod resolver;
|
||||
pub mod rooms;
|
||||
pub mod sending;
|
||||
|
||||
194
src/service/registration_tokens/data.rs
Normal file
194
src/service/registration_tokens/data.rs
Normal file
@@ -0,0 +1,194 @@
|
||||
use std::{sync::Arc, time::SystemTime};
|
||||
|
||||
use futures::Stream;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tuwunel_core::{
|
||||
Err, Result,
|
||||
utils::{
|
||||
self,
|
||||
stream::{ReadyExt, TryIgnore},
|
||||
},
|
||||
};
|
||||
use tuwunel_database::{Database, Deserialized, Json, Map};
|
||||
|
||||
pub(super) struct Data {
|
||||
registrationtoken_info: Arc<Map>,
|
||||
}
|
||||
|
||||
/// Metadata of a registration token.
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct DatabaseTokenInfo {
|
||||
/// The number of times this token has been used to create an account.
|
||||
pub uses: u64,
|
||||
/// When this token will expire, if it expires.
|
||||
pub expires: TokenExpires,
|
||||
}
|
||||
|
||||
impl DatabaseTokenInfo {
|
||||
pub(super) fn new(expires: TokenExpires) -> Self { Self { uses: 0, expires } }
|
||||
|
||||
/// Determine whether this token info represents a valid token, i.e. one
|
||||
/// that has not expired according to its [`Self::expires`] property. If
|
||||
/// [`Self::expires`] is [`None`], this function will always return `true`.
|
||||
#[must_use]
|
||||
pub fn is_valid(&self) -> bool {
|
||||
if let Some(max_uses) = self.expires.max_uses
|
||||
&& max_uses >= self.uses
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
if let Some(max_age) = self.expires.max_age {
|
||||
let now = SystemTime::now();
|
||||
|
||||
if now > max_age {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for DatabaseTokenInfo {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "Token used {} times. {}", self.uses, self.expires)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct TokenExpires {
|
||||
pub max_uses: Option<u64>,
|
||||
pub max_age: Option<SystemTime>,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for TokenExpires {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let mut msgs = vec![];
|
||||
|
||||
if let Some(max_uses) = self.max_uses {
|
||||
msgs.push(format!("after {max_uses} uses"));
|
||||
}
|
||||
|
||||
if let Some(max_age) = self.max_age {
|
||||
let now = SystemTime::now();
|
||||
let expires_at = utils::time::format(max_age, "%F %T");
|
||||
|
||||
match max_age.duration_since(now) {
|
||||
| Ok(duration) => {
|
||||
let expires_in = utils::time::pretty(duration);
|
||||
msgs.push(format!("in {expires_in} ({expires_at})"));
|
||||
},
|
||||
| Err(_) => {
|
||||
write!(f, "Expired at {expires_at}")?;
|
||||
return Ok(());
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if !msgs.is_empty() {
|
||||
write!(f, "Expires {}.", msgs.join(" or "))?;
|
||||
} else {
|
||||
write!(f, "Never expires.")?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl Data {
|
||||
pub(super) fn new(db: &Arc<Database>) -> Self {
|
||||
Self {
|
||||
registrationtoken_info: db["registrationtoken_info"].clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Associate a registration token with its metadata in the database.
|
||||
pub(super) async fn save_token(
|
||||
&self,
|
||||
token: &str,
|
||||
expires: TokenExpires,
|
||||
) -> Result<DatabaseTokenInfo> {
|
||||
if self
|
||||
.registrationtoken_info
|
||||
.exists(token)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
let info = DatabaseTokenInfo::new(expires);
|
||||
|
||||
self.registrationtoken_info
|
||||
.raw_put(token, Json(&info));
|
||||
|
||||
Ok(info)
|
||||
} else {
|
||||
Err!(Request(InvalidParam("Registration token already exists")))
|
||||
}
|
||||
}
|
||||
|
||||
/// Delete a registration token.
|
||||
pub(super) async fn revoke_token(&self, token: &str) -> Result {
|
||||
if self
|
||||
.registrationtoken_info
|
||||
.exists(token)
|
||||
.await
|
||||
.is_ok()
|
||||
{
|
||||
self.registrationtoken_info.remove(token);
|
||||
|
||||
Ok(())
|
||||
} else {
|
||||
Err!(Request(NotFound("Registration token not found")))
|
||||
}
|
||||
}
|
||||
|
||||
/// Look up a registration token's metadata.
|
||||
pub(super) async fn check_token(&self, token: &str, consume: bool) -> bool {
|
||||
let info = self
|
||||
.registrationtoken_info
|
||||
.get(token)
|
||||
.await
|
||||
.deserialized::<DatabaseTokenInfo>()
|
||||
.ok();
|
||||
|
||||
info.map(|mut info| {
|
||||
if !info.is_valid() {
|
||||
self.registrationtoken_info.remove(token);
|
||||
return false;
|
||||
}
|
||||
|
||||
if consume {
|
||||
info.uses = info.uses.saturating_add(1);
|
||||
|
||||
if info.is_valid() {
|
||||
self.registrationtoken_info
|
||||
.raw_put(token, Json(info));
|
||||
} else {
|
||||
self.registrationtoken_info.remove(token);
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
})
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
/// Iterate over all valid tokens and delete expired ones.
|
||||
pub(super) fn iterate_and_clean_tokens(
|
||||
&self,
|
||||
) -> impl Stream<Item = (&str, DatabaseTokenInfo)> + Send + '_ {
|
||||
self.registrationtoken_info
|
||||
.stream()
|
||||
.ignore_err()
|
||||
.ready_filter_map(|(token, info): (&str, DatabaseTokenInfo)| {
|
||||
if info.is_valid() {
|
||||
Some((token, info))
|
||||
} else {
|
||||
self.registrationtoken_info.remove(token);
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
161
src/service/registration_tokens/mod.rs
Normal file
161
src/service/registration_tokens/mod.rs
Normal file
@@ -0,0 +1,161 @@
|
||||
mod data;
|
||||
|
||||
use std::{collections::HashSet, sync::Arc};
|
||||
|
||||
use data::Data;
|
||||
pub use data::{DatabaseTokenInfo, TokenExpires};
|
||||
use futures::{Stream, StreamExt, pin_mut};
|
||||
use tuwunel_core::{
|
||||
Err, Result, error,
|
||||
utils::{self, IterStream},
|
||||
};
|
||||
|
||||
const RANDOM_TOKEN_LENGTH: usize = 16;
|
||||
|
||||
pub struct Service {
|
||||
db: Data,
|
||||
services: Arc<crate::services::OnceServices>,
|
||||
}
|
||||
|
||||
/// A validated registration token which may be used to create an account.
|
||||
#[derive(Debug)]
|
||||
pub struct ValidToken {
|
||||
pub token: String,
|
||||
pub source: ValidTokenSource,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ValidToken {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "`{}` --- {}", self.token, &self.source)
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq<str> for ValidToken {
|
||||
fn eq(&self, other: &str) -> bool { self.token == other }
|
||||
}
|
||||
|
||||
/// The source of a valid database token.
|
||||
#[derive(Debug)]
|
||||
pub enum ValidTokenSource {
|
||||
/// The static token set in the homeserver's config file, which is
|
||||
/// always valid.
|
||||
ConfigFile,
|
||||
/// A database token which has been checked to be valid.
|
||||
Database(DatabaseTokenInfo),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ValidTokenSource {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
| Self::ConfigFile => write!(f, "Token defined in config."),
|
||||
| Self::Database(info) => info.fmt(f),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl crate::Service for Service {
|
||||
fn build(args: &crate::Args<'_>) -> Result<Arc<Self>> {
|
||||
Ok(Arc::new(Self {
|
||||
db: Data::new(args.db),
|
||||
services: args.services.clone(),
|
||||
}))
|
||||
}
|
||||
|
||||
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
|
||||
}
|
||||
|
||||
impl Service {
|
||||
/// Issue a new registration token and save it in the database.
|
||||
pub async fn issue_token(
|
||||
&self,
|
||||
expires: TokenExpires,
|
||||
) -> Result<(String, DatabaseTokenInfo)> {
|
||||
let token = utils::random_string(RANDOM_TOKEN_LENGTH);
|
||||
|
||||
let info = self.db.save_token(&token, expires).await?;
|
||||
|
||||
Ok((token, info))
|
||||
}
|
||||
|
||||
pub async fn is_enabled(&self) -> bool {
|
||||
let stream = self.iterate_tokens();
|
||||
|
||||
pin_mut!(stream);
|
||||
|
||||
stream.next().await.is_some()
|
||||
}
|
||||
|
||||
pub fn get_config_tokens(&self) -> HashSet<String> {
|
||||
let mut tokens = HashSet::new();
|
||||
if let Some(file) = &self
|
||||
.services
|
||||
.server
|
||||
.config
|
||||
.registration_token_file
|
||||
.as_ref()
|
||||
{
|
||||
match std::fs::read_to_string(file) {
|
||||
| Err(e) => error!("Failed to read the registration token file: {e}"),
|
||||
| Ok(text) => {
|
||||
text.split_ascii_whitespace().for_each(|token| {
|
||||
tokens.insert(token.to_owned());
|
||||
});
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(token) = &self.services.server.config.registration_token {
|
||||
tokens.insert(token.to_owned());
|
||||
}
|
||||
|
||||
tokens
|
||||
}
|
||||
|
||||
pub async fn is_token_valid(&self, token: &str) -> Result { self.check(token, false).await }
|
||||
|
||||
pub async fn try_consume(&self, token: &str) -> Result { self.check(token, true).await }
|
||||
|
||||
async fn check(&self, token: &str, consume: bool) -> Result {
|
||||
if self.get_config_tokens().contains(token) || self.db.check_token(token, consume).await {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
Err!(Request(Forbidden("Registration token not valid")))
|
||||
}
|
||||
|
||||
/// Try to revoke a valid token.
|
||||
///
|
||||
/// Note that tokens set in the config file cannot be revoked.
|
||||
pub async fn revoke_token(&self, token: &str) -> Result {
|
||||
if self.get_config_tokens().contains(token) {
|
||||
return Err!(
|
||||
"The token set in the config file cannot be revoked. Edit the config file to \
|
||||
change it."
|
||||
);
|
||||
}
|
||||
|
||||
self.db.revoke_token(token).await
|
||||
}
|
||||
|
||||
/// Iterate over all valid registration tokens.
|
||||
pub fn iterate_tokens(&self) -> impl Stream<Item = ValidToken> + Send + '_ {
|
||||
let config_tokens = self
|
||||
.get_config_tokens()
|
||||
.into_iter()
|
||||
.map(|token| ValidToken {
|
||||
token,
|
||||
source: ValidTokenSource::ConfigFile,
|
||||
})
|
||||
.stream();
|
||||
|
||||
let db_tokens = self
|
||||
.db
|
||||
.iterate_and_clean_tokens()
|
||||
.map(|(token, info)| ValidToken {
|
||||
token: token.to_owned(),
|
||||
source: ValidTokenSource::Database(info),
|
||||
});
|
||||
|
||||
config_tokens.chain(db_tokens)
|
||||
}
|
||||
}
|
||||
@@ -12,7 +12,7 @@ use crate::{
|
||||
account_data, admin, appservice, client, config, deactivate, emergency, federation, globals,
|
||||
key_backups,
|
||||
manager::Manager,
|
||||
media, membership, oauth, presence, pusher, resolver,
|
||||
media, membership, oauth, presence, pusher, registration_tokens, resolver,
|
||||
rooms::{self, retention},
|
||||
sending, server_keys,
|
||||
service::{Args, Service},
|
||||
@@ -62,6 +62,7 @@ pub struct Services {
|
||||
pub deactivate: Arc<deactivate::Service>,
|
||||
pub oauth: Arc<oauth::Service>,
|
||||
pub retention: Arc<retention::Service>,
|
||||
pub registration_tokens: Arc<registration_tokens::Service>,
|
||||
|
||||
manager: Mutex<Option<Arc<Manager>>>,
|
||||
pub server: Arc<Server>,
|
||||
@@ -121,6 +122,7 @@ pub async fn build(server: Arc<Server>) -> Result<Arc<Self>> {
|
||||
deactivate: deactivate::Service::build(&args)?,
|
||||
oauth: oauth::Service::build(&args)?,
|
||||
retention: retention::Service::build(&args)?,
|
||||
registration_tokens: registration_tokens::Service::build(&args)?,
|
||||
|
||||
manager: Mutex::new(None),
|
||||
server,
|
||||
@@ -181,6 +183,7 @@ pub(crate) fn services(&self) -> impl Iterator<Item = Arc<dyn Service>> + Send {
|
||||
cast!(self.deactivate),
|
||||
cast!(self.oauth),
|
||||
cast!(self.retention),
|
||||
cast!(self.registration_tokens),
|
||||
]
|
||||
.into_iter()
|
||||
}
|
||||
|
||||
@@ -143,8 +143,14 @@ pub async fn try_auth(
|
||||
uiaainfo.completed.push(AuthType::Password);
|
||||
},
|
||||
| AuthData::RegistrationToken(t) => {
|
||||
let tokens = self.services.globals.get_registration_tokens();
|
||||
if tokens.contains(t.token.trim()) {
|
||||
let token = t.token.trim();
|
||||
if self
|
||||
.services
|
||||
.registration_tokens
|
||||
.try_consume(token)
|
||||
.await
|
||||
.is_ok()
|
||||
{
|
||||
uiaainfo
|
||||
.completed
|
||||
.push(AuthType::RegistrationToken);
|
||||
|
||||
Reference in New Issue
Block a user