diff --git a/src/rate_limit/cidr.rs b/src/rate_limit/cidr.rs new file mode 100644 index 0000000..9c0057a --- /dev/null +++ b/src/rate_limit/cidr.rs @@ -0,0 +1,143 @@ +use std::net::IpAddr; + +/// A parsed CIDR block for allowlist matching. +#[derive(Debug, Clone)] +pub struct CidrBlock { + addr: IpAddr, + prefix_len: u8, +} + +impl CidrBlock { + /// Parse a CIDR string like "10.0.0.0/8" or "fd00::/8". + pub fn parse(s: &str) -> Option { + let (addr_str, len_str) = s.split_once('/')?; + let prefix_len: u8 = len_str.parse().ok()?; + let addr: IpAddr = addr_str.parse().ok()?; + match &addr { + IpAddr::V4(_) if prefix_len > 32 => return None, + IpAddr::V6(_) if prefix_len > 128 => return None, + _ => {} + } + Some(Self { addr, prefix_len }) + } + + /// Check whether `ip` falls within this CIDR block. + /// Handles IPv4-mapped IPv6 addresses (e.g. `::ffff:10.0.0.1`). + pub fn contains(&self, ip: IpAddr) -> bool { + // Normalise IPv4-mapped IPv6 → IPv4 + let ip = normalise(ip); + let addr = normalise(self.addr); + + match (addr, ip) { + (IpAddr::V4(net), IpAddr::V4(candidate)) => { + let mask = v4_mask(self.prefix_len); + u32::from(net) & mask == u32::from(candidate) & mask + } + (IpAddr::V6(net), IpAddr::V6(candidate)) => { + let mask = v6_mask(self.prefix_len); + let net_bits = u128::from(net); + let cand_bits = u128::from(candidate); + net_bits & mask == cand_bits & mask + } + _ => false, // v4 vs v6 mismatch + } + } +} + +/// Parse a list of CIDR strings, skipping any that are invalid. +pub fn parse_cidrs(strings: &[String]) -> Vec { + strings.iter().filter_map(|s| CidrBlock::parse(s)).collect() +} + +/// Check if `ip` is contained in any of the given CIDR blocks. +pub fn is_bypassed(ip: IpAddr, cidrs: &[CidrBlock]) -> bool { + cidrs.iter().any(|c| c.contains(ip)) +} + +fn normalise(ip: IpAddr) -> IpAddr { + match ip { + IpAddr::V6(v6) => match v6.to_ipv4_mapped() { + Some(v4) => IpAddr::V4(v4), + None => IpAddr::V6(v6), + }, + other => other, + } +} + +fn v4_mask(prefix_len: u8) -> u32 { + if prefix_len == 0 { + 0 + } else { + u32::MAX << (32 - prefix_len) + } +} + +fn v6_mask(prefix_len: u8) -> u128 { + if prefix_len == 0 { + 0 + } else { + u128::MAX << (128 - prefix_len) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_ipv4_contains() { + let cidr = CidrBlock::parse("10.0.0.0/8").unwrap(); + assert!(cidr.contains("10.1.2.3".parse().unwrap())); + assert!(cidr.contains("10.255.255.255".parse().unwrap())); + assert!(!cidr.contains("11.0.0.1".parse().unwrap())); + } + + #[test] + fn test_ipv6_contains() { + let cidr = CidrBlock::parse("fd00::/8").unwrap(); + assert!(cidr.contains("fd12::1".parse().unwrap())); + assert!(!cidr.contains("fe80::1".parse().unwrap())); + } + + #[test] + fn test_ipv4_mapped_v6() { + let cidr = CidrBlock::parse("10.0.0.0/8").unwrap(); + // ::ffff:10.0.0.1 should match 10.0.0.0/8 + let mapped: IpAddr = "::ffff:10.0.0.1".parse().unwrap(); + assert!(cidr.contains(mapped)); + } + + #[test] + fn test_private_ranges() { + let cidrs = parse_cidrs(&[ + "10.0.0.0/8".into(), + "172.16.0.0/12".into(), + "192.168.0.0/16".into(), + ]); + assert!(is_bypassed("10.0.0.1".parse().unwrap(), &cidrs)); + assert!(is_bypassed("172.31.255.1".parse().unwrap(), &cidrs)); + assert!(is_bypassed("192.168.1.1".parse().unwrap(), &cidrs)); + assert!(!is_bypassed("8.8.8.8".parse().unwrap(), &cidrs)); + } + + #[test] + fn test_invalid_cidrs() { + assert!(CidrBlock::parse("not-a-cidr").is_none()); + assert!(CidrBlock::parse("10.0.0.0/33").is_none()); + assert!(CidrBlock::parse("10.0.0.0").is_none()); + } + + #[test] + fn test_slash_zero() { + let cidr = CidrBlock::parse("0.0.0.0/0").unwrap(); + assert!(cidr.contains("1.2.3.4".parse().unwrap())); + assert!(cidr.contains("255.255.255.255".parse().unwrap())); + } + + #[test] + fn test_slash_32() { + let cidr = CidrBlock::parse("1.2.3.4/32").unwrap(); + assert!(cidr.contains("1.2.3.4".parse().unwrap())); + assert!(!cidr.contains("1.2.3.5".parse().unwrap())); + } +} diff --git a/src/rate_limit/key.rs b/src/rate_limit/key.rs new file mode 100644 index 0000000..be0514f --- /dev/null +++ b/src/rate_limit/key.rs @@ -0,0 +1,145 @@ +use std::hash::{Hash, Hasher}; +use std::net::IpAddr; + +/// Identity key for rate limiting buckets. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum RateLimitKey { + /// Hashed auth credential (session cookie or bearer token). + Identity(u64), + /// Unauthenticated fallback: client IP. + Ip(IpAddr), +} + +impl RateLimitKey { + pub fn is_authenticated(&self) -> bool { + matches!(self, RateLimitKey::Identity(_)) + } +} + +/// Extract a rate-limit key from request headers. +/// +/// Priority: +/// 1. `ory_kratos_session` cookie → `Identity(hash)` +/// 2. `Authorization: Bearer ` → `Identity(hash)` +/// 3. Client IP → `Ip(addr)` +pub fn extract_key( + cookie_header: Option<&str>, + auth_header: Option<&str>, + client_ip: IpAddr, +) -> RateLimitKey { + // 1. Check for Kratos session cookie + if let Some(cookies) = cookie_header { + if let Some(value) = extract_cookie_value(cookies, "ory_kratos_session") { + return RateLimitKey::Identity(fx_hash(value)); + } + } + + // 2. Check for Bearer token + if let Some(auth) = auth_header { + if let Some(token) = auth.strip_prefix("Bearer ") { + let token = token.trim(); + if !token.is_empty() { + return RateLimitKey::Identity(fx_hash(token)); + } + } + } + + // 3. Fall back to IP + RateLimitKey::Ip(client_ip) +} + +fn extract_cookie_value<'a>(cookies: &'a str, name: &str) -> Option<&'a str> { + for pair in cookies.split(';') { + let pair = pair.trim(); + if let Some((k, v)) = pair.split_once('=') { + if k.trim() == name { + return Some(v.trim()); + } + } + } + None +} + +fn fx_hash(s: &str) -> u64 { + let mut h = rustc_hash::FxHasher::default(); + s.hash(&mut h); + h.finish() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_cookie_extraction() { + let ip: IpAddr = "1.2.3.4".parse().unwrap(); + let key = extract_key( + Some("ory_kratos_session=abc123; other=val"), + None, + ip, + ); + assert!(key.is_authenticated()); + assert!(matches!(key, RateLimitKey::Identity(_))); + } + + #[test] + fn test_cookie_only_cookie() { + let ip: IpAddr = "1.2.3.4".parse().unwrap(); + let key = extract_key(Some("ory_kratos_session=tok"), None, ip); + assert!(key.is_authenticated()); + } + + #[test] + fn test_bearer_extraction() { + let ip: IpAddr = "1.2.3.4".parse().unwrap(); + let key = extract_key(None, Some("Bearer my-token-123"), ip); + assert!(key.is_authenticated()); + } + + #[test] + fn test_ip_fallback() { + let ip: IpAddr = "1.2.3.4".parse().unwrap(); + let key = extract_key(None, None, ip); + assert_eq!(key, RateLimitKey::Ip(ip)); + assert!(!key.is_authenticated()); + } + + #[test] + fn test_cookie_takes_priority_over_bearer() { + let ip: IpAddr = "1.2.3.4".parse().unwrap(); + let key_cookie = extract_key( + Some("ory_kratos_session=sess1"), + Some("Bearer tok1"), + ip, + ); + let key_bearer = extract_key(None, Some("Bearer tok1"), ip); + // Cookie and bearer should produce different hashes + assert_ne!(key_cookie, key_bearer); + assert!(key_cookie.is_authenticated()); + } + + #[test] + fn test_empty_bearer_falls_back() { + let ip: IpAddr = "1.2.3.4".parse().unwrap(); + let key = extract_key(None, Some("Bearer "), ip); + assert_eq!(key, RateLimitKey::Ip(ip)); + } + + #[test] + fn test_multiple_cookies() { + let ip: IpAddr = "1.2.3.4".parse().unwrap(); + let key = extract_key( + Some("foo=bar; ory_kratos_session=mysess; baz=qux"), + None, + ip, + ); + assert!(key.is_authenticated()); + } + + #[test] + fn test_wrong_cookie_name_falls_back() { + let ip: IpAddr = "1.2.3.4".parse().unwrap(); + let key = extract_key(Some("session=val"), None, ip); + assert_eq!(key, RateLimitKey::Ip(ip)); + } +} diff --git a/src/rate_limit/limiter.rs b/src/rate_limit/limiter.rs new file mode 100644 index 0000000..92eb9a3 --- /dev/null +++ b/src/rate_limit/limiter.rs @@ -0,0 +1,264 @@ +use crate::config::{BucketConfig, RateLimitConfig}; +use crate::rate_limit::cidr::{self, CidrBlock}; +use crate::rate_limit::key::RateLimitKey; +use rustc_hash::FxHashMap; +use std::hash::{Hash, Hasher}; +use std::net::IpAddr; +use std::sync::RwLock; +use std::time::Instant; + +const NUM_SHARDS: usize = 256; + +/// Result of a rate limit check. +#[derive(Debug, PartialEq)] +pub enum RateLimitResult { + Allow, + Reject { retry_after: u64 }, +} + +struct Bucket { + tokens: f64, + last_refill: Instant, + authenticated: bool, +} + +pub struct RateLimiter { + shards: Vec>>, + bypass_cidrs: Vec, + authenticated: BucketConfig, + unauthenticated: BucketConfig, + stale_after_secs: u64, +} + +fn shard_index(key: &RateLimitKey) -> usize { + let mut h = rustc_hash::FxHasher::default(); + key.hash(&mut h); + h.finish() as usize % NUM_SHARDS +} + +impl RateLimiter { + pub fn new(config: &RateLimitConfig) -> Self { + let shards = (0..NUM_SHARDS) + .map(|_| RwLock::new(FxHashMap::default())) + .collect(); + let bypass_cidrs = cidr::parse_cidrs(&config.bypass_cidrs); + Self { + shards, + bypass_cidrs, + authenticated: config.authenticated.clone(), + unauthenticated: config.unauthenticated.clone(), + stale_after_secs: config.stale_after_secs, + } + } + + /// Check whether a request should be allowed or rejected. + pub fn check(&self, ip: IpAddr, key: RateLimitKey) -> RateLimitResult { + // CIDR bypass + if cidr::is_bypassed(ip, &self.bypass_cidrs) { + return RateLimitResult::Allow; + } + + let cfg = if key.is_authenticated() { + &self.authenticated + } else { + &self.unauthenticated + }; + + let now = Instant::now(); + let idx = shard_index(&key); + let mut shard = self.shards[idx].write().unwrap_or_else(|e| e.into_inner()); + + let bucket = shard.entry(key).or_insert_with(|| Bucket { + tokens: cfg.burst as f64, + last_refill: now, + authenticated: key.is_authenticated(), + }); + + // Refill tokens based on elapsed time + let elapsed = now.duration_since(bucket.last_refill).as_secs_f64(); + let tier = if bucket.authenticated { + &self.authenticated + } else { + &self.unauthenticated + }; + bucket.tokens = (bucket.tokens + elapsed * tier.rate).min(tier.burst as f64); + bucket.last_refill = now; + + if bucket.tokens >= 1.0 { + bucket.tokens -= 1.0; + RateLimitResult::Allow + } else { + let retry_after = ((1.0 - bucket.tokens) / tier.rate).ceil() as u64; + RateLimitResult::Reject { + retry_after: retry_after.max(1), + } + } + } + + /// Remove buckets that haven't been used for `stale_after_secs`. + pub fn evict_stale(&self) { + let now = Instant::now(); + let stale = std::time::Duration::from_secs(self.stale_after_secs); + for shard in &self.shards { + let mut map = shard.write().unwrap_or_else(|e| e.into_inner()); + map.retain(|_, bucket| now.duration_since(bucket.last_refill) < stale); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::{BucketConfig, RateLimitConfig}; + + fn test_config() -> RateLimitConfig { + RateLimitConfig { + enabled: true, + bypass_cidrs: vec!["10.0.0.0/8".into()], + eviction_interval_secs: 60, + stale_after_secs: 120, + authenticated: BucketConfig { + burst: 10, + rate: 5.0, + }, + unauthenticated: BucketConfig { + burst: 3, + rate: 1.0, + }, + } + } + + #[test] + fn test_burst_exhaustion_then_reject() { + let limiter = RateLimiter::new(&test_config()); + let ip: IpAddr = "203.0.113.1".parse().unwrap(); + let key = RateLimitKey::Ip(ip); + + // Burst of 3 for unauthenticated + for _ in 0..3 { + assert_eq!(limiter.check(ip, key), RateLimitResult::Allow); + } + // 4th request should be rejected + match limiter.check(ip, key) { + RateLimitResult::Reject { retry_after } => { + assert!(retry_after >= 1); + } + _ => panic!("expected reject"), + } + } + + #[test] + fn test_refill_allows_again() { + let limiter = RateLimiter::new(&test_config()); + let ip: IpAddr = "203.0.113.1".parse().unwrap(); + let key = RateLimitKey::Ip(ip); + + // Exhaust burst + for _ in 0..3 { + limiter.check(ip, key); + } + assert!(matches!( + limiter.check(ip, key), + RateLimitResult::Reject { .. } + )); + + // Manually simulate time passing by manipulating the bucket + { + let idx = shard_index(&key); + let mut shard = limiter.shards[idx].write().unwrap(); + let bucket = shard.get_mut(&key).unwrap(); + // Pretend 2 seconds passed (rate=1.0/s → 2 tokens refill) + bucket.last_refill -= std::time::Duration::from_secs(2); + } + + assert_eq!(limiter.check(ip, key), RateLimitResult::Allow); + } + + #[test] + fn test_cidr_bypass() { + let limiter = RateLimiter::new(&test_config()); + let ip: IpAddr = "10.0.0.1".parse().unwrap(); + let key = RateLimitKey::Ip(ip); + + // Should always allow for bypassed CIDRs, even after many requests + for _ in 0..100 { + assert_eq!(limiter.check(ip, key), RateLimitResult::Allow); + } + } + + #[test] + fn test_authenticated_higher_burst() { + let limiter = RateLimiter::new(&test_config()); + let ip: IpAddr = "203.0.113.1".parse().unwrap(); + let key = RateLimitKey::Identity(12345); + + // Authenticated burst is 10 + for _ in 0..10 { + assert_eq!(limiter.check(ip, key), RateLimitResult::Allow); + } + assert!(matches!( + limiter.check(ip, key), + RateLimitResult::Reject { .. } + )); + } + + #[test] + fn test_independent_keys() { + let limiter = RateLimiter::new(&test_config()); + let ip1: IpAddr = "203.0.113.1".parse().unwrap(); + let ip2: IpAddr = "203.0.113.2".parse().unwrap(); + + // Exhaust ip1's burst + for _ in 0..3 { + limiter.check(ip1, RateLimitKey::Ip(ip1)); + } + assert!(matches!( + limiter.check(ip1, RateLimitKey::Ip(ip1)), + RateLimitResult::Reject { .. } + )); + + // ip2 should still be allowed + assert_eq!( + limiter.check(ip2, RateLimitKey::Ip(ip2)), + RateLimitResult::Allow + ); + } + + #[test] + fn test_eviction() { + let mut cfg = test_config(); + cfg.stale_after_secs = 0; // everything is stale immediately + let limiter = RateLimiter::new(&cfg); + let ip: IpAddr = "203.0.113.1".parse().unwrap(); + let key = RateLimitKey::Ip(ip); + + limiter.check(ip, key); + // Buckets should be populated + let idx = shard_index(&key); + assert!(!limiter.shards[idx].read().unwrap().is_empty()); + + // After eviction, buckets should be gone (stale_after_secs=0) + std::thread::sleep(std::time::Duration::from_millis(10)); + limiter.evict_stale(); + assert!(limiter.shards[idx].read().unwrap().is_empty()); + } + + #[test] + fn test_retry_after_value() { + let limiter = RateLimiter::new(&test_config()); + let ip: IpAddr = "203.0.113.1".parse().unwrap(); + let key = RateLimitKey::Ip(ip); + + // Exhaust burst (rate=1.0/s for unauth) + for _ in 0..3 { + limiter.check(ip, key); + } + match limiter.check(ip, key) { + RateLimitResult::Reject { retry_after } => { + // With rate=1.0, need ~1 token, so retry_after should be 1 + assert_eq!(retry_after, 1); + } + _ => panic!("expected reject"), + } + } +} diff --git a/src/rate_limit/mod.rs b/src/rate_limit/mod.rs new file mode 100644 index 0000000..8b551c0 --- /dev/null +++ b/src/rate_limit/mod.rs @@ -0,0 +1,3 @@ +pub mod cidr; +pub mod key; +pub mod limiter; diff --git a/tests/rate_limit_test.rs b/tests/rate_limit_test.rs new file mode 100644 index 0000000..1b517dd --- /dev/null +++ b/tests/rate_limit_test.rs @@ -0,0 +1,118 @@ +use std::net::IpAddr; +use sunbeam_proxy::config::{BucketConfig, RateLimitConfig}; +use sunbeam_proxy::rate_limit::key::{self, RateLimitKey}; +use sunbeam_proxy::rate_limit::limiter::{RateLimitResult, RateLimiter}; + +fn default_config() -> RateLimitConfig { + RateLimitConfig { + enabled: true, + bypass_cidrs: vec!["10.0.0.0/8".into(), "fd00::/8".into()], + eviction_interval_secs: 300, + stale_after_secs: 600, + authenticated: BucketConfig { + burst: 20, + rate: 10.0, + }, + unauthenticated: BucketConfig { + burst: 5, + rate: 2.0, + }, + } +} + +#[test] +fn test_unauthenticated_burst_limit() { + let limiter = RateLimiter::new(&default_config()); + let ip: IpAddr = "203.0.113.50".parse().unwrap(); + let key = RateLimitKey::Ip(ip); + + for _ in 0..5 { + assert_eq!(limiter.check(ip, key), RateLimitResult::Allow); + } + match limiter.check(ip, key) { + RateLimitResult::Reject { retry_after } => assert!(retry_after >= 1), + _ => panic!("expected reject after burst exhaustion"), + } +} + +#[test] +fn test_authenticated_gets_higher_burst() { + let limiter = RateLimiter::new(&default_config()); + let ip: IpAddr = "203.0.113.51".parse().unwrap(); + let key = RateLimitKey::Identity(99999); + + for i in 0..20 { + assert_eq!( + limiter.check(ip, key), + RateLimitResult::Allow, + "request {i} should be allowed" + ); + } + assert!(matches!( + limiter.check(ip, key), + RateLimitResult::Reject { .. } + )); +} + +#[test] +fn test_cluster_cidr_bypass() { + let limiter = RateLimiter::new(&default_config()); + let internal_ip: IpAddr = "10.244.1.5".parse().unwrap(); + let key = RateLimitKey::Ip(internal_ip); + + // Should never be rate-limited + for _ in 0..200 { + assert_eq!(limiter.check(internal_ip, key), RateLimitResult::Allow); + } +} + +#[test] +fn test_key_extraction_cookie() { + let ip: IpAddr = "1.2.3.4".parse().unwrap(); + let key = key::extract_key( + Some("_csrf=abc; ory_kratos_session=session123; lang=en"), + None, + ip, + ); + assert!(key.is_authenticated()); +} + +#[test] +fn test_key_extraction_bearer() { + let ip: IpAddr = "1.2.3.4".parse().unwrap(); + let key = key::extract_key(None, Some("Bearer gitea_pat_abc123"), ip); + assert!(key.is_authenticated()); +} + +#[test] +fn test_key_extraction_ip_fallback() { + let ip: IpAddr = "1.2.3.4".parse().unwrap(); + let key = key::extract_key(None, None, ip); + assert_eq!(key, RateLimitKey::Ip(ip)); +} + +#[test] +fn test_concurrent_different_ips() { + let limiter = std::sync::Arc::new(RateLimiter::new(&default_config())); + let mut handles = vec![]; + + for i in 0..10u8 { + let limiter = limiter.clone(); + handles.push(std::thread::spawn(move || { + let ip: IpAddr = format!("203.0.113.{i}").parse().unwrap(); + let key = RateLimitKey::Ip(ip); + let mut allowed = 0; + for _ in 0..10 { + if limiter.check(ip, key) == RateLimitResult::Allow { + allowed += 1; + } + } + // Each IP gets burst=5, so exactly 5 should be allowed + assert_eq!(allowed, 5, "IP 203.0.113.{i} allowed {allowed}, expected 5"); + })); + } + + for h in handles { + h.join().unwrap(); + } +}