From 0b864801f54952a86366e21ade0be2ccee4d8ed8 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 30 Jan 2026 23:55:28 +0000 Subject: [PATCH] Use map of identity_provider to accommodate env var enumerations. Signed-off-by: Jason Volk --- src/admin/query/oauth.rs | 2 +- src/api/client/session/mod.rs | 2 +- src/api/client/session/sso.rs | 6 ++-- src/core/config/check.rs | 23 ++++++++++++-- src/core/config/mod.rs | 55 +++++++++++++++++++++++++++++----- src/service/oauth/providers.rs | 6 ++-- 6 files changed, 76 insertions(+), 18 deletions(-) diff --git a/src/admin/query/oauth.rs b/src/admin/query/oauth.rs index 9b87a2aa..ef932966 100644 --- a/src/admin/query/oauth.rs +++ b/src/admin/query/oauth.rs @@ -140,7 +140,7 @@ pub(super) async fn oauth_list_providers(&self) -> Result { self.services .config .identity_provider - .iter() + .values() .try_stream() .map_ok(Provider::id) .map_ok(|id| format!("{id}\n")) diff --git a/src/api/client/session/mod.rs b/src/api/client/session/mod.rs index b834f1c6..edc505ff 100644 --- a/src/api/client/session/mod.rs +++ b/src/api/client/session/mod.rs @@ -52,7 +52,7 @@ pub(crate) async fn get_login_types_route( let identity_providers: Vec<_> = services .config .identity_provider - .iter() + .values() .filter(|_| list_idps) .cloned() .map(|config| IdentityProvider { diff --git a/src/api/client/session/sso.rs b/src/api/client/session/sso.rs index a166f6db..b80288ab 100644 --- a/src/api/client/session/sso.rs +++ b/src/api/client/session/sso.rs @@ -90,9 +90,9 @@ pub(crate) async fn sso_login_route( let default_idp_id = services .config .identity_provider - .iter() + .values() .find(|idp| idp.default) - .or_else(|| services.config.identity_provider.iter().next()) + .or_else(|| services.config.identity_provider.values().next()) .map(IdentityProvider::id) .map(ToOwned::to_owned) .unwrap_or_default(); @@ -419,7 +419,7 @@ pub(crate) async fn sso_callback_route( let next_idp_url = services .config .identity_provider - .iter() + .values() .filter(|idp| idp.default || services.config.single_sso) .skip_while(|idp| idp.id() != idp_id) .nth(1) diff --git a/src/core/config/check.rs b/src/core/config/check.rs index 6bad9f49..ba74773f 100644 --- a/src/core/config/check.rs +++ b/src/core/config/check.rs @@ -300,7 +300,24 @@ pub fn check(config: &Config) -> Result { )); } - for (i, provider) in config.identity_provider.iter().enumerate() { + for a in config.identity_provider.values() { + let count = config + .identity_provider + .values() + .filter(|b| a.id().eq(b.id())) + .count(); + + debug_assert_ne!(count, 0, "expected at least one identity_provider"); + if count > 1 { + return Err!(Config( + "client_id", + "Duplicate identity_provider with client_id {}", + a.client_id + )); + } + } + + for (i, provider) in &config.identity_provider { if provider.client_secret.is_some() { continue; } @@ -333,14 +350,14 @@ pub fn check(config: &Config) -> Result { && config.identity_provider.len() > 1 && config .identity_provider - .iter() + .values() .filter(|idp| idp.default) .count() .eq(&0) { let default = config .identity_provider - .iter() + .values() .next() .map(IdentityProvider::id) .expect("Check at least one provider is configured to reach here"); diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs index 3f5c0267..eb6c614a 100644 --- a/src/core/config/mod.rs +++ b/src/core/config/mod.rs @@ -3,8 +3,7 @@ pub mod manager; pub mod proxy; use std::{ - collections::{BTreeMap, BTreeSet, HashSet}, - hash::{Hash, Hasher}, + collections::{BTreeMap, BTreeSet}, net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, path::{Path, PathBuf}, }; @@ -2259,8 +2258,8 @@ pub struct Config { pub appservice: BTreeMap, // external structure; separate sections - #[serde(default)] - pub identity_provider: HashSet, + #[serde(default, with = "identity_provider_serde")] + pub identity_provider: BTreeMap, #[serde(flatten)] #[expect(clippy::zero_sized_map_values)] @@ -2578,7 +2577,7 @@ pub struct JwtConfig { pub validate_signature: bool, } -#[derive(Clone, Debug, Deserialize, Eq, PartialEq)] +#[derive(Clone, Debug, Deserialize)] #[config_example_generator( filename = "tuwunel-example.toml", section = "[global.identity_provider]" @@ -2763,8 +2762,50 @@ impl IdentityProvider { } } -impl Hash for IdentityProvider { - fn hash(&self, state: &mut H) { self.id().hash(state) } +mod identity_provider_serde { + use std::{collections::BTreeMap, fmt, marker::PhantomData}; + + use serde::{ + Deserializer, de, + de::{MapAccess, SeqAccess}, + }; + + struct Visitor(PhantomData); + + type IdentityProviders = BTreeMap; + + pub(super) fn deserialize<'de, D>(de: D) -> Result + where + D: Deserializer<'de>, + { + de.deserialize_any(Visitor(PhantomData)) + } + + impl<'de> de::Visitor<'de> for Visitor { + type Value = IdentityProviders; + + fn expecting(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt.write_str("Mapping or Sequence") + } + + fn visit_map>(self, mut map: A) -> Result { + let mut ret = Self::Value::new(); + while let Some((k, v)) = map.next_entry()? { + ret.insert(k, v); + } + + Ok(ret) + } + + fn visit_seq>(self, mut seq: A) -> Result { + let mut ret = Self::Value::new(); + while let Some(v) = seq.next_element()? { + ret.insert(ret.len().to_string(), v); + } + + Ok(ret) + } + } } #[derive(Clone, Debug, Default, Deserialize)] diff --git a/src/service/oauth/providers.rs b/src/service/oauth/providers.rs index 5598546c..21a1556c 100644 --- a/src/service/oauth/providers.rs +++ b/src/service/oauth/providers.rs @@ -60,7 +60,7 @@ pub fn get_config(&self, id: &str) -> Result { let providers = &self.services.config.identity_provider; if let Some(provider) = providers - .iter() + .values() .find(|config| config.id() == id) .cloned() { @@ -68,11 +68,11 @@ pub fn get_config(&self, id: &str) -> Result { } if let Some(provider) = providers - .iter() + .values() .find(|config| config.brand == id.to_lowercase()) .filter(|_| { providers - .iter() + .values() .filter(|config| config.brand == id.to_lowercase()) .count() .eq(&1)