diff --git a/src/auth.rs b/src/auth.rs new file mode 100644 index 0000000..4b8368b --- /dev/null +++ b/src/auth.rs @@ -0,0 +1,769 @@ +//! OAuth2 Authorization Code flow with PKCE for CLI authentication against Hydra. + +use crate::error::{Result, ResultExt, SunbeamError}; +use base64::Engine; +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; +use std::path::PathBuf; + +// --------------------------------------------------------------------------- +// Token cache data +// --------------------------------------------------------------------------- + +/// Cached OAuth2 tokens persisted to disk. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AuthTokens { + pub access_token: String, + pub refresh_token: String, + pub expires_at: DateTime, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub id_token: Option, + pub domain: String, +} + +/// Default client ID when the K8s secret is unavailable. +const DEFAULT_CLIENT_ID: &str = "sunbeam-cli"; + +// --------------------------------------------------------------------------- +// Cache file helpers +// --------------------------------------------------------------------------- + +fn cache_path() -> PathBuf { + dirs::data_dir() + .unwrap_or_else(|| { + dirs::home_dir() + .unwrap_or_else(|| PathBuf::from(".")) + .join(".local/share") + }) + .join("sunbeam") + .join("auth.json") +} + +fn read_cache() -> Result { + let path = cache_path(); + let content = std::fs::read_to_string(&path).map_err(|e| { + SunbeamError::Identity(format!("No cached auth tokens ({}): {e}", path.display())) + })?; + let tokens: AuthTokens = serde_json::from_str(&content) + .ctx("Failed to parse cached auth tokens")?; + Ok(tokens) +} + +fn write_cache(tokens: &AuthTokens) -> Result<()> { + let path = cache_path(); + if let Some(parent) = path.parent() { + std::fs::create_dir_all(parent) + .with_ctx(|| format!("Failed to create auth cache dir: {}", parent.display()))?; + } + let content = serde_json::to_string_pretty(tokens)?; + std::fs::write(&path, &content) + .with_ctx(|| format!("Failed to write auth cache to {}", path.display()))?; + + // Set 0600 permissions on unix + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let perms = std::fs::Permissions::from_mode(0o600); + std::fs::set_permissions(&path, perms) + .with_ctx(|| format!("Failed to set permissions on {}", path.display()))?; + } + + Ok(()) +} + +// --------------------------------------------------------------------------- +// PKCE +// --------------------------------------------------------------------------- + +/// Generate a PKCE code_verifier and code_challenge (S256). +fn generate_pkce() -> (String, String) { + let verifier_bytes: [u8; 32] = rand::random(); + let verifier = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(verifier_bytes); + let challenge = { + let hash = Sha256::digest(verifier.as_bytes()); + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hash) + }; + (verifier, challenge) +} + +/// Generate a random state parameter for OAuth2. +fn generate_state() -> String { + let bytes: [u8; 16] = rand::random(); + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes) +} + +// --------------------------------------------------------------------------- +// OIDC discovery +// --------------------------------------------------------------------------- + +#[derive(Debug, Deserialize)] +struct OidcDiscovery { + authorization_endpoint: String, + token_endpoint: String, +} + +async fn discover_oidc(domain: &str) -> Result { + let url = format!("https://auth.{domain}/.well-known/openid-configuration"); + let client = reqwest::Client::new(); + let resp = client + .get(&url) + .send() + .await + .with_ctx(|| format!("Failed to fetch OIDC discovery from {url}"))?; + + if !resp.status().is_success() { + return Err(SunbeamError::network(format!( + "OIDC discovery returned HTTP {}", + resp.status() + ))); + } + + let discovery: OidcDiscovery = resp + .json() + .await + .ctx("Failed to parse OIDC discovery response")?; + Ok(discovery) +} + +// --------------------------------------------------------------------------- +// Token exchange / refresh +// --------------------------------------------------------------------------- + +#[derive(Debug, Deserialize)] +struct TokenResponse { + access_token: String, + #[serde(default)] + refresh_token: Option, + #[serde(default)] + expires_in: Option, + #[serde(default)] + id_token: Option, +} + +async fn exchange_code( + token_endpoint: &str, + code: &str, + redirect_uri: &str, + client_id: &str, + code_verifier: &str, +) -> Result { + let client = reqwest::Client::new(); + let resp = client + .post(token_endpoint) + .form(&[ + ("grant_type", "authorization_code"), + ("code", code), + ("redirect_uri", redirect_uri), + ("client_id", client_id), + ("code_verifier", code_verifier), + ]) + .send() + .await + .ctx("Failed to exchange authorization code")?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + return Err(SunbeamError::identity(format!( + "Token exchange failed (HTTP {status}): {body}" + ))); + } + + let token_resp: TokenResponse = resp.json().await.ctx("Failed to parse token response")?; + Ok(token_resp) +} + +/// Refresh an access token using a refresh token. +async fn refresh_token(cached: &AuthTokens) -> Result { + let discovery = discover_oidc(&cached.domain).await?; + + // Try to get client_id from K8s, fall back to default + let client_id = resolve_client_id().await; + + let client = reqwest::Client::new(); + let resp = client + .post(&discovery.token_endpoint) + .form(&[ + ("grant_type", "refresh_token"), + ("refresh_token", &cached.refresh_token), + ("client_id", &client_id), + ]) + .send() + .await + .ctx("Failed to refresh token")?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().await.unwrap_or_default(); + return Err(SunbeamError::identity(format!( + "Token refresh failed (HTTP {status}): {body}" + ))); + } + + let token_resp: TokenResponse = resp + .json() + .await + .ctx("Failed to parse refresh token response")?; + + let expires_at = Utc::now() + + chrono::Duration::seconds(token_resp.expires_in.unwrap_or(3600)); + + let new_tokens = AuthTokens { + access_token: token_resp.access_token, + refresh_token: token_resp + .refresh_token + .unwrap_or_else(|| cached.refresh_token.clone()), + expires_at, + id_token: token_resp.id_token.or_else(|| cached.id_token.clone()), + domain: cached.domain.clone(), + }; + + write_cache(&new_tokens)?; + Ok(new_tokens) +} + +// --------------------------------------------------------------------------- +// Client ID resolution +// --------------------------------------------------------------------------- + +/// Try to read the client_id from K8s secret `oidc-sunbeam-cli` in `ory` namespace. +/// Falls back to the default client ID. +async fn resolve_client_id() -> String { + match crate::kube::kube_get_secret_field("ory", "oidc-sunbeam-cli", "client_id").await { + Ok(id) if !id.is_empty() => id, + _ => DEFAULT_CLIENT_ID.to_string(), + } +} + +// --------------------------------------------------------------------------- +// JWT payload decoding (minimal, no verification) +// --------------------------------------------------------------------------- + +/// Decode the payload of a JWT (middle segment) without verification. +/// Returns the parsed JSON value. +fn decode_jwt_payload(token: &str) -> Result { + let parts: Vec<&str> = token.splitn(3, '.').collect(); + if parts.len() < 2 { + return Err(SunbeamError::identity("Invalid JWT: not enough segments")); + } + let payload_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD + .decode(parts[1]) + .ctx("Failed to base64-decode JWT payload")?; + let payload: serde_json::Value = + serde_json::from_slice(&payload_bytes).ctx("Failed to parse JWT payload as JSON")?; + Ok(payload) +} + +/// Extract the email claim from an id_token. +fn extract_email(id_token: &str) -> Option { + let payload = decode_jwt_payload(id_token).ok()?; + payload + .get("email") + .and_then(|v| v.as_str()) + .map(|s| s.to_string()) +} + +// --------------------------------------------------------------------------- +// HTTP callback server +// --------------------------------------------------------------------------- + +/// Parsed callback parameters from the OAuth2 redirect. +struct CallbackParams { + code: String, + #[allow(dead_code)] + state: String, +} + +/// Bind a TCP listener for the OAuth2 callback, preferring ports 9876-9880. +async fn bind_callback_listener() -> Result<(tokio::net::TcpListener, u16)> { + for port in 9876..=9880 { + if let Ok(listener) = tokio::net::TcpListener::bind(("127.0.0.1", port)).await { + return Ok((listener, port)); + } + } + // Fall back to ephemeral port + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .ctx("Failed to bind callback listener")?; + let port = listener.local_addr().ctx("No local address")?.port(); + Ok((listener, port)) +} + +/// Wait for a single HTTP callback request, extract code and state, send HTML response. +async fn wait_for_callback( + listener: tokio::net::TcpListener, + expected_state: &str, +) -> Result { + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + let (mut stream, _) = listener.accept().await.ctx("Failed to accept callback connection")?; + + let mut buf = vec![0u8; 4096]; + let n = stream + .read(&mut buf) + .await + .ctx("Failed to read callback request")?; + let request = String::from_utf8_lossy(&buf[..n]); + + // Parse the GET request line: "GET /callback?code=...&state=... HTTP/1.1" + let request_line = request + .lines() + .next() + .ctx("Empty callback request")?; + + let path = request_line + .split_whitespace() + .nth(1) + .ctx("No path in callback request")?; + + // Parse query params + let query = path + .split('?') + .nth(1) + .ctx("No query params in callback")?; + + let mut code = None; + let mut state = None; + + for param in query.split('&') { + let mut kv = param.splitn(2, '='); + match (kv.next(), kv.next()) { + (Some("code"), Some(v)) => code = Some(v.to_string()), + (Some("state"), Some(v)) => state = Some(v.to_string()), + _ => {} + } + } + + let code = code.ok_or_else(|| SunbeamError::identity("No 'code' in callback"))?; + let state = state.ok_or_else(|| SunbeamError::identity("No 'state' in callback"))?; + + if state != expected_state { + return Err(SunbeamError::identity( + "OAuth2 state mismatch -- possible CSRF attack", + )); + } + + // Send success response + let html = concat!( + "", + "

