feat(rate_limit): add per-identity leaky bucket rate limiter
256-shard RwLock<FxHashMap> for concurrent access, auth key extraction (ory_kratos_session cookie > Bearer token > client IP), CIDR bypass for trusted networks, and background eviction of stale buckets. Signed-off-by: Sienna Meridian Satterwhite <sienna@sunbeam.pt>
This commit is contained in:
143
src/rate_limit/cidr.rs
Normal file
143
src/rate_limit/cidr.rs
Normal file
@@ -0,0 +1,143 @@
|
||||
use std::net::IpAddr;
|
||||
|
||||
/// A parsed CIDR block for allowlist matching.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct CidrBlock {
|
||||
addr: IpAddr,
|
||||
prefix_len: u8,
|
||||
}
|
||||
|
||||
impl CidrBlock {
|
||||
/// Parse a CIDR string like "10.0.0.0/8" or "fd00::/8".
|
||||
pub fn parse(s: &str) -> Option<Self> {
|
||||
let (addr_str, len_str) = s.split_once('/')?;
|
||||
let prefix_len: u8 = len_str.parse().ok()?;
|
||||
let addr: IpAddr = addr_str.parse().ok()?;
|
||||
match &addr {
|
||||
IpAddr::V4(_) if prefix_len > 32 => return None,
|
||||
IpAddr::V6(_) if prefix_len > 128 => return None,
|
||||
_ => {}
|
||||
}
|
||||
Some(Self { addr, prefix_len })
|
||||
}
|
||||
|
||||
/// Check whether `ip` falls within this CIDR block.
|
||||
/// Handles IPv4-mapped IPv6 addresses (e.g. `::ffff:10.0.0.1`).
|
||||
pub fn contains(&self, ip: IpAddr) -> bool {
|
||||
// Normalise IPv4-mapped IPv6 → IPv4
|
||||
let ip = normalise(ip);
|
||||
let addr = normalise(self.addr);
|
||||
|
||||
match (addr, ip) {
|
||||
(IpAddr::V4(net), IpAddr::V4(candidate)) => {
|
||||
let mask = v4_mask(self.prefix_len);
|
||||
u32::from(net) & mask == u32::from(candidate) & mask
|
||||
}
|
||||
(IpAddr::V6(net), IpAddr::V6(candidate)) => {
|
||||
let mask = v6_mask(self.prefix_len);
|
||||
let net_bits = u128::from(net);
|
||||
let cand_bits = u128::from(candidate);
|
||||
net_bits & mask == cand_bits & mask
|
||||
}
|
||||
_ => false, // v4 vs v6 mismatch
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a list of CIDR strings, skipping any that are invalid.
|
||||
pub fn parse_cidrs(strings: &[String]) -> Vec<CidrBlock> {
|
||||
strings.iter().filter_map(|s| CidrBlock::parse(s)).collect()
|
||||
}
|
||||
|
||||
/// Check if `ip` is contained in any of the given CIDR blocks.
|
||||
pub fn is_bypassed(ip: IpAddr, cidrs: &[CidrBlock]) -> bool {
|
||||
cidrs.iter().any(|c| c.contains(ip))
|
||||
}
|
||||
|
||||
fn normalise(ip: IpAddr) -> IpAddr {
|
||||
match ip {
|
||||
IpAddr::V6(v6) => match v6.to_ipv4_mapped() {
|
||||
Some(v4) => IpAddr::V4(v4),
|
||||
None => IpAddr::V6(v6),
|
||||
},
|
||||
other => other,
|
||||
}
|
||||
}
|
||||
|
||||
fn v4_mask(prefix_len: u8) -> u32 {
|
||||
if prefix_len == 0 {
|
||||
0
|
||||
} else {
|
||||
u32::MAX << (32 - prefix_len)
|
||||
}
|
||||
}
|
||||
|
||||
fn v6_mask(prefix_len: u8) -> u128 {
|
||||
if prefix_len == 0 {
|
||||
0
|
||||
} else {
|
||||
u128::MAX << (128 - prefix_len)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_ipv4_contains() {
|
||||
let cidr = CidrBlock::parse("10.0.0.0/8").unwrap();
|
||||
assert!(cidr.contains("10.1.2.3".parse().unwrap()));
|
||||
assert!(cidr.contains("10.255.255.255".parse().unwrap()));
|
||||
assert!(!cidr.contains("11.0.0.1".parse().unwrap()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ipv6_contains() {
|
||||
let cidr = CidrBlock::parse("fd00::/8").unwrap();
|
||||
assert!(cidr.contains("fd12::1".parse().unwrap()));
|
||||
assert!(!cidr.contains("fe80::1".parse().unwrap()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ipv4_mapped_v6() {
|
||||
let cidr = CidrBlock::parse("10.0.0.0/8").unwrap();
|
||||
// ::ffff:10.0.0.1 should match 10.0.0.0/8
|
||||
let mapped: IpAddr = "::ffff:10.0.0.1".parse().unwrap();
|
||||
assert!(cidr.contains(mapped));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_private_ranges() {
|
||||
let cidrs = parse_cidrs(&[
|
||||
"10.0.0.0/8".into(),
|
||||
"172.16.0.0/12".into(),
|
||||
"192.168.0.0/16".into(),
|
||||
]);
|
||||
assert!(is_bypassed("10.0.0.1".parse().unwrap(), &cidrs));
|
||||
assert!(is_bypassed("172.31.255.1".parse().unwrap(), &cidrs));
|
||||
assert!(is_bypassed("192.168.1.1".parse().unwrap(), &cidrs));
|
||||
assert!(!is_bypassed("8.8.8.8".parse().unwrap(), &cidrs));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_invalid_cidrs() {
|
||||
assert!(CidrBlock::parse("not-a-cidr").is_none());
|
||||
assert!(CidrBlock::parse("10.0.0.0/33").is_none());
|
||||
assert!(CidrBlock::parse("10.0.0.0").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_slash_zero() {
|
||||
let cidr = CidrBlock::parse("0.0.0.0/0").unwrap();
|
||||
assert!(cidr.contains("1.2.3.4".parse().unwrap()));
|
||||
assert!(cidr.contains("255.255.255.255".parse().unwrap()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_slash_32() {
|
||||
let cidr = CidrBlock::parse("1.2.3.4/32").unwrap();
|
||||
assert!(cidr.contains("1.2.3.4".parse().unwrap()));
|
||||
assert!(!cidr.contains("1.2.3.5".parse().unwrap()));
|
||||
}
|
||||
}
|
||||
145
src/rate_limit/key.rs
Normal file
145
src/rate_limit/key.rs
Normal file
@@ -0,0 +1,145 @@
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::net::IpAddr;
|
||||
|
||||
/// Identity key for rate limiting buckets.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||
pub enum RateLimitKey {
|
||||
/// Hashed auth credential (session cookie or bearer token).
|
||||
Identity(u64),
|
||||
/// Unauthenticated fallback: client IP.
|
||||
Ip(IpAddr),
|
||||
}
|
||||
|
||||
impl RateLimitKey {
|
||||
pub fn is_authenticated(&self) -> bool {
|
||||
matches!(self, RateLimitKey::Identity(_))
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract a rate-limit key from request headers.
|
||||
///
|
||||
/// Priority:
|
||||
/// 1. `ory_kratos_session` cookie → `Identity(hash)`
|
||||
/// 2. `Authorization: Bearer <token>` → `Identity(hash)`
|
||||
/// 3. Client IP → `Ip(addr)`
|
||||
pub fn extract_key(
|
||||
cookie_header: Option<&str>,
|
||||
auth_header: Option<&str>,
|
||||
client_ip: IpAddr,
|
||||
) -> RateLimitKey {
|
||||
// 1. Check for Kratos session cookie
|
||||
if let Some(cookies) = cookie_header {
|
||||
if let Some(value) = extract_cookie_value(cookies, "ory_kratos_session") {
|
||||
return RateLimitKey::Identity(fx_hash(value));
|
||||
}
|
||||
}
|
||||
|
||||
// 2. Check for Bearer token
|
||||
if let Some(auth) = auth_header {
|
||||
if let Some(token) = auth.strip_prefix("Bearer ") {
|
||||
let token = token.trim();
|
||||
if !token.is_empty() {
|
||||
return RateLimitKey::Identity(fx_hash(token));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Fall back to IP
|
||||
RateLimitKey::Ip(client_ip)
|
||||
}
|
||||
|
||||
fn extract_cookie_value<'a>(cookies: &'a str, name: &str) -> Option<&'a str> {
|
||||
for pair in cookies.split(';') {
|
||||
let pair = pair.trim();
|
||||
if let Some((k, v)) = pair.split_once('=') {
|
||||
if k.trim() == name {
|
||||
return Some(v.trim());
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn fx_hash(s: &str) -> u64 {
|
||||
let mut h = rustc_hash::FxHasher::default();
|
||||
s.hash(&mut h);
|
||||
h.finish()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_cookie_extraction() {
|
||||
let ip: IpAddr = "1.2.3.4".parse().unwrap();
|
||||
let key = extract_key(
|
||||
Some("ory_kratos_session=abc123; other=val"),
|
||||
None,
|
||||
ip,
|
||||
);
|
||||
assert!(key.is_authenticated());
|
||||
assert!(matches!(key, RateLimitKey::Identity(_)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cookie_only_cookie() {
|
||||
let ip: IpAddr = "1.2.3.4".parse().unwrap();
|
||||
let key = extract_key(Some("ory_kratos_session=tok"), None, ip);
|
||||
assert!(key.is_authenticated());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_bearer_extraction() {
|
||||
let ip: IpAddr = "1.2.3.4".parse().unwrap();
|
||||
let key = extract_key(None, Some("Bearer my-token-123"), ip);
|
||||
assert!(key.is_authenticated());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ip_fallback() {
|
||||
let ip: IpAddr = "1.2.3.4".parse().unwrap();
|
||||
let key = extract_key(None, None, ip);
|
||||
assert_eq!(key, RateLimitKey::Ip(ip));
|
||||
assert!(!key.is_authenticated());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cookie_takes_priority_over_bearer() {
|
||||
let ip: IpAddr = "1.2.3.4".parse().unwrap();
|
||||
let key_cookie = extract_key(
|
||||
Some("ory_kratos_session=sess1"),
|
||||
Some("Bearer tok1"),
|
||||
ip,
|
||||
);
|
||||
let key_bearer = extract_key(None, Some("Bearer tok1"), ip);
|
||||
// Cookie and bearer should produce different hashes
|
||||
assert_ne!(key_cookie, key_bearer);
|
||||
assert!(key_cookie.is_authenticated());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_bearer_falls_back() {
|
||||
let ip: IpAddr = "1.2.3.4".parse().unwrap();
|
||||
let key = extract_key(None, Some("Bearer "), ip);
|
||||
assert_eq!(key, RateLimitKey::Ip(ip));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_multiple_cookies() {
|
||||
let ip: IpAddr = "1.2.3.4".parse().unwrap();
|
||||
let key = extract_key(
|
||||
Some("foo=bar; ory_kratos_session=mysess; baz=qux"),
|
||||
None,
|
||||
ip,
|
||||
);
|
||||
assert!(key.is_authenticated());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wrong_cookie_name_falls_back() {
|
||||
let ip: IpAddr = "1.2.3.4".parse().unwrap();
|
||||
let key = extract_key(Some("session=val"), None, ip);
|
||||
assert_eq!(key, RateLimitKey::Ip(ip));
|
||||
}
|
||||
}
|
||||
264
src/rate_limit/limiter.rs
Normal file
264
src/rate_limit/limiter.rs
Normal file
@@ -0,0 +1,264 @@
|
||||
use crate::config::{BucketConfig, RateLimitConfig};
|
||||
use crate::rate_limit::cidr::{self, CidrBlock};
|
||||
use crate::rate_limit::key::RateLimitKey;
|
||||
use rustc_hash::FxHashMap;
|
||||
use std::hash::{Hash, Hasher};
|
||||
use std::net::IpAddr;
|
||||
use std::sync::RwLock;
|
||||
use std::time::Instant;
|
||||
|
||||
const NUM_SHARDS: usize = 256;
|
||||
|
||||
/// Result of a rate limit check.
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub enum RateLimitResult {
|
||||
Allow,
|
||||
Reject { retry_after: u64 },
|
||||
}
|
||||
|
||||
struct Bucket {
|
||||
tokens: f64,
|
||||
last_refill: Instant,
|
||||
authenticated: bool,
|
||||
}
|
||||
|
||||
pub struct RateLimiter {
|
||||
shards: Vec<RwLock<FxHashMap<RateLimitKey, Bucket>>>,
|
||||
bypass_cidrs: Vec<CidrBlock>,
|
||||
authenticated: BucketConfig,
|
||||
unauthenticated: BucketConfig,
|
||||
stale_after_secs: u64,
|
||||
}
|
||||
|
||||
fn shard_index(key: &RateLimitKey) -> usize {
|
||||
let mut h = rustc_hash::FxHasher::default();
|
||||
key.hash(&mut h);
|
||||
h.finish() as usize % NUM_SHARDS
|
||||
}
|
||||
|
||||
impl RateLimiter {
|
||||
pub fn new(config: &RateLimitConfig) -> Self {
|
||||
let shards = (0..NUM_SHARDS)
|
||||
.map(|_| RwLock::new(FxHashMap::default()))
|
||||
.collect();
|
||||
let bypass_cidrs = cidr::parse_cidrs(&config.bypass_cidrs);
|
||||
Self {
|
||||
shards,
|
||||
bypass_cidrs,
|
||||
authenticated: config.authenticated.clone(),
|
||||
unauthenticated: config.unauthenticated.clone(),
|
||||
stale_after_secs: config.stale_after_secs,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check whether a request should be allowed or rejected.
|
||||
pub fn check(&self, ip: IpAddr, key: RateLimitKey) -> RateLimitResult {
|
||||
// CIDR bypass
|
||||
if cidr::is_bypassed(ip, &self.bypass_cidrs) {
|
||||
return RateLimitResult::Allow;
|
||||
}
|
||||
|
||||
let cfg = if key.is_authenticated() {
|
||||
&self.authenticated
|
||||
} else {
|
||||
&self.unauthenticated
|
||||
};
|
||||
|
||||
let now = Instant::now();
|
||||
let idx = shard_index(&key);
|
||||
let mut shard = self.shards[idx].write().unwrap_or_else(|e| e.into_inner());
|
||||
|
||||
let bucket = shard.entry(key).or_insert_with(|| Bucket {
|
||||
tokens: cfg.burst as f64,
|
||||
last_refill: now,
|
||||
authenticated: key.is_authenticated(),
|
||||
});
|
||||
|
||||
// Refill tokens based on elapsed time
|
||||
let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
|
||||
let tier = if bucket.authenticated {
|
||||
&self.authenticated
|
||||
} else {
|
||||
&self.unauthenticated
|
||||
};
|
||||
bucket.tokens = (bucket.tokens + elapsed * tier.rate).min(tier.burst as f64);
|
||||
bucket.last_refill = now;
|
||||
|
||||
if bucket.tokens >= 1.0 {
|
||||
bucket.tokens -= 1.0;
|
||||
RateLimitResult::Allow
|
||||
} else {
|
||||
let retry_after = ((1.0 - bucket.tokens) / tier.rate).ceil() as u64;
|
||||
RateLimitResult::Reject {
|
||||
retry_after: retry_after.max(1),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove buckets that haven't been used for `stale_after_secs`.
|
||||
pub fn evict_stale(&self) {
|
||||
let now = Instant::now();
|
||||
let stale = std::time::Duration::from_secs(self.stale_after_secs);
|
||||
for shard in &self.shards {
|
||||
let mut map = shard.write().unwrap_or_else(|e| e.into_inner());
|
||||
map.retain(|_, bucket| now.duration_since(bucket.last_refill) < stale);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::{BucketConfig, RateLimitConfig};
|
||||
|
||||
fn test_config() -> RateLimitConfig {
|
||||
RateLimitConfig {
|
||||
enabled: true,
|
||||
bypass_cidrs: vec!["10.0.0.0/8".into()],
|
||||
eviction_interval_secs: 60,
|
||||
stale_after_secs: 120,
|
||||
authenticated: BucketConfig {
|
||||
burst: 10,
|
||||
rate: 5.0,
|
||||
},
|
||||
unauthenticated: BucketConfig {
|
||||
burst: 3,
|
||||
rate: 1.0,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_burst_exhaustion_then_reject() {
|
||||
let limiter = RateLimiter::new(&test_config());
|
||||
let ip: IpAddr = "203.0.113.1".parse().unwrap();
|
||||
let key = RateLimitKey::Ip(ip);
|
||||
|
||||
// Burst of 3 for unauthenticated
|
||||
for _ in 0..3 {
|
||||
assert_eq!(limiter.check(ip, key), RateLimitResult::Allow);
|
||||
}
|
||||
// 4th request should be rejected
|
||||
match limiter.check(ip, key) {
|
||||
RateLimitResult::Reject { retry_after } => {
|
||||
assert!(retry_after >= 1);
|
||||
}
|
||||
_ => panic!("expected reject"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_refill_allows_again() {
|
||||
let limiter = RateLimiter::new(&test_config());
|
||||
let ip: IpAddr = "203.0.113.1".parse().unwrap();
|
||||
let key = RateLimitKey::Ip(ip);
|
||||
|
||||
// Exhaust burst
|
||||
for _ in 0..3 {
|
||||
limiter.check(ip, key);
|
||||
}
|
||||
assert!(matches!(
|
||||
limiter.check(ip, key),
|
||||
RateLimitResult::Reject { .. }
|
||||
));
|
||||
|
||||
// Manually simulate time passing by manipulating the bucket
|
||||
{
|
||||
let idx = shard_index(&key);
|
||||
let mut shard = limiter.shards[idx].write().unwrap();
|
||||
let bucket = shard.get_mut(&key).unwrap();
|
||||
// Pretend 2 seconds passed (rate=1.0/s → 2 tokens refill)
|
||||
bucket.last_refill -= std::time::Duration::from_secs(2);
|
||||
}
|
||||
|
||||
assert_eq!(limiter.check(ip, key), RateLimitResult::Allow);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cidr_bypass() {
|
||||
let limiter = RateLimiter::new(&test_config());
|
||||
let ip: IpAddr = "10.0.0.1".parse().unwrap();
|
||||
let key = RateLimitKey::Ip(ip);
|
||||
|
||||
// Should always allow for bypassed CIDRs, even after many requests
|
||||
for _ in 0..100 {
|
||||
assert_eq!(limiter.check(ip, key), RateLimitResult::Allow);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_authenticated_higher_burst() {
|
||||
let limiter = RateLimiter::new(&test_config());
|
||||
let ip: IpAddr = "203.0.113.1".parse().unwrap();
|
||||
let key = RateLimitKey::Identity(12345);
|
||||
|
||||
// Authenticated burst is 10
|
||||
for _ in 0..10 {
|
||||
assert_eq!(limiter.check(ip, key), RateLimitResult::Allow);
|
||||
}
|
||||
assert!(matches!(
|
||||
limiter.check(ip, key),
|
||||
RateLimitResult::Reject { .. }
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_independent_keys() {
|
||||
let limiter = RateLimiter::new(&test_config());
|
||||
let ip1: IpAddr = "203.0.113.1".parse().unwrap();
|
||||
let ip2: IpAddr = "203.0.113.2".parse().unwrap();
|
||||
|
||||
// Exhaust ip1's burst
|
||||
for _ in 0..3 {
|
||||
limiter.check(ip1, RateLimitKey::Ip(ip1));
|
||||
}
|
||||
assert!(matches!(
|
||||
limiter.check(ip1, RateLimitKey::Ip(ip1)),
|
||||
RateLimitResult::Reject { .. }
|
||||
));
|
||||
|
||||
// ip2 should still be allowed
|
||||
assert_eq!(
|
||||
limiter.check(ip2, RateLimitKey::Ip(ip2)),
|
||||
RateLimitResult::Allow
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_eviction() {
|
||||
let mut cfg = test_config();
|
||||
cfg.stale_after_secs = 0; // everything is stale immediately
|
||||
let limiter = RateLimiter::new(&cfg);
|
||||
let ip: IpAddr = "203.0.113.1".parse().unwrap();
|
||||
let key = RateLimitKey::Ip(ip);
|
||||
|
||||
limiter.check(ip, key);
|
||||
// Buckets should be populated
|
||||
let idx = shard_index(&key);
|
||||
assert!(!limiter.shards[idx].read().unwrap().is_empty());
|
||||
|
||||
// After eviction, buckets should be gone (stale_after_secs=0)
|
||||
std::thread::sleep(std::time::Duration::from_millis(10));
|
||||
limiter.evict_stale();
|
||||
assert!(limiter.shards[idx].read().unwrap().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_retry_after_value() {
|
||||
let limiter = RateLimiter::new(&test_config());
|
||||
let ip: IpAddr = "203.0.113.1".parse().unwrap();
|
||||
let key = RateLimitKey::Ip(ip);
|
||||
|
||||
// Exhaust burst (rate=1.0/s for unauth)
|
||||
for _ in 0..3 {
|
||||
limiter.check(ip, key);
|
||||
}
|
||||
match limiter.check(ip, key) {
|
||||
RateLimitResult::Reject { retry_after } => {
|
||||
// With rate=1.0, need ~1 token, so retry_after should be 1
|
||||
assert_eq!(retry_after, 1);
|
||||
}
|
||||
_ => panic!("expected reject"),
|
||||
}
|
||||
}
|
||||
}
|
||||
3
src/rate_limit/mod.rs
Normal file
3
src/rate_limit/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
||||
pub mod cidr;
|
||||
pub mod key;
|
||||
pub mod limiter;
|
||||
Reference in New Issue
Block a user