feat(rate_limit): add per-identity leaky bucket rate limiter

256-shard RwLock<FxHashMap> for concurrent access, auth key extraction
(ory_kratos_session cookie > Bearer token > client IP), CIDR bypass
for trusted networks, and background eviction of stale buckets.

Signed-off-by: Sienna Meridian Satterwhite <sienna@sunbeam.pt>
This commit is contained in:
2026-03-10 23:38:19 +00:00
parent 007865fbe7
commit 4bccff3303
5 changed files with 673 additions and 0 deletions

143
src/rate_limit/cidr.rs Normal file
View File

@@ -0,0 +1,143 @@
use std::net::IpAddr;
/// A parsed CIDR block for allowlist matching.
#[derive(Debug, Clone)]
pub struct CidrBlock {
addr: IpAddr,
prefix_len: u8,
}
impl CidrBlock {
/// Parse a CIDR string like "10.0.0.0/8" or "fd00::/8".
pub fn parse(s: &str) -> Option<Self> {
let (addr_str, len_str) = s.split_once('/')?;
let prefix_len: u8 = len_str.parse().ok()?;
let addr: IpAddr = addr_str.parse().ok()?;
match &addr {
IpAddr::V4(_) if prefix_len > 32 => return None,
IpAddr::V6(_) if prefix_len > 128 => return None,
_ => {}
}
Some(Self { addr, prefix_len })
}
/// Check whether `ip` falls within this CIDR block.
/// Handles IPv4-mapped IPv6 addresses (e.g. `::ffff:10.0.0.1`).
pub fn contains(&self, ip: IpAddr) -> bool {
// Normalise IPv4-mapped IPv6 → IPv4
let ip = normalise(ip);
let addr = normalise(self.addr);
match (addr, ip) {
(IpAddr::V4(net), IpAddr::V4(candidate)) => {
let mask = v4_mask(self.prefix_len);
u32::from(net) & mask == u32::from(candidate) & mask
}
(IpAddr::V6(net), IpAddr::V6(candidate)) => {
let mask = v6_mask(self.prefix_len);
let net_bits = u128::from(net);
let cand_bits = u128::from(candidate);
net_bits & mask == cand_bits & mask
}
_ => false, // v4 vs v6 mismatch
}
}
}
/// Parse a list of CIDR strings, skipping any that are invalid.
pub fn parse_cidrs(strings: &[String]) -> Vec<CidrBlock> {
strings.iter().filter_map(|s| CidrBlock::parse(s)).collect()
}
/// Check if `ip` is contained in any of the given CIDR blocks.
pub fn is_bypassed(ip: IpAddr, cidrs: &[CidrBlock]) -> bool {
cidrs.iter().any(|c| c.contains(ip))
}
fn normalise(ip: IpAddr) -> IpAddr {
match ip {
IpAddr::V6(v6) => match v6.to_ipv4_mapped() {
Some(v4) => IpAddr::V4(v4),
None => IpAddr::V6(v6),
},
other => other,
}
}
fn v4_mask(prefix_len: u8) -> u32 {
if prefix_len == 0 {
0
} else {
u32::MAX << (32 - prefix_len)
}
}
fn v6_mask(prefix_len: u8) -> u128 {
if prefix_len == 0 {
0
} else {
u128::MAX << (128 - prefix_len)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ipv4_contains() {
let cidr = CidrBlock::parse("10.0.0.0/8").unwrap();
assert!(cidr.contains("10.1.2.3".parse().unwrap()));
assert!(cidr.contains("10.255.255.255".parse().unwrap()));
assert!(!cidr.contains("11.0.0.1".parse().unwrap()));
}
#[test]
fn test_ipv6_contains() {
let cidr = CidrBlock::parse("fd00::/8").unwrap();
assert!(cidr.contains("fd12::1".parse().unwrap()));
assert!(!cidr.contains("fe80::1".parse().unwrap()));
}
#[test]
fn test_ipv4_mapped_v6() {
let cidr = CidrBlock::parse("10.0.0.0/8").unwrap();
// ::ffff:10.0.0.1 should match 10.0.0.0/8
let mapped: IpAddr = "::ffff:10.0.0.1".parse().unwrap();
assert!(cidr.contains(mapped));
}
#[test]
fn test_private_ranges() {
let cidrs = parse_cidrs(&[
"10.0.0.0/8".into(),
"172.16.0.0/12".into(),
"192.168.0.0/16".into(),
]);
assert!(is_bypassed("10.0.0.1".parse().unwrap(), &cidrs));
assert!(is_bypassed("172.31.255.1".parse().unwrap(), &cidrs));
assert!(is_bypassed("192.168.1.1".parse().unwrap(), &cidrs));
assert!(!is_bypassed("8.8.8.8".parse().unwrap(), &cidrs));
}
#[test]
fn test_invalid_cidrs() {
assert!(CidrBlock::parse("not-a-cidr").is_none());
assert!(CidrBlock::parse("10.0.0.0/33").is_none());
assert!(CidrBlock::parse("10.0.0.0").is_none());
}
#[test]
fn test_slash_zero() {
let cidr = CidrBlock::parse("0.0.0.0/0").unwrap();
assert!(cidr.contains("1.2.3.4".parse().unwrap()));
assert!(cidr.contains("255.255.255.255".parse().unwrap()));
}
#[test]
fn test_slash_32() {
let cidr = CidrBlock::parse("1.2.3.4/32").unwrap();
assert!(cidr.contains("1.2.3.4".parse().unwrap()));
assert!(!cidr.contains("1.2.3.5".parse().unwrap()));
}
}

145
src/rate_limit/key.rs Normal file
View File

@@ -0,0 +1,145 @@
use std::hash::{Hash, Hasher};
use std::net::IpAddr;
/// Identity key for rate limiting buckets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RateLimitKey {
/// Hashed auth credential (session cookie or bearer token).
Identity(u64),
/// Unauthenticated fallback: client IP.
Ip(IpAddr),
}
impl RateLimitKey {
pub fn is_authenticated(&self) -> bool {
matches!(self, RateLimitKey::Identity(_))
}
}
/// Extract a rate-limit key from request headers.
///
/// Priority:
/// 1. `ory_kratos_session` cookie → `Identity(hash)`
/// 2. `Authorization: Bearer <token>` → `Identity(hash)`
/// 3. Client IP → `Ip(addr)`
pub fn extract_key(
cookie_header: Option<&str>,
auth_header: Option<&str>,
client_ip: IpAddr,
) -> RateLimitKey {
// 1. Check for Kratos session cookie
if let Some(cookies) = cookie_header {
if let Some(value) = extract_cookie_value(cookies, "ory_kratos_session") {
return RateLimitKey::Identity(fx_hash(value));
}
}
// 2. Check for Bearer token
if let Some(auth) = auth_header {
if let Some(token) = auth.strip_prefix("Bearer ") {
let token = token.trim();
if !token.is_empty() {
return RateLimitKey::Identity(fx_hash(token));
}
}
}
// 3. Fall back to IP
RateLimitKey::Ip(client_ip)
}
fn extract_cookie_value<'a>(cookies: &'a str, name: &str) -> Option<&'a str> {
for pair in cookies.split(';') {
let pair = pair.trim();
if let Some((k, v)) = pair.split_once('=') {
if k.trim() == name {
return Some(v.trim());
}
}
}
None
}
fn fx_hash(s: &str) -> u64 {
let mut h = rustc_hash::FxHasher::default();
s.hash(&mut h);
h.finish()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_cookie_extraction() {
let ip: IpAddr = "1.2.3.4".parse().unwrap();
let key = extract_key(
Some("ory_kratos_session=abc123; other=val"),
None,
ip,
);
assert!(key.is_authenticated());
assert!(matches!(key, RateLimitKey::Identity(_)));
}
#[test]
fn test_cookie_only_cookie() {
let ip: IpAddr = "1.2.3.4".parse().unwrap();
let key = extract_key(Some("ory_kratos_session=tok"), None, ip);
assert!(key.is_authenticated());
}
#[test]
fn test_bearer_extraction() {
let ip: IpAddr = "1.2.3.4".parse().unwrap();
let key = extract_key(None, Some("Bearer my-token-123"), ip);
assert!(key.is_authenticated());
}
#[test]
fn test_ip_fallback() {
let ip: IpAddr = "1.2.3.4".parse().unwrap();
let key = extract_key(None, None, ip);
assert_eq!(key, RateLimitKey::Ip(ip));
assert!(!key.is_authenticated());
}
#[test]
fn test_cookie_takes_priority_over_bearer() {
let ip: IpAddr = "1.2.3.4".parse().unwrap();
let key_cookie = extract_key(
Some("ory_kratos_session=sess1"),
Some("Bearer tok1"),
ip,
);
let key_bearer = extract_key(None, Some("Bearer tok1"), ip);
// Cookie and bearer should produce different hashes
assert_ne!(key_cookie, key_bearer);
assert!(key_cookie.is_authenticated());
}
#[test]
fn test_empty_bearer_falls_back() {
let ip: IpAddr = "1.2.3.4".parse().unwrap();
let key = extract_key(None, Some("Bearer "), ip);
assert_eq!(key, RateLimitKey::Ip(ip));
}
#[test]
fn test_multiple_cookies() {
let ip: IpAddr = "1.2.3.4".parse().unwrap();
let key = extract_key(
Some("foo=bar; ory_kratos_session=mysess; baz=qux"),
None,
ip,
);
assert!(key.is_authenticated());
}
#[test]
fn test_wrong_cookie_name_falls_back() {
let ip: IpAddr = "1.2.3.4".parse().unwrap();
let key = extract_key(Some("session=val"), None, ip);
assert_eq!(key, RateLimitKey::Ip(ip));
}
}