Authentication successful

", + "

You can close this tab and return to the terminal.

", + "" + ); + let response = format!( + "HTTP/1.1 200 OK\r\nContent-Type: text/html\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + html.len(), + html + ); + let _ = stream.write_all(response.as_bytes()).await; + let _ = stream.shutdown().await; + + Ok(CallbackParams { code, state }) +} + +// --------------------------------------------------------------------------- +// Public API +// --------------------------------------------------------------------------- + +/// Get a valid access token, refreshing if needed. +/// +/// Returns the access token string ready for use in Authorization headers. +/// If no cached token exists or refresh fails, returns an error prompting +/// the user to run `sunbeam auth login`. +pub async fn get_token() -> Result { + let cached = match read_cache() { + Ok(tokens) => tokens, + Err(_) => { + return Err(SunbeamError::identity( + "Not logged in. Run `sunbeam auth login` to authenticate.", + )); + } + }; + + // Check if access token is still valid (>60s remaining) + let now = Utc::now(); + if cached.expires_at > now + chrono::Duration::seconds(60) { + return Ok(cached.access_token); + } + + // Try to refresh + if !cached.refresh_token.is_empty() { + match refresh_token(&cached).await { + Ok(new_tokens) => return Ok(new_tokens.access_token), + Err(e) => { + crate::output::warn(&format!("Token refresh failed: {e}")); + } + } + } + + Err(SunbeamError::identity( + "Session expired. Run `sunbeam auth login` to re-authenticate.", + )) +} + +/// Interactive browser-based OAuth2 login. +pub async fn cmd_auth_login() -> Result<()> { + crate::output::step("Authenticating with Hydra"); + + // Resolve domain + let config = crate::config::load_config(); + let domain = if !config.production_host.is_empty() { + // Extract domain from production host if available + let host = &config.production_host; + let raw = host.split('@').last().unwrap_or(host); + let raw = raw.split(':').next().unwrap_or(raw); + // If it looks like an IP or hostname, try to get domain from cluster + if raw.contains('.') && !raw.chars().next().unwrap_or('0').is_ascii_digit() { + raw.to_string() + } else { + crate::kube::get_domain().await? + } + } else { + crate::kube::get_domain().await? + }; + + crate::output::ok(&format!("Domain: {domain}")); + + // OIDC discovery + let discovery = discover_oidc(&domain).await?; + + // Resolve client_id + let client_id = resolve_client_id().await; + + // Generate PKCE + let (code_verifier, code_challenge) = generate_pkce(); + + // Generate state + let state = generate_state(); + + // Bind callback listener + let (listener, port) = bind_callback_listener().await?; + let redirect_uri = format!("http://localhost:{port}/callback"); + + // Build authorization URL + let auth_url = format!( + "{}?client_id={}&redirect_uri={}&response_type=code&scope={}&code_challenge={}&code_challenge_method=S256&state={}", + discovery.authorization_endpoint, + urlencoding(&client_id), + urlencoding(&redirect_uri), + "openid+email+profile+offline_access", + code_challenge, + state, + ); + + crate::output::ok("Opening browser for login..."); + println!("\n {auth_url}\n"); + + // Try to open the browser + let _open_result = open_browser(&auth_url); + + // Wait for callback + crate::output::ok("Waiting for authentication callback..."); + let callback = wait_for_callback(listener, &state).await?; + + // Exchange code for tokens + crate::output::ok("Exchanging authorization code for tokens..."); + let token_resp = exchange_code( + &discovery.token_endpoint, + &callback.code, + &redirect_uri, + &client_id, + &code_verifier, + ) + .await?; + + let expires_at = Utc::now() + + chrono::Duration::seconds(token_resp.expires_in.unwrap_or(3600)); + + let tokens = AuthTokens { + access_token: token_resp.access_token, + refresh_token: token_resp.refresh_token.unwrap_or_default(), + expires_at, + id_token: token_resp.id_token.clone(), + domain: domain.clone(), + }; + + write_cache(&tokens)?; + + // Print success with email if available + if let Some(ref id_token) = tokens.id_token { + if let Some(email) = extract_email(id_token) { + crate::output::ok(&format!("Logged in as {email}")); + } else { + crate::output::ok("Logged in successfully"); + } + } else { + crate::output::ok("Logged in successfully"); + } + + Ok(()) +} + +/// Remove cached auth tokens. +pub async fn cmd_auth_logout() -> Result<()> { + let path = cache_path(); + if path.exists() { + std::fs::remove_file(&path) + .with_ctx(|| format!("Failed to remove {}", path.display()))?; + crate::output::ok("Logged out (cached tokens removed)"); + } else { + crate::output::ok("Not logged in (no cached tokens to remove)"); + } + Ok(()) +} + +/// Print current auth status. +pub async fn cmd_auth_status() -> Result<()> { + match read_cache() { + Ok(tokens) => { + let now = Utc::now(); + let expired = tokens.expires_at <= now; + + // Try to get email from id_token + let identity = tokens + .id_token + .as_deref() + .and_then(extract_email) + .unwrap_or_else(|| "unknown".to_string()); + + if expired { + crate::output::ok(&format!( + "Logged in as {identity} (token expired at {})", + tokens.expires_at.format("%Y-%m-%d %H:%M:%S UTC") + )); + if !tokens.refresh_token.is_empty() { + crate::output::ok("Token can be refreshed automatically on next use"); + } + } else { + crate::output::ok(&format!( + "Logged in as {identity} (token valid until {})", + tokens.expires_at.format("%Y-%m-%d %H:%M:%S UTC") + )); + } + crate::output::ok(&format!("Domain: {}", tokens.domain)); + } + Err(_) => { + crate::output::ok("Not logged in. Run `sunbeam auth login` to authenticate."); + } + } + Ok(()) +} + +// --------------------------------------------------------------------------- +// Utility helpers +// --------------------------------------------------------------------------- + +/// Minimal percent-encoding for URL query parameters. +fn urlencoding(s: &str) -> String { + let mut out = String::with_capacity(s.len()); + for b in s.bytes() { + match b { + b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => { + out.push(b as char); + } + _ => { + out.push_str(&format!("%{:02X}", b)); + } + } + } + out +} + +/// Try to open a URL in the default browser. +fn open_browser(url: &str) -> std::result::Result<(), std::io::Error> { + #[cfg(target_os = "macos")] + { + std::process::Command::new("open").arg(url).spawn()?; + } + #[cfg(target_os = "linux")] + { + std::process::Command::new("xdg-open").arg(url).spawn()?; + } + #[cfg(not(any(target_os = "macos", target_os = "linux")))] + { + let _ = url; + // No-op on unsupported platforms; URL is printed to the terminal. + } + Ok(()) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + use chrono::Duration; + + #[test] + fn test_pkce_generation() { + let (verifier, challenge) = generate_pkce(); + + // Verifier should be base64url-encoded 32 bytes -> 43 chars + assert_eq!(verifier.len(), 43); + + // Challenge should be base64url-encoded SHA256 -> 43 chars + assert_eq!(challenge.len(), 43); + + // Verify the challenge matches the verifier + let expected_hash = Sha256::digest(verifier.as_bytes()); + let expected_challenge = + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(expected_hash); + assert_eq!(challenge, expected_challenge); + + // Two calls should produce different values + let (v2, c2) = generate_pkce(); + assert_ne!(verifier, v2); + assert_ne!(challenge, c2); + } + + #[test] + fn test_token_cache_roundtrip() { + let tokens = AuthTokens { + access_token: "access_abc".to_string(), + refresh_token: "refresh_xyz".to_string(), + expires_at: Utc::now() + Duration::hours(1), + id_token: Some("eyJhbGciOiJSUzI1NiJ9.eyJlbWFpbCI6InRlc3RAZXhhbXBsZS5jb20ifQ.sig".to_string()), + domain: "sunbeam.pt".to_string(), + }; + + let json = serde_json::to_string_pretty(&tokens).unwrap(); + let deserialized: AuthTokens = serde_json::from_str(&json).unwrap(); + + assert_eq!(deserialized.access_token, "access_abc"); + assert_eq!(deserialized.refresh_token, "refresh_xyz"); + assert_eq!(deserialized.domain, "sunbeam.pt"); + assert!(deserialized.id_token.is_some()); + + // Verify expires_at survives roundtrip (within 1 second tolerance) + let diff = (deserialized.expires_at - tokens.expires_at) + .num_milliseconds() + .abs(); + assert!(diff < 1000, "expires_at drift: {diff}ms"); + } + + #[test] + fn test_token_cache_roundtrip_no_id_token() { + let tokens = AuthTokens { + access_token: "access".to_string(), + refresh_token: "refresh".to_string(), + expires_at: Utc::now() + Duration::hours(1), + id_token: None, + domain: "example.com".to_string(), + }; + + let json = serde_json::to_string(&tokens).unwrap(); + // id_token should be absent from the JSON when None + assert!(!json.contains("id_token")); + + let deserialized: AuthTokens = serde_json::from_str(&json).unwrap(); + assert!(deserialized.id_token.is_none()); + } + + #[test] + fn test_token_expiry_check_valid() { + let tokens = AuthTokens { + access_token: "valid".to_string(), + refresh_token: "refresh".to_string(), + expires_at: Utc::now() + Duration::hours(1), + id_token: None, + domain: "example.com".to_string(), + }; + + let now = Utc::now(); + // Token is valid: more than 60 seconds until expiry + assert!(tokens.expires_at > now + Duration::seconds(60)); + } + + #[test] + fn test_token_expiry_check_expired() { + let tokens = AuthTokens { + access_token: "expired".to_string(), + refresh_token: "refresh".to_string(), + expires_at: Utc::now() - Duration::hours(1), + id_token: None, + domain: "example.com".to_string(), + }; + + let now = Utc::now(); + // Token is expired + assert!(tokens.expires_at <= now + Duration::seconds(60)); + } + + #[test] + fn test_token_expiry_check_almost_expired() { + let tokens = AuthTokens { + access_token: "almost".to_string(), + refresh_token: "refresh".to_string(), + expires_at: Utc::now() + Duration::seconds(30), + id_token: None, + domain: "example.com".to_string(), + }; + + let now = Utc::now(); + // Token expires in 30s, which is within the 60s threshold + assert!(tokens.expires_at <= now + Duration::seconds(60)); + } + + #[test] + fn test_jwt_payload_decode() { + // Build a fake JWT: header.payload.signature + let payload_json = r#"{"email":"user@example.com","sub":"12345"}"#; + let encoded_payload = + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(payload_json.as_bytes()); + let fake_jwt = format!("eyJhbGciOiJSUzI1NiJ9.{encoded_payload}.fakesig"); + + let payload = decode_jwt_payload(&fake_jwt).unwrap(); + assert_eq!(payload["email"], "user@example.com"); + assert_eq!(payload["sub"], "12345"); + } + + #[test] + fn test_extract_email() { + let payload_json = r#"{"email":"alice@sunbeam.pt","name":"Alice"}"#; + let encoded_payload = + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(payload_json.as_bytes()); + let fake_jwt = format!("eyJhbGciOiJSUzI1NiJ9.{encoded_payload}.fakesig"); + + assert_eq!(extract_email(&fake_jwt), Some("alice@sunbeam.pt".to_string())); + } + + #[test] + fn test_extract_email_missing() { + let payload_json = r#"{"sub":"12345","name":"Bob"}"#; + let encoded_payload = + base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(payload_json.as_bytes()); + let fake_jwt = format!("eyJhbGciOiJSUzI1NiJ9.{encoded_payload}.fakesig"); + + assert_eq!(extract_email(&fake_jwt), None); + } + + #[test] + fn test_urlencoding() { + assert_eq!(urlencoding("hello"), "hello"); + assert_eq!(urlencoding("hello world"), "hello%20world"); + assert_eq!( + urlencoding("http://localhost:9876/callback"), + "http%3A%2F%2Flocalhost%3A9876%2Fcallback" + ); + } + + #[test] + fn test_generate_state() { + let s1 = generate_state(); + let s2 = generate_state(); + assert_ne!(s1, s2); + // 16 bytes base64url -> 22 chars + assert_eq!(s1.len(), 22); + } + + #[test] + fn test_cache_path_is_under_sunbeam() { + let path = cache_path(); + let path_str = path.to_string_lossy(); + assert!(path_str.contains("sunbeam")); + assert!(path_str.ends_with("auth.json")); + } +} diff --git a/src/main.rs b/src/main.rs index b5b62e7..c6feb42 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,7 @@ #[macro_use] mod error; +mod auth; mod checks; mod cli; mod cluster; @@ -12,6 +13,7 @@ mod kube; mod manifests; mod openbao; mod output; +mod pm; mod secrets; mod services; mod tools;