Use map of identity_provider to accommodate env var enumerations.
Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
@@ -140,7 +140,7 @@ pub(super) async fn oauth_list_providers(&self) -> Result {
|
|||||||
self.services
|
self.services
|
||||||
.config
|
.config
|
||||||
.identity_provider
|
.identity_provider
|
||||||
.iter()
|
.values()
|
||||||
.try_stream()
|
.try_stream()
|
||||||
.map_ok(Provider::id)
|
.map_ok(Provider::id)
|
||||||
.map_ok(|id| format!("{id}\n"))
|
.map_ok(|id| format!("{id}\n"))
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ pub(crate) async fn get_login_types_route(
|
|||||||
let identity_providers: Vec<_> = services
|
let identity_providers: Vec<_> = services
|
||||||
.config
|
.config
|
||||||
.identity_provider
|
.identity_provider
|
||||||
.iter()
|
.values()
|
||||||
.filter(|_| list_idps)
|
.filter(|_| list_idps)
|
||||||
.cloned()
|
.cloned()
|
||||||
.map(|config| IdentityProvider {
|
.map(|config| IdentityProvider {
|
||||||
|
|||||||
@@ -90,9 +90,9 @@ pub(crate) async fn sso_login_route(
|
|||||||
let default_idp_id = services
|
let default_idp_id = services
|
||||||
.config
|
.config
|
||||||
.identity_provider
|
.identity_provider
|
||||||
.iter()
|
.values()
|
||||||
.find(|idp| idp.default)
|
.find(|idp| idp.default)
|
||||||
.or_else(|| services.config.identity_provider.iter().next())
|
.or_else(|| services.config.identity_provider.values().next())
|
||||||
.map(IdentityProvider::id)
|
.map(IdentityProvider::id)
|
||||||
.map(ToOwned::to_owned)
|
.map(ToOwned::to_owned)
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
@@ -419,7 +419,7 @@ pub(crate) async fn sso_callback_route(
|
|||||||
let next_idp_url = services
|
let next_idp_url = services
|
||||||
.config
|
.config
|
||||||
.identity_provider
|
.identity_provider
|
||||||
.iter()
|
.values()
|
||||||
.filter(|idp| idp.default || services.config.single_sso)
|
.filter(|idp| idp.default || services.config.single_sso)
|
||||||
.skip_while(|idp| idp.id() != idp_id)
|
.skip_while(|idp| idp.id() != idp_id)
|
||||||
.nth(1)
|
.nth(1)
|
||||||
|
|||||||
@@ -300,7 +300,24 @@ pub fn check(config: &Config) -> Result {
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
for (i, provider) in config.identity_provider.iter().enumerate() {
|
for a in config.identity_provider.values() {
|
||||||
|
let count = config
|
||||||
|
.identity_provider
|
||||||
|
.values()
|
||||||
|
.filter(|b| a.id().eq(b.id()))
|
||||||
|
.count();
|
||||||
|
|
||||||
|
debug_assert_ne!(count, 0, "expected at least one identity_provider");
|
||||||
|
if count > 1 {
|
||||||
|
return Err!(Config(
|
||||||
|
"client_id",
|
||||||
|
"Duplicate identity_provider with client_id {}",
|
||||||
|
a.client_id
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (i, provider) in &config.identity_provider {
|
||||||
if provider.client_secret.is_some() {
|
if provider.client_secret.is_some() {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
@@ -333,14 +350,14 @@ pub fn check(config: &Config) -> Result {
|
|||||||
&& config.identity_provider.len() > 1
|
&& config.identity_provider.len() > 1
|
||||||
&& config
|
&& config
|
||||||
.identity_provider
|
.identity_provider
|
||||||
.iter()
|
.values()
|
||||||
.filter(|idp| idp.default)
|
.filter(|idp| idp.default)
|
||||||
.count()
|
.count()
|
||||||
.eq(&0)
|
.eq(&0)
|
||||||
{
|
{
|
||||||
let default = config
|
let default = config
|
||||||
.identity_provider
|
.identity_provider
|
||||||
.iter()
|
.values()
|
||||||
.next()
|
.next()
|
||||||
.map(IdentityProvider::id)
|
.map(IdentityProvider::id)
|
||||||
.expect("Check at least one provider is configured to reach here");
|
.expect("Check at least one provider is configured to reach here");
|
||||||
|
|||||||
@@ -3,8 +3,7 @@ pub mod manager;
|
|||||||
pub mod proxy;
|
pub mod proxy;
|
||||||
|
|
||||||
use std::{
|
use std::{
|
||||||
collections::{BTreeMap, BTreeSet, HashSet},
|
collections::{BTreeMap, BTreeSet},
|
||||||
hash::{Hash, Hasher},
|
|
||||||
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
|
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
|
||||||
path::{Path, PathBuf},
|
path::{Path, PathBuf},
|
||||||
};
|
};
|
||||||
@@ -2259,8 +2258,8 @@ pub struct Config {
|
|||||||
pub appservice: BTreeMap<String, AppService>,
|
pub appservice: BTreeMap<String, AppService>,
|
||||||
|
|
||||||
// external structure; separate sections
|
// external structure; separate sections
|
||||||
#[serde(default)]
|
#[serde(default, with = "identity_provider_serde")]
|
||||||
pub identity_provider: HashSet<IdentityProvider>,
|
pub identity_provider: BTreeMap<String, IdentityProvider>,
|
||||||
|
|
||||||
#[serde(flatten)]
|
#[serde(flatten)]
|
||||||
#[expect(clippy::zero_sized_map_values)]
|
#[expect(clippy::zero_sized_map_values)]
|
||||||
@@ -2578,7 +2577,7 @@ pub struct JwtConfig {
|
|||||||
pub validate_signature: bool,
|
pub validate_signature: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
|
#[derive(Clone, Debug, Deserialize)]
|
||||||
#[config_example_generator(
|
#[config_example_generator(
|
||||||
filename = "tuwunel-example.toml",
|
filename = "tuwunel-example.toml",
|
||||||
section = "[global.identity_provider]"
|
section = "[global.identity_provider]"
|
||||||
@@ -2763,8 +2762,50 @@ impl IdentityProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Hash for IdentityProvider {
|
mod identity_provider_serde {
|
||||||
fn hash<H: Hasher>(&self, state: &mut H) { self.id().hash(state) }
|
use std::{collections::BTreeMap, fmt, marker::PhantomData};
|
||||||
|
|
||||||
|
use serde::{
|
||||||
|
Deserializer, de,
|
||||||
|
de::{MapAccess, SeqAccess},
|
||||||
|
};
|
||||||
|
|
||||||
|
struct Visitor(PhantomData<IdentityProviders>);
|
||||||
|
|
||||||
|
type IdentityProviders = BTreeMap<String, super::IdentityProvider>;
|
||||||
|
|
||||||
|
pub(super) fn deserialize<'de, D>(de: D) -> Result<IdentityProviders, D::Error>
|
||||||
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
{
|
||||||
|
de.deserialize_any(Visitor(PhantomData))
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'de> de::Visitor<'de> for Visitor {
|
||||||
|
type Value = IdentityProviders;
|
||||||
|
|
||||||
|
fn expecting(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
fmt.write_str("Mapping or Sequence")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_map<A: MapAccess<'de>>(self, mut map: A) -> Result<Self::Value, A::Error> {
|
||||||
|
let mut ret = Self::Value::new();
|
||||||
|
while let Some((k, v)) = map.next_entry()? {
|
||||||
|
ret.insert(k, v);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(ret)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
|
||||||
|
let mut ret = Self::Value::new();
|
||||||
|
while let Some(v) = seq.next_element()? {
|
||||||
|
ret.insert(ret.len().to_string(), v);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(ret)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Default, Deserialize)]
|
#[derive(Clone, Debug, Default, Deserialize)]
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ pub fn get_config(&self, id: &str) -> Result<Provider> {
|
|||||||
let providers = &self.services.config.identity_provider;
|
let providers = &self.services.config.identity_provider;
|
||||||
|
|
||||||
if let Some(provider) = providers
|
if let Some(provider) = providers
|
||||||
.iter()
|
.values()
|
||||||
.find(|config| config.id() == id)
|
.find(|config| config.id() == id)
|
||||||
.cloned()
|
.cloned()
|
||||||
{
|
{
|
||||||
@@ -68,11 +68,11 @@ pub fn get_config(&self, id: &str) -> Result<Provider> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if let Some(provider) = providers
|
if let Some(provider) = providers
|
||||||
.iter()
|
.values()
|
||||||
.find(|config| config.brand == id.to_lowercase())
|
.find(|config| config.brand == id.to_lowercase())
|
||||||
.filter(|_| {
|
.filter(|_| {
|
||||||
providers
|
providers
|
||||||
.iter()
|
.values()
|
||||||
.filter(|config| config.brand == id.to_lowercase())
|
.filter(|config| config.brand == id.to_lowercase())
|
||||||
.count()
|
.count()
|
||||||
.eq(&1)
|
.eq(&1)
|
||||||
|
|||||||
Reference in New Issue
Block a user