feat(code): gRPC server with JWT auth + tool routing
tonic 0.14 gRPC server for sunbeam code sessions:
- bidirectional streaming Session RPC
- JWT interceptor validates tokens against Hydra JWKS
- tool router classifies calls as client-side (file_read, bash, grep, etc.) or server-side (gitea, identity, search, etc.)
- service stub with session lifecycle (start, chat, tool results, end)
- coding_model config (default: devstral-small-2506)
- grpc config section (listen_addr, jwks_url)
- 182 tests (5 new: JWT claims, tool routing)

Phase 2 TODOs: Matrix room bridge, Mistral agent loop, streaming
This commit is contained in:
@@ -38,3 +38,12 @@ uuid = { version = "1", features = ["v4"] }
|
||||
base64 = "0.22"
|
||||
rusqlite = { version = "0.32", features = ["bundled"] }
|
||||
futures = "0.3"
|
||||
tonic = "0.14"
|
||||
tonic-prost = "0.14"
|
||||
prost = "0.14"
|
||||
tokio-stream = "0.1"
|
||||
jsonwebtoken = "9"
|
||||
|
||||
[build-dependencies]
|
||||
tonic-build = "0.14"
|
||||
tonic-prost-build = "0.14"
|
||||
|
||||
4
build.rs
Normal file
4
build.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
tonic_prost_build::compile_protos("proto/code.proto")?;
|
||||
Ok(())
|
||||
}
|
||||
1
proto/code.proto
Symbolic link
1
proto/code.proto
Symbolic link
@@ -0,0 +1 @@
|
||||
/Users/sienna/Development/sunbeam/cli-worktree/sunbeam-proto/proto/code.proto
|
||||
@@ -12,6 +12,8 @@ pub struct Config {
|
||||
pub services: ServicesConfig,
|
||||
#[serde(default)]
|
||||
pub vault: VaultConfig,
|
||||
#[serde(default)]
|
||||
pub grpc: Option<GrpcConfig>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
@@ -40,6 +42,9 @@ pub struct AgentsConfig {
|
||||
/// Max recursion depth for research agents spawning sub-agents.
|
||||
#[serde(default = "default_research_max_depth")]
|
||||
pub research_max_depth: usize,
|
||||
/// Model for coding agent sessions (sunbeam code).
|
||||
#[serde(default = "default_coding_model")]
|
||||
pub coding_model: String,
|
||||
}
|
||||
|
||||
impl Default for AgentsConfig {
|
||||
@@ -53,6 +58,7 @@ impl Default for AgentsConfig {
|
||||
research_max_iterations: default_research_max_iterations(),
|
||||
research_max_agents: default_research_max_agents(),
|
||||
research_max_depth: default_research_max_depth(),
|
||||
coding_model: default_coding_model(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -233,6 +239,19 @@ fn default_research_agent_model() -> String { "ministral-3b-latest".into() }
|
||||
/// Default value for `AgentsConfig::research_max_iterations`.
fn default_research_max_iterations() -> usize {
    const MAX_ITERATIONS: usize = 10;
    MAX_ITERATIONS
}
|
||||
/// Default value for `AgentsConfig::research_max_agents`.
fn default_research_max_agents() -> usize {
    const MAX_AGENTS: usize = 25;
    MAX_AGENTS
}
|
||||
/// Default value for `AgentsConfig::research_max_depth` (the maximum
/// recursion depth for research agents spawning sub-agents).
fn default_research_max_depth() -> usize {
    const MAX_DEPTH: usize = 4;
    MAX_DEPTH
}
|
||||
/// Default value for `AgentsConfig::coding_model`, the model used for coding
/// agent sessions (sunbeam code).
fn default_coding_model() -> String {
    String::from("devstral-small-2506")
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct GrpcConfig {
|
||||
/// Address to listen on (default: 0.0.0.0:50051).
|
||||
#[serde(default = "default_grpc_addr")]
|
||||
pub listen_addr: String,
|
||||
/// JWKS URL for JWT validation (default: Hydra's .well-known endpoint).
|
||||
#[serde(default)]
|
||||
pub jwks_url: Option<String>,
|
||||
}
|
||||
|
||||
/// Default value for `GrpcConfig::listen_addr`.
fn default_grpc_addr() -> String {
    String::from("0.0.0.0:50051")
}
|
||||
|
||||
impl Config {
|
||||
pub fn load(path: &str) -> anyhow::Result<Self> {
|
||||
|
||||
131
src/grpc/auth.rs
Normal file
131
src/grpc/auth.rs
Normal file
@@ -0,0 +1,131 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use jsonwebtoken::{decode, decode_header, jwk::JwkSet, Algorithm, DecodingKey, Validation};
|
||||
use serde::Deserialize;
|
||||
use tonic::{Request, Status};
|
||||
use tracing::{debug, warn};
|
||||
|
||||
/// Claims extracted from a valid JWT.
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct Claims {
|
||||
pub sub: String,
|
||||
#[serde(default)]
|
||||
pub email: Option<String>,
|
||||
#[serde(default)]
|
||||
pub exp: u64,
|
||||
}
|
||||
|
||||
/// Validates JWTs against Hydra's JWKS endpoint.
|
||||
pub struct JwtValidator {
|
||||
jwks: JwkSet,
|
||||
}
|
||||
|
||||
impl JwtValidator {
|
||||
/// Fetch JWKS from the given URL and create a validator.
|
||||
pub async fn new(jwks_url: &str) -> anyhow::Result<Self> {
|
||||
let resp = reqwest::get(jwks_url)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to fetch JWKS from {jwks_url}: {e}"))?;
|
||||
|
||||
let jwks: JwkSet = resp
|
||||
.json()
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to parse JWKS: {e}"))?;
|
||||
|
||||
debug!(keys = jwks.keys.len(), "Loaded JWKS");
|
||||
|
||||
Ok(Self { jwks })
|
||||
}
|
||||
|
||||
/// Validate a JWT and return the claims.
|
||||
pub fn validate(&self, token: &str) -> Result<Claims, Status> {
|
||||
let header = decode_header(token).map_err(|e| {
|
||||
warn!("Invalid JWT header: {e}");
|
||||
Status::unauthenticated("Invalid token")
|
||||
})?;
|
||||
|
||||
let kid = header
|
||||
.kid
|
||||
.as_deref()
|
||||
.ok_or_else(|| Status::unauthenticated("Token missing kid"))?;
|
||||
|
||||
let key = self.jwks.find(kid).ok_or_else(|| {
|
||||
warn!(kid, "Unknown key ID in JWT");
|
||||
Status::unauthenticated("Unknown signing key")
|
||||
})?;
|
||||
|
||||
let decoding_key = DecodingKey::from_jwk(key).map_err(|e| {
|
||||
warn!("Failed to create decoding key: {e}");
|
||||
Status::unauthenticated("Invalid signing key")
|
||||
})?;
|
||||
|
||||
let mut validation = Validation::new(Algorithm::RS256);
|
||||
validation.validate_exp = true;
|
||||
validation.validate_aud = false;
|
||||
|
||||
let token_data =
|
||||
decode::<Claims>(token, &decoding_key, &validation).map_err(|e| {
|
||||
warn!("JWT validation failed: {e}");
|
||||
Status::unauthenticated("Token validation failed")
|
||||
})?;
|
||||
|
||||
Ok(token_data.claims)
|
||||
}
|
||||
}
|
||||
|
||||
/// Tonic interceptor that validates JWT from the `authorization` metadata
|
||||
/// and inserts `Claims` into the request extensions.
|
||||
#[derive(Clone)]
|
||||
pub struct JwtInterceptor {
|
||||
validator: Arc<JwtValidator>,
|
||||
}
|
||||
|
||||
impl JwtInterceptor {
|
||||
pub fn new(validator: Arc<JwtValidator>) -> Self {
|
||||
Self { validator }
|
||||
}
|
||||
}
|
||||
|
||||
impl tonic::service::Interceptor for JwtInterceptor {
|
||||
fn call(&mut self, mut req: Request<()>) -> Result<Request<()>, Status> {
|
||||
let token = req
|
||||
.metadata()
|
||||
.get("authorization")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|v| v.strip_prefix("Bearer "));
|
||||
|
||||
match token {
|
||||
Some(token) => {
|
||||
let claims = self.validator.validate(token)?;
|
||||
debug!(sub = claims.sub.as_str(), "Authenticated gRPC request");
|
||||
req.extensions_mut().insert(claims);
|
||||
Ok(req)
|
||||
}
|
||||
None => Err(Status::unauthenticated("Missing authorization token")),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// A payload carrying `sub`, `email` and `exp` deserializes into `Claims`.
    #[test]
    fn test_claims_deserialize() {
        let payload = serde_json::json!({
            "sub": "996a2cb8-98b7-46ed-9cdf-d31c883d897f",
            "email": "sienna@sunbeam.pt",
            "exp": 1774310400
        });

        let parsed = serde_json::from_value::<Claims>(payload).unwrap();

        assert_eq!(parsed.sub, "996a2cb8-98b7-46ed-9cdf-d31c883d897f");
        assert_eq!(parsed.email.as_deref(), Some("sienna@sunbeam.pt"));
    }

    /// `email` is optional: a payload without it still deserializes.
    #[test]
    fn test_claims_minimal() {
        let payload = serde_json::json!({ "sub": "user-123", "exp": 0 });

        let parsed = serde_json::from_value::<Claims>(payload).unwrap();

        assert!(parsed.email.is_none());
    }
}
|
||||
64
src/grpc/mod.rs
Normal file
64
src/grpc/mod.rs
Normal file
@@ -0,0 +1,64 @@
|
||||
pub mod auth;
|
||||
pub mod router;
|
||||
pub mod service;
|
||||
|
||||
mod proto {
|
||||
tonic::include_proto!("sunbeam.code.v1");
|
||||
}
|
||||
|
||||
pub use proto::code_agent_server::{CodeAgent, CodeAgentServer};
|
||||
pub use proto::*;
|
||||
|
||||
use std::sync::Arc;
|
||||
use tonic::transport::Server;
|
||||
use tracing::{error, info};
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::persistence::Store;
|
||||
use crate::tools::ToolRegistry;
|
||||
|
||||
/// Shared state for the gRPC server.
|
||||
pub struct GrpcState {
|
||||
pub config: Arc<Config>,
|
||||
pub tools: Arc<ToolRegistry>,
|
||||
pub store: Arc<Store>,
|
||||
pub mistral: Arc<mistralai_client::v1::client::Client>,
|
||||
pub matrix: matrix_sdk::Client,
|
||||
}
|
||||
|
||||
/// Start the gRPC server. Call from main.rs alongside the Matrix sync loop.
|
||||
pub async fn start_server(state: Arc<GrpcState>) -> anyhow::Result<()> {
|
||||
let addr = state
|
||||
.config
|
||||
.grpc
|
||||
.as_ref()
|
||||
.map(|g| g.listen_addr.clone())
|
||||
.unwrap_or_else(|| "0.0.0.0:50051".into());
|
||||
|
||||
let addr = addr.parse()?;
|
||||
|
||||
let jwks_url = state
|
||||
.config
|
||||
.grpc
|
||||
.as_ref()
|
||||
.and_then(|g| g.jwks_url.clone())
|
||||
.unwrap_or_else(|| {
|
||||
"http://hydra-public.ory.svc.cluster.local:4444/.well-known/jwks.json".into()
|
||||
});
|
||||
|
||||
// Initialize JWT validator (fetches JWKS from Hydra)
|
||||
let jwt_validator = Arc::new(auth::JwtValidator::new(&jwks_url).await?);
|
||||
let interceptor = auth::JwtInterceptor::new(jwt_validator);
|
||||
|
||||
let svc = service::CodeAgentService::new(state);
|
||||
let svc = CodeAgentServer::with_interceptor(svc, interceptor);
|
||||
|
||||
info!(%addr, "Starting gRPC server");
|
||||
|
||||
Server::builder()
|
||||
.add_service(svc)
|
||||
.serve(addr)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
68
src/grpc/router.rs
Normal file
68
src/grpc/router.rs
Normal file
@@ -0,0 +1,68 @@
|
||||
/// Names of the tools that are executed on the client (local filesystem)
/// rather than on the server (Sol's ToolRegistry).
///
/// Client-side tools require the gRPC stream to be active. When no client
/// is connected (user is on Matrix), Sol falls back to Gitea for file access.
const CLIENT_TOOLS: &[&str] = &[
    "file_read",
    "file_write",
    "search_replace",
    "grep",
    "bash",
    "list_directory",
    "ask_user",
];
|
||||
|
||||
/// Where a tool call is executed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolSide {
    /// Runs on the developer's machine, reached through the gRPC stream.
    Client,

    /// Runs on Sol's server through the ToolRegistry.
    Server,
}
|
||||
|
||||
/// Route a tool call to the appropriate side.
|
||||
pub fn route(tool_name: &str) -> ToolSide {
|
||||
if CLIENT_TOOLS.contains(&tool_name) {
|
||||
ToolSide::Client
|
||||
} else {
|
||||
ToolSide::Server
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a tool is a client-side tool.
|
||||
pub fn is_client_tool(tool_name: &str) -> bool {
|
||||
route(tool_name) == ToolSide::Client
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Every name in the client tool list routes to the client.
    #[test]
    fn test_client_tools() {
        let client_side = [
            "file_read",
            "file_write",
            "bash",
            "grep",
            "search_replace",
            "list_directory",
            "ask_user",
        ];
        for name in client_side {
            assert_eq!(route(name), ToolSide::Client);
        }
    }

    /// Tools outside the client list route to the server.
    #[test]
    fn test_server_tools() {
        let server_side = [
            "search_archive",
            "search_web",
            "gitea_list_repos",
            "identity_list_users",
            "research",
            "run_script",
        ];
        for name in server_side {
            assert_eq!(route(name), ToolSide::Server);
        }
    }

    /// Names the router has never heard of fall through to the server.
    #[test]
    fn test_unknown_defaults_to_server() {
        assert_eq!(route("unknown_tool"), ToolSide::Server);
    }
}
|
||||
175
src/grpc/service.rs
Normal file
175
src/grpc/service.rs
Normal file
@@ -0,0 +1,175 @@
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::Stream;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_stream::wrappers::ReceiverStream;
|
||||
use tonic::{Request, Response, Status, Streaming};
|
||||
use tracing::{error, info, warn};
|
||||
|
||||
use super::auth::Claims;
|
||||
use super::proto::code_agent_server::CodeAgent;
|
||||
use super::proto::*;
|
||||
use super::GrpcState;
|
||||
|
||||
pub struct CodeAgentService {
|
||||
state: Arc<GrpcState>,
|
||||
}
|
||||
|
||||
impl CodeAgentService {
|
||||
pub fn new(state: Arc<GrpcState>) -> Self {
|
||||
Self { state }
|
||||
}
|
||||
}
|
||||
|
||||
#[tonic::async_trait]
|
||||
impl CodeAgent for CodeAgentService {
|
||||
type SessionStream = Pin<Box<dyn Stream<Item = Result<ServerMessage, Status>> + Send>>;
|
||||
|
||||
async fn session(
|
||||
&self,
|
||||
request: Request<Streaming<ClientMessage>>,
|
||||
) -> Result<Response<Self::SessionStream>, Status> {
|
||||
// Extract JWT claims from the request extensions (set by auth middleware)
|
||||
let claims = request
|
||||
.extensions()
|
||||
.get::<Claims>()
|
||||
.cloned()
|
||||
.ok_or_else(|| Status::unauthenticated("No valid authentication token"))?;
|
||||
|
||||
info!(
|
||||
user = claims.sub.as_str(),
|
||||
email = claims.email.as_deref().unwrap_or("?"),
|
||||
"New coding session"
|
||||
);
|
||||
|
||||
let mut in_stream = request.into_inner();
|
||||
let state = self.state.clone();
|
||||
|
||||
// Channel for sending server messages to the client
|
||||
let (tx, rx) = mpsc::channel::<Result<ServerMessage, Status>>(32);
|
||||
|
||||
// Spawn the session handler
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = handle_session(&state, &claims, &mut in_stream, &tx).await {
|
||||
error!(user = claims.sub.as_str(), "Session error: {e}");
|
||||
let _ = tx
|
||||
.send(Ok(ServerMessage {
|
||||
payload: Some(server_message::Payload::Error(Error {
|
||||
message: e.to_string(),
|
||||
fatal: true,
|
||||
})),
|
||||
}))
|
||||
.await;
|
||||
}
|
||||
});
|
||||
|
||||
let out_stream = ReceiverStream::new(rx);
|
||||
Ok(Response::new(Box::pin(out_stream)))
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle a single coding session (runs in a spawned task).
|
||||
async fn handle_session(
|
||||
state: &GrpcState,
|
||||
claims: &Claims,
|
||||
in_stream: &mut Streaming<ClientMessage>,
|
||||
tx: &mpsc::Sender<Result<ServerMessage, Status>>,
|
||||
) -> anyhow::Result<()> {
|
||||
// Wait for the first message — must be StartSession
|
||||
let first = in_stream
|
||||
.message()
|
||||
.await?
|
||||
.ok_or_else(|| anyhow::anyhow!("Stream closed before StartSession"))?;
|
||||
|
||||
let start = match first.payload {
|
||||
Some(client_message::Payload::Start(s)) => s,
|
||||
_ => anyhow::bail!("First message must be StartSession"),
|
||||
};
|
||||
|
||||
info!(
|
||||
user = claims.sub.as_str(),
|
||||
project = start.project_path.as_str(),
|
||||
model = start.model.as_str(),
|
||||
client_tools = start.client_tools.len(),
|
||||
"Session started"
|
||||
);
|
||||
|
||||
// TODO Phase 2: Create/find Matrix room for this project
|
||||
// TODO Phase 2: Create Mistral conversation
|
||||
// TODO Phase 2: Enter agent loop
|
||||
|
||||
// For now, send SessionReady and echo back
|
||||
tx.send(Ok(ServerMessage {
|
||||
payload: Some(server_message::Payload::Ready(SessionReady {
|
||||
session_id: uuid::Uuid::new_v4().to_string(),
|
||||
room_id: String::new(), // TODO: Matrix room
|
||||
model: if start.model.is_empty() {
|
||||
state
|
||||
.config
|
||||
.agents
|
||||
.coding_model
|
||||
.clone()
|
||||
} else {
|
||||
start.model.clone()
|
||||
},
|
||||
})),
|
||||
}))
|
||||
.await?;
|
||||
|
||||
// Main message loop
|
||||
while let Some(msg) = in_stream.message().await? {
|
||||
match msg.payload {
|
||||
Some(client_message::Payload::Input(input)) => {
|
||||
info!(
|
||||
user = claims.sub.as_str(),
|
||||
text_len = input.text.len(),
|
||||
"User input received"
|
||||
);
|
||||
|
||||
// TODO Phase 2: Send to Mistral, handle tool calls, stream response
|
||||
// For now, echo back as a simple acknowledgment
|
||||
tx.send(Ok(ServerMessage {
|
||||
payload: Some(server_message::Payload::Done(TextDone {
|
||||
full_text: format!("[stub] received: {}", input.text),
|
||||
input_tokens: 0,
|
||||
output_tokens: 0,
|
||||
})),
|
||||
}))
|
||||
.await?;
|
||||
}
|
||||
Some(client_message::Payload::ToolResult(result)) => {
|
||||
info!(
|
||||
call_id = result.call_id.as_str(),
|
||||
is_error = result.is_error,
|
||||
"Tool result received"
|
||||
);
|
||||
// TODO Phase 2: Feed back to Mistral
|
||||
}
|
||||
Some(client_message::Payload::Approval(approval)) => {
|
||||
info!(
|
||||
call_id = approval.call_id.as_str(),
|
||||
approved = approval.approved,
|
||||
"Tool approval received"
|
||||
);
|
||||
// TODO Phase 2: Execute or skip tool
|
||||
}
|
||||
Some(client_message::Payload::End(_)) => {
|
||||
info!(user = claims.sub.as_str(), "Session ended by client");
|
||||
tx.send(Ok(ServerMessage {
|
||||
payload: Some(server_message::Payload::End(SessionEnd {
|
||||
summary: "Session ended.".into(),
|
||||
})),
|
||||
}))
|
||||
.await?;
|
||||
break;
|
||||
}
|
||||
Some(client_message::Payload::Start(_)) => {
|
||||
warn!("Received duplicate StartSession — ignoring");
|
||||
}
|
||||
None => continue,
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -8,6 +8,7 @@ mod conversations;
|
||||
mod matrix_utils;
|
||||
mod memory;
|
||||
mod persistence;
|
||||
mod grpc;
|
||||
mod sdk;
|
||||
mod sync;
|
||||
mod time_context;
|
||||
|
||||
Reference in New Issue
Block a user