Implement refresh-tokens. (resolves #50)
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
@@ -1,6 +1,9 @@
|
||||
use std::sync::Arc;
|
||||
use std::{
|
||||
sync::Arc,
|
||||
time::{Duration, SystemTime},
|
||||
};
|
||||
|
||||
use futures::{Stream, StreamExt};
|
||||
use futures::{FutureExt, Stream, StreamExt, future::join};
|
||||
use ruma::{
|
||||
DeviceId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedUserId, UserId,
|
||||
api::client::device::Device, events::AnyToDeviceEvent, serde::Raw,
|
||||
@@ -8,17 +11,26 @@ use ruma::{
|
||||
use serde_json::json;
|
||||
use tuwunel_core::{
|
||||
Err, Result, at, implement,
|
||||
utils::{self, ReadyExt, stream::TryIgnore},
|
||||
utils::{
|
||||
self, ReadyExt,
|
||||
stream::TryIgnore,
|
||||
time::{duration_since_epoch, timepoint_from_epoch, timepoint_from_now},
|
||||
},
|
||||
};
|
||||
use tuwunel_database::{Deserialized, Ignore, Interfix, Json, Map};
|
||||
|
||||
/// generated user access token length
|
||||
pub const TOKEN_LENGTH: usize = 32;
|
||||
|
||||
/// Adds a new device to a user.
|
||||
#[implement(super::Service)]
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
pub async fn create_device(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
token: &str,
|
||||
(access_token, expires_in): (&str, Option<Duration>),
|
||||
refresh_token: Option<&str>,
|
||||
initial_device_display_name: Option<String>,
|
||||
client_ip: Option<String>,
|
||||
) -> Result {
|
||||
@@ -38,25 +50,16 @@ pub async fn create_device(
|
||||
|
||||
increment(&self.db.userid_devicelistversion, user_id.as_bytes());
|
||||
self.db.userdeviceid_metadata.put(key, Json(val));
|
||||
self.set_access_token(user_id, device_id, token)
|
||||
self.set_access_token(user_id, device_id, access_token, expires_in, refresh_token)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Removes a device from a user.
|
||||
#[implement(super::Service)]
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
pub async fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) {
|
||||
let userdeviceid = (user_id, device_id);
|
||||
|
||||
// Remove tokens
|
||||
if let Ok(old_token) = self
|
||||
.db
|
||||
.userdeviceid_token
|
||||
.qry(&userdeviceid)
|
||||
.await
|
||||
{
|
||||
self.db.userdeviceid_token.del(userdeviceid);
|
||||
self.db.token_userdeviceid.remove(&old_token);
|
||||
}
|
||||
// Remove access tokens
|
||||
self.remove_tokens(user_id, device_id).await;
|
||||
|
||||
// Remove todevice events
|
||||
let prefix = (user_id, device_id, Interfix);
|
||||
@@ -71,6 +74,7 @@ pub async fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) {
|
||||
|
||||
increment(&self.db.userid_devicelistversion, user_id.as_bytes());
|
||||
|
||||
let userdeviceid = (user_id, device_id);
|
||||
self.db.userdeviceid_metadata.del(userdeviceid);
|
||||
self.mark_device_key_update(user_id).await;
|
||||
}
|
||||
@@ -89,50 +93,97 @@ pub fn all_device_ids<'a>(
|
||||
.map(|(_, device_id): (Ignore, &DeviceId)| device_id)
|
||||
}
|
||||
|
||||
/// Replaces the access token of one device.
|
||||
/// Find out which user an access or refresh token belongs to.
|
||||
#[implement(super::Service)]
|
||||
pub async fn set_access_token(
|
||||
#[tracing::instrument(level = "trace", skip(self, token))]
|
||||
pub async fn find_from_token(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
token: &str,
|
||||
) -> Result {
|
||||
let key = (user_id, device_id);
|
||||
if self
|
||||
.db
|
||||
.userdeviceid_metadata
|
||||
.qry(&key)
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
return Err!(Database(error!(
|
||||
?user_id,
|
||||
?device_id,
|
||||
"User does not exist or device has no metadata."
|
||||
)));
|
||||
}
|
||||
|
||||
// Remove old token
|
||||
if let Ok(old_token) = self.db.userdeviceid_token.qry(&key).await {
|
||||
self.db.token_userdeviceid.remove(&old_token);
|
||||
// It will be removed from userdeviceid_token by the insert later
|
||||
}
|
||||
|
||||
// Assign token to user device combination
|
||||
self.db.userdeviceid_token.put_raw(key, token);
|
||||
self.db.token_userdeviceid.raw_put(token, key);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Find out which user an access token belongs to.
|
||||
#[implement(super::Service)]
|
||||
pub async fn find_from_access_token(&self, token: &str) -> Result<(OwnedUserId, OwnedDeviceId)> {
|
||||
) -> Result<(OwnedUserId, OwnedDeviceId, Option<SystemTime>)> {
|
||||
self.db
|
||||
.token_userdeviceid
|
||||
.get(token)
|
||||
.await
|
||||
.deserialized()
|
||||
.and_then(|(user_id, device_id, expires_at): (_, _, Option<u64>)| {
|
||||
let expires_at = expires_at
|
||||
.map(Duration::from_secs)
|
||||
.map(timepoint_from_epoch)
|
||||
.transpose()?;
|
||||
|
||||
Ok((user_id, device_id, expires_at))
|
||||
})
|
||||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
pub async fn remove_tokens(&self, user_id: &UserId, device_id: &DeviceId) {
|
||||
let remove_access = self
|
||||
.remove_access_token(user_id, device_id)
|
||||
.map(Result::ok);
|
||||
|
||||
let remove_refresh = self
|
||||
.remove_refresh_token(user_id, device_id)
|
||||
.map(Result::ok);
|
||||
|
||||
join(remove_access, remove_refresh).await;
|
||||
}
|
||||
|
||||
/// Replaces the access token of one device.
|
||||
#[implement(super::Service)]
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
pub async fn set_access_token(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
access_token: &str,
|
||||
expires_in: Option<Duration>,
|
||||
refresh_token: Option<&str>,
|
||||
) -> Result {
|
||||
if let Some(refresh_token) = refresh_token {
|
||||
self.set_refresh_token(user_id, device_id, refresh_token)
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Remove old token.
|
||||
self.remove_access_token(user_id, device_id)
|
||||
.await
|
||||
.ok();
|
||||
|
||||
let expires_at = expires_in
|
||||
.map(timepoint_from_now)
|
||||
.transpose()?
|
||||
.map(duration_since_epoch)
|
||||
.as_ref()
|
||||
.map(Duration::as_secs);
|
||||
|
||||
let userdeviceid = (user_id, device_id);
|
||||
let value = (user_id, device_id, expires_at);
|
||||
self.db
|
||||
.token_userdeviceid
|
||||
.raw_put(access_token, value);
|
||||
self.db
|
||||
.userdeviceid_token
|
||||
.put_raw(userdeviceid, access_token);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Revoke the access token without deleting the device. Take care to not leave
|
||||
/// dangling devices if using this method.
|
||||
#[implement(super::Service)]
|
||||
pub async fn remove_access_token(&self, user_id: &UserId, device_id: &DeviceId) -> Result {
|
||||
let userdeviceid = (user_id, device_id);
|
||||
let access_token = self
|
||||
.db
|
||||
.userdeviceid_token
|
||||
.qry(&userdeviceid)
|
||||
.await?;
|
||||
|
||||
self.db.userdeviceid_token.del(userdeviceid);
|
||||
self.db.token_userdeviceid.remove(&access_token);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
@@ -145,6 +196,75 @@ pub async fn get_access_token(&self, user_id: &UserId, device_id: &DeviceId) ->
|
||||
.deserialized()
|
||||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
pub fn generate_access_token(&self, expires: bool) -> (String, Option<Duration>) {
|
||||
let access_token = utils::random_string(TOKEN_LENGTH);
|
||||
let expires_in = expires
|
||||
.then_some(self.services.server.config.access_token_ttl)
|
||||
.map(Duration::from_secs);
|
||||
|
||||
(access_token, expires_in)
|
||||
}
|
||||
|
||||
/// Replaces the refresh token of one device.
|
||||
#[implement(super::Service)]
|
||||
#[tracing::instrument(level = "debug", skip(self))]
|
||||
pub async fn set_refresh_token(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
refresh_token: &str,
|
||||
) -> Result {
|
||||
debug_assert!(refresh_token.starts_with("refresh_"), "refresh_token missing prefix");
|
||||
|
||||
// Remove old token
|
||||
self.remove_refresh_token(user_id, device_id)
|
||||
.await
|
||||
.ok();
|
||||
|
||||
let userdeviceid = (user_id, device_id);
|
||||
self.db
|
||||
.token_userdeviceid
|
||||
.raw_put(refresh_token, userdeviceid);
|
||||
self.db
|
||||
.userdeviceid_refresh
|
||||
.put_raw(userdeviceid, refresh_token);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Revoke the refresh token without deleting the device. Take care to not leave
|
||||
/// dangling devices if using this method.
|
||||
#[implement(super::Service)]
|
||||
pub async fn remove_refresh_token(&self, user_id: &UserId, device_id: &DeviceId) -> Result {
|
||||
let userdeviceid = (user_id, device_id);
|
||||
let refresh_token = self
|
||||
.db
|
||||
.userdeviceid_refresh
|
||||
.qry(&userdeviceid)
|
||||
.await?;
|
||||
|
||||
self.db.userdeviceid_refresh.del(userdeviceid);
|
||||
self.db.token_userdeviceid.remove(&refresh_token);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
pub async fn get_refresh_token(&self, user_id: &UserId, device_id: &DeviceId) -> Result<String> {
|
||||
let key = (user_id, device_id);
|
||||
self.db
|
||||
.userdeviceid_refresh
|
||||
.qry(&key)
|
||||
.await
|
||||
.deserialized()
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn generate_refresh_token() -> String {
|
||||
format!("refresh_{}", utils::random_string(TOKEN_LENGTH))
|
||||
}
|
||||
|
||||
#[implement(super::Service)]
|
||||
pub async fn add_to_device_event(
|
||||
&self,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
mod device;
|
||||
pub mod device;
|
||||
mod keys;
|
||||
mod ldap;
|
||||
mod profile;
|
||||
@@ -44,6 +44,7 @@ struct Data {
|
||||
token_userdeviceid: Arc<Map>,
|
||||
userdeviceid_metadata: Arc<Map>,
|
||||
userdeviceid_token: Arc<Map>,
|
||||
userdeviceid_refresh: Arc<Map>,
|
||||
userfilterid_filter: Arc<Map>,
|
||||
userid_avatarurl: Arc<Map>,
|
||||
userid_blurhash: Arc<Map>,
|
||||
@@ -80,6 +81,7 @@ impl crate::Service for Service {
|
||||
token_userdeviceid: args.db["token_userdeviceid"].clone(),
|
||||
userdeviceid_metadata: args.db["userdeviceid_metadata"].clone(),
|
||||
userdeviceid_token: args.db["userdeviceid_token"].clone(),
|
||||
userdeviceid_refresh: args.db["userdeviceid_refresh"].clone(),
|
||||
userfilterid_filter: args.db["userfilterid_filter"].clone(),
|
||||
userid_avatarurl: args.db["userid_avatarurl"].clone(),
|
||||
userid_blurhash: args.db["userid_blurhash"].clone(),
|
||||
|
||||
Reference in New Issue
Block a user