264
src/rate_limit/limiter.rs Normal file
View File

@@ -0,0 +1,264 @@
use crate::config::{BucketConfig, RateLimitConfig};
use crate::rate_limit::cidr::{self, CidrBlock};
use crate::rate_limit::key::RateLimitKey;
use rustc_hash::FxHashMap;
use std::hash::{Hash, Hasher};
use std::net::IpAddr;
use std::sync::RwLock;
use std::time::Instant;
const NUM_SHARDS: usize = 256;
/// Result of a rate limit check.
#[derive(Debug, PartialEq)]
pub enum RateLimitResult {
Allow,
Reject { retry_after: u64 },
}
struct Bucket {
tokens: f64,
last_refill: Instant,
authenticated: bool,
}
pub struct RateLimiter {
shards: Vec<RwLock<FxHashMap<RateLimitKey, Bucket>>>,
bypass_cidrs: Vec<CidrBlock>,
authenticated: BucketConfig,
unauthenticated: BucketConfig,
stale_after_secs: u64,
}
fn shard_index(key: &RateLimitKey) -> usize {
let mut h = rustc_hash::FxHasher::default();
key.hash(&mut h);
h.finish() as usize % NUM_SHARDS
}
impl RateLimiter {
pub fn new(config: &RateLimitConfig) -> Self {
let shards = (0..NUM_SHARDS)
.map(|_| RwLock::new(FxHashMap::default()))
.collect();
let bypass_cidrs = cidr::parse_cidrs(&config.bypass_cidrs);
Self {
shards,
bypass_cidrs,
authenticated: config.authenticated.clone(),
unauthenticated: config.unauthenticated.clone(),
stale_after_secs: config.stale_after_secs,
}
}
/// Check whether a request should be allowed or rejected.
pub fn check(&self, ip: IpAddr, key: RateLimitKey) -> RateLimitResult {
// CIDR bypass
if cidr::is_bypassed(ip, &self.bypass_cidrs) {
return RateLimitResult::Allow;
}
let cfg = if key.is_authenticated() {
&self.authenticated
} else {
&self.unauthenticated
};
let now = Instant::now();
let idx = shard_index(&key);
let mut shard = self.shards[idx].write().unwrap_or_else(|e| e.into_inner());
let bucket = shard.entry(key).or_insert_with(|| Bucket {
tokens: cfg.burst as f64,
last_refill: now,
authenticated: key.is_authenticated(),
});
// Refill tokens based on elapsed time
let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
let tier = if bucket.authenticated {
&self.authenticated
} else {
&self.unauthenticated
};
bucket.tokens = (bucket.tokens + elapsed * tier.rate).min(tier.burst as f64);
bucket.last_refill = now;
if bucket.tokens >= 1.0 {
bucket.tokens -= 1.0;
RateLimitResult::Allow
} else {
let retry_after = ((1.0 - bucket.tokens) / tier.rate).ceil() as u64;
RateLimitResult::Reject {
retry_after: retry_after.max(1),
}
}
}
/// Remove buckets that haven't been used for `stale_after_secs`.
pub fn evict_stale(&self) {
let now = Instant::now();
let stale = std::time::Duration::from_secs(self.stale_after_secs);
for shard in &self.shards {
let mut map = shard.write().unwrap_or_else(|e| e.into_inner());
map.retain(|_, bucket| now.duration_since(bucket.last_refill) < stale);
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::{BucketConfig, RateLimitConfig};
fn test_config() -> RateLimitConfig {
RateLimitConfig {
enabled: true,
bypass_cidrs: vec!["10.0.0.0/8".into()],
eviction_interval_secs: 60,
stale_after_secs: 120,
authenticated: BucketConfig {
burst: 10,
rate: 5.0,
},
unauthenticated: BucketConfig {
burst: 3,
rate: 1.0,
},
}
}
#[test]
fn test_burst_exhaustion_then_reject() {
let limiter = RateLimiter::new(&test_config());
let ip: IpAddr = "203.0.113.1".parse().unwrap();
let key = RateLimitKey::Ip(ip);
// Burst of 3 for unauthenticated
for _ in 0..3 {
assert_eq!(limiter.check(ip, key), RateLimitResult::Allow);
}
// 4th request should be rejected
match limiter.check(ip, key) {
RateLimitResult::Reject { retry_after } => {
assert!(retry_after >= 1);
}
_ => panic!("expected reject"),
}
}
#[test]
fn test_refill_allows_again() {
let limiter = RateLimiter::new(&test_config());
let ip: IpAddr = "203.0.113.1".parse().unwrap();
let key = RateLimitKey::Ip(ip);
// Exhaust burst
for _ in 0..3 {
limiter.check(ip, key);
}
assert!(matches!(
limiter.check(ip, key),
RateLimitResult::Reject { .. }
));
// Manually simulate time passing by manipulating the bucket
{
let idx = shard_index(&key);
let mut shard = limiter.shards[idx].write().unwrap();
let bucket = shard.get_mut(&key).unwrap();
// Pretend 2 seconds passed (rate=1.0/s → 2 tokens refill)
bucket.last_refill -= std::time::Duration::from_secs(2);
}
assert_eq!(limiter.check(ip, key), RateLimitResult::Allow);
}
#[test]
fn test_cidr_bypass() {
let limiter = RateLimiter::new(&test_config());
let ip: IpAddr = "10.0.0.1".parse().unwrap();
let key = RateLimitKey::Ip(ip);
// Should always allow for bypassed CIDRs, even after many requests
for _ in 0..100 {
assert_eq!(limiter.check(ip, key), RateLimitResult::Allow);
}
}
#[test]
fn test_authenticated_higher_burst() {
let limiter = RateLimiter::new(&test_config());
let ip: IpAddr = "203.0.113.1".parse().unwrap();
let key = RateLimitKey::Identity(12345);
// Authenticated burst is 10
for _ in 0..10 {
assert_eq!(limiter.check(ip, key), RateLimitResult::Allow);
}
assert!(matches!(
limiter.check(ip, key),
RateLimitResult::Reject { .. }
));
}
#[test]
fn test_independent_keys() {
let limiter = RateLimiter::new(&test_config());
let ip1: IpAddr = "203.0.113.1".parse().unwrap();
let ip2: IpAddr = "203.0.113.2".parse().unwrap();
// Exhaust ip1's burst
for _ in 0..3 {
limiter.check(ip1, RateLimitKey::Ip(ip1));
}
assert!(matches!(
limiter.check(ip1, RateLimitKey::Ip(ip1)),
RateLimitResult::Reject { .. }
));
// ip2 should still be allowed
assert_eq!(
limiter.check(ip2, RateLimitKey::Ip(ip2)),
RateLimitResult::Allow
);
}
#[test]
fn test_eviction() {
let mut cfg = test_config();
cfg.stale_after_secs = 0; // everything is stale immediately
let limiter = RateLimiter::new(&cfg);
let ip: IpAddr = "203.0.113.1".parse().unwrap();
let key = RateLimitKey::Ip(ip);
limiter.check(ip, key);
// Buckets should be populated
let idx = shard_index(&key);
assert!(!limiter.shards[idx].read().unwrap().is_empty());
// After eviction, buckets should be gone (stale_after_secs=0)
std::thread::sleep(std::time::Duration::from_millis(10));
limiter.evict_stale();
assert!(limiter.shards[idx].read().unwrap().is_empty());
}
#[test]
fn test_retry_after_value() {
let limiter = RateLimiter::new(&test_config());
let ip: IpAddr = "203.0.113.1".parse().unwrap();
let key = RateLimitKey::Ip(ip);
// Exhaust burst (rate=1.0/s for unauth)
for _ in 0..3 {
limiter.check(ip, key);
}
match limiter.check(ip, key) {
RateLimitResult::Reject { retry_after } => {
// With rate=1.0, need ~1 token, so retry_after should be 1
assert_eq!(retry_after, 1);
}
_ => panic!("expected reject"),
}
}
}

3
src/rate_limit/mod.rs Normal file
View File

@@ -0,0 +1,3 @@
pub mod cidr;
pub mod key;
pub mod limiter;

118
tests/rate_limit_test.rs Normal file
View File

@@ -0,0 +1,118 @@
use std::net::IpAddr;
use sunbeam_proxy::config::{BucketConfig, RateLimitConfig};
use sunbeam_proxy::rate_limit::key::{self, RateLimitKey};
use sunbeam_proxy::rate_limit::limiter::{RateLimitResult, RateLimiter};
fn default_config() -> RateLimitConfig {
RateLimitConfig {
enabled: true,
bypass_cidrs: vec!["10.0.0.0/8".into(), "fd00::/8".into()],
eviction_interval_secs: 300,
stale_after_secs: 600,
authenticated: BucketConfig {
burst: 20,
rate: 10.0,
},
unauthenticated: BucketConfig {
burst: 5,
rate: 2.0,
},
}
}
#[test]
fn test_unauthenticated_burst_limit() {
let limiter = RateLimiter::new(&default_config());
let ip: IpAddr = "203.0.113.50".parse().unwrap();
let key = RateLimitKey::Ip(ip);
for _ in 0..5 {
assert_eq!(limiter.check(ip, key), RateLimitResult::Allow);
}
match limiter.check(ip, key) {
RateLimitResult::Reject { retry_after } => assert!(retry_after >= 1),
_ => panic!("expected reject after burst exhaustion"),
}
}
#[test]
fn test_authenticated_gets_higher_burst() {
let limiter = RateLimiter::new(&default_config());
let ip: IpAddr = "203.0.113.51".parse().unwrap();
let key = RateLimitKey::Identity(99999);
for i in 0..20 {
assert_eq!(
limiter.check(ip, key),
RateLimitResult::Allow,
"request {i} should be allowed"
);
}
assert!(matches!(
limiter.check(ip, key),
RateLimitResult::Reject { .. }
));
}
#[test]
fn test_cluster_cidr_bypass() {
let limiter = RateLimiter::new(&default_config());
let internal_ip: IpAddr = "10.244.1.5".parse().unwrap();
let key = RateLimitKey::Ip(internal_ip);
// Should never be rate-limited
for _ in 0..200 {
assert_eq!(limiter.check(internal_ip, key), RateLimitResult::Allow);
}
}
#[test]
fn test_key_extraction_cookie() {
let ip: IpAddr = "1.2.3.4".parse().unwrap();
let key = key::extract_key(
Some("_csrf=abc; ory_kratos_session=session123; lang=en"),
None,
ip,
);
assert!(key.is_authenticated());
}
#[test]
fn test_key_extraction_bearer() {
let ip: IpAddr = "1.2.3.4".parse().unwrap();
let key = key::extract_key(None, Some("Bearer gitea_pat_abc123"), ip);
assert!(key.is_authenticated());
}
#[test]
fn test_key_extraction_ip_fallback() {
let ip: IpAddr = "1.2.3.4".parse().unwrap();
let key = key::extract_key(None, None, ip);
assert_eq!(key, RateLimitKey::Ip(ip));
}
#[test]
fn test_concurrent_different_ips() {
let limiter = std::sync::Arc::new(RateLimiter::new(&default_config()));
let mut handles = vec![];
for i in 0..10u8 {
let limiter = limiter.clone();
handles.push(std::thread::spawn(move || {
let ip: IpAddr = format!("203.0.113.{i}").parse().unwrap();
let key = RateLimitKey::Ip(ip);
let mut allowed = 0;
for _ in 0..10 {
if limiter.check(ip, key) == RateLimitResult::Allow {
allowed += 1;
}
}
// Each IP gets burst=5, so exactly 5 should be allowed
assert_eq!(allowed, 5, "IP 203.0.113.{i} allowed {allowed}, expected 5");
}));
}
for h in handles {
h.join().unwrap();
}
}