feat: streaming conversation responses via Mistral API

- Bump mistralai-client to 1.2.0 (conversation streaming support)
- session.rs: add create_or_append_conversation_streaming() — calls
  Mistral with stream:true, emits TextDelta messages to gRPC client
  as SSE chunks arrive, accumulates into ConversationResponse for
  orchestrator tool loop. Handles corrupted conversation recovery
- service.rs: session_chat_via_orchestrator uses streaming variant
- Integration tests: streaming create + append against real Mistral
  API, SSE parsing, accumulate text + function calls
This commit is contained in:
2026-03-24 21:33:01 +00:00
parent 6a2aafdccc
commit 261c39b424
5 changed files with 369 additions and 23 deletions

40
Cargo.lock generated
View File

@@ -828,6 +828,15 @@ dependencies = [
"cfg-if", "cfg-if",
] ]
[[package]]
name = "crossbeam-channel"
version = "0.5.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2"
dependencies = [
"crossbeam-utils",
]
[[package]] [[package]]
name = "crossbeam-deque" name = "crossbeam-deque"
version = "0.8.6" version = "0.8.6"
@@ -1357,7 +1366,7 @@ dependencies = [
"libc", "libc",
"option-ext", "option-ext",
"redox_users", "redox_users",
"windows-sys 0.61.2", "windows-sys 0.59.0",
] ]
[[package]] [[package]]
@@ -1482,7 +1491,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
dependencies = [ dependencies = [
"libc", "libc",
"windows-sys 0.61.2", "windows-sys 0.59.0",
] ]
[[package]] [[package]]
@@ -3070,9 +3079,9 @@ dependencies = [
[[package]] [[package]]
name = "mistralai-client" name = "mistralai-client"
version = "1.1.0" version = "1.2.0"
source = "sparse+https://src.sunbeam.pt/api/packages/studio/cargo/" source = "sparse+https://src.sunbeam.pt/api/packages/studio/cargo/"
checksum = "6a17c5600508f30965a8e6ab78e947a0db7017edef8ffb022aa9ac71f50ebaff" checksum = "24258b2ba72432f9c147a5a80e6d862c7b3c43db859135c619c948cf6c9e4fdf"
dependencies = [ dependencies = [
"async-stream", "async-stream",
"async-trait", "async-trait",
@@ -3154,7 +3163,7 @@ version = "0.50.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5" checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
dependencies = [ dependencies = [
"windows-sys 0.61.2", "windows-sys 0.59.0",
] ]
[[package]] [[package]]
@@ -3777,7 +3786,7 @@ dependencies = [
"once_cell", "once_cell",
"socket2", "socket2",
"tracing", "tracing",
"windows-sys 0.60.2", "windows-sys 0.59.0",
] ]
[[package]] [[package]]
@@ -4290,7 +4299,7 @@ dependencies = [
"errno", "errno",
"libc", "libc",
"linux-raw-sys 0.12.1", "linux-raw-sys 0.12.1",
"windows-sys 0.61.2", "windows-sys 0.59.0",
] ]
[[package]] [[package]]
@@ -4736,6 +4745,7 @@ dependencies = [
"tonic-prost", "tonic-prost",
"tonic-prost-build", "tonic-prost-build",
"tracing", "tracing",
"tracing-appender",
"tracing-subscriber", "tracing-subscriber",
"tree-sitter", "tree-sitter",
"tree-sitter-python", "tree-sitter-python",
@@ -5355,7 +5365,7 @@ dependencies = [
"getrandom 0.4.2", "getrandom 0.4.2",
"once_cell", "once_cell",
"rustix 1.1.4", "rustix 1.1.4",
"windows-sys 0.61.2", "windows-sys 0.59.0",
] ]
[[package]] [[package]]
@@ -5838,6 +5848,18 @@ dependencies = [
"tracing-core", "tracing-core",
] ]
[[package]]
name = "tracing-appender"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "786d480bce6247ab75f005b14ae1624ad978d3029d9113f0a22fa1ac773faeaf"
dependencies = [
"crossbeam-channel",
"thiserror 2.0.18",
"time",
"tracing-subscriber",
]
[[package]] [[package]]
name = "tracing-attributes" name = "tracing-attributes"
version = "0.1.31" version = "0.1.31"
@@ -6425,7 +6447,7 @@ version = "0.1.11"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22" checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
dependencies = [ dependencies = [
"windows-sys 0.61.2", "windows-sys 0.59.0",
] ]
[[package]] [[package]]

