diff --git a/Cargo.toml b/Cargo.toml index 05a3082..85d5234 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -38,3 +38,12 @@ uuid = { version = "1", features = ["v4"] } base64 = "0.22" rusqlite = { version = "0.32", features = ["bundled"] } futures = "0.3" +tonic = "0.14" +tonic-prost = "0.14" +prost = "0.14" +tokio-stream = "0.1" +jsonwebtoken = "9" + +[build-dependencies] +tonic-build = "0.14" +tonic-prost-build = "0.14" diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..8c3f998 --- /dev/null +++ b/build.rs @@ -0,0 +1,4 @@ +fn main() -> Result<(), Box> { + tonic_prost_build::compile_protos("proto/code.proto")?; + Ok(()) +} diff --git a/proto/code.proto b/proto/code.proto new file mode 120000 index 0000000..bd45f08 --- /dev/null +++ b/proto/code.proto @@ -0,0 +1 @@ +/Users/sienna/Development/sunbeam/cli-worktree/sunbeam-proto/proto/code.proto \ No newline at end of file diff --git a/src/config.rs b/src/config.rs index f4921bb..c6d0c56 100644 --- a/src/config.rs +++ b/src/config.rs @@ -12,6 +12,8 @@ pub struct Config { pub services: ServicesConfig, #[serde(default)] pub vault: VaultConfig, + #[serde(default)] + pub grpc: Option, } #[derive(Debug, Clone, Deserialize)] @@ -40,6 +42,9 @@ pub struct AgentsConfig { /// Max recursion depth for research agents spawning sub-agents. #[serde(default = "default_research_max_depth")] pub research_max_depth: usize, + /// Model for coding agent sessions (sunbeam code). + #[serde(default = "default_coding_model")] + pub coding_model: String, } impl Default for AgentsConfig { @@ -53,6 +58,7 @@ impl Default for AgentsConfig { research_max_iterations: default_research_max_iterations(), research_max_agents: default_research_max_agents(), research_max_depth: default_research_max_depth(), + coding_model: default_coding_model(), } } } @@ -233,6 +239,19 @@ fn default_research_agent_model() -> String { "ministral-3b-latest".into() } fn default_research_max_iterations() -> usize { 10 } fn default_research_max_agents() -> usize { 25 } fn default_research_max_depth() -> usize { 4 } +fn default_coding_model() -> String { "devstral-small-2506".into() } + +#[derive(Debug, Clone, Deserialize)] +pub struct GrpcConfig { + /// Address to listen on (default: 0.0.0.0:50051). + #[serde(default = "default_grpc_addr")] + pub listen_addr: String, + /// JWKS URL for JWT validation (default: Hydra's .well-known endpoint). + #[serde(default)] + pub jwks_url: Option, +} + +fn default_grpc_addr() -> String { "0.0.0.0:50051".into() } impl Config { pub fn load(path: &str) -> anyhow::Result { diff --git a/src/grpc/auth.rs b/src/grpc/auth.rs new file mode 100644 index 0000000..4a85a68 --- /dev/null +++ b/src/grpc/auth.rs @@ -0,0 +1,131 @@ +use std::sync::Arc; + +use jsonwebtoken::{decode, decode_header, jwk::JwkSet, Algorithm, DecodingKey, Validation}; +use serde::Deserialize; +use tonic::{Request, Status}; +use tracing::{debug, warn}; + +/// Claims extracted from a valid JWT. +#[derive(Debug, Clone, Deserialize)] +pub struct Claims { + pub sub: String, + #[serde(default)] + pub email: Option, + #[serde(default)] + pub exp: u64, +} + +/// Validates JWTs against Hydra's JWKS endpoint. +pub struct JwtValidator { + jwks: JwkSet, +} + +impl JwtValidator { + /// Fetch JWKS from the given URL and create a validator. + pub async fn new(jwks_url: &str) -> anyhow::Result { + let resp = reqwest::get(jwks_url) + .await + .map_err(|e| anyhow::anyhow!("Failed to fetch JWKS from {jwks_url}: {e}"))?; + + let jwks: JwkSet = resp + .json() + .await + .map_err(|e| anyhow::anyhow!("Failed to parse JWKS: {e}"))?; + + debug!(keys = jwks.keys.len(), "Loaded JWKS"); + + Ok(Self { jwks }) + } + + /// Validate a JWT and return the claims. + pub fn validate(&self, token: &str) -> Result { + let header = decode_header(token).map_err(|e| { + warn!("Invalid JWT header: {e}"); + Status::unauthenticated("Invalid token") + })?; + + let kid = header + .kid + .as_deref() + .ok_or_else(|| Status::unauthenticated("Token missing kid"))?; + + let key = self.jwks.find(kid).ok_or_else(|| { + warn!(kid, "Unknown key ID in JWT"); + Status::unauthenticated("Unknown signing key") + })?; + + let decoding_key = DecodingKey::from_jwk(key).map_err(|e| { + warn!("Failed to create decoding key: {e}"); + Status::unauthenticated("Invalid signing key") + })?; + + let mut validation = Validation::new(Algorithm::RS256); + validation.validate_exp = true; + validation.validate_aud = false; + + let token_data = + decode::(token, &decoding_key, &validation).map_err(|e| { + warn!("JWT validation failed: {e}"); + Status::unauthenticated("Token validation failed") + })?; + + Ok(token_data.claims) + } +} + +/// Tonic interceptor that validates JWT from the `authorization` metadata +/// and inserts `Claims` into the request extensions. +#[derive(Clone)] +pub struct JwtInterceptor { + validator: Arc, +} + +impl JwtInterceptor { + pub fn new(validator: Arc) -> Self { + Self { validator } + } +} + +impl tonic::service::Interceptor for JwtInterceptor { + fn call(&mut self, mut req: Request<()>) -> Result, Status> { + let token = req + .metadata() + .get("authorization") + .and_then(|v| v.to_str().ok()) + .and_then(|v| v.strip_prefix("Bearer ")); + + match token { + Some(token) => { + let claims = self.validator.validate(token)?; + debug!(sub = claims.sub.as_str(), "Authenticated gRPC request"); + req.extensions_mut().insert(claims); + Ok(req) + } + None => Err(Status::unauthenticated("Missing authorization token")), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_claims_deserialize() { + let json = serde_json::json!({ + "sub": "996a2cb8-98b7-46ed-9cdf-d31c883d897f", + "email": "sienna@sunbeam.pt", + "exp": 1774310400 + }); + let claims: Claims = serde_json::from_value(json).unwrap(); + assert_eq!(claims.sub, "996a2cb8-98b7-46ed-9cdf-d31c883d897f"); + assert_eq!(claims.email.as_deref(), Some("sienna@sunbeam.pt")); + } + + #[test] + fn test_claims_minimal() { + let json = serde_json::json!({ "sub": "user-123", "exp": 0 }); + let claims: Claims = serde_json::from_value(json).unwrap(); + assert!(claims.email.is_none()); + } +} diff --git a/src/grpc/mod.rs b/src/grpc/mod.rs new file mode 100644 index 0000000..42d52c7 --- /dev/null +++ b/src/grpc/mod.rs @@ -0,0 +1,64 @@ +pub mod auth; +pub mod router; +pub mod service; + +mod proto { + tonic::include_proto!("sunbeam.code.v1"); +} + +pub use proto::code_agent_server::{CodeAgent, CodeAgentServer}; +pub use proto::*; + +use std::sync::Arc; +use tonic::transport::Server; +use tracing::{error, info}; + +use crate::config::Config; +use crate::persistence::Store; +use crate::tools::ToolRegistry; + +/// Shared state for the gRPC server. +pub struct GrpcState { + pub config: Arc, + pub tools: Arc, + pub store: Arc, + pub mistral: Arc, + pub matrix: matrix_sdk::Client, +} + +/// Start the gRPC server. Call from main.rs alongside the Matrix sync loop. +pub async fn start_server(state: Arc) -> anyhow::Result<()> { + let addr = state + .config + .grpc + .as_ref() + .map(|g| g.listen_addr.clone()) + .unwrap_or_else(|| "0.0.0.0:50051".into()); + + let addr = addr.parse()?; + + let jwks_url = state + .config + .grpc + .as_ref() + .and_then(|g| g.jwks_url.clone()) + .unwrap_or_else(|| { + "http://hydra-public.ory.svc.cluster.local:4444/.well-known/jwks.json".into() + }); + + // Initialize JWT validator (fetches JWKS from Hydra) + let jwt_validator = Arc::new(auth::JwtValidator::new(&jwks_url).await?); + let interceptor = auth::JwtInterceptor::new(jwt_validator); + + let svc = service::CodeAgentService::new(state); + let svc = CodeAgentServer::with_interceptor(svc, interceptor); + + info!(%addr, "Starting gRPC server"); + + Server::builder() + .add_service(svc) + .serve(addr) + .await?; + + Ok(()) +} diff --git a/src/grpc/router.rs b/src/grpc/router.rs new file mode 100644 index 0000000..6d4e3cc --- /dev/null +++ b/src/grpc/router.rs @@ -0,0 +1,68 @@ +/// Determines whether a tool call should be executed on the client (local +/// filesystem) or on the server (Sol's ToolRegistry). +/// +/// Client-side tools require the gRPC stream to be active. When no client +/// is connected (user is on Matrix), Sol falls back to Gitea for file access. + +const CLIENT_TOOLS: &[&str] = &[ + "file_read", + "file_write", + "search_replace", + "grep", + "bash", + "list_directory", + "ask_user", +]; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ToolSide { + /// Execute on the developer's machine via gRPC stream. + Client, + /// Execute on Sol's server via ToolRegistry. + Server, +} + +/// Route a tool call to the appropriate side. +pub fn route(tool_name: &str) -> ToolSide { + if CLIENT_TOOLS.contains(&tool_name) { + ToolSide::Client + } else { + ToolSide::Server + } +} + +/// Check if a tool is a client-side tool. +pub fn is_client_tool(tool_name: &str) -> bool { + route(tool_name) == ToolSide::Client +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_client_tools() { + assert_eq!(route("file_read"), ToolSide::Client); + assert_eq!(route("file_write"), ToolSide::Client); + assert_eq!(route("bash"), ToolSide::Client); + assert_eq!(route("grep"), ToolSide::Client); + assert_eq!(route("search_replace"), ToolSide::Client); + assert_eq!(route("list_directory"), ToolSide::Client); + assert_eq!(route("ask_user"), ToolSide::Client); + } + + #[test] + fn test_server_tools() { + assert_eq!(route("search_archive"), ToolSide::Server); + assert_eq!(route("search_web"), ToolSide::Server); + assert_eq!(route("gitea_list_repos"), ToolSide::Server); + assert_eq!(route("identity_list_users"), ToolSide::Server); + assert_eq!(route("research"), ToolSide::Server); + assert_eq!(route("run_script"), ToolSide::Server); + } + + #[test] + fn test_unknown_defaults_to_server() { + assert_eq!(route("unknown_tool"), ToolSide::Server); + } +} diff --git a/src/grpc/service.rs b/src/grpc/service.rs new file mode 100644 index 0000000..6d358ac --- /dev/null +++ b/src/grpc/service.rs @@ -0,0 +1,175 @@ +use std::pin::Pin; +use std::sync::Arc; + +use futures::Stream; +use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; +use tonic::{Request, Response, Status, Streaming}; +use tracing::{error, info, warn}; + +use super::auth::Claims; +use super::proto::code_agent_server::CodeAgent; +use super::proto::*; +use super::GrpcState; + +pub struct CodeAgentService { + state: Arc, +} + +impl CodeAgentService { + pub fn new(state: Arc) -> Self { + Self { state } + } +} + +#[tonic::async_trait] +impl CodeAgent for CodeAgentService { + type SessionStream = Pin> + Send>>; + + async fn session( + &self, + request: Request>, + ) -> Result, Status> { + // Extract JWT claims from the request extensions (set by auth middleware) + let claims = request + .extensions() + .get::() + .cloned() + .ok_or_else(|| Status::unauthenticated("No valid authentication token"))?; + + info!( + user = claims.sub.as_str(), + email = claims.email.as_deref().unwrap_or("?"), + "New coding session" + ); + + let mut in_stream = request.into_inner(); + let state = self.state.clone(); + + // Channel for sending server messages to the client + let (tx, rx) = mpsc::channel::>(32); + + // Spawn the session handler + tokio::spawn(async move { + if let Err(e) = handle_session(&state, &claims, &mut in_stream, &tx).await { + error!(user = claims.sub.as_str(), "Session error: {e}"); + let _ = tx + .send(Ok(ServerMessage { + payload: Some(server_message::Payload::Error(Error { + message: e.to_string(), + fatal: true, + })), + })) + .await; + } + }); + + let out_stream = ReceiverStream::new(rx); + Ok(Response::new(Box::pin(out_stream))) + } +} + +/// Handle a single coding session (runs in a spawned task). +async fn handle_session( + state: &GrpcState, + claims: &Claims, + in_stream: &mut Streaming, + tx: &mpsc::Sender>, +) -> anyhow::Result<()> { + // Wait for the first message — must be StartSession + let first = in_stream + .message() + .await? + .ok_or_else(|| anyhow::anyhow!("Stream closed before StartSession"))?; + + let start = match first.payload { + Some(client_message::Payload::Start(s)) => s, + _ => anyhow::bail!("First message must be StartSession"), + }; + + info!( + user = claims.sub.as_str(), + project = start.project_path.as_str(), + model = start.model.as_str(), + client_tools = start.client_tools.len(), + "Session started" + ); + + // TODO Phase 2: Create/find Matrix room for this project + // TODO Phase 2: Create Mistral conversation + // TODO Phase 2: Enter agent loop + + // For now, send SessionReady and echo back + tx.send(Ok(ServerMessage { + payload: Some(server_message::Payload::Ready(SessionReady { + session_id: uuid::Uuid::new_v4().to_string(), + room_id: String::new(), // TODO: Matrix room + model: if start.model.is_empty() { + state + .config + .agents + .coding_model + .clone() + } else { + start.model.clone() + }, + })), + })) + .await?; + + // Main message loop + while let Some(msg) = in_stream.message().await? { + match msg.payload { + Some(client_message::Payload::Input(input)) => { + info!( + user = claims.sub.as_str(), + text_len = input.text.len(), + "User input received" + ); + + // TODO Phase 2: Send to Mistral, handle tool calls, stream response + // For now, echo back as a simple acknowledgment + tx.send(Ok(ServerMessage { + payload: Some(server_message::Payload::Done(TextDone { + full_text: format!("[stub] received: {}", input.text), + input_tokens: 0, + output_tokens: 0, + })), + })) + .await?; + } + Some(client_message::Payload::ToolResult(result)) => { + info!( + call_id = result.call_id.as_str(), + is_error = result.is_error, + "Tool result received" + ); + // TODO Phase 2: Feed back to Mistral + } + Some(client_message::Payload::Approval(approval)) => { + info!( + call_id = approval.call_id.as_str(), + approved = approval.approved, + "Tool approval received" + ); + // TODO Phase 2: Execute or skip tool + } + Some(client_message::Payload::End(_)) => { + info!(user = claims.sub.as_str(), "Session ended by client"); + tx.send(Ok(ServerMessage { + payload: Some(server_message::Payload::End(SessionEnd { + summary: "Session ended.".into(), + })), + })) + .await?; + break; + } + Some(client_message::Payload::Start(_)) => { + warn!("Received duplicate StartSession — ignoring"); + } + None => continue, + } + } + + Ok(()) +} diff --git a/src/main.rs b/src/main.rs index 11d52cd..c6046c6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,6 +8,7 @@ mod conversations; mod matrix_utils; mod memory; mod persistence; +mod grpc; mod sdk; mod sync; mod time_context;