215 lines
7.6 KiB
Rust
215 lines
7.6 KiB
Rust
//! Integration test: starts a mock gRPC server and connects the client.
//!
//! Tests the full bidirectional stream lifecycle without needing Sol or Mistral.
|
|
|
|
use std::pin::Pin;
|
|
use std::sync::Arc;
|
|
|
|
use futures::Stream;
|
|
use sunbeam_proto::sunbeam_code_v1::code_agent_server::{CodeAgent, CodeAgentServer};
|
|
use sunbeam_proto::sunbeam_code_v1::*;
|
|
use tokio::sync::mpsc;
|
|
use tokio_stream::wrappers::ReceiverStream;
|
|
use tonic::{Request, Response, Status, Streaming};
|
|
|
|
/// Mock server that echoes back user input as assistant text.
|
|
struct MockCodeAgent;
|
|
|
|
#[tonic::async_trait]
|
|
impl CodeAgent for MockCodeAgent {
|
|
type SessionStream = Pin<Box<dyn Stream<Item = Result<ServerMessage, Status>> + Send>>;
|
|
|
|
async fn session(
|
|
&self,
|
|
request: Request<Streaming<ClientMessage>>,
|
|
) -> Result<Response<Self::SessionStream>, Status> {
|
|
let mut in_stream = request.into_inner();
|
|
let (tx, rx) = mpsc::channel(32);
|
|
|
|
tokio::spawn(async move {
|
|
// Wait for StartSession
|
|
if let Ok(Some(msg)) = in_stream.message().await {
|
|
if let Some(client_message::Payload::Start(start)) = msg.payload {
|
|
let _ = tx.send(Ok(ServerMessage {
|
|
payload: Some(server_message::Payload::Ready(SessionReady {
|
|
session_id: "test-session-123".into(),
|
|
room_id: "!test-room:local".into(),
|
|
model: if start.model.is_empty() {
|
|
"devstral-2".into()
|
|
} else {
|
|
start.model
|
|
},
|
|
resumed: false,
|
|
history: vec![],
|
|
})),
|
|
})).await;
|
|
}
|
|
}
|
|
|
|
// Echo loop
|
|
while let Ok(Some(msg)) = in_stream.message().await {
|
|
match msg.payload {
|
|
Some(client_message::Payload::Input(input)) => {
|
|
let _ = tx.send(Ok(ServerMessage {
|
|
payload: Some(server_message::Payload::Done(TextDone {
|
|
full_text: format!("[echo] {}", input.text),
|
|
input_tokens: 10,
|
|
output_tokens: 5,
|
|
})),
|
|
})).await;
|
|
}
|
|
Some(client_message::Payload::End(_)) => {
|
|
let _ = tx.send(Ok(ServerMessage {
|
|
payload: Some(server_message::Payload::End(SessionEnd {
|
|
summary: "Session ended.".into(),
|
|
})),
|
|
})).await;
|
|
break;
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
});
|
|
|
|
Ok(Response::new(Box::pin(ReceiverStream::new(rx))))
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_session_lifecycle() {
|
|
// Start mock server on a random port
|
|
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
|
let addr = listener.local_addr().unwrap();
|
|
|
|
tokio::spawn(async move {
|
|
let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener);
|
|
tonic::transport::Server::builder()
|
|
.add_service(CodeAgentServer::new(MockCodeAgent))
|
|
.serve_with_incoming(incoming)
|
|
.await
|
|
.unwrap();
|
|
});
|
|
|
|
// Give server a moment to start
|
|
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
|
|
|
// Connect client
|
|
let endpoint = format!("http://{addr}");
|
|
|
|
use sunbeam_proto::sunbeam_code_v1::code_agent_client::CodeAgentClient;
|
|
let mut client = CodeAgentClient::connect(endpoint).await.unwrap();
|
|
|
|
let (tx, client_rx) = mpsc::channel::<ClientMessage>(32);
|
|
let client_stream = ReceiverStream::new(client_rx);
|
|
let response = client.session(client_stream).await.unwrap();
|
|
let mut rx = response.into_inner();
|
|
|
|
// Send StartSession
|
|
tx.send(ClientMessage {
|
|
payload: Some(client_message::Payload::Start(StartSession {
|
|
project_path: "/test/project".into(),
|
|
prompt_md: "test prompt".into(),
|
|
config_toml: String::new(),
|
|
git_branch: "main".into(),
|
|
git_status: String::new(),
|
|
file_tree: vec!["src/".into(), "Cargo.toml".into()],
|
|
model: "test-model".into(),
|
|
client_tools: vec![],
|
|
})),
|
|
}).await.unwrap();
|
|
|
|
// Receive SessionReady
|
|
let msg = rx.message().await.unwrap().unwrap();
|
|
match msg.payload {
|
|
Some(server_message::Payload::Ready(ready)) => {
|
|
assert_eq!(ready.session_id, "test-session-123");
|
|
assert_eq!(ready.model, "test-model");
|
|
}
|
|
other => panic!("Expected SessionReady, got {other:?}"),
|
|
}
|
|
|
|
// Send a chat message
|
|
tx.send(ClientMessage {
|
|
payload: Some(client_message::Payload::Input(UserInput {
|
|
text: "hello sol".into(),
|
|
})),
|
|
}).await.unwrap();
|
|
|
|
// Receive echo response
|
|
let msg = rx.message().await.unwrap().unwrap();
|
|
match msg.payload {
|
|
Some(server_message::Payload::Done(done)) => {
|
|
assert_eq!(done.full_text, "[echo] hello sol");
|
|
assert_eq!(done.input_tokens, 10);
|
|
assert_eq!(done.output_tokens, 5);
|
|
}
|
|
other => panic!("Expected TextDone, got {other:?}"),
|
|
}
|
|
|
|
// End session
|
|
tx.send(ClientMessage {
|
|
payload: Some(client_message::Payload::End(EndSession {})),
|
|
}).await.unwrap();
|
|
|
|
let msg = rx.message().await.unwrap().unwrap();
|
|
match msg.payload {
|
|
Some(server_message::Payload::End(end)) => {
|
|
assert_eq!(end.summary, "Session ended.");
|
|
}
|
|
other => panic!("Expected SessionEnd, got {other:?}"),
|
|
}
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_multiple_messages() {
|
|
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
|
|
let addr = listener.local_addr().unwrap();
|
|
|
|
tokio::spawn(async move {
|
|
let incoming = tokio_stream::wrappers::TcpListenerStream::new(listener);
|
|
tonic::transport::Server::builder()
|
|
.add_service(CodeAgentServer::new(MockCodeAgent))
|
|
.serve_with_incoming(incoming)
|
|
.await
|
|
.unwrap();
|
|
});
|
|
|
|
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
|
|
|
|
let endpoint = format!("http://{addr}");
|
|
use sunbeam_proto::sunbeam_code_v1::code_agent_client::CodeAgentClient;
|
|
let mut client = CodeAgentClient::connect(endpoint).await.unwrap();
|
|
|
|
let (tx, client_rx) = mpsc::channel::<ClientMessage>(32);
|
|
let client_stream = ReceiverStream::new(client_rx);
|
|
let response = client.session(client_stream).await.unwrap();
|
|
let mut rx = response.into_inner();
|
|
|
|
// Start
|
|
tx.send(ClientMessage {
|
|
payload: Some(client_message::Payload::Start(StartSession {
|
|
project_path: "/test".into(),
|
|
model: "devstral-2".into(),
|
|
..Default::default()
|
|
})),
|
|
}).await.unwrap();
|
|
|
|
let _ = rx.message().await.unwrap().unwrap(); // SessionReady
|
|
|
|
// Send 3 messages and verify each echo
|
|
for i in 0..3 {
|
|
tx.send(ClientMessage {
|
|
payload: Some(client_message::Payload::Input(UserInput {
|
|
text: format!("message {i}"),
|
|
})),
|
|
}).await.unwrap();
|
|
|
|
let msg = rx.message().await.unwrap().unwrap();
|
|
match msg.payload {
|
|
Some(server_message::Payload::Done(done)) => {
|
|
assert_eq!(done.full_text, format!("[echo] message {i}"));
|
|
}
|
|
other => panic!("Expected TextDone for message {i}, got {other:?}"),
|
|
}
|
|
}
|
|
}
|