feat(code): gRPC server with JWT auth + tool routing

tonic 0.14 gRPC server for sunbeam code sessions:
- bidirectional streaming Session RPC
- JWT interceptor validates tokens against Hydra JWKS
- tool router classifies calls as client-side (file_read, bash,
  grep, etc.) or server-side (gitea, identity, search, etc.)
- service stub with session lifecycle (start, chat, tool results, end)
- coding_model config (default: devstral-small-2506)
- grpc config section (listen_addr, jwks_url)
- 182 tests (5 new: JWT claims, tool routing)

phase 2 TODOs: Matrix room bridge, Mistral agent loop, streaming
This commit is contained in:
2026-03-23 11:35:37 +00:00
parent 2a1d7a003d
commit 35b6246fa7
9 changed files with 472 additions and 0 deletions

View File

@@ -38,3 +38,12 @@ uuid = { version = "1", features = ["v4"] }
base64 = "0.22"
rusqlite = { version = "0.32", features = ["bundled"] }
futures = "0.3"
tonic = "0.14"
tonic-prost = "0.14"
prost = "0.14"
tokio-stream = "0.1"
jsonwebtoken = "9"
[build-dependencies]
tonic-build = "0.14"
tonic-prost-build = "0.14"

4
build.rs Normal file
View File

@@ -0,0 +1,4 @@
fn main() -> Result<(), Box<dyn std::error::Error>> {
tonic_prost_build::compile_protos("proto/code.proto")?;
Ok(())
}

1
proto/code.proto Symbolic link
View File

@@ -0,0 +1 @@
/Users/sienna/Development/sunbeam/cli-worktree/sunbeam-proto/proto/code.proto

View File