View File

@@ -13,7 +13,7 @@ name = "sol"
path = "src/main.rs" path = "src/main.rs"
[dependencies] [dependencies]
mistralai-client = { version = "1.1.0", registry = "sunbeam" } mistralai-client = { version = "1.2.0", registry = "sunbeam" }
matrix-sdk = { version = "0.9", features = ["e2e-encryption", "sqlite"] } matrix-sdk = { version = "0.9", features = ["e2e-encryption", "sqlite"] }
opensearch = "2" opensearch = "2"
tokio = { version = "1", features = ["full"] } tokio = { version = "1", features = ["full"] }

View File

@@ -314,7 +314,7 @@ async fn session_chat_via_orchestrator(
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
use crate::orchestrator::event::*; use crate::orchestrator::event::*;
let conversation_response = session.create_or_append_conversation(text).await?; let conversation_response = session.create_or_append_conversation_streaming(text, tx).await?;
session.post_to_matrix(text).await; session.post_to_matrix(text).await;
let request_id = RequestId::new(); let request_id = RequestId::new();

View File

@@ -388,46 +388,160 @@ you also have access to server-side tools: search_archive, search_web, research,
tools tools
} }
/// Create or append to the Mistral conversation. Returns the response /// Create or append to the Mistral conversation with streaming.
/// for the orchestrator to run through its tool loop. /// Emits TextDelta messages to the gRPC client as chunks arrive.
pub async fn create_or_append_conversation( /// Returns the accumulated response for the orchestrator's tool loop.
pub async fn create_or_append_conversation_streaming(
&mut self, &mut self,
text: &str, text: &str,
client_tx: &tokio::sync::mpsc::Sender<Result<ServerMessage, tonic::Status>>,
) -> anyhow::Result<ConversationResponse> { ) -> anyhow::Result<ConversationResponse> {
use futures::StreamExt;
use mistralai_client::v1::conversation_stream::{self, ConversationEvent};
let context_header = self.build_context_header(text).await; let context_header = self.build_context_header(text).await;
let input_text = format!("{context_header}\n{text}"); let input_text = format!("{context_header}\n{text}");
let conv_id_for_accumulate: String;
let mut events = Vec::new();
if let Some(ref conv_id) = self.conversation_id { if let Some(ref conv_id) = self.conversation_id {
conv_id_for_accumulate = conv_id.clone();
let req = AppendConversationRequest { let req = AppendConversationRequest {
inputs: ConversationInput::Text(input_text.clone()), inputs: ConversationInput::Text(input_text.clone()),
completion_args: None, completion_args: None,
handoff_execution: None, handoff_execution: None,
store: Some(true), store: Some(true),
tool_confirmations: None, tool_confirmations: None,
stream: false, stream: true,
}; };
match self.state match self.state.mistral
.mistral .append_conversation_stream_async(conv_id, &req)
.append_conversation_async(conv_id, &req)
.await .await
{ {
Ok(resp) => Ok(resp), Ok(stream) => {
tokio::pin!(stream);
while let Some(result) = stream.next().await {
match result {
Ok(event) => {
if let Some(delta) = event.text_delta() {
let _ = client_tx.send(Ok(ServerMessage {
payload: Some(server_message::Payload::Delta(TextDelta {
text: delta,
})),
})).await;
}
events.push(event);
}
Err(e) => {
warn!("Stream event error: {}", e.message);
}
}
}
}
Err(e) if e.message.contains("function calls and responses") Err(e) if e.message.contains("function calls and responses")
|| e.message.contains("invalid_request_error") => || e.message.contains("invalid_request_error") =>
{ {
warn!( warn!(
conversation_id = conv_id.as_str(), conversation_id = conv_id.as_str(),
error = e.message.as_str(), error = e.message.as_str(),
"Conversation corrupted — creating fresh conversation" "Conversation corrupted — creating fresh conversation (streaming)"
); );
self.conversation_id = None; self.conversation_id = None;
self.create_fresh_conversation(input_text).await return self.create_fresh_conversation_streaming(input_text, client_tx).await;
} }
Err(e) => Err(anyhow::anyhow!("append_conversation failed: {}", e.message)), Err(e) => return Err(anyhow::anyhow!("append_conversation_stream failed: {}", e.message)),
} }
} else { } else {
self.create_fresh_conversation(input_text).await return self.create_fresh_conversation_streaming(input_text, client_tx).await;
} }
Ok(conversation_stream::accumulate(&conv_id_for_accumulate, &events))
}
/// Create a fresh streaming conversation — used on first message or after corruption.
async fn create_fresh_conversation_streaming(
&mut self,
input_text: String,
client_tx: &tokio::sync::mpsc::Sender<Result<ServerMessage, tonic::Status>>,
) -> anyhow::Result<ConversationResponse> {
use futures::StreamExt;
use mistralai_client::v1::conversation_stream::{self, ConversationEvent};
let instructions = self.build_instructions();
let req = CreateConversationRequest {
inputs: ConversationInput::Text(input_text),
model: Some(self.model.clone()),
agent_id: None,
agent_version: None,
name: Some(format!("code-{}", self.project_name)),
description: None,
instructions: Some(instructions),
completion_args: None,
tools: Some(self.build_tool_definitions()),
handoff_execution: None,
metadata: None,
store: Some(true),
stream: true,
};
let stream = self.state.mistral
.create_conversation_stream_async(&req)
.await
.map_err(|e| anyhow::anyhow!("create_conversation_stream failed: {}", e.message))?;
tokio::pin!(stream);
let mut events = Vec::new();
let mut conversation_id = String::new();
while let Some(result) = stream.next().await {
match result {
Ok(event) => {
if let Some(delta) = event.text_delta() {
let _ = client_tx.send(Ok(ServerMessage {
payload: Some(server_message::Payload::Delta(TextDelta {
text: delta,
})),
})).await;
}
events.push(event);
}
Err(e) => {
warn!("Stream event error: {}", e.message);
}
}
}
let resp = conversation_stream::accumulate(&conversation_id, &events);
// Extract conversation_id from the accumulated response
if !resp.conversation_id.is_empty() {
conversation_id = resp.conversation_id.clone();
}
// If we still don't have a conversation_id, the streaming didn't return one.
// This happens with the Conversations API — the ID comes from the response headers
// or the ResponseDone event. Let's check events for it.
if conversation_id.is_empty() {
// Fallback: create non-streaming to get the conversation_id
warn!("Streaming didn't return conversation_id — falling back to non-streaming create");
return self.create_fresh_conversation(
"".into() // empty — conversation already created, just need the ID
).await;
}
self.conversation_id = Some(conversation_id.clone());
self.state.store.set_code_session_conversation(
&self.session_id,
&conversation_id,
);
info!(
conversation_id = conversation_id.as_str(),
"Created streaming Mistral conversation for code session"
);
Ok(resp)
} }
/// Post user message to the Matrix room. /// Post user message to the Matrix room.

