Use map of identity_provider to accommodate env var enumerations.

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk
2026-01-30 23:55:28 +00:00
parent 0474753333
commit 0b864801f5
6 changed files with 76 additions and 18 deletions

View File

@@ -300,7 +300,24 @@ pub fn check(config: &Config) -> Result {
));
}
for (i, provider) in config.identity_provider.iter().enumerate() {
for a in config.identity_provider.values() {
let count = config
.identity_provider
.values()
.filter(|b| a.id().eq(b.id()))
.count();
debug_assert_ne!(count, 0, "expected at least one identity_provider");
if count > 1 {
return Err!(Config(
"client_id",
"Duplicate identity_provider with client_id {}",
a.client_id
));
}
}
for (i, provider) in &config.identity_provider {
if provider.client_secret.is_some() {
continue;
}
@@ -333,14 +350,14 @@ pub fn check(config: &Config) -> Result {
&& config.identity_provider.len() > 1
&& config
.identity_provider
.iter()
.values()
.filter(|idp| idp.default)
.count()
.eq(&0)
{
let default = config
.identity_provider
.iter()
.values()
.next()
.map(IdentityProvider::id)
.expect("Check at least one provider is configured to reach here");

View File

@@ -3,8 +3,7 @@ pub mod manager;
pub mod proxy;
use std::{
collections::{BTreeMap, BTreeSet, HashSet},
hash::{Hash, Hasher},
collections::{BTreeMap, BTreeSet},
net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
path::{Path, PathBuf},
};
@@ -2259,8 +2258,8 @@ pub struct Config {
pub appservice: BTreeMap<String, AppService>,
// external structure; separate sections
#[serde(default)]
pub identity_provider: HashSet<IdentityProvider>,
#[serde(default, with = "identity_provider_serde")]
pub identity_provider: BTreeMap<String, IdentityProvider>,
#[serde(flatten)]
#[expect(clippy::zero_sized_map_values)]
@@ -2578,7 +2577,7 @@ pub struct JwtConfig {
pub validate_signature: bool,
}
#[derive(Clone, Debug, Deserialize, Eq, PartialEq)]
#[derive(Clone, Debug, Deserialize)]
#[config_example_generator(
filename = "tuwunel-example.toml",
section = "[global.identity_provider]"
@@ -2763,8 +2762,50 @@ impl IdentityProvider {
}
}
impl Hash for IdentityProvider {
fn hash<H: Hasher>(&self, state: &mut H) { self.id().hash(state) }
mod identity_provider_serde {
use std::{collections::BTreeMap, fmt, marker::PhantomData};
use serde::{
Deserializer, de,
de::{MapAccess, SeqAccess},
};
struct Visitor(PhantomData<IdentityProviders>);
type IdentityProviders = BTreeMap<String, super::IdentityProvider>;
pub(super) fn deserialize<'de, D>(de: D) -> Result<IdentityProviders, D::Error>
where
D: Deserializer<'de>,
{
de.deserialize_any(Visitor(PhantomData))
}
impl<'de> de::Visitor<'de> for Visitor {
type Value = IdentityProviders;
fn expecting(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
fmt.write_str("Mapping or Sequence")
}
fn visit_map<A: MapAccess<'de>>(self, mut map: A) -> Result<Self::Value, A::Error> {
let mut ret = Self::Value::new();
while let Some((k, v)) = map.next_entry()? {
ret.insert(k, v);
}
Ok(ret)
}
fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
let mut ret = Self::Value::new();
while let Some(v) = seq.next_element()? {
ret.insert(ret.len().to_string(), v);
}
Ok(ret)
}
}
}
#[derive(Clone, Debug, Default, Deserialize)]