@@ -12,6 +12,8 @@ pub struct Config {
pub services: ServicesConfig,
#[serde(default)]
pub vault: VaultConfig,
#[serde(default)]
pub grpc: Option<GrpcConfig>,
}
#[derive(Debug, Clone, Deserialize)]
@@ -40,6 +42,9 @@ pub struct AgentsConfig {
/// Max recursion depth for research agents spawning sub-agents.
#[serde(default = "default_research_max_depth")]
pub research_max_depth: usize,
/// Model for coding agent sessions (sunbeam code).
#[serde(default = "default_coding_model")]
pub coding_model: String,
}
impl Default for AgentsConfig {
@@ -53,6 +58,7 @@ impl Default for AgentsConfig {
research_max_iterations: default_research_max_iterations(),
research_max_agents: default_research_max_agents(),
research_max_depth: default_research_max_depth(),
coding_model: default_coding_model(),
}
}
}
@@ -233,6 +239,19 @@ fn default_research_agent_model() -> String { "ministral-3b-latest".into() }
fn default_research_max_iterations() -> usize { 10 }
fn default_research_max_agents() -> usize { 25 }
fn default_research_max_depth() -> usize { 4 }
fn default_coding_model() -> String { "devstral-small-2506".into() }
#[derive(Debug, Clone, Deserialize)]
pub struct GrpcConfig {
/// Address to listen on (default: 0.0.0.0:50051).
#[serde(default = "default_grpc_addr")]
pub listen_addr: String,
/// JWKS URL for JWT validation (default: Hydra's .well-known endpoint).
#[serde(default)]
pub jwks_url: Option<String>,
}
fn default_grpc_addr() -> String { "0.0.0.0:50051".into() }
impl Config {
pub fn load(path: &str) -> anyhow::Result<Self> {

131
src/grpc/auth.rs Normal file
View File

@@ -0,0 +1,131 @@
use std::sync::Arc;
use jsonwebtoken::{decode, decode_header, jwk::JwkSet, Algorithm, DecodingKey, Validation};
use serde::Deserialize;
use tonic::{Request, Status};
use tracing::{debug, warn};
/// Claims extracted from a valid JWT.
#[derive(Debug, Clone, Deserialize)]
pub struct Claims {
pub sub: String,
#[serde(default)]
pub email: Option<String>,
#[serde(default)]
pub exp: u64,
}
/// Validates JWTs against Hydra's JWKS endpoint.
pub struct JwtValidator {
jwks: JwkSet,
}
impl JwtValidator {
/// Fetch JWKS from the given URL and create a validator.
pub async fn new(jwks_url: &str) -> anyhow::Result<Self> {
let resp = reqwest::get(jwks_url)
.await
.map_err(|e| anyhow::anyhow!("Failed to fetch JWKS from {jwks_url}: {e}"))?;
let jwks: JwkSet = resp
.json()
.await
.map_err(|e| anyhow::anyhow!("Failed to parse JWKS: {e}"))?;
debug!(keys = jwks.keys.len(), "Loaded JWKS");
Ok(Self { jwks })
}
/// Validate a JWT and return the claims.
pub fn validate(&self, token: &str) -> Result<Claims, Status> {
let header = decode_header(token).map_err(|e| {
warn!("Invalid JWT header: {e}");
Status::unauthenticated("Invalid token")
})?;
let kid = header
.kid
.as_deref()
.ok_or_else(|| Status::unauthenticated("Token missing kid"))?;
let key = self.jwks.find(kid).ok_or_else(|| {
warn!(kid, "Unknown key ID in JWT");
Status::unauthenticated("Unknown signing key")
})?;
let decoding_key = DecodingKey::from_jwk(key).map_err(|e| {
warn!("Failed to create decoding key: {e}");
Status::unauthenticated("Invalid signing key")
})?;
let mut validation = Validation::new(Algorithm::RS256);
validation.validate_exp = true;
validation.validate_aud = false;
let token_data =
decode::<Claims>(token, &decoding_key, &validation).map_err(|e| {
warn!("JWT validation failed: {e}");
Status::unauthenticated("Token validation failed")
})?;
Ok(token_data.claims)
}
}
/// Tonic interceptor that validates JWT from the `authorization` metadata
/// and inserts `Claims` into the request extensions.
#[derive(Clone)]
pub struct JwtInterceptor {
validator: Arc<JwtValidator>,
}
impl JwtInterceptor {
pub fn new(validator: Arc<JwtValidator>) -> Self {
Self { validator }
}
}
impl tonic::service::Interceptor for JwtInterceptor {
fn call(&mut self, mut req: Request<()>) -> Result<Request<()>, Status> {
let token = req
.metadata()
.get("authorization")
.and_then(|v| v.to_str().ok())
.and_then(|v| v.strip_prefix("Bearer "));
match token {
Some(token) => {
let claims = self.validator.validate(token)?;
debug!(sub = claims.sub.as_str(), "Authenticated gRPC request");
req.extensions_mut().insert(claims);
Ok(req)
}
None => Err(Status::unauthenticated("Missing authorization token")),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_claims_deserialize() {
let json = serde_json::json!({
"sub": "996a2cb8-98b7-46ed-9cdf-d31c883d897f",
"email": "sienna@sunbeam.pt",
"exp": 1774310400
});
let claims: Claims = serde_json::from_value(json).unwrap();
assert_eq!(claims.sub, "996a2cb8-98b7-46ed-9cdf-d31c883d897f");
assert_eq!(claims.email.as_deref(), Some("sienna@sunbeam.pt"));
}
#[test]
fn test_claims_minimal() {
let json = serde_json::json!({ "sub": "user-123", "exp": 0 });
let claims: Claims = serde_json::from_value(json).unwrap();
assert!(claims.email.is_none());
}
}

64
src/grpc/mod.rs Normal file
View File

@@ -0,0 +1,64 @@
pub mod auth;
pub mod router;
pub mod service;
mod proto {
tonic::include_proto!("sunbeam.code.v1");
}
pub use proto::code_agent_server::{CodeAgent, CodeAgentServer};
pub use proto::*;
use std::sync::Arc;
use tonic::transport::Server;
use tracing::{error, info};
use crate::config::Config;
use crate::persistence::Store;
use crate::tools::ToolRegistry;
/// Shared state for the gRPC server.
pub struct GrpcState {
pub config: Arc<Config>,
pub tools: Arc<ToolRegistry>,
pub store: Arc<Store>,
pub mistral: Arc<mistralai_client::v1::client::Client>,
pub matrix: matrix_sdk::Client,
}
/// Start the gRPC server. Call from main.rs alongside the Matrix sync loop.
pub async fn start_server(state: Arc<GrpcState>) -> anyhow::Result<()> {
let addr = state
.config
.grpc
.as_ref()
.map(|g| g.listen_addr.clone())
.unwrap_or_else(|| "0.0.0.0:50051".into());
let addr = addr.parse()?;
let jwks_url = state
.config
.grpc
.as_ref()
.and_then(|g| g.jwks_url.clone())
.unwrap_or_else(|| {
"http://hydra-public.ory.svc.cluster.local:4444/.well-known/jwks.json".into()
});
// Initialize JWT validator (fetches JWKS from Hydra)
let jwt_validator = Arc::new(auth::JwtValidator::new(&jwks_url).await?);
let interceptor = auth::JwtInterceptor::new(jwt_validator);
let svc = service::CodeAgentService::new(state);
let svc = CodeAgentServer::with_interceptor(svc, interceptor);
info!(%addr, "Starting gRPC server");
Server::builder()
.add_service(svc)
.serve(addr)
.await?;
Ok(())
}

68
src/grpc/router.rs Normal file
View File

@@ -0,0 +1,68 @@
/// Determines whether a tool call should be executed on the client (local
/// filesystem) or on the server (Sol's ToolRegistry).
///
/// Client-side tools require the gRPC stream to be active. When no client
/// is connected (user is on Matrix), Sol falls back to Gitea for file access.
const CLIENT_TOOLS: &[&str] = &[
"file_read",
"file_write",
"search_replace",
"grep",
"bash",
"list_directory",
"ask_user",
];
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolSide {
/// Execute on the developer's machine via gRPC stream.
Client,
/// Execute on Sol's server via ToolRegistry.
Server,
}
/// Route a tool call to the appropriate side.
pub fn route(tool_name: &str) -> ToolSide {
if CLIENT_TOOLS.contains(&tool_name) {
ToolSide::Client
} else {
ToolSide::Server
}
}
/// Check if a tool is a client-side tool.
pub fn is_client_tool(tool_name: &str) -> bool {
route(tool_name) == ToolSide::Client
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_client_tools() {
assert_eq!(route("file_read"), ToolSide::Client);
assert_eq!(route("file_write"), ToolSide::Client);
assert_eq!(route("bash"), ToolSide::Client);
assert_eq!(route("grep"), ToolSide::Client);
assert_eq!(route("search_replace"), ToolSide::Client);
assert_eq!(route("list_directory"), ToolSide::Client);
assert_eq!(route("ask_user"), ToolSide::Client);
}
#[test]
fn test_server_tools() {
assert_eq!(route("search_archive"), ToolSide::Server);
assert_eq!(route("search_web"), ToolSide::Server);
assert_eq!(route("gitea_list_repos"), ToolSide::Server);
assert_eq!(route("identity_list_users"), ToolSide::Server);
assert_eq!(route("research"), ToolSide::Server);
assert_eq!(route("run_script"), ToolSide::Server);
}
#[test]
fn test_unknown_defaults_to_server() {
assert_eq!(route("unknown_tool"), ToolSide::Server);
}
}

175
src/grpc/service.rs Normal file
View File

@@ -0,0 +1,175 @@
use std::pin::Pin;
use std::sync::Arc;
use futures::Stream;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tonic::{Request, Response, Status, Streaming};
use tracing::{error, info, warn};
use super::auth::Claims;
use super::proto::code_agent_server::CodeAgent;
use super::proto::*;
use super::GrpcState;
pub struct CodeAgentService {
state: Arc<GrpcState>,
}
impl CodeAgentService {
pub fn new(state: Arc<GrpcState>) -> Self {
Self { state }
}
}
#[tonic::async_trait]
impl CodeAgent for CodeAgentService {
type SessionStream = Pin<Box<dyn Stream<Item = Result<ServerMessage, Status>> + Send>>;
async fn session(
&self,
request: Request<Streaming<ClientMessage>>,
) -> Result<Response<Self::SessionStream>, Status> {
// Extract JWT claims from the request extensions (set by auth middleware)
let claims = request
.extensions()
.get::<Claims>()
.cloned()
.ok_or_else(|| Status::unauthenticated("No valid authentication token"))?;
info!(
user = claims.sub.as_str(),
email = claims.email.as_deref().unwrap_or("?"),
"New coding session"
);
let mut in_stream = request.into_inner();
let state = self.state.clone();
// Channel for sending server messages to the client
let (tx, rx) = mpsc::channel::<Result<ServerMessage, Status>>(32);
// Spawn the session handler
tokio::spawn(async move {
if let Err(e) = handle_session(&state, &claims, &mut in_stream, &tx).await {
error!(user = claims.sub.as_str(), "Session error: {e}");
let _ = tx
.send(Ok(ServerMessage {
payload: Some(server_message::Payload::Error(Error {
message: e.to_string(),
fatal: true,
})),
}))
.await;
}
});
let out_stream = ReceiverStream::new(rx);
Ok(Response::new(Box::pin(out_stream)))
}
}
/// Handle a single coding session (runs in a spawned task).
async fn handle_session(
state: &GrpcState,
claims: &Claims,
in_stream: &mut Streaming<ClientMessage>,
tx: &mpsc::Sender<Result<ServerMessage, Status>>,
) -> anyhow::Result<()> {
// Wait for the first message — must be StartSession
let first = in_stream
.message()
.await?
.ok_or_else(|| anyhow::anyhow!("Stream closed before StartSession"))?;
let start = match first.payload {
Some(client_message::Payload::Start(s)) => s,
_ => anyhow::bail!("First message must be StartSession"),
};
info!(
user = claims.sub.as_str(),
project = start.project_path.as_str(),
model = start.model.as_str(),
client_tools = start.client_tools.len(),
"Session started"
);
// TODO Phase 2: Create/find Matrix room for this project
// TODO Phase 2: Create Mistral conversation
// TODO Phase 2: Enter agent loop
// For now, send SessionReady and echo back
tx.send(Ok(ServerMessage {
payload: Some(server_message::Payload::Ready(SessionReady {
session_id: uuid::Uuid::new_v4().to_string(),
room_id: String::new(), // TODO: Matrix room
model: if start.model.is_empty() {
state
.config
.agents
.coding_model
.clone()
} else {
start.model.clone()
},
})),
}))
.await?;
// Main message loop
while let Some(msg) = in_stream.message().await? {
match msg.payload {
Some(client_message::Payload::Input(input)) => {
info!(
user = claims.sub.as_str(),
text_len = input.text.len(),
"User input received"
);
// TODO Phase 2: Send to Mistral, handle tool calls, stream response
// For now, echo back as a simple acknowledgment
tx.send(Ok(ServerMessage {
payload: Some(server_message::Payload::Done(TextDone {
full_text: format!("[stub] received: {}", input.text),
input_tokens: 0,
output_tokens: 0,
})),
}))
.await?;
}
Some(client_message::Payload::ToolResult(result)) => {
info!(
call_id = result.call_id.as_str(),
is_error = result.is_error,
"Tool result received"
);
// TODO Phase 2: Feed back to Mistral
}
Some(client_message::Payload::Approval(approval)) => {
info!(
call_id = approval.call_id.as_str(),
approved = approval.approved,
"Tool approval received"
);
// TODO Phase 2: Execute or skip tool
}
Some(client_message::Payload::End(_)) => {
info!(user = claims.sub.as_str(), "Session ended by client");
tx.send(Ok(ServerMessage {
payload: Some(server_message::Payload::End(SessionEnd {
summary: "Session ended.".into(),
})),
}))
.await?;
break;
}
Some(client_message::Payload::Start(_)) => {
warn!("Received duplicate StartSession — ignoring");
}
None => continue,
}
}
Ok(())
}

View File

@@ -8,6 +8,7 @@ mod conversations;
mod matrix_utils;
mod memory;
mod persistence;
mod grpc;
mod sdk;
mod sync;
mod time_context;