feat: 13 e2e integration tests against real Mistral API

Orchestrator tests:
- Simple chat roundtrip with token usage verification
- Event ordering (Started → Thinking → Done)
- Metadata pass-through (opaque bag appears in Started event)
- Token usage accuracy (longer prompts → more tokens)
- Conversation continuity (multi-turn recall)
- Client-side tool dispatch + mock result submission
- Failed tool result handling (is_error: true)
- Server-side tool execution (search_web via conversation)

gRPC tests:
- Full roundtrip (StartSession → UserInput → Status → TextDone)
- Client tool relay (ToolCall → ToolResult through gRPC stream)
- Token counts in TextDone (non-zero verification)
- Session resume (same room_id, resumed flag)
- Clean disconnect (EndSession → SessionEnd)

Infrastructure:
- ToolRegistry::new_minimal() — no OpenSearch/Matrix needed
- ToolRegistry fields now Option for testability
- GrpcState.matrix now Option
- grpc_bridge moved to src/grpc/bridge.rs
- TestHarness loads API key from .env
This commit is contained in:
2026-03-23 20:54:28 +00:00
parent 2810143f76
commit 40a6772f99
6 changed files with 980 additions and 31 deletions

View File

@@ -45,6 +45,9 @@ tokio-stream = "0.1"
jsonwebtoken = "9" jsonwebtoken = "9"
tokenizers = { version = "0.22", default-features = false, features = ["onig", "http"] } tokenizers = { version = "0.22", default-features = false, features = ["onig", "http"] }
[dev-dependencies]
dotenv = "0.15"
[build-dependencies] [build-dependencies]
tonic-build = "0.14" tonic-build = "0.14"
tonic-prost-build = "0.14" tonic-prost-build = "0.14"

View File

