From 40a6772f995c1d147b8584346b980965951e0b07 Mon Sep 17 00:00:00 2001 From: Sienna Meridian Satterwhite Date: Mon, 23 Mar 2026 20:54:28 +0000 Subject: [PATCH] feat: 13 e2e integration tests against real Mistral API MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Orchestrator tests: - Simple chat roundtrip with token usage verification - Event ordering (Started → Thinking → Done) - Metadata pass-through (opaque bag appears in Started event) - Token usage accuracy (longer prompts → more tokens) - Conversation continuity (multi-turn recall) - Client-side tool dispatch + mock result submission - Failed tool result handling (is_error: true) - Server-side tool execution (search_web via conversation) gRPC tests: - Full roundtrip (StartSession → UserInput → Status → TextDone) - Client tool relay (ToolCall → ToolResult through gRPC stream) - Token counts in TextDone (non-zero verification) - Session resume (same room_id, resumed flag) - Clean disconnect (EndSession → SessionEnd) Infrastructure: - ToolRegistry::new_minimal() — no OpenSearch/Matrix needed - ToolRegistry fields now Option for testability - GrpcState.matrix now Option - grpc_bridge moved to src/grpc/bridge.rs - TestHarness loads API key from .env --- Cargo.toml | 3 + src/grpc/mod.rs | 2 +- src/grpc/session.rs | 36 +- src/integration_test.rs | 923 ++++++++++++++++++++++++++++++++++++++++ src/main.rs | 4 +- src/tools/mod.rs | 43 +- 6 files changed, 980 insertions(+), 31 deletions(-) create mode 100644 src/integration_test.rs diff --git a/Cargo.toml b/Cargo.toml index 93c3b58..e2d69a9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -45,6 +45,9 @@ tokio-stream = "0.1" jsonwebtoken = "9" tokenizers = { version = "0.22", default-features = false, features = ["onig", "http"] } +[dev-dependencies] +dotenv = "0.15" + [build-dependencies] tonic-build = "0.14" tonic-prost-build = "0.14" diff --git a/src/grpc/mod.rs b/src/grpc/mod.rs index 6f3ba9b..071d47e 100644 --- a/src/grpc/mod.rs 
+++ b/src/grpc/mod.rs @@ -25,7 +25,7 @@ pub struct GrpcState { pub tools: Arc, pub store: Arc, pub mistral: Arc, - pub matrix: matrix_sdk::Client, + pub matrix: Option, pub system_prompt: String, pub orchestrator_agent_id: String, pub orchestrator: Option>, diff --git a/src/grpc/session.rs b/src/grpc/session.rs index 57b50d4..cf5a15d 100644 --- a/src/grpc/session.rs +++ b/src/grpc/session.rs @@ -59,10 +59,11 @@ impl CodeSession { "Resuming existing code session" ); - let room = state.matrix.get_room( + let room = state.matrix.as_ref().and_then(|m| { <&matrix_sdk::ruma::RoomId>::try_from(room_id.as_str()) - .map_err(|e| anyhow::anyhow!("Invalid room ID: {e}"))?, - ); + .ok() + .and_then(|rid| m.get_room(rid)) + }); state.store.touch_code_session(&session_id); @@ -85,21 +86,20 @@ impl CodeSession { // Create private Matrix room for this project let room_name = format!("code: {project_name}"); - let room_id = create_project_room(&state.matrix, &room_name, &claims.email) - .await - .unwrap_or_else(|e| { - warn!("Failed to create Matrix room: {e}"); - format!("!code-{session_id}:local") // fallback ID - }); - - let room = state.matrix.get_room( - <&matrix_sdk::ruma::RoomId>::try_from(room_id.as_str()).ok() - .unwrap_or_else(|| { - // This shouldn't happen but handle gracefully - warn!("Invalid room ID {room_id}, session will work without Matrix bridge"); - <&matrix_sdk::ruma::RoomId>::try_from("!invalid:local").unwrap() - }), - ); + let (room_id, room) = if let Some(ref matrix) = state.matrix { + let rid = create_project_room(matrix, &room_name, &claims.email) + .await + .unwrap_or_else(|e| { + warn!("Failed to create Matrix room: {e}"); + format!("!code-{session_id}:local") + }); + let room = <&matrix_sdk::ruma::RoomId>::try_from(rid.as_str()) + .ok() + .and_then(|r| matrix.get_room(r)); + (rid, room) + } else { + (format!("!code-{session_id}:local"), None) + }; state.store.create_code_session( &session_id, diff --git a/src/integration_test.rs b/src/integration_test.rs 
new file mode 100644 index 0000000..2dc06f0 --- /dev/null +++ b/src/integration_test.rs @@ -0,0 +1,923 @@ +//! End-to-end integration tests against the real Mistral API. +//! +//! Requires SOL_MISTRAL_API_KEY in .env file. +//! Run: cargo test integration_test -- --test-threads=1 + +#![cfg(test)] + +use std::sync::Arc; +use std::time::Duration; + +use crate::config::Config; +use crate::conversations::ConversationRegistry; +use crate::orchestrator::event::*; +use crate::orchestrator::Orchestrator; +use crate::persistence::Store; +use crate::tools::ToolRegistry; + +// ── Test harness ──────────────────────────────────────────────────────── + +struct TestHarness { + orchestrator: Arc, + event_rx: tokio::sync::broadcast::Receiver, +} + +impl TestHarness { + async fn new() -> Self { + // Load .env from project root + let env_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join(".env"); + if let Ok(contents) = std::fs::read_to_string(&env_path) { + for line in contents.lines() { + let line = line.trim(); + if line.is_empty() || line.starts_with('#') { + continue; + } + if let Some((key, value)) = line.split_once('=') { + std::env::set_var(key.trim(), value.trim()); + } + } + } + + let api_key = std::env::var("SOL_MISTRAL_API_KEY") + .expect("SOL_MISTRAL_API_KEY must be set in .env"); + + let config = test_config(); + let mistral = Arc::new( + mistralai_client::v1::client::Client::new(Some(api_key), None, None, None) + .expect("Failed to create Mistral client"), + ); + let store = Arc::new(Store::open_memory().expect("Failed to create in-memory store")); + let conversations = Arc::new(ConversationRegistry::new( + config.agents.orchestrator_model.clone(), + config.agents.compaction_threshold, + store, + )); + + let tools = Arc::new(ToolRegistry::new_minimal(config.clone())); + + let orchestrator = Arc::new(Orchestrator::new( + config, + tools, + mistral, + conversations, + "you are sol. respond briefly and concisely. 
lowercase only.".into(), + )); + let event_rx = orchestrator.subscribe(); + + Self { orchestrator, event_rx } + } + + async fn collect_events_for( + &mut self, + request_id: &RequestId, + timeout_secs: u64, + ) -> Vec { + let mut events = Vec::new(); + let deadline = tokio::time::Instant::now() + Duration::from_secs(timeout_secs); + + loop { + match tokio::time::timeout_at(deadline, self.event_rx.recv()).await { + Ok(Ok(event)) => { + if event.request_id() != request_id { + continue; + } + let is_terminal = matches!( + event, + OrchestratorEvent::Done { .. } | OrchestratorEvent::Failed { .. } + ); + events.push(event); + if is_terminal { + break; + } + } + Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(_))) => continue, + Ok(Err(_)) => break, + Err(_) => panic!("Timeout after {timeout_secs}s waiting for events"), + } + } + + events + } +} + +fn test_config() -> Arc { + let toml = r#" + [matrix] + homeserver_url = "http://localhost:8008" + user_id = "@test:localhost" + state_store_path = "/tmp/sol-test-state" + db_path = ":memory:" + + [opensearch] + url = "http://localhost:9200" + index = "sol_test" + + [mistral] + default_model = "mistral-medium-latest" + max_tool_iterations = 10 + + [behavior] + instant_responses = true + memory_extraction_enabled = false + + [agents] + orchestrator_model = "mistral-medium-latest" + use_conversations_api = true + agent_prefix = "test" + + [grpc] + listen_addr = "0.0.0.0:0" + dev_mode = true + "#; + Arc::new(Config::from_str(toml).expect("Failed to parse test config")) +} + +fn make_request(text: &str) -> GenerateRequest { + GenerateRequest { + request_id: RequestId::new(), + text: text.into(), + user_id: "test-user".into(), + display_name: None, + conversation_key: format!("test-{}", uuid::Uuid::new_v4()), + is_direct: true, + image: None, + metadata: Metadata::new(), + } +} + +fn make_request_with_key(text: &str, conversation_key: &str) -> GenerateRequest { + GenerateRequest { + request_id: RequestId::new(), + text: 
text.into(), + user_id: "test-user".into(), + display_name: None, + conversation_key: conversation_key.into(), + is_direct: true, + image: None, + metadata: Metadata::new(), + } +} + +// ── Test 1: Simple chat round-trip ────────────────────────────────────── + +#[tokio::test] +async fn test_simple_chat_roundtrip() { + let mut h = TestHarness::new().await; + let request = make_request("what is 2+2? answer with just the number."); + + let rid = request.request_id.clone(); + let orch = h.orchestrator.clone(); + let gen = tokio::spawn(async move { orch.generate(&request).await }); + + let events = h.collect_events_for(&rid, 30).await; + let result = gen.await.unwrap(); + + assert!(result.is_some(), "Expected a response"); + let text = result.unwrap(); + assert!(text.contains('4'), "Expected '4' in response, got: {text}"); + + assert!(events.iter().any(|e| matches!(e, OrchestratorEvent::Started { .. }))); + assert!(events.iter().any(|e| matches!(e, OrchestratorEvent::Thinking { .. }))); + + let done = events.iter().find(|e| matches!(e, OrchestratorEvent::Done { .. })); + assert!(done.is_some(), "Missing Done event"); + if let Some(OrchestratorEvent::Done { usage, .. }) = done { + assert!(usage.prompt_tokens > 0); + assert!(usage.completion_tokens > 0); + } +} + +// ── Test 2: Conversation continuity ───────────────────────────────────── + +#[tokio::test] +async fn test_conversation_continuity() { + let mut h = TestHarness::new().await; + let conv_key = format!("test-{}", uuid::Uuid::new_v4()); + + // Turn 1 + let r1 = make_request_with_key("my favorite color is cerulean. 
just acknowledge.", &conv_key); + let rid1 = r1.request_id.clone(); + let orch1 = h.orchestrator.clone(); + let gen1 = tokio::spawn(async move { orch1.generate(&r1).await }); + h.collect_events_for(&rid1, 30).await; + let result1 = gen1.await.unwrap(); + assert!(result1.is_some(), "Turn 1 should get a response"); + + // Turn 2 + let r2 = make_request_with_key("what is my favorite color?", &conv_key); + let rid2 = r2.request_id.clone(); + let orch2 = h.orchestrator.clone(); + let gen2 = tokio::spawn(async move { orch2.generate(&r2).await }); + h.collect_events_for(&rid2, 30).await; + let result2 = gen2.await.unwrap(); + + assert!(result2.is_some(), "Turn 2 should get a response"); + let text = result2.unwrap().to_lowercase(); + assert!(text.contains("cerulean"), "Expected 'cerulean', got: {text}"); +} + +// ── Test 3: Client-side tool dispatch ─────────────────────────────────── + +#[tokio::test] +async fn test_client_tool_dispatch() { + use mistralai_client::v1::conversations::{ + ConversationInput, CreateConversationRequest, + }; + use mistralai_client::v1::agents::AgentTool; + + let mut h = TestHarness::new().await; + + // Create conversation directly with file_read tool defined + let api_key = std::env::var("SOL_MISTRAL_API_KEY") + .expect("SOL_MISTRAL_API_KEY must be set"); + let mistral = mistralai_client::v1::client::Client::new(Some(api_key), None, None, None) + .expect("Failed to create Mistral client"); + + let file_read_tool = AgentTool::function( + "file_read".into(), + "Read a file's contents. 
Use path for the file path.".into(), + serde_json::json!({ + "type": "object", + "properties": { + "path": { "type": "string", "description": "File path to read" } + }, + "required": ["path"] + }), + ); + + let req = CreateConversationRequest { + inputs: ConversationInput::Text( + "use your file_read tool to read the file at path 'README.md'".into(), + ), + model: Some("mistral-medium-latest".into()), + agent_id: None, + agent_version: None, + name: Some("test-client-tool".into()), + description: None, + instructions: Some("you are a coding assistant. use tools when asked.".into()), + completion_args: None, + tools: Some(vec![file_read_tool]), + handoff_execution: None, + metadata: None, + store: Some(true), + stream: false, + }; + + let conv_response = mistral + .create_conversation_async(&req) + .await + .expect("Failed to create conversation"); + + // Now pass to orchestrator + let request = make_request("use your file_read tool to read README.md"); + let rid = request.request_id.clone(); + let orch = h.orchestrator.clone(); + let orch_submit = h.orchestrator.clone(); + + let gen = tokio::spawn(async move { + orch.generate_from_response(&request, conv_response).await + }); + + let mut got_client_tool = false; + let deadline = tokio::time::Instant::now() + Duration::from_secs(60); + + loop { + match tokio::time::timeout_at(deadline, h.event_rx.recv()).await { + Ok(Ok(event)) => { + if event.request_id() != &rid { continue; } + match event { + OrchestratorEvent::ToolCallDetected { + side: ToolSide::Client, call_id, ref name, .. + } => { + assert_eq!(name, "file_read"); + got_client_tool = true; + orch_submit + .submit_tool_result(&call_id, ToolResultPayload { + text: "# Sol\n\nVirtual librarian for Sunbeam.".into(), + is_error: false, + }) + .await + .expect("Failed to submit tool result"); + } + OrchestratorEvent::Done { .. } | OrchestratorEvent::Failed { .. 
} => break, + _ => continue, + } + } + Ok(Err(_)) => break, + Err(_) => panic!("Timeout waiting for client tool dispatch"), + } + } + + assert!(got_client_tool, "Expected client-side file_read tool call"); + let result = gen.await.unwrap(); + assert!(result.is_some(), "Expected response after tool execution"); +} + +// ── Test 4: Event ordering ─────────────────────────────────────────────── + +#[tokio::test] +async fn test_event_ordering() { + let mut h = TestHarness::new().await; + let request = make_request("say hello"); + + let rid = request.request_id.clone(); + let orch = h.orchestrator.clone(); + let gen = tokio::spawn(async move { orch.generate(&request).await }); + + let events = h.collect_events_for(&rid, 30).await; + let _ = gen.await; + + // Verify strict ordering: Started → Thinking → Done + assert!(events.len() >= 3, "Expected at least 3 events, got {}", events.len()); + assert!(matches!(events[0], OrchestratorEvent::Started { .. }), "First event should be Started"); + assert!(matches!(events[1], OrchestratorEvent::Thinking { .. }), "Second event should be Thinking"); + assert!(matches!(events.last().unwrap(), OrchestratorEvent::Done { .. }), "Last event should be Done"); +} + +// ── Test 5: Metadata pass-through ─────────────────────────────────────── + +#[tokio::test] +async fn test_metadata_passthrough() { + let mut h = TestHarness::new().await; + let mut request = make_request("hi"); + request.metadata.insert("room_id", "!test-room:localhost"); + request.metadata.insert("custom_key", "custom_value"); + + let rid = request.request_id.clone(); + let orch = h.orchestrator.clone(); + let gen = tokio::spawn(async move { orch.generate(&request).await }); + + let events = h.collect_events_for(&rid, 30).await; + let _ = gen.await; + + // Started event should carry metadata + let started = events.iter().find(|e| matches!(e, OrchestratorEvent::Started { .. 
})); + assert!(started.is_some(), "Missing Started event"); + if let Some(OrchestratorEvent::Started { metadata, .. }) = started { + assert_eq!(metadata.get("room_id"), Some("!test-room:localhost")); + assert_eq!(metadata.get("custom_key"), Some("custom_value")); + } +} + +// ── Test 6: Token usage accuracy ──────────────────────────────────────── + +#[tokio::test] +async fn test_token_usage_accuracy() { + let mut h = TestHarness::new().await; + + // Short prompt → small token count + let r1 = make_request("say ok"); + let rid1 = r1.request_id.clone(); + let orch1 = h.orchestrator.clone(); + let gen1 = tokio::spawn(async move { orch1.generate(&r1).await }); + let events1 = h.collect_events_for(&rid1, 30).await; + let _ = gen1.await; + + let done1 = events1.iter().find_map(|e| match e { + OrchestratorEvent::Done { usage, .. } => Some(usage.clone()), + _ => None, + }).expect("Missing Done event"); + + // Longer prompt → larger token count + let r2 = make_request( + "write a haiku about the sun setting over the ocean. include imagery of waves." + ); + let rid2 = r2.request_id.clone(); + let orch2 = h.orchestrator.clone(); + let gen2 = tokio::spawn(async move { orch2.generate(&r2).await }); + let events2 = h.collect_events_for(&rid2, 30).await; + let _ = gen2.await; + + let done2 = events2.iter().find_map(|e| match e { + OrchestratorEvent::Done { usage, .. 
} => Some(usage.clone()), + _ => None, + }).expect("Missing Done event"); + + // Both should have non-zero tokens + assert!(done1.prompt_tokens > 0); + assert!(done1.completion_tokens > 0); + assert!(done2.prompt_tokens > 0); + assert!(done2.completion_tokens > 0); + + // The longer prompt should use more completion tokens (haiku vs "ok") + assert!( + done2.completion_tokens > done1.completion_tokens, + "Longer request should produce more completion tokens: {} vs {}", + done2.completion_tokens, done1.completion_tokens + ); +} + +// ── Test 7: Failed tool result ────────────────────────────────────────── + +#[tokio::test] +async fn test_failed_tool_result() { + use mistralai_client::v1::conversations::{ + ConversationInput, CreateConversationRequest, + }; + use mistralai_client::v1::agents::AgentTool; + + let mut h = TestHarness::new().await; + + let api_key = std::env::var("SOL_MISTRAL_API_KEY").unwrap(); + let mistral = mistralai_client::v1::client::Client::new(Some(api_key), None, None, None).unwrap(); + + let tool = AgentTool::function( + "file_read".into(), + "Read a file.".into(), + serde_json::json!({ + "type": "object", + "properties": { "path": { "type": "string" } }, + "required": ["path"] + }), + ); + + let req = CreateConversationRequest { + inputs: ConversationInput::Text("read the file at /nonexistent/path".into()), + model: Some("mistral-medium-latest".into()), + agent_id: None, + agent_version: None, + name: Some("test-failed-tool".into()), + description: None, + instructions: Some("use tools when asked.".into()), + completion_args: None, + tools: Some(vec![tool]), + handoff_execution: None, + metadata: None, + store: Some(true), + stream: false, + }; + + let conv_response = mistral.create_conversation_async(&req).await.unwrap(); + let request = make_request("read /nonexistent/path"); + let rid = request.request_id.clone(); + let orch = h.orchestrator.clone(); + let orch_submit = h.orchestrator.clone(); + + let gen = tokio::spawn(async move { + 
orch.generate_from_response(&request, conv_response).await + }); + + // Submit error result when tool is called + let deadline = tokio::time::Instant::now() + Duration::from_secs(60); + loop { + match tokio::time::timeout_at(deadline, h.event_rx.recv()).await { + Ok(Ok(event)) => { + if event.request_id() != &rid { continue; } + match event { + OrchestratorEvent::ToolCallDetected { side: ToolSide::Client, call_id, .. } => { + orch_submit.submit_tool_result(&call_id, ToolResultPayload { + text: "Error: file not found".into(), + is_error: true, + }).await.unwrap(); + } + OrchestratorEvent::ToolCompleted { success, .. } => { + assert!(!success, "Expected tool to report failure"); + } + OrchestratorEvent::Done { .. } | OrchestratorEvent::Failed { .. } => break, + _ => continue, + } + } + Ok(Err(_)) => break, + Err(_) => panic!("Timeout"), + } + } + + // Model should still produce a response (explaining the error) + let result = gen.await.unwrap(); + assert!(result.is_some(), "Expected response even after tool error"); +} + +// ── Test 8: Server-side tool execution (search_web) ───────────────────── +// Note: run_script requires deno sandbox + tool definitions from the agent. +// search_web is more reliably available in test conversations. + +#[tokio::test] +async fn test_server_tool_execution() { + use mistralai_client::v1::conversations::{ + ConversationInput, CreateConversationRequest, + }; + use mistralai_client::v1::agents::AgentTool; + + let mut h = TestHarness::new().await; + + let api_key = std::env::var("SOL_MISTRAL_API_KEY").unwrap(); + let mistral = mistralai_client::v1::client::Client::new(Some(api_key), None, None, None).unwrap(); + + // Create conversation with search_web tool + let tool = AgentTool::function( + "search_web".into(), + "Search the web. 
Returns titles, URLs, and snippets.".into(), + serde_json::json!({ + "type": "object", + "properties": { + "query": { "type": "string", "description": "Search query" } + }, + "required": ["query"] + }), + ); + + let req = CreateConversationRequest { + inputs: ConversationInput::Text("search the web for 'rust programming language'".into()), + model: Some("mistral-medium-latest".into()), + agent_id: None, + agent_version: None, + name: Some("test-server-tool".into()), + description: None, + instructions: Some("use tools when asked. always use the search_web tool for any web search request.".into()), + completion_args: None, + tools: Some(vec![tool]), + handoff_execution: None, + metadata: None, + store: Some(true), + stream: false, + }; + + let conv_response = mistral.create_conversation_async(&req).await.unwrap(); + let request = make_request("search the web for rust"); + let rid = request.request_id.clone(); + let orch = h.orchestrator.clone(); + let gen = tokio::spawn(async move { + orch.generate_from_response(&request, conv_response).await + }); + + let events = h.collect_events_for(&rid, 60).await; + let result = gen.await.unwrap(); + + // May or may not produce a result (search_web needs SearXNG running) + // But we should at least see the tool call events + let tool_detected = events.iter().find(|e| matches!(e, OrchestratorEvent::ToolCallDetected { .. })); + assert!(tool_detected.is_some(), "Expected ToolCallDetected for search_web"); + + if let Some(OrchestratorEvent::ToolCallDetected { side, name, .. }) = tool_detected { + assert_eq!(*side, ToolSide::Server); + assert_eq!(name, "search_web"); + } + + assert!(events.iter().any(|e| matches!(e, OrchestratorEvent::ToolStarted { .. }))); + assert!(events.iter().any(|e| matches!(e, OrchestratorEvent::ToolCompleted { .. 
}))); +} + +// ══════════════════════════════════════════════════════════════════════════ +// gRPC integration tests — full round-trip through the gRPC server +// ══════════════════════════════════════════════════════════════════════════ + +mod grpc_tests { + use super::*; + use crate::grpc::{self, GrpcState}; + use crate::grpc::code_agent_client::CodeAgentClient; + use crate::grpc::*; + use tokio::sync::mpsc; + use tokio_stream::wrappers::ReceiverStream; + + /// Start a gRPC server on a random port and return the endpoint URL. + async fn start_test_server() -> (String, Arc) { + let env_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join(".env"); + if let Ok(contents) = std::fs::read_to_string(&env_path) { + for line in contents.lines() { + let line = line.trim(); + if line.is_empty() || line.starts_with('#') { continue; } + if let Some((k, v)) = line.split_once('=') { + std::env::set_var(k.trim(), v.trim()); + } + } + } + + let api_key = std::env::var("SOL_MISTRAL_API_KEY") + .expect("SOL_MISTRAL_API_KEY must be set"); + + let config = test_config(); + let mistral = Arc::new( + mistralai_client::v1::client::Client::new(Some(api_key), None, None, None).unwrap(), + ); + let store = Arc::new(Store::open_memory().unwrap()); + let conversations = Arc::new(ConversationRegistry::new( + config.agents.orchestrator_model.clone(), + config.agents.compaction_threshold, + store.clone(), + )); + let tools = Arc::new(ToolRegistry::new_minimal(config.clone())); + let orch = Arc::new(Orchestrator::new( + config.clone(), tools.clone(), mistral.clone(), conversations, + "you are sol. respond briefly. lowercase only.".into(), + )); + + let grpc_state = Arc::new(GrpcState { + config: config.clone(), + tools, + store, + mistral, + matrix: None, // not needed for tests + system_prompt: "you are sol. respond briefly. 
lowercase only.".into(), + orchestrator_agent_id: String::new(), + orchestrator: Some(orch), + }); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let endpoint = format!("http://{addr}"); + + let state = grpc_state.clone(); + tokio::spawn(async move { + let svc = crate::grpc::service::CodeAgentService::new(state); + let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener); + tonic::transport::Server::builder() + .add_service(CodeAgentServer::new(svc)) + .serve_with_incoming(incoming) + .await + .unwrap(); + }); + + tokio::time::sleep(Duration::from_millis(100)).await; + (endpoint, grpc_state) + } + + /// Connect a gRPC client and send StartSession. Returns (tx, rx, session_ready). + async fn connect_session( + endpoint: &str, + ) -> ( + mpsc::Sender, + tonic::Streaming, + SessionReady, + ) { + let mut client = CodeAgentClient::connect(endpoint.to_string()).await.unwrap(); + + let (tx, client_rx) = mpsc::channel::(32); + let stream = ReceiverStream::new(client_rx); + let response = client.session(stream).await.unwrap(); + let mut rx = response.into_inner(); + + // Send StartSession + tx.send(ClientMessage { + payload: Some(client_message::Payload::Start(StartSession { + project_path: "/tmp/test-project".into(), + prompt_md: String::new(), + config_toml: String::new(), + git_branch: "main".into(), + git_status: String::new(), + file_tree: vec![], + model: "mistral-medium-latest".into(), + client_tools: vec![], + })), + }) + .await + .unwrap(); + + // Wait for SessionReady + let ready = loop { + match rx.message().await.unwrap() { + Some(ServerMessage { payload: Some(server_message::Payload::Ready(r)) }) => break r, + Some(ServerMessage { payload: Some(server_message::Payload::Error(e)) }) => { + panic!("Session start failed: {}", e.message); + } + _ => continue, + } + }; + + (tx, rx, ready) + } + + #[tokio::test] + async fn test_grpc_simple_roundtrip() { + let (endpoint, 
_state) = start_test_server().await; + let (tx, mut rx, ready) = connect_session(&endpoint).await; + + assert!(!ready.session_id.is_empty()); + assert!(!ready.room_id.is_empty()); + + // Send a message + tx.send(ClientMessage { + payload: Some(client_message::Payload::Input(UserInput { + text: "what is 3+3? answer with just the number.".into(), + })), + }) + .await + .unwrap(); + + // Collect server messages until TextDone + let mut got_status = false; + let mut got_done = false; + let deadline = tokio::time::Instant::now() + Duration::from_secs(30); + + loop { + match tokio::time::timeout_at(deadline, rx.message()).await { + Ok(Ok(Some(msg))) => match msg.payload { + Some(server_message::Payload::Status(_)) => got_status = true, + Some(server_message::Payload::Done(d)) => { + got_done = true; + assert!(d.full_text.contains('6'), "Expected '6', got: {}", d.full_text); + assert!(d.input_tokens > 0, "Expected non-zero input tokens"); + assert!(d.output_tokens > 0, "Expected non-zero output tokens"); + break; + } + Some(server_message::Payload::Error(e)) => { + panic!("Server error: {}", e.message); + } + _ => continue, + }, + Ok(Ok(None)) => panic!("Stream closed before Done"), + Ok(Err(e)) => panic!("Stream error: {e}"), + Err(_) => panic!("Timeout waiting for gRPC response"), + } + } + + assert!(got_status, "Expected at least one Status message"); + assert!(got_done, "Expected TextDone message"); + + // Clean disconnect + tx.send(ClientMessage { + payload: Some(client_message::Payload::End(EndSession {})), + }) + .await + .unwrap(); + } + + #[tokio::test] + async fn test_grpc_client_tool_relay() { + let (endpoint, _state) = start_test_server().await; + let (tx, mut rx, _ready) = connect_session(&endpoint).await; + + // Send a message that should trigger file_read + tx.send(ClientMessage { + payload: Some(client_message::Payload::Input(UserInput { + text: "use file_read to read README.md".into(), + })), + }) + .await + .unwrap(); + + let mut got_tool_call = false; + 
let mut got_done = false; + let deadline = tokio::time::Instant::now() + Duration::from_secs(60); + + loop { + match tokio::time::timeout_at(deadline, rx.message()).await { + Ok(Ok(Some(msg))) => match msg.payload { + Some(server_message::Payload::ToolCall(tc)) => { + assert!(tc.is_local, "Expected local tool, got: {}", tc.name); + // Model may call file_read, list_directory, or other client tools + got_tool_call = true; + + // Send tool result back + tx.send(ClientMessage { + payload: Some(client_message::Payload::ToolResult( + crate::grpc::ToolResult { + call_id: tc.call_id, + result: "# Sol\nVirtual librarian.".into(), + is_error: false, + }, + )), + }) + .await + .unwrap(); + } + Some(server_message::Payload::Done(d)) => { + got_done = true; + assert!(!d.full_text.is_empty(), "Expected non-empty response"); + break; + } + Some(server_message::Payload::Error(e)) => { + panic!("Server error: {}", e.message); + } + _ => continue, + }, + Ok(Ok(None)) => break, + Ok(Err(e)) => panic!("Stream error: {e}"), + Err(_) => panic!("Timeout"), + } + } + + assert!(got_tool_call, "Expected ToolCall for file_read"); + assert!(got_done, "Expected TextDone after tool execution"); + } + + #[tokio::test] + async fn test_grpc_token_counts() { + let (endpoint, _state) = start_test_server().await; + let (tx, mut rx, _ready) = connect_session(&endpoint).await; + + tx.send(ClientMessage { + payload: Some(client_message::Payload::Input(UserInput { + text: "say hello".into(), + })), + }) + .await + .unwrap(); + + let deadline = tokio::time::Instant::now() + Duration::from_secs(30); + loop { + match tokio::time::timeout_at(deadline, rx.message()).await { + Ok(Ok(Some(msg))) => match msg.payload { + Some(server_message::Payload::Done(d)) => { + assert!(d.input_tokens > 0, "input_tokens should be > 0, got {}", d.input_tokens); + assert!(d.output_tokens > 0, "output_tokens should be > 0, got {}", d.output_tokens); + break; + } + Some(server_message::Payload::Error(e)) => panic!("Error: {}", 
e.message), + _ => continue, + }, + Ok(Ok(None)) => panic!("Stream closed"), + Ok(Err(e)) => panic!("Stream error: {e}"), + Err(_) => panic!("Timeout"), + } + } + } + + #[tokio::test] + async fn test_grpc_session_resume() { + let (endpoint, _state) = start_test_server().await; + + // Session 1: establish context + let (tx1, mut rx1, ready1) = connect_session(&endpoint).await; + tx1.send(ClientMessage { + payload: Some(client_message::Payload::Input(UserInput { + text: "my secret code is 42. remember it.".into(), + })), + }).await.unwrap(); + + // Wait for response + let deadline = tokio::time::Instant::now() + Duration::from_secs(30); + loop { + match tokio::time::timeout_at(deadline, rx1.message()).await { + Ok(Ok(Some(msg))) => match msg.payload { + Some(server_message::Payload::Done(_)) => break, + Some(server_message::Payload::Error(e)) => panic!("Error: {}", e.message), + _ => continue, + }, + Ok(Ok(None)) => break, + Ok(Err(e)) => panic!("Error: {e}"), + Err(_) => panic!("Timeout"), + } + } + + // Disconnect (don't send End — keeps session active) + drop(tx1); + drop(rx1); + + // Session 2: reconnect — should resume the same session + let (tx2, mut rx2, ready2) = connect_session(&endpoint).await; + + // Should be the same session (same project path → same room) + assert_eq!(ready2.room_id, ready1.room_id, "Should resume same room"); + assert!(ready2.resumed, "Should indicate resumed session"); + + // History requires Matrix (not available in tests) — just check session resumed + + // Ask for recall + tx2.send(ClientMessage { + payload: Some(client_message::Payload::Input(UserInput { + text: "what is my secret code?".into(), + })), + }).await.unwrap(); + + let deadline = tokio::time::Instant::now() + Duration::from_secs(30); + loop { + match tokio::time::timeout_at(deadline, rx2.message()).await { + Ok(Ok(Some(msg))) => match msg.payload { + Some(server_message::Payload::Done(d)) => { + assert!( + d.full_text.contains("42"), + "Expected model to recall '42', 
got: {}", + d.full_text + ); + break; + } + Some(server_message::Payload::Error(e)) => panic!("Error: {}", e.message), + _ => continue, + }, + Ok(Ok(None)) => panic!("Stream closed"), + Ok(Err(e)) => panic!("Error: {e}"), + Err(_) => panic!("Timeout"), + } + } + } + + #[tokio::test] + async fn test_grpc_clean_disconnect() { + let (endpoint, _state) = start_test_server().await; + let (tx, mut rx, ready) = connect_session(&endpoint).await; + + assert!(!ready.session_id.is_empty()); + + // Clean disconnect + tx.send(ClientMessage { + payload: Some(client_message::Payload::End(EndSession {})), + }).await.unwrap(); + + // Should get SessionEnd + let deadline = tokio::time::Instant::now() + Duration::from_secs(5); + let mut got_end = false; + loop { + match tokio::time::timeout_at(deadline, rx.message()).await { + Ok(Ok(Some(msg))) => match msg.payload { + Some(server_message::Payload::End(_)) => { got_end = true; break; } + _ => continue, + }, + Ok(Ok(None)) => break, + Ok(Err(_)) => break, + Err(_) => break, + } + } + + assert!(got_end, "Server should send SessionEnd on clean disconnect"); + } +} diff --git a/src/main.rs b/src/main.rs index e458cae..11ed79b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -10,6 +10,8 @@ mod memory; mod persistence; mod grpc; mod orchestrator; +#[cfg(test)] +mod integration_test; mod sdk; mod sync; mod time_context; @@ -319,7 +321,7 @@ async fn main() -> anyhow::Result<()> { tools: state.responder.tools(), store: store.clone(), mistral: state.mistral.clone(), - matrix: matrix_client.clone(), + matrix: Some(matrix_client.clone()), system_prompt: system_prompt_text.clone(), orchestrator_agent_id: orchestrator_id, orchestrator: Some(orch), diff --git a/src/tools/mod.rs b/src/tools/mod.rs index a5282a4..82c61bd 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -27,8 +27,8 @@ use crate::sdk::kratos::KratosClient; pub struct ToolRegistry { - opensearch: OpenSearch, - matrix: MatrixClient, + opensearch: Option, + matrix: Option, config: 
Arc, gitea: Option>, kratos: Option>, @@ -47,8 +47,8 @@ impl ToolRegistry { store: Option>, ) -> Self { Self { - opensearch, - matrix, + opensearch: Some(opensearch), + matrix: Some(matrix), config, gitea, kratos, @@ -57,6 +57,21 @@ impl ToolRegistry { } } + /// Create a minimal ToolRegistry for integration tests. + /// OpenSearch and Matrix are left unset, so every tool that needs one of them + /// (including `run_script`) returns a "not configured" error if called. + pub fn new_minimal(config: Arc) -> Self { + Self { + opensearch: None, + matrix: None, + config, + gitea: None, + kratos: None, + mistral: None, + store: None, + } + } + pub fn has_gitea(&self) -> bool { self.gitea.is_some() } @@ -232,7 +247,10 @@ impl ToolRegistry { /// its members are also members of the requesting room. This is enforced /// at the query level — Sol never sees filtered-out results. async fn allowed_room_ids(&self, requesting_room_id: &str) -> Vec { - let rooms = self.matrix.joined_rooms(); + let Some(ref matrix) = self.matrix else { + return vec![requesting_room_id.to_string()]; + }; + let rooms = matrix.joined_rooms(); // Get requesting room's member set let requesting_room = rooms.iter().find(|r| r.room_id().as_str() == requesting_room_id); @@ -304,11 +322,14 @@ impl ToolRegistry { arguments: &str, response_ctx: &ResponseContext, ) -> anyhow::Result { + let os = || self.opensearch.as_ref().ok_or_else(|| anyhow::anyhow!("OpenSearch not configured")); + let mx = || self.matrix.as_ref().ok_or_else(|| anyhow::anyhow!("Matrix not configured")); + match name { "search_archive" => { let allowed = self.allowed_room_ids(&response_ctx.room_id).await; search::search_archive( - &self.opensearch, + os()?, &self.config.opensearch.index, arguments, &allowed, @@ -318,20 +339,20 @@ impl ToolRegistry { "get_room_context" => { let allowed = self.allowed_room_ids(&response_ctx.room_id).await; room_history::get_room_context( - &self.opensearch, + os()?, &self.config.opensearch.index, arguments, &allowed, ) .await } - "list_rooms" => 
room_info::list_rooms(&self.matrix).await, - "get_room_members" => room_info::get_room_members(&self.matrix, arguments).await, + "list_rooms" => room_info::list_rooms(mx()?).await, + "get_room_members" => room_info::get_room_members(mx()?, arguments).await, "run_script" => { let allowed = self.allowed_room_ids(&response_ctx.room_id).await; script::run_script( - &self.opensearch, - &self.matrix, + os()?, + mx()?, &self.config, arguments, response_ctx,