View File

@@ -6051,3 +6051,213 @@ class Calculator:
assert!(sessions.is_empty(), "Session should be marked complete, not running"); assert!(sessions.is_empty(), "Session should be marked complete, not running");
} }
} }
// ══════════════════════════════════════════════════════════════════════════
// Conversation streaming — Mistral API integration
// ══════════════════════════════════════════════════════════════════════════
// Tests for Mistral conversation streaming. The two `#[tokio::test]` cases
// call the live API and self-skip when no API key is available; the
// remaining tests exercise SSE parsing and event accumulation offline.
mod conversation_stream_tests {
    use futures::StreamExt;
    use mistralai_client::v1::conversation_stream::{self, ConversationEvent};
    use mistralai_client::v1::conversations::*;

    /// Return `SOL_MISTRAL_API_KEY`, first sourcing a `.env` file next to the
    /// crate manifest (simple KEY=VALUE lines; blanks and `#` comments are
    /// skipped). `None` means the key is unset, so callers skip the test.
    // NOTE(review): `std::env::set_var` mutates process-global state and is
    // not thread-safe; with a multi-threaded test runtime this can race with
    // other tests reading the environment — confirm this is acceptable here.
    fn load_env() -> Option<String> {
        let env_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join(".env");
        if let Ok(contents) = std::fs::read_to_string(&env_path) {
            for line in contents.lines() {
                let line = line.trim();
                if line.is_empty() || line.starts_with('#') { continue; }
                if let Some((k, v)) = line.split_once('=') {
                    std::env::set_var(k.trim(), v.trim());
                }
            }
        }
        std::env::var("SOL_MISTRAL_API_KEY").ok()
    }

