feat(rate_limit): add per-identity leaky bucket rate limiter
256-shard RwLock<FxHashMap> for concurrent access, auth key extraction (ory_kratos_session cookie > Bearer token > client IP), CIDR bypass for trusted networks, and background eviction of stale buckets. Signed-off-by: Sienna Meridian Satterwhite <sienna@sunbeam.pt>
This commit is contained in:
143
src/rate_limit/cidr.rs
Normal file
143
src/rate_limit/cidr.rs
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
use std::net::IpAddr;
|
||||||
|
|
||||||
|
/// A parsed CIDR block for allowlist matching.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct CidrBlock {
|
||||||
|
addr: IpAddr,
|
||||||
|
prefix_len: u8,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CidrBlock {
|
||||||
|
/// Parse a CIDR string like "10.0.0.0/8" or "fd00::/8".
|
||||||
|
pub fn parse(s: &str) -> Option<Self> {
|
||||||
|
let (addr_str, len_str) = s.split_once('/')?;
|
||||||
|
let prefix_len: u8 = len_str.parse().ok()?;
|
||||||
|
let addr: IpAddr = addr_str.parse().ok()?;
|
||||||
|
match &addr {
|
||||||
|
IpAddr::V4(_) if prefix_len > 32 => return None,
|
||||||
|
IpAddr::V6(_) if prefix_len > 128 => return None,
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
Some(Self { addr, prefix_len })
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check whether `ip` falls within this CIDR block.
|
||||||
|
/// Handles IPv4-mapped IPv6 addresses (e.g. `::ffff:10.0.0.1`).
|
||||||
|
pub fn contains(&self, ip: IpAddr) -> bool {
|
||||||
|
// Normalise IPv4-mapped IPv6 → IPv4
|
||||||
|
let ip = normalise(ip);
|
||||||
|
let addr = normalise(self.addr);
|
||||||
|
|
||||||
|
match (addr, ip) {
|
||||||
|
(IpAddr::V4(net), IpAddr::V4(candidate)) => {
|
||||||
|
let mask = v4_mask(self.prefix_len);
|
||||||
|
u32::from(net) & mask == u32::from(candidate) & mask
|
||||||
|
}
|
||||||
|
(IpAddr::V6(net), IpAddr::V6(candidate)) => {
|
||||||
|
let mask = v6_mask(self.prefix_len);
|
||||||
|
let net_bits = u128::from(net);
|
||||||
|
let cand_bits = u128::from(candidate);
|
||||||
|
net_bits & mask == cand_bits & mask
|
||||||
|
}
|
||||||
|
_ => false, // v4 vs v6 mismatch
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse a list of CIDR strings, skipping any that are invalid.
|
||||||
|
pub fn parse_cidrs(strings: &[String]) -> Vec<CidrBlock> {
|
||||||
|
strings.iter().filter_map(|s| CidrBlock::parse(s)).collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if `ip` is contained in any of the given CIDR blocks.
|
||||||
|
pub fn is_bypassed(ip: IpAddr, cidrs: &[CidrBlock]) -> bool {
|
||||||
|
cidrs.iter().any(|c| c.contains(ip))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn normalise(ip: IpAddr) -> IpAddr {
|
||||||
|
match ip {
|
||||||
|
IpAddr::V6(v6) => match v6.to_ipv4_mapped() {
|
||||||
|
Some(v4) => IpAddr::V4(v4),
|
||||||
|
None => IpAddr::V6(v6),
|
||||||
|
},
|
||||||
|
other => other,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn v4_mask(prefix_len: u8) -> u32 {
|
||||||
|
if prefix_len == 0 {
|
||||||
|
0
|
||||||
|
} else {
|
||||||
|
u32::MAX << (32 - prefix_len)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn v6_mask(prefix_len: u8) -> u128 {
|
||||||
|
if prefix_len == 0 {
|
||||||
|
0
|
||||||
|
} else {
|
||||||
|
u128::MAX << (128 - prefix_len)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ipv4_contains() {
|
||||||
|
let cidr = CidrBlock::parse("10.0.0.0/8").unwrap();
|
||||||
|
assert!(cidr.contains("10.1.2.3".parse().unwrap()));
|
||||||
|
assert!(cidr.contains("10.255.255.255".parse().unwrap()));
|
||||||
|
assert!(!cidr.contains("11.0.0.1".parse().unwrap()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ipv6_contains() {
|
||||||
|
let cidr = CidrBlock::parse("fd00::/8").unwrap();
|
||||||
|
assert!(cidr.contains("fd12::1".parse().unwrap()));
|
||||||
|
assert!(!cidr.contains("fe80::1".parse().unwrap()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ipv4_mapped_v6() {
|
||||||
|
let cidr = CidrBlock::parse("10.0.0.0/8").unwrap();
|
||||||
|
// ::ffff:10.0.0.1 should match 10.0.0.0/8
|
||||||
|
let mapped: IpAddr = "::ffff:10.0.0.1".parse().unwrap();
|
||||||
|
assert!(cidr.contains(mapped));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_private_ranges() {
|
||||||
|
let cidrs = parse_cidrs(&[
|
||||||
|
"10.0.0.0/8".into(),
|
||||||
|
"172.16.0.0/12".into(),
|
||||||
|
"192.168.0.0/16".into(),
|
||||||
|
]);
|
||||||
|
assert!(is_bypassed("10.0.0.1".parse().unwrap(), &cidrs));
|
||||||
|
assert!(is_bypassed("172.31.255.1".parse().unwrap(), &cidrs));
|
||||||
|
assert!(is_bypassed("192.168.1.1".parse().unwrap(), &cidrs));
|
||||||
|
assert!(!is_bypassed("8.8.8.8".parse().unwrap(), &cidrs));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_invalid_cidrs() {
|
||||||
|
assert!(CidrBlock::parse("not-a-cidr").is_none());
|
||||||
|
assert!(CidrBlock::parse("10.0.0.0/33").is_none());
|
||||||
|
assert!(CidrBlock::parse("10.0.0.0").is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_slash_zero() {
|
||||||
|
let cidr = CidrBlock::parse("0.0.0.0/0").unwrap();
|
||||||
|
assert!(cidr.contains("1.2.3.4".parse().unwrap()));
|
||||||
|
assert!(cidr.contains("255.255.255.255".parse().unwrap()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_slash_32() {
|
||||||
|
let cidr = CidrBlock::parse("1.2.3.4/32").unwrap();
|
||||||
|
assert!(cidr.contains("1.2.3.4".parse().unwrap()));
|
||||||
|
assert!(!cidr.contains("1.2.3.5".parse().unwrap()));
|
||||||
|
}
|
||||||
|
}
|
||||||
145
src/rate_limit/key.rs
Normal file
145
src/rate_limit/key.rs
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
use std::hash::{Hash, Hasher};
|
||||||
|
use std::net::IpAddr;
|
||||||
|
|
||||||
|
/// Identity key for rate limiting buckets.
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
|
||||||
|
pub enum RateLimitKey {
|
||||||
|
/// Hashed auth credential (session cookie or bearer token).
|
||||||
|
Identity(u64),
|
||||||
|
/// Unauthenticated fallback: client IP.
|
||||||
|
Ip(IpAddr),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RateLimitKey {
|
||||||
|
pub fn is_authenticated(&self) -> bool {
|
||||||
|
matches!(self, RateLimitKey::Identity(_))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract a rate-limit key from request headers.
|
||||||
|
///
|
||||||
|
/// Priority:
|
||||||
|
/// 1. `ory_kratos_session` cookie → `Identity(hash)`
|
||||||
|
/// 2. `Authorization: Bearer <token>` → `Identity(hash)`
|
||||||
|
/// 3. Client IP → `Ip(addr)`
|
||||||
|
pub fn extract_key(
|
||||||
|
cookie_header: Option<&str>,
|
||||||
|
auth_header: Option<&str>,
|
||||||
|
client_ip: IpAddr,
|
||||||
|
) -> RateLimitKey {
|
||||||
|
// 1. Check for Kratos session cookie
|
||||||
|
if let Some(cookies) = cookie_header {
|
||||||
|
if let Some(value) = extract_cookie_value(cookies, "ory_kratos_session") {
|
||||||
|
return RateLimitKey::Identity(fx_hash(value));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Check for Bearer token
|
||||||
|
if let Some(auth) = auth_header {
|
||||||
|
if let Some(token) = auth.strip_prefix("Bearer ") {
|
||||||
|
let token = token.trim();
|
||||||
|
if !token.is_empty() {
|
||||||
|
return RateLimitKey::Identity(fx_hash(token));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. Fall back to IP
|
||||||
|
RateLimitKey::Ip(client_ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn extract_cookie_value<'a>(cookies: &'a str, name: &str) -> Option<&'a str> {
|
||||||
|
for pair in cookies.split(';') {
|
||||||
|
let pair = pair.trim();
|
||||||
|
if let Some((k, v)) = pair.split_once('=') {
|
||||||
|
if k.trim() == name {
|
||||||
|
return Some(v.trim());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
fn fx_hash(s: &str) -> u64 {
|
||||||
|
let mut h = rustc_hash::FxHasher::default();
|
||||||
|
s.hash(&mut h);
|
||||||
|
h.finish()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_cookie_extraction() {
|
||||||
|
let ip: IpAddr = "1.2.3.4".parse().unwrap();
|
||||||
|
let key = extract_key(
|
||||||
|
Some("ory_kratos_session=abc123; other=val"),
|
||||||
|
None,
|
||||||
|
ip,
|
||||||
|
);
|
||||||
|
assert!(key.is_authenticated());
|
||||||
|
assert!(matches!(key, RateLimitKey::Identity(_)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_cookie_only_cookie() {
|
||||||
|
let ip: IpAddr = "1.2.3.4".parse().unwrap();
|
||||||
|
let key = extract_key(Some("ory_kratos_session=tok"), None, ip);
|
||||||
|
assert!(key.is_authenticated());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_bearer_extraction() {
|
||||||
|
let ip: IpAddr = "1.2.3.4".parse().unwrap();
|
||||||
|
let key = extract_key(None, Some("Bearer my-token-123"), ip);
|
||||||
|
assert!(key.is_authenticated());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ip_fallback() {
|
||||||
|
let ip: IpAddr = "1.2.3.4".parse().unwrap();
|
||||||
|
let key = extract_key(None, None, ip);
|
||||||
|
assert_eq!(key, RateLimitKey::Ip(ip));
|
||||||
|
assert!(!key.is_authenticated());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_cookie_takes_priority_over_bearer() {
|
||||||
|
let ip: IpAddr = "1.2.3.4".parse().unwrap();
|
||||||
|
let key_cookie = extract_key(
|
||||||
|
Some("ory_kratos_session=sess1"),
|
||||||
|
Some("Bearer tok1"),
|
||||||
|
ip,
|
||||||
|
);
|
||||||
|
let key_bearer = extract_key(None, Some("Bearer tok1"), ip);
|
||||||
|
// Cookie and bearer should produce different hashes
|
||||||
|
assert_ne!(key_cookie, key_bearer);
|
||||||
|
assert!(key_cookie.is_authenticated());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_empty_bearer_falls_back() {
|
||||||
|
let ip: IpAddr = "1.2.3.4".parse().unwrap();
|
||||||
|
let key = extract_key(None, Some("Bearer "), ip);
|
||||||
|
assert_eq!(key, RateLimitKey::Ip(ip));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_multiple_cookies() {
|
||||||
|
let ip: IpAddr = "1.2.3.4".parse().unwrap();
|
||||||
|
let key = extract_key(
|
||||||
|
Some("foo=bar; ory_kratos_session=mysess; baz=qux"),
|
||||||
|
None,
|
||||||
|
ip,
|
||||||
|
);
|
||||||
|
assert!(key.is_authenticated());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_wrong_cookie_name_falls_back() {
|
||||||
|
let ip: IpAddr = "1.2.3.4".parse().unwrap();
|
||||||
|
let key = extract_key(Some("session=val"), None, ip);
|
||||||
|
assert_eq!(key, RateLimitKey::Ip(ip));
|
||||||
|
}
|
||||||
|
}
|
||||||
264
src/rate_limit/limiter.rs
Normal file
264
src/rate_limit/limiter.rs
Normal file
@@ -0,0 +1,264 @@
|
|||||||
|
use crate::config::{BucketConfig, RateLimitConfig};
|
||||||
|
use crate::rate_limit::cidr::{self, CidrBlock};
|
||||||
|
use crate::rate_limit::key::RateLimitKey;
|
||||||
|
use rustc_hash::FxHashMap;
|
||||||
|
use std::hash::{Hash, Hasher};
|
||||||
|
use std::net::IpAddr;
|
||||||
|
use std::sync::RwLock;
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
const NUM_SHARDS: usize = 256;
|
||||||
|
|
||||||
|
/// Result of a rate limit check.
|
||||||
|
#[derive(Debug, PartialEq)]
|
||||||
|
pub enum RateLimitResult {
|
||||||
|
Allow,
|
||||||
|
Reject { retry_after: u64 },
|
||||||
|
}
|
||||||
|
|
||||||
|
struct Bucket {
|
||||||
|
tokens: f64,
|
||||||
|
last_refill: Instant,
|
||||||
|
authenticated: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct RateLimiter {
|
||||||
|
shards: Vec<RwLock<FxHashMap<RateLimitKey, Bucket>>>,
|
||||||
|
bypass_cidrs: Vec<CidrBlock>,
|
||||||
|
authenticated: BucketConfig,
|
||||||
|
unauthenticated: BucketConfig,
|
||||||
|
stale_after_secs: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn shard_index(key: &RateLimitKey) -> usize {
|
||||||
|
let mut h = rustc_hash::FxHasher::default();
|
||||||
|
key.hash(&mut h);
|
||||||
|
h.finish() as usize % NUM_SHARDS
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RateLimiter {
|
||||||
|
pub fn new(config: &RateLimitConfig) -> Self {
|
||||||
|
let shards = (0..NUM_SHARDS)
|
||||||
|
.map(|_| RwLock::new(FxHashMap::default()))
|
||||||
|
.collect();
|
||||||
|
let bypass_cidrs = cidr::parse_cidrs(&config.bypass_cidrs);
|
||||||
|
Self {
|
||||||
|
shards,
|
||||||
|
bypass_cidrs,
|
||||||
|
authenticated: config.authenticated.clone(),
|
||||||
|
unauthenticated: config.unauthenticated.clone(),
|
||||||
|
stale_after_secs: config.stale_after_secs,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check whether a request should be allowed or rejected.
|
||||||
|
pub fn check(&self, ip: IpAddr, key: RateLimitKey) -> RateLimitResult {
|
||||||
|
// CIDR bypass
|
||||||
|
if cidr::is_bypassed(ip, &self.bypass_cidrs) {
|
||||||
|
return RateLimitResult::Allow;
|
||||||
|
}
|
||||||
|
|
||||||
|
let cfg = if key.is_authenticated() {
|
||||||
|
&self.authenticated
|
||||||
|
} else {
|
||||||
|
&self.unauthenticated
|
||||||
|
};
|
||||||
|
|
||||||
|
let now = Instant::now();
|
||||||
|
let idx = shard_index(&key);
|
||||||
|
let mut shard = self.shards[idx].write().unwrap_or_else(|e| e.into_inner());
|
||||||
|
|
||||||
|
let bucket = shard.entry(key).or_insert_with(|| Bucket {
|
||||||
|
tokens: cfg.burst as f64,
|
||||||
|
last_refill: now,
|
||||||
|
authenticated: key.is_authenticated(),
|
||||||
|
});
|
||||||
|
|
||||||
|
// Refill tokens based on elapsed time
|
||||||
|
let elapsed = now.duration_since(bucket.last_refill).as_secs_f64();
|
||||||
|
let tier = if bucket.authenticated {
|
||||||
|
&self.authenticated
|
||||||
|
} else {
|
||||||
|
&self.unauthenticated
|
||||||
|
};
|
||||||
|
bucket.tokens = (bucket.tokens + elapsed * tier.rate).min(tier.burst as f64);
|
||||||
|
bucket.last_refill = now;
|
||||||
|
|
||||||
|
if bucket.tokens >= 1.0 {
|
||||||
|
bucket.tokens -= 1.0;
|
||||||
|
RateLimitResult::Allow
|
||||||
|
} else {
|
||||||
|
let retry_after = ((1.0 - bucket.tokens) / tier.rate).ceil() as u64;
|
||||||
|
RateLimitResult::Reject {
|
||||||
|
retry_after: retry_after.max(1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Remove buckets that haven't been used for `stale_after_secs`.
|
||||||
|
pub fn evict_stale(&self) {
|
||||||
|
let now = Instant::now();
|
||||||
|
let stale = std::time::Duration::from_secs(self.stale_after_secs);
|
||||||
|
for shard in &self.shards {
|
||||||
|
let mut map = shard.write().unwrap_or_else(|e| e.into_inner());
|
||||||
|
map.retain(|_, bucket| now.duration_since(bucket.last_refill) < stale);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::config::{BucketConfig, RateLimitConfig};
|
||||||
|
|
||||||
|
fn test_config() -> RateLimitConfig {
|
||||||
|
RateLimitConfig {
|
||||||
|
enabled: true,
|
||||||
|
bypass_cidrs: vec!["10.0.0.0/8".into()],
|
||||||
|
eviction_interval_secs: 60,
|
||||||
|
stale_after_secs: 120,
|
||||||
|
authenticated: BucketConfig {
|
||||||
|
burst: 10,
|
||||||
|
rate: 5.0,
|
||||||
|
},
|
||||||
|
unauthenticated: BucketConfig {
|
||||||
|
burst: 3,
|
||||||
|
rate: 1.0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_burst_exhaustion_then_reject() {
|
||||||
|
let limiter = RateLimiter::new(&test_config());
|
||||||
|
let ip: IpAddr = "203.0.113.1".parse().unwrap();
|
||||||
|
let key = RateLimitKey::Ip(ip);
|
||||||
|
|
||||||
|
// Burst of 3 for unauthenticated
|
||||||
|
for _ in 0..3 {
|
||||||
|
assert_eq!(limiter.check(ip, key), RateLimitResult::Allow);
|
||||||
|
}
|
||||||
|
// 4th request should be rejected
|
||||||
|
match limiter.check(ip, key) {
|
||||||
|
RateLimitResult::Reject { retry_after } => {
|
||||||
|
assert!(retry_after >= 1);
|
||||||
|
}
|
||||||
|
_ => panic!("expected reject"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_refill_allows_again() {
|
||||||
|
let limiter = RateLimiter::new(&test_config());
|
||||||
|
let ip: IpAddr = "203.0.113.1".parse().unwrap();
|
||||||
|
let key = RateLimitKey::Ip(ip);
|
||||||
|
|
||||||
|
// Exhaust burst
|
||||||
|
for _ in 0..3 {
|
||||||
|
limiter.check(ip, key);
|
||||||
|
}
|
||||||
|
assert!(matches!(
|
||||||
|
limiter.check(ip, key),
|
||||||
|
RateLimitResult::Reject { .. }
|
||||||
|
));
|
||||||
|
|
||||||
|
// Manually simulate time passing by manipulating the bucket
|
||||||
|
{
|
||||||
|
let idx = shard_index(&key);
|
||||||
|
let mut shard = limiter.shards[idx].write().unwrap();
|
||||||
|
let bucket = shard.get_mut(&key).unwrap();
|
||||||
|
// Pretend 2 seconds passed (rate=1.0/s → 2 tokens refill)
|
||||||
|
bucket.last_refill -= std::time::Duration::from_secs(2);
|
||||||
|
}
|
||||||
|
|
||||||
|
assert_eq!(limiter.check(ip, key), RateLimitResult::Allow);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_cidr_bypass() {
|
||||||
|
let limiter = RateLimiter::new(&test_config());
|
||||||
|
let ip: IpAddr = "10.0.0.1".parse().unwrap();
|
||||||
|
let key = RateLimitKey::Ip(ip);
|
||||||
|
|
||||||
|
// Should always allow for bypassed CIDRs, even after many requests
|
||||||
|
for _ in 0..100 {
|
||||||
|
assert_eq!(limiter.check(ip, key), RateLimitResult::Allow);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_authenticated_higher_burst() {
|
||||||
|
let limiter = RateLimiter::new(&test_config());
|
||||||
|
let ip: IpAddr = "203.0.113.1".parse().unwrap();
|
||||||
|
let key = RateLimitKey::Identity(12345);
|
||||||
|
|
||||||
|
// Authenticated burst is 10
|
||||||
|
for _ in 0..10 {
|
||||||
|
assert_eq!(limiter.check(ip, key), RateLimitResult::Allow);
|
||||||
|
}
|
||||||
|
assert!(matches!(
|
||||||
|
limiter.check(ip, key),
|
||||||
|
RateLimitResult::Reject { .. }
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_independent_keys() {
|
||||||
|
let limiter = RateLimiter::new(&test_config());
|
||||||
|
let ip1: IpAddr = "203.0.113.1".parse().unwrap();
|
||||||
|
let ip2: IpAddr = "203.0.113.2".parse().unwrap();
|
||||||
|
|
||||||
|
// Exhaust ip1's burst
|
||||||
|
for _ in 0..3 {
|
||||||
|
limiter.check(ip1, RateLimitKey::Ip(ip1));
|
||||||
|
}
|
||||||
|
assert!(matches!(
|
||||||
|
limiter.check(ip1, RateLimitKey::Ip(ip1)),
|
||||||
|
RateLimitResult::Reject { .. }
|
||||||
|
));
|
||||||
|
|
||||||
|
// ip2 should still be allowed
|
||||||
|
assert_eq!(
|
||||||
|
limiter.check(ip2, RateLimitKey::Ip(ip2)),
|
||||||
|
RateLimitResult::Allow
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_eviction() {
|
||||||
|
let mut cfg = test_config();
|
||||||
|
cfg.stale_after_secs = 0; // everything is stale immediately
|
||||||
|
let limiter = RateLimiter::new(&cfg);
|
||||||
|
let ip: IpAddr = "203.0.113.1".parse().unwrap();
|
||||||
|
let key = RateLimitKey::Ip(ip);
|
||||||
|
|
||||||
|
limiter.check(ip, key);
|
||||||
|
// Buckets should be populated
|
||||||
|
let idx = shard_index(&key);
|
||||||
|
assert!(!limiter.shards[idx].read().unwrap().is_empty());
|
||||||
|
|
||||||
|
// After eviction, buckets should be gone (stale_after_secs=0)
|
||||||
|
std::thread::sleep(std::time::Duration::from_millis(10));
|
||||||
|
limiter.evict_stale();
|
||||||
|
assert!(limiter.shards[idx].read().unwrap().is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_retry_after_value() {
|
||||||
|
let limiter = RateLimiter::new(&test_config());
|
||||||
|
let ip: IpAddr = "203.0.113.1".parse().unwrap();
|
||||||
|
let key = RateLimitKey::Ip(ip);
|
||||||
|
|
||||||
|
// Exhaust burst (rate=1.0/s for unauth)
|
||||||
|
for _ in 0..3 {
|
||||||
|
limiter.check(ip, key);
|
||||||
|
}
|
||||||
|
match limiter.check(ip, key) {
|
||||||
|
RateLimitResult::Reject { retry_after } => {
|
||||||
|
// With rate=1.0, need ~1 token, so retry_after should be 1
|
||||||
|
assert_eq!(retry_after, 1);
|
||||||
|
}
|
||||||
|
_ => panic!("expected reject"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
3
src/rate_limit/mod.rs
Normal file
3
src/rate_limit/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
pub mod cidr;
|
||||||
|
pub mod key;
|
||||||
|
pub mod limiter;
|
||||||
118
tests/rate_limit_test.rs
Normal file
118
tests/rate_limit_test.rs
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
use std::net::IpAddr;
|
||||||
|
use sunbeam_proxy::config::{BucketConfig, RateLimitConfig};
|
||||||
|
use sunbeam_proxy::rate_limit::key::{self, RateLimitKey};
|
||||||
|
use sunbeam_proxy::rate_limit::limiter::{RateLimitResult, RateLimiter};
|
||||||
|
|
||||||
|
fn default_config() -> RateLimitConfig {
|
||||||
|
RateLimitConfig {
|
||||||
|
enabled: true,
|
||||||
|
bypass_cidrs: vec!["10.0.0.0/8".into(), "fd00::/8".into()],
|
||||||
|
eviction_interval_secs: 300,
|
||||||
|
stale_after_secs: 600,
|
||||||
|
authenticated: BucketConfig {
|
||||||
|
burst: 20,
|
||||||
|
rate: 10.0,
|
||||||
|
},
|
||||||
|
unauthenticated: BucketConfig {
|
||||||
|
burst: 5,
|
||||||
|
rate: 2.0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_unauthenticated_burst_limit() {
|
||||||
|
let limiter = RateLimiter::new(&default_config());
|
||||||
|
let ip: IpAddr = "203.0.113.50".parse().unwrap();
|
||||||
|
let key = RateLimitKey::Ip(ip);
|
||||||
|
|
||||||
|
for _ in 0..5 {
|
||||||
|
assert_eq!(limiter.check(ip, key), RateLimitResult::Allow);
|
||||||
|
}
|
||||||
|
match limiter.check(ip, key) {
|
||||||
|
RateLimitResult::Reject { retry_after } => assert!(retry_after >= 1),
|
||||||
|
_ => panic!("expected reject after burst exhaustion"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_authenticated_gets_higher_burst() {
|
||||||
|
let limiter = RateLimiter::new(&default_config());
|
||||||
|
let ip: IpAddr = "203.0.113.51".parse().unwrap();
|
||||||
|
let key = RateLimitKey::Identity(99999);
|
||||||
|
|
||||||
|
for i in 0..20 {
|
||||||
|
assert_eq!(
|
||||||
|
limiter.check(ip, key),
|
||||||
|
RateLimitResult::Allow,
|
||||||
|
"request {i} should be allowed"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
assert!(matches!(
|
||||||
|
limiter.check(ip, key),
|
||||||
|
RateLimitResult::Reject { .. }
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_cluster_cidr_bypass() {
|
||||||
|
let limiter = RateLimiter::new(&default_config());
|
||||||
|
let internal_ip: IpAddr = "10.244.1.5".parse().unwrap();
|
||||||
|
let key = RateLimitKey::Ip(internal_ip);
|
||||||
|
|
||||||
|
// Should never be rate-limited
|
||||||
|
for _ in 0..200 {
|
||||||
|
assert_eq!(limiter.check(internal_ip, key), RateLimitResult::Allow);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_key_extraction_cookie() {
|
||||||
|
let ip: IpAddr = "1.2.3.4".parse().unwrap();
|
||||||
|
let key = key::extract_key(
|
||||||
|
Some("_csrf=abc; ory_kratos_session=session123; lang=en"),
|
||||||
|
None,
|
||||||
|
ip,
|
||||||
|
);
|
||||||
|
assert!(key.is_authenticated());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_key_extraction_bearer() {
|
||||||
|
let ip: IpAddr = "1.2.3.4".parse().unwrap();
|
||||||
|
let key = key::extract_key(None, Some("Bearer gitea_pat_abc123"), ip);
|
||||||
|
assert!(key.is_authenticated());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_key_extraction_ip_fallback() {
|
||||||
|
let ip: IpAddr = "1.2.3.4".parse().unwrap();
|
||||||
|
let key = key::extract_key(None, None, ip);
|
||||||
|
assert_eq!(key, RateLimitKey::Ip(ip));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_concurrent_different_ips() {
|
||||||
|
let limiter = std::sync::Arc::new(RateLimiter::new(&default_config()));
|
||||||
|
let mut handles = vec![];
|
||||||
|
|
||||||
|
for i in 0..10u8 {
|
||||||
|
let limiter = limiter.clone();
|
||||||
|
handles.push(std::thread::spawn(move || {
|
||||||
|
let ip: IpAddr = format!("203.0.113.{i}").parse().unwrap();
|
||||||
|
let key = RateLimitKey::Ip(ip);
|
||||||
|
let mut allowed = 0;
|
||||||
|
for _ in 0..10 {
|
||||||
|
if limiter.check(ip, key) == RateLimitResult::Allow {
|
||||||
|
allowed += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Each IP gets burst=5, so exactly 5 should be allowed
|
||||||
|
assert_eq!(allowed, 5, "IP 203.0.113.{i} allowed {allowed}, expected 5");
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
for h in handles {
|
||||||
|
h.join().unwrap();
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user