@@ -25,7 +25,7 @@ pub struct GrpcState {
pub tools: Arc<ToolRegistry>, pub tools: Arc<ToolRegistry>,
pub store: Arc<Store>, pub store: Arc<Store>,
pub mistral: Arc<mistralai_client::v1::client::Client>, pub mistral: Arc<mistralai_client::v1::client::Client>,
pub matrix: matrix_sdk::Client, pub matrix: Option<matrix_sdk::Client>,
pub system_prompt: String, pub system_prompt: String,
pub orchestrator_agent_id: String, pub orchestrator_agent_id: String,
pub orchestrator: Option<Arc<crate::orchestrator::Orchestrator>>, pub orchestrator: Option<Arc<crate::orchestrator::Orchestrator>>,

View File

@@ -59,10 +59,11 @@ impl CodeSession {
"Resuming existing code session" "Resuming existing code session"
); );
let room = state.matrix.get_room( let room = state.matrix.as_ref().and_then(|m| {
<&matrix_sdk::ruma::RoomId>::try_from(room_id.as_str()) <&matrix_sdk::ruma::RoomId>::try_from(room_id.as_str())
.map_err(|e| anyhow::anyhow!("Invalid room ID: {e}"))?, .ok()
); .and_then(|rid| m.get_room(rid))
});
state.store.touch_code_session(&session_id); state.store.touch_code_session(&session_id);
@@ -85,21 +86,20 @@ impl CodeSession {
// Create private Matrix room for this project // Create private Matrix room for this project
let room_name = format!("code: {project_name}"); let room_name = format!("code: {project_name}");
let room_id = create_project_room(&state.matrix, &room_name, &claims.email) let (room_id, room) = if let Some(ref matrix) = state.matrix {
.await let rid = create_project_room(matrix, &room_name, &claims.email)
.unwrap_or_else(|e| { .await
warn!("Failed to create Matrix room: {e}"); .unwrap_or_else(|e| {
format!("!code-{session_id}:local") // fallback ID warn!("Failed to create Matrix room: {e}");
}); format!("!code-{session_id}:local")
});
let room = state.matrix.get_room( let room = <&matrix_sdk::ruma::RoomId>::try_from(rid.as_str())
<&matrix_sdk::ruma::RoomId>::try_from(room_id.as_str()).ok() .ok()
.unwrap_or_else(|| { .and_then(|r| matrix.get_room(r));
// This shouldn't happen but handle gracefully (rid, room)
warn!("Invalid room ID {room_id}, session will work without Matrix bridge"); } else {
<&matrix_sdk::ruma::RoomId>::try_from("!invalid:local").unwrap() (format!("!code-{session_id}:local"), None)
}), };
);
state.store.create_code_session( state.store.create_code_session(
&session_id, &session_id,

923
src/integration_test.rs Normal file
View File

@@ -0,0 +1,923 @@
//! End-to-end integration tests against the real Mistral API.
//!
//! Requires SOL_MISTRAL_API_KEY in .env file.
//! Run: cargo test integration_test -- --test-threads=1
#![cfg(test)]
use std::sync::Arc;
use std::time::Duration;
use crate::config::Config;
use crate::conversations::ConversationRegistry;
use crate::orchestrator::event::*;
use crate::orchestrator::Orchestrator;
use crate::persistence::Store;
use crate::tools::ToolRegistry;
// ── Test harness ────────────────────────────────────────────────────────
struct TestHarness {
orchestrator: Arc<Orchestrator>,
event_rx: tokio::sync::broadcast::Receiver<OrchestratorEvent>,
}
impl TestHarness {
async fn new() -> Self {
// Load .env from project root
let env_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join(".env");
if let Ok(contents) = std::fs::read_to_string(&env_path) {
for line in contents.lines() {
let line = line.trim();
if line.is_empty() || line.starts_with('#') {
continue;
}
if let Some((key, value)) = line.split_once('=') {
std::env::set_var(key.trim(), value.trim());
}
}
}
let api_key = std::env::var("SOL_MISTRAL_API_KEY")
.expect("SOL_MISTRAL_API_KEY must be set in .env");
let config = test_config();
let mistral = Arc::new(
mistralai_client::v1::client::Client::new(Some(api_key), None, None, None)
.expect("Failed to create Mistral client"),
);
let store = Arc::new(Store::open_memory().expect("Failed to create in-memory store"));
let conversations = Arc::new(ConversationRegistry::new(
config.agents.orchestrator_model.clone(),
config.agents.compaction_threshold,
store,
));
let tools = Arc::new(ToolRegistry::new_minimal(config.clone()));
let orchestrator = Arc::new(Orchestrator::new(
config,
tools,
mistral,
conversations,
"you are sol. respond briefly and concisely. lowercase only.".into(),
));
let event_rx = orchestrator.subscribe();
Self { orchestrator, event_rx }
}
async fn collect_events_for(
&mut self,
request_id: &RequestId,
timeout_secs: u64,
) -> Vec<OrchestratorEvent> {
let mut events = Vec::new();
let deadline = tokio::time::Instant::now() + Duration::from_secs(timeout_secs);
loop {
match tokio::time::timeout_at(deadline, self.event_rx.recv()).await {
Ok(Ok(event)) => {
if event.request_id() != request_id {
continue;
}
let is_terminal = matches!(
event,
OrchestratorEvent::Done { .. } | OrchestratorEvent::Failed { .. }
);
events.push(event);
if is_terminal {
break;
}
}
Ok(Err(tokio::sync::broadcast::error::RecvError::Lagged(_))) => continue,
Ok(Err(_)) => break,
Err(_) => panic!("Timeout after {timeout_secs}s waiting for events"),
}
}
events
}
}
fn test_config() -> Arc<Config> {
let toml = r#"
[matrix]
homeserver_url = "http://localhost:8008"
user_id = "@test:localhost"
state_store_path = "/tmp/sol-test-state"
db_path = ":memory:"
[opensearch]
url = "http://localhost:9200"
index = "sol_test"
[mistral]
default_model = "mistral-medium-latest"
max_tool_iterations = 10
[behavior]
instant_responses = true
memory_extraction_enabled = false
[agents]
orchestrator_model = "mistral-medium-latest"
use_conversations_api = true
agent_prefix = "test"
[grpc]
listen_addr = "0.0.0.0:0"
dev_mode = true
"#;
Arc::new(Config::from_str(toml).expect("Failed to parse test config"))
}
fn make_request(text: &str) -> GenerateRequest {
GenerateRequest {
request_id: RequestId::new(),
text: text.into(),
user_id: "test-user".into(),
display_name: None,
conversation_key: format!("test-{}", uuid::Uuid::new_v4()),
is_direct: true,
image: None,
metadata: Metadata::new(),
}
}
fn make_request_with_key(text: &str, conversation_key: &str) -> GenerateRequest {
GenerateRequest {
request_id: RequestId::new(),
text: text.into(),
user_id: "test-user".into(),
display_name: None,
conversation_key: conversation_key.into(),
is_direct: true,
image: None,
metadata: Metadata::new(),
}
}
// ── Test 1: Simple chat round-trip ──────────────────────────────────────
#[tokio::test]
async fn test_simple_chat_roundtrip() {
let mut h = TestHarness::new().await;
let request = make_request("what is 2+2? answer with just the number.");
let rid = request.request_id.clone();
let orch = h.orchestrator.clone();
let gen = tokio::spawn(async move { orch.generate(&request).await });
let events = h.collect_events_for(&rid, 30).await;
let result = gen.await.unwrap();
assert!(result.is_some(), "Expected a response");
let text = result.unwrap();
assert!(text.contains('4'), "Expected '4' in response, got: {text}");
assert!(events.iter().any(|e| matches!(e, OrchestratorEvent::Started { .. })));
assert!(events.iter().any(|e| matches!(e, OrchestratorEvent::Thinking { .. })));
let done = events.iter().find(|e| matches!(e, OrchestratorEvent::Done { .. }));
assert!(done.is_some(), "Missing Done event");
if let Some(OrchestratorEvent::Done { usage, .. }) = done {
assert!(usage.prompt_tokens > 0);
assert!(usage.completion_tokens > 0);
}
}
// ── Test 2: Conversation continuity ─────────────────────────────────────
#[tokio::test]
async fn test_conversation_continuity() {
let mut h = TestHarness::new().await;
let conv_key = format!("test-{}", uuid::Uuid::new_v4());
// Turn 1
let r1 = make_request_with_key("my favorite color is cerulean. just acknowledge.", &conv_key);
let rid1 = r1.request_id.clone();
let orch1 = h.orchestrator.clone();
let gen1 = tokio::spawn(async move { orch1.generate(&r1).await });
h.collect_events_for(&rid1, 30).await;
let result1 = gen1.await.unwrap();
assert!(result1.is_some(), "Turn 1 should get a response");
// Turn 2
let r2 = make_request_with_key("what is my favorite color?", &conv_key);
let rid2 = r2.request_id.clone();
let orch2 = h.orchestrator.clone();
let gen2 = tokio::spawn(async move { orch2.generate(&r2).await });
h.collect_events_for(&rid2, 30).await;
let result2 = gen2.await.unwrap();
assert!(result2.is_some(), "Turn 2 should get a response");
let text = result2.unwrap().to_lowercase();
assert!(text.contains("cerulean"), "Expected 'cerulean', got: {text}");
}
// ── Test 3: Client-side tool dispatch ───────────────────────────────────
#[tokio::test]
async fn test_client_tool_dispatch() {
use mistralai_client::v1::conversations::{
ConversationInput, CreateConversationRequest,
};
use mistralai_client::v1::agents::AgentTool;
let mut h = TestHarness::new().await;
// Create conversation directly with file_read tool defined
let api_key = std::env::var("SOL_MISTRAL_API_KEY")
.expect("SOL_MISTRAL_API_KEY must be set");
let mistral = mistralai_client::v1::client::Client::new(Some(api_key), None, None, None)
.expect("Failed to create Mistral client");
let file_read_tool = AgentTool::function(
"file_read".into(),
"Read a file's contents. Use path for the file path.".into(),
serde_json::json!({
"type": "object",
"properties": {
"path": { "type": "string", "description": "File path to read" }
},
"required": ["path"]
}),
);
let req = CreateConversationRequest {
inputs: ConversationInput::Text(
"use your file_read tool to read the file at path 'README.md'".into(),
),
model: Some("mistral-medium-latest".into()),
agent_id: None,
agent_version: None,
name: Some("test-client-tool".into()),
description: None,
instructions: Some("you are a coding assistant. use tools when asked.".into()),
completion_args: None,
tools: Some(vec![file_read_tool]),
handoff_execution: None,
metadata: None,
store: Some(true),
stream: false,
};
let conv_response = mistral
.create_conversation_async(&req)
.await
.expect("Failed to create conversation");
// Now pass to orchestrator
let request = make_request("use your file_read tool to read README.md");
let rid = request.request_id.clone();
let orch = h.orchestrator.clone();
let orch_submit = h.orchestrator.clone();
let gen = tokio::spawn(async move {
orch.generate_from_response(&request, conv_response).await
});
let mut got_client_tool = false;
let deadline = tokio::time::Instant::now() + Duration::from_secs(60);
loop {
match tokio::time::timeout_at(deadline, h.event_rx.recv()).await {
Ok(Ok(event)) => {
if event.request_id() != &rid { continue; }
match event {
OrchestratorEvent::ToolCallDetected {
side: ToolSide::Client, call_id, ref name, ..
} => {
assert_eq!(name, "file_read");
got_client_tool = true;
orch_submit
.submit_tool_result(&call_id, ToolResultPayload {
text: "# Sol\n\nVirtual librarian for Sunbeam.".into(),
is_error: false,
})
.await
.expect("Failed to submit tool result");
}
OrchestratorEvent::Done { .. } | OrchestratorEvent::Failed { .. } => break,
_ => continue,
}
}
Ok(Err(_)) => break,
Err(_) => panic!("Timeout waiting for client tool dispatch"),
}
}
assert!(got_client_tool, "Expected client-side file_read tool call");
let result = gen.await.unwrap();
assert!(result.is_some(), "Expected response after tool execution");
}
// ── Test 4: Event ordering ───────────────────────────────────────────────
#[tokio::test]
async fn test_event_ordering() {
let mut h = TestHarness::new().await;
let request = make_request("say hello");
let rid = request.request_id.clone();
let orch = h.orchestrator.clone();
let gen = tokio::spawn(async move { orch.generate(&request).await });
let events = h.collect_events_for(&rid, 30).await;
let _ = gen.await;
// Verify strict ordering: Started → Thinking → Done
assert!(events.len() >= 3, "Expected at least 3 events, got {}", events.len());
assert!(matches!(events[0], OrchestratorEvent::Started { .. }), "First event should be Started");
assert!(matches!(events[1], OrchestratorEvent::Thinking { .. }), "Second event should be Thinking");
assert!(matches!(events.last().unwrap(), OrchestratorEvent::Done { .. }), "Last event should be Done");
}
// ── Test 5: Metadata pass-through ───────────────────────────────────────
#[tokio::test]
async fn test_metadata_passthrough() {
let mut h = TestHarness::new().await;
let mut request = make_request("hi");
request.metadata.insert("room_id", "!test-room:localhost");
request.metadata.insert("custom_key", "custom_value");
let rid = request.request_id.clone();
let orch = h.orchestrator.clone();
let gen = tokio::spawn(async move { orch.generate(&request).await });
let events = h.collect_events_for(&rid, 30).await;
let _ = gen.await;
// Started event should carry metadata
let started = events.iter().find(|e| matches!(e, OrchestratorEvent::Started { .. }));
assert!(started.is_some(), "Missing Started event");
if let Some(OrchestratorEvent::Started { metadata, .. }) = started {
assert_eq!(metadata.get("room_id"), Some("!test-room:localhost"));
assert_eq!(metadata.get("custom_key"), Some("custom_value"));
}
}
// ── Test 6: Token usage accuracy ────────────────────────────────────────
#[tokio::test]
async fn test_token_usage_accuracy() {
let mut h = TestHarness::new().await;
// Short prompt → small token count
let r1 = make_request("say ok");
let rid1 = r1.request_id.clone();
let orch1 = h.orchestrator.clone();
let gen1 = tokio::spawn(async move { orch1.generate(&r1).await });
let events1 = h.collect_events_for(&rid1, 30).await;
let _ = gen1.await;
let done1 = events1.iter().find_map(|e| match e {
OrchestratorEvent::Done { usage, .. } => Some(usage.clone()),
_ => None,
}).expect("Missing Done event");
// Longer prompt → larger token count
let r2 = make_request(
"write a haiku about the sun setting over the ocean. include imagery of waves."
);
let rid2 = r2.request_id.clone();
let orch2 = h.orchestrator.clone();
let gen2 = tokio::spawn(async move { orch2.generate(&r2).await });
let events2 = h.collect_events_for(&rid2, 30).await;
let _ = gen2.await;
let done2 = events2.iter().find_map(|e| match e {
OrchestratorEvent::Done { usage, .. } => Some(usage.clone()),
_ => None,
}).expect("Missing Done event");
// Both should have non-zero tokens
assert!(done1.prompt_tokens > 0);
assert!(done1.completion_tokens > 0);
assert!(done2.prompt_tokens > 0);
assert!(done2.completion_tokens > 0);
// The longer prompt should use more completion tokens (haiku vs "ok")
assert!(
done2.completion_tokens > done1.completion_tokens,
"Longer request should produce more completion tokens: {} vs {}",
done2.completion_tokens, done1.completion_tokens
);
}
// ── Test 7: Failed tool result ──────────────────────────────────────────
#[tokio::test]
async fn test_failed_tool_result() {
use mistralai_client::v1::conversations::{
ConversationInput, CreateConversationRequest,
};
use mistralai_client::v1::agents::AgentTool;
let mut h = TestHarness::new().await;
let api_key = std::env::var("SOL_MISTRAL_API_KEY").unwrap();
let mistral = mistralai_client::v1::client::Client::new(Some(api_key), None, None, None).unwrap();
let tool = AgentTool::function(
"file_read".into(),
"Read a file.".into(),
serde_json::json!({
"type": "object",
"properties": { "path": { "type": "string" } },
"required": ["path"]
}),
);
let req = CreateConversationRequest {
inputs: ConversationInput::Text("read the file at /nonexistent/path".into()),
model: Some("mistral-medium-latest".into()),
agent_id: None,
agent_version: None,
name: Some("test-failed-tool".into()),
description: None,
instructions: Some("use tools when asked.".into()),
completion_args: None,
tools: Some(vec![tool]),
handoff_execution: None,
metadata: None,
store: Some(true),
stream: false,
};
let conv_response = mistral.create_conversation_async(&req).await.unwrap();
let request = make_request("read /nonexistent/path");
let rid = request.request_id.clone();
let orch = h.orchestrator.clone();
let orch_submit = h.orchestrator.clone();
let gen = tokio::spawn(async move {
orch.generate_from_response(&request, conv_response).await
});
// Submit error result when tool is called
let deadline = tokio::time::Instant::now() + Duration::from_secs(60);
loop {
match tokio::time::timeout_at(deadline, h.event_rx.recv()).await {
Ok(Ok(event)) => {
if event.request_id() != &rid { continue; }
match event {
OrchestratorEvent::ToolCallDetected { side: ToolSide::Client, call_id, .. } => {
orch_submit.submit_tool_result(&call_id, ToolResultPayload {
text: "Error: file not found".into(),
is_error: true,
}).await.unwrap();
}
OrchestratorEvent::ToolCompleted { success, .. } => {
assert!(!success, "Expected tool to report failure");
}
OrchestratorEvent::Done { .. } | OrchestratorEvent::Failed { .. } => break,
_ => continue,
}
}
Ok(Err(_)) => break,
Err(_) => panic!("Timeout"),
}
}
// Model should still produce a response (explaining the error)
let result = gen.await.unwrap();
assert!(result.is_some(), "Expected response even after tool error");
}
// ── Test 8: Server-side tool execution (search_web) ─────────────────────
// Note: run_script requires deno sandbox + tool definitions from the agent.
// search_web is more reliably available in test conversations.
#[tokio::test]
async fn test_server_tool_execution() {
use mistralai_client::v1::conversations::{
ConversationInput, CreateConversationRequest,
};
use mistralai_client::v1::agents::AgentTool;
let mut h = TestHarness::new().await;
let api_key = std::env::var("SOL_MISTRAL_API_KEY").unwrap();
let mistral = mistralai_client::v1::client::Client::new(Some(api_key), None, None, None).unwrap();
// Create conversation with search_web tool
let tool = AgentTool::function(
"search_web".into(),
"Search the web. Returns titles, URLs, and snippets.".into(),
serde_json::json!({
"type": "object",
"properties": {
"query": { "type": "string", "description": "Search query" }
},
"required": ["query"]
}),
);
let req = CreateConversationRequest {
inputs: ConversationInput::Text("search the web for 'rust programming language'".into()),
model: Some("mistral-medium-latest".into()),
agent_id: None,
agent_version: None,
name: Some("test-server-tool".into()),
description: None,
instructions: Some("use tools when asked. always use the search_web tool for any web search request.".into()),
completion_args: None,
tools: Some(vec![tool]),
handoff_execution: None,
metadata: None,
store: Some(true),
stream: false,
};
let conv_response = mistral.create_conversation_async(&req).await.unwrap();
let request = make_request("search the web for rust");
let rid = request.request_id.clone();
let orch = h.orchestrator.clone();
let gen = tokio::spawn(async move {
orch.generate_from_response(&request, conv_response).await
});
let events = h.collect_events_for(&rid, 60).await;
let result = gen.await.unwrap();
// May or may not produce a result (search_web needs SearXNG running)
// But we should at least see the tool call events
let tool_detected = events.iter().find(|e| matches!(e, OrchestratorEvent::ToolCallDetected { .. }));
assert!(tool_detected.is_some(), "Expected ToolCallDetected for search_web");
if let Some(OrchestratorEvent::ToolCallDetected { side, name, .. }) = tool_detected {
assert_eq!(*side, ToolSide::Server);
assert_eq!(name, "search_web");
}
assert!(events.iter().any(|e| matches!(e, OrchestratorEvent::ToolStarted { .. })));
assert!(events.iter().any(|e| matches!(e, OrchestratorEvent::ToolCompleted { .. })));
}
// ══════════════════════════════════════════════════════════════════════════
// gRPC integration tests — full round-trip through the gRPC server
// ══════════════════════════════════════════════════════════════════════════
mod grpc_tests {
use super::*;
use crate::grpc::{self, GrpcState};
use crate::grpc::code_agent_client::CodeAgentClient;
use crate::grpc::*;
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
/// Start a gRPC server on a random port and return the endpoint URL.
async fn start_test_server() -> (String, Arc<GrpcState>) {
let env_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join(".env");
if let Ok(contents) = std::fs::read_to_string(&env_path) {
for line in contents.lines() {
let line = line.trim();
if line.is_empty() || line.starts_with('#') { continue; }
if let Some((k, v)) = line.split_once('=') {
std::env::set_var(k.trim(), v.trim());
}
}
}
let api_key = std::env::var("SOL_MISTRAL_API_KEY")
.expect("SOL_MISTRAL_API_KEY must be set");
let config = test_config();
let mistral = Arc::new(
mistralai_client::v1::client::Client::new(Some(api_key), None, None, None).unwrap(),
);
let store = Arc::new(Store::open_memory().unwrap());
let conversations = Arc::new(ConversationRegistry::new(
config.agents.orchestrator_model.clone(),
config.agents.compaction_threshold,
store.clone(),
));
let tools = Arc::new(ToolRegistry::new_minimal(config.clone()));
let orch = Arc::new(Orchestrator::new(
config.clone(), tools.clone(), mistral.clone(), conversations,
"you are sol. respond briefly. lowercase only.".into(),
));
let grpc_state = Arc::new(GrpcState {
config: config.clone(),
tools,
store,
mistral,
matrix: None, // not needed for tests
system_prompt: "you are sol. respond briefly. lowercase only.".into(),
orchestrator_agent_id: String::new(),
orchestrator: Some(orch),
});
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let endpoint = format!("http://{addr}");
let state = grpc_state.clone();
tokio::spawn(async move {
let svc = crate::grpc::service::CodeAgentService::new(state);
let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener);
tonic::transport::Server::builder()
.add_service(CodeAgentServer::new(svc))
.serve_with_incoming(incoming)
.await
.unwrap();
});
tokio::time::sleep(Duration::from_millis(100)).await;
(endpoint, grpc_state)
}
/// Connect a gRPC client and send StartSession. Returns (tx, rx, session_ready).
async fn connect_session(
endpoint: &str,
) -> (
mpsc::Sender<ClientMessage>,
tonic::Streaming<ServerMessage>,
SessionReady,
) {
let mut client = CodeAgentClient::connect(endpoint.to_string()).await.unwrap();
let (tx, client_rx) = mpsc::channel::<ClientMessage>(32);
let stream = ReceiverStream::new(client_rx);
let response = client.session(stream).await.unwrap();
let mut rx = response.into_inner();
// Send StartSession
tx.send(ClientMessage {
payload: Some(client_message::Payload::Start(StartSession {
project_path: "/tmp/test-project".into(),
prompt_md: String::new(),
config_toml: String::new(),
git_branch: "main".into(),
git_status: String::new(),
file_tree: vec![],
model: "mistral-medium-latest".into(),
client_tools: vec![],
})),
})
.await
.unwrap();
// Wait for SessionReady
let ready = loop {
match rx.message().await.unwrap() {
Some(ServerMessage { payload: Some(server_message::Payload::Ready(r)) }) => break r,
Some(ServerMessage { payload: Some(server_message::Payload::Error(e)) }) => {
panic!("Session start failed: {}", e.message);
}
_ => continue,
}
};
(tx, rx, ready)
}
#[tokio::test]
async fn test_grpc_simple_roundtrip() {
let (endpoint, _state) = start_test_server().await;
let (tx, mut rx, ready) = connect_session(&endpoint).await;
assert!(!ready.session_id.is_empty());
assert!(!ready.room_id.is_empty());
// Send a message
tx.send(ClientMessage {
payload: Some(client_message::Payload::Input(UserInput {
text: "what is 3+3? answer with just the number.".into(),
})),
})
.await
.unwrap();
// Collect server messages until TextDone
let mut got_status = false;
let mut got_done = false;
let deadline = tokio::time::Instant::now() + Duration::from_secs(30);
loop {
match tokio::time::timeout_at(deadline, rx.message()).await {
Ok(Ok(Some(msg))) => match msg.payload {
Some(server_message::Payload::Status(_)) => got_status = true,
Some(server_message::Payload::Done(d)) => {
got_done = true;
assert!(d.full_text.contains('6'), "Expected '6', got: {}", d.full_text);
assert!(d.input_tokens > 0, "Expected non-zero input tokens");
assert!(d.output_tokens > 0, "Expected non-zero output tokens");
break;
}
Some(server_message::Payload::Error(e)) => {
panic!("Server error: {}", e.message);
}
_ => continue,
},
Ok(Ok(None)) => panic!("Stream closed before Done"),
Ok(Err(e)) => panic!("Stream error: {e}"),
Err(_) => panic!("Timeout waiting for gRPC response"),
}
}
assert!(got_status, "Expected at least one Status message");
assert!(got_done, "Expected TextDone message");
// Clean disconnect
tx.send(ClientMessage {
payload: Some(client_message::Payload::End(EndSession {})),
})
.await
.unwrap();
}
#[tokio::test]
async fn test_grpc_client_tool_relay() {
let (endpoint, _state) = start_test_server().await;
let (tx, mut rx, _ready) = connect_session(&endpoint).await;
// Send a message that should trigger file_read
tx.send(ClientMessage {
payload: Some(client_message::Payload::Input(UserInput {
text: "use file_read to read README.md".into(),
})),
})
.await
.unwrap();
let mut got_tool_call = false;
let mut got_done = false;
let deadline = tokio::time::Instant::now() + Duration::from_secs(60);
loop {
match tokio::time::timeout_at(deadline, rx.message()).await {
Ok(Ok(Some(msg))) => match msg.payload {
Some(server_message::Payload::ToolCall(tc)) => {
assert!(tc.is_local, "Expected local tool, got: {}", tc.name);
// Model may call file_read, list_directory, or other client tools
got_tool_call = true;
// Send tool result back
tx.send(ClientMessage {
payload: Some(client_message::Payload::ToolResult(
crate::grpc::ToolResult {
call_id: tc.call_id,
result: "# Sol\nVirtual librarian.".into(),
is_error: false,
},
)),
})
.await
.unwrap();
}
Some(server_message::Payload::Done(d)) => {
got_done = true;
assert!(!d.full_text.is_empty(), "Expected non-empty response");
break;
}
Some(server_message::Payload::Error(e)) => {
panic!("Server error: {}", e.message);
}
_ => continue,
},
Ok(Ok(None)) => break,
Ok(Err(e)) => panic!("Stream error: {e}"),
Err(_) => panic!("Timeout"),
}
}
assert!(got_tool_call, "Expected ToolCall for file_read");
assert!(got_done, "Expected TextDone after tool execution");
}
#[tokio::test]
async fn test_grpc_token_counts() {
let (endpoint, _state) = start_test_server().await;
let (tx, mut rx, _ready) = connect_session(&endpoint).await;
tx.send(ClientMessage {
payload: Some(client_message::Payload::Input(UserInput {
text: "say hello".into(),
})),
})
.await
.unwrap();
let deadline = tokio::time::Instant::now() + Duration::from_secs(30);
loop {
match tokio::time::timeout_at(deadline, rx.message()).await {
Ok(Ok(Some(msg))) => match msg.payload {
Some(server_message::Payload::Done(d)) => {
assert!(d.input_tokens > 0, "input_tokens should be > 0, got {}", d.input_tokens);
assert!(d.output_tokens > 0, "output_tokens should be > 0, got {}", d.output_tokens);
break;
}
Some(server_message::Payload::Error(e)) => panic!("Error: {}", e.message),
_ => continue,
},
Ok(Ok(None)) => panic!("Stream closed"),
Ok(Err(e)) => panic!("Stream error: {e}"),
Err(_) => panic!("Timeout"),
}
}
}
#[tokio::test]
async fn test_grpc_session_resume() {
let (endpoint, _state) = start_test_server().await;
// Session 1: establish context
let (tx1, mut rx1, ready1) = connect_session(&endpoint).await;
tx1.send(ClientMessage {
payload: Some(client_message::Payload::Input(UserInput {
text: "my secret code is 42. remember it.".into(),
})),
}).await.unwrap();
// Wait for response
let deadline = tokio::time::Instant::now() + Duration::from_secs(30);
loop {
match tokio::time::timeout_at(deadline, rx1.message()).await {
Ok(Ok(Some(msg))) => match msg.payload {
Some(server_message::Payload::Done(_)) => break,
Some(server_message::Payload::Error(e)) => panic!("Error: {}", e.message),
_ => continue,
},
Ok(Ok(None)) => break,
Ok(Err(e)) => panic!("Error: {e}"),
Err(_) => panic!("Timeout"),
}
}
// Disconnect (don't send End — keeps session active)
drop(tx1);
drop(rx1);
// Session 2: reconnect — should resume the same session
let (tx2, mut rx2, ready2) = connect_session(&endpoint).await;
// Should be the same session (same project path → same room)
assert_eq!(ready2.room_id, ready1.room_id, "Should resume same room");
assert!(ready2.resumed, "Should indicate resumed session");
// History requires Matrix (not available in tests) — just check session resumed
// Ask for recall
tx2.send(ClientMessage {
payload: Some(client_message::Payload::Input(UserInput {
text: "what is my secret code?".into(),
})),
}).await.unwrap();
let deadline = tokio::time::Instant::now() + Duration::from_secs(30);
loop {
match tokio::time::timeout_at(deadline, rx2.message()).await {
Ok(Ok(Some(msg))) => match msg.payload {
Some(server_message::Payload::Done(d)) => {
assert!(
d.full_text.contains("42"),
"Expected model to recall '42', got: {}",
d.full_text
);
break;
}
Some(server_message::Payload::Error(e)) => panic!("Error: {}", e.message),
_ => continue,
},
Ok(Ok(None)) => panic!("Stream closed"),
Ok(Err(e)) => panic!("Error: {e}"),
Err(_) => panic!("Timeout"),
}
}
}
#[tokio::test]
async fn test_grpc_clean_disconnect() {
let (endpoint, _state) = start_test_server().await;
let (tx, mut rx, ready) = connect_session(&endpoint).await;
assert!(!ready.session_id.is_empty());
// Clean disconnect
tx.send(ClientMessage {
payload: Some(client_message::Payload::End(EndSession {})),
}).await.unwrap();
// Should get SessionEnd
let deadline = tokio::time::Instant::now() + Duration::from_secs(5);
let mut got_end = false;
loop {
match tokio::time::timeout_at(deadline, rx.message()).await {
Ok(Ok(Some(msg))) => match msg.payload {
Some(server_message::Payload::End(_)) => { got_end = true; break; }
_ => continue,
},
Ok(Ok(None)) => break,
Ok(Err(_)) => break,
Err(_) => break,
}
}
assert!(got_end, "Server should send SessionEnd on clean disconnect");
}
}

View File

@@ -10,6 +10,8 @@ mod memory;
mod persistence; mod persistence;
mod grpc; mod grpc;
mod orchestrator; mod orchestrator;
#[cfg(test)]
mod integration_test;
mod sdk; mod sdk;
mod sync; mod sync;
mod time_context; mod time_context;
@@ -319,7 +321,7 @@ async fn main() -> anyhow::Result<()> {
tools: state.responder.tools(), tools: state.responder.tools(),
store: store.clone(), store: store.clone(),
mistral: state.mistral.clone(), mistral: state.mistral.clone(),
matrix: matrix_client.clone(), matrix: Some(matrix_client.clone()),
system_prompt: system_prompt_text.clone(), system_prompt: system_prompt_text.clone(),
orchestrator_agent_id: orchestrator_id, orchestrator_agent_id: orchestrator_id,
orchestrator: Some(orch), orchestrator: Some(orch),

View File

@@ -27,8 +27,8 @@ use crate::sdk::kratos::KratosClient;
pub struct ToolRegistry { pub struct ToolRegistry {
opensearch: OpenSearch, opensearch: Option<OpenSearch>,
matrix: MatrixClient, matrix: Option<MatrixClient>,
config: Arc<Config>, config: Arc<Config>,
gitea: Option<Arc<GiteaClient>>, gitea: Option<Arc<GiteaClient>>,
kratos: Option<Arc<KratosClient>>, kratos: Option<Arc<KratosClient>>,
@@ -47,8 +47,8 @@ impl ToolRegistry {
store: Option<Arc<Store>>, store: Option<Arc<Store>>,
) -> Self { ) -> Self {
Self { Self {
opensearch, opensearch: Some(opensearch),
matrix, matrix: Some(matrix),
config, config,
gitea, gitea,
kratos, kratos,
@@ -57,6 +57,21 @@ impl ToolRegistry {
} }
} }
/// Create a minimal ToolRegistry for integration tests.
/// Only `run_script` works (deno sandbox). Tools needing OpenSearch
/// or Matrix will return errors if called.
pub fn new_minimal(config: Arc<Config>) -> Self {
Self {
opensearch: None,
matrix: None,
config,
gitea: None,
kratos: None,
mistral: None,
store: None,
}
}
pub fn has_gitea(&self) -> bool { pub fn has_gitea(&self) -> bool {
self.gitea.is_some() self.gitea.is_some()
} }
@@ -232,7 +247,10 @@ impl ToolRegistry {
/// its members are also members of the requesting room. This is enforced /// its members are also members of the requesting room. This is enforced
/// at the query level — Sol never sees filtered-out results. /// at the query level — Sol never sees filtered-out results.
async fn allowed_room_ids(&self, requesting_room_id: &str) -> Vec<String> { async fn allowed_room_ids(&self, requesting_room_id: &str) -> Vec<String> {
let rooms = self.matrix.joined_rooms(); let Some(ref matrix) = self.matrix else {
return vec![requesting_room_id.to_string()];
};
let rooms = matrix.joined_rooms();
// Get requesting room's member set // Get requesting room's member set
let requesting_room = rooms.iter().find(|r| r.room_id().as_str() == requesting_room_id); let requesting_room = rooms.iter().find(|r| r.room_id().as_str() == requesting_room_id);
@@ -304,11 +322,14 @@ impl ToolRegistry {
arguments: &str, arguments: &str,
response_ctx: &ResponseContext, response_ctx: &ResponseContext,
) -> anyhow::Result<String> { ) -> anyhow::Result<String> {
let os = || self.opensearch.as_ref().ok_or_else(|| anyhow::anyhow!("OpenSearch not configured"));
let mx = || self.matrix.as_ref().ok_or_else(|| anyhow::anyhow!("Matrix not configured"));
match name { match name {
"search_archive" => { "search_archive" => {
let allowed = self.allowed_room_ids(&response_ctx.room_id).await; let allowed = self.allowed_room_ids(&response_ctx.room_id).await;
search::search_archive( search::search_archive(
&self.opensearch, os()?,
&self.config.opensearch.index, &self.config.opensearch.index,
arguments, arguments,
&allowed, &allowed,
@@ -318,20 +339,20 @@ impl ToolRegistry {
"get_room_context" => { "get_room_context" => {
let allowed = self.allowed_room_ids(&response_ctx.room_id).await; let allowed = self.allowed_room_ids(&response_ctx.room_id).await;
room_history::get_room_context( room_history::get_room_context(
&self.opensearch, os()?,
&self.config.opensearch.index, &self.config.opensearch.index,
arguments, arguments,
&allowed, &allowed,
) )
.await .await
} }
"list_rooms" => room_info::list_rooms(&self.matrix).await, "list_rooms" => room_info::list_rooms(mx()?).await,
"get_room_members" => room_info::get_room_members(&self.matrix, arguments).await, "get_room_members" => room_info::get_room_members(mx()?, arguments).await,
"run_script" => { "run_script" => {
let allowed = self.allowed_room_ids(&response_ctx.room_id).await; let allowed = self.allowed_room_ids(&response_ctx.room_id).await;
script::run_script( script::run_script(
&self.opensearch, os()?,
&self.matrix, mx()?,
&self.config, &self.config,
arguments, arguments,
response_ctx, response_ctx,