    /// Live API: streaming create should yield text deltas, end with
    /// `ResponseDone`, and accumulate into a response containing the answer
    /// ("56") and non-zero token usage. `store: false` so nothing persists.
    #[tokio::test]
    async fn test_streaming_create_conversation() {
        let Some(api_key) = load_env() else { eprintln!("Skipping: no API key"); return; };
        let client = mistralai_client::v1::client::Client::new(Some(api_key), None, None, None).unwrap();
        let req = CreateConversationRequest {
            inputs: ConversationInput::Text("What is 7 * 8? Answer with just the number.".into()),
            model: Some("mistral-medium-latest".into()),
            agent_id: None,
            agent_version: None,
            name: None,
            description: None,
            instructions: Some("Answer concisely.".into()),
            completion_args: None,
            tools: None,
            handoff_execution: None,
            metadata: None,
            store: Some(false),
            stream: true,
        };
        let stream = client.create_conversation_stream_async(&req).await.unwrap();
        tokio::pin!(stream);
        let mut events = Vec::new();
        let mut saw_text = false;
        let mut saw_done = false;
        while let Some(result) = stream.next().await {
            let event = result.unwrap();
            if event.text_delta().is_some() { saw_text = true; }
            if matches!(&event, ConversationEvent::ResponseDone { .. }) { saw_done = true; }
            events.push(event);
        }
        assert!(saw_text, "Should receive text deltas");
        assert!(saw_done, "Should receive ResponseDone");
        // Empty id: a fresh streaming conversation's id is recovered from the
        // events by accumulate(), not known up front.
        let resp = conversation_stream::accumulate("", &events);
        let text = resp.assistant_text().unwrap_or_default();
        assert!(text.contains("56"), "Should compute 7*8=56, got: {text}");
        assert!(resp.usage.total_tokens > 0, "Should have token usage");
    }

    /// Live API: create a stored conversation non-streaming, then append with
    /// streaming and verify the model recalls context ("SUNBEAM") from the
    /// first turn. Deletes the conversation afterwards.
    // NOTE(review): if an assert above the cleanup panics, the stored
    // conversation is never deleted — consider a drop guard for cleanup.
    #[tokio::test]
    async fn test_streaming_append_conversation() {
        let Some(api_key) = load_env() else { eprintln!("Skipping: no API key"); return; };
        let client = mistralai_client::v1::client::Client::new(Some(api_key), None, None, None).unwrap();
        // Create non-streaming first
        let create_req = CreateConversationRequest {
            inputs: ConversationInput::Text("The magic word is SUNBEAM. Just say ok.".into()),
            model: Some("mistral-medium-latest".into()),
            agent_id: None,
            agent_version: None,
            name: None,
            description: None,
            instructions: Some("Be brief.".into()),
            completion_args: None,
            tools: None,
            handoff_execution: None,
            metadata: None,
            store: Some(true),
            stream: false,
        };
        let created = client.create_conversation_async(&create_req).await.unwrap();
        // Append with streaming
        let append_req = AppendConversationRequest {
            inputs: ConversationInput::Text("What is the magic word?".into()),
            completion_args: None,
            handoff_execution: None,
            store: Some(true),
            tool_confirmations: None,
            stream: true,
        };
        let stream = client
            .append_conversation_stream_async(&created.conversation_id, &append_req)
            .await
            .unwrap();
        tokio::pin!(stream);
        let mut events = Vec::new();
        while let Some(result) = stream.next().await {
            events.push(result.unwrap());
        }
        let resp = conversation_stream::accumulate(&created.conversation_id, &events);
        // Uppercase to tolerate case variation in the model's reply.
        let text = resp.assistant_text().unwrap_or_default().to_uppercase();
        assert!(text.contains("SUNBEAM"), "Should recall magic word, got: {text}");
        // Cleanup
        let _ = client.delete_conversation_async(&created.conversation_id).await;
    }

    /// A `data:` SSE line with a JSON payload parses into an event.
    #[test]
    fn test_parse_sse_line_data_prefix() {
        let line = r#"data: {"type":"conversation.response.started"}"#;
        let event = conversation_stream::parse_sse_line(line).unwrap();
        assert!(event.is_some());
        assert!(matches!(event.unwrap(), ConversationEvent::ResponseStarted { .. }));
    }

    /// `event:` lines carry no payload and are skipped (yield `None`).
    #[test]
    fn test_parse_sse_line_event_prefix_skipped() {
        let line = "event: conversation.response.started";
        assert!(conversation_stream::parse_sse_line(line).unwrap().is_none());
    }

    /// The `[DONE]` sentinel terminates the stream without producing an event.
    #[test]
    fn test_parse_sse_line_done() {
        assert!(conversation_stream::parse_sse_line("data: [DONE]").unwrap().is_none());
    }

    /// Blank and whitespace-only lines (SSE event separators) yield `None`.
    #[test]
    fn test_parse_sse_line_empty() {
        assert!(conversation_stream::parse_sse_line("").unwrap().is_none());
        assert!(conversation_stream::parse_sse_line(" ").unwrap().is_none());
    }

    /// Accumulating two MessageOutput fragments with the same id concatenates
    /// their text in order, and ResponseDone supplies the usage totals.
    #[test]
    fn test_accumulate_text_and_usage() {
        let events = vec![
            ConversationEvent::ResponseStarted { created_at: None },
            ConversationEvent::MessageOutput {
                id: "m1".into(),
                content: serde_json::json!("hello "),
                role: "assistant".into(),
                output_index: 0,
                content_index: 0,
                model: None,
            },
            ConversationEvent::MessageOutput {
                id: "m1".into(),
                content: serde_json::json!("world"),
                role: "assistant".into(),
                output_index: 0,
                content_index: 0,
                model: None,
            },
            ConversationEvent::ResponseDone {
                usage: ConversationUsageInfo {
                    prompt_tokens: 10,
                    completion_tokens: 5,
                    total_tokens: 15,
                },
                created_at: None,
            },
        ];
        let resp = conversation_stream::accumulate("conv-123", &events);
        assert_eq!(resp.conversation_id, "conv-123");
        assert_eq!(resp.assistant_text().unwrap(), "hello world");
        assert_eq!(resp.usage.total_tokens, 15);
    }

    /// A function-call-only stream accumulates to no assistant text and one
    /// FunctionCall entry carrying the tool_call_id and name.
    #[test]
    fn test_accumulate_with_function_calls() {
        let events = vec![
            ConversationEvent::FunctionCall {
                id: "fc-1".into(),
                name: "search_web".into(),
                tool_call_id: "tc-1".into(),
                arguments: r#"{"query":"rust"}"#.into(),
                output_index: 0,
                model: None,
                confirmation_status: None,
            },
            ConversationEvent::ResponseDone {
                usage: ConversationUsageInfo {
                    prompt_tokens: 20,
                    completion_tokens: 10,
                    total_tokens: 30,
                },
                created_at: None,
            },
        ];
        let resp = conversation_stream::accumulate("conv-456", &events);
        assert!(resp.assistant_text().is_none());
        let calls = resp.function_calls();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "search_web");
        assert_eq!(calls[0].tool_call_id.as_deref(), Some("tc-1"));
    }
}