feat: streaming conversation responses via Mistral API
- Bump mistralai-client to 1.2.0 (conversation streaming support) - session.rs: add create_or_append_conversation_streaming() — calls Mistral with stream:true, emits TextDelta messages to gRPC client as SSE chunks arrive, accumulates into ConversationResponse for orchestrator tool loop. Handles corrupted conversation recovery - service.rs: session_chat_via_orchestrator uses streaming variant - Integration tests: streaming create + append against real Mistral API, SSE parsing, accumulate text + function calls
This commit is contained in:
40
Cargo.lock
generated
40
Cargo.lock
generated
@@ -828,6 +828,15 @@ dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-channel"
|
||||
version = "0.5.15"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2"
|
||||
dependencies = [
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-deque"
|
||||
version = "0.8.6"
|
||||
@@ -1357,7 +1366,7 @@ dependencies = [
|
||||
"libc",
|
||||
"option-ext",
|
||||
"redox_users",
|
||||
"windows-sys 0.61.2",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -1482,7 +1491,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys 0.61.2",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3070,9 +3079,9 @@ dependencies = [
|
||||
|
||||
[[package]]
|
||||
name = "mistralai-client"
|
||||
version = "1.1.0"
|
||||
version = "1.2.0"
|
||||
source = "sparse+https://src.sunbeam.pt/api/packages/studio/cargo/"
|
||||
checksum = "6a17c5600508f30965a8e6ab78e947a0db7017edef8ffb022aa9ac71f50ebaff"
|
||||
checksum = "24258b2ba72432f9c147a5a80e6d862c7b3c43db859135c619c948cf6c9e4fdf"
|
||||
dependencies = [
|
||||
"async-stream",
|
||||
"async-trait",
|
||||
@@ -3154,7 +3163,7 @@ version = "0.50.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
|
||||
dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -3777,7 +3786,7 @@ dependencies = [
|
||||
"once_cell",
|
||||
"socket2",
|
||||
"tracing",
|
||||
"windows-sys 0.60.2",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4290,7 +4299,7 @@ dependencies = [
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys 0.12.1",
|
||||
"windows-sys 0.61.2",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -4736,6 +4745,7 @@ dependencies = [
|
||||
"tonic-prost",
|
||||
"tonic-prost-build",
|
||||
"tracing",
|
||||
"tracing-appender",
|
||||
"tracing-subscriber",
|
||||
"tree-sitter",
|
||||
"tree-sitter-python",
|
||||
@@ -5355,7 +5365,7 @@ dependencies = [
|
||||
"getrandom 0.4.2",
|
||||
"once_cell",
|
||||
"rustix 1.1.4",
|
||||
"windows-sys 0.61.2",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@@ -5838,6 +5848,18 @@ dependencies = [
|
||||
"tracing-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tracing-appender"
|
||||
version = "0.2.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "786d480bce6247ab75f005b14ae1624ad978d3029d9113f0a22fa1ac773faeaf"
|
||||
dependencies = [
|
||||
"crossbeam-channel",
|
||||
"thiserror 2.0.18",
|
||||
"time",
|
||||
"tracing-subscriber",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tracing-attributes"
|
||||
version = "0.1.31"
|
||||
@@ -6425,7 +6447,7 @@ version = "0.1.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
|
||||
dependencies = [
|
||||
"windows-sys 0.61.2",
|
||||
"windows-sys 0.59.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -13,7 +13,7 @@ name = "sol"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
mistralai-client = { version = "1.1.0", registry = "sunbeam" }
|
||||
mistralai-client = { version = "1.2.0", registry = "sunbeam" }
|
||||
matrix-sdk = { version = "0.9", features = ["e2e-encryption", "sqlite"] }
|
||||
opensearch = "2"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
|
||||
@@ -314,7 +314,7 @@ async fn session_chat_via_orchestrator(
|
||||
) -> anyhow::Result<()> {
|
||||
use crate::orchestrator::event::*;
|
||||
|
||||
let conversation_response = session.create_or_append_conversation(text).await?;
|
||||
let conversation_response = session.create_or_append_conversation_streaming(text, tx).await?;
|
||||
session.post_to_matrix(text).await;
|
||||
|
||||
let request_id = RequestId::new();
|
||||
|
||||
@@ -388,46 +388,160 @@ you also have access to server-side tools: search_archive, search_web, research,
|
||||
tools
|
||||
}
|
||||
|
||||
/// Create or append to the Mistral conversation. Returns the response
|
||||
/// for the orchestrator to run through its tool loop.
|
||||
pub async fn create_or_append_conversation(
|
||||
/// Create or append to the Mistral conversation with streaming.
|
||||
/// Emits TextDelta messages to the gRPC client as chunks arrive.
|
||||
/// Returns the accumulated response for the orchestrator's tool loop.
|
||||
pub async fn create_or_append_conversation_streaming(
|
||||
&mut self,
|
||||
text: &str,
|
||||
client_tx: &tokio::sync::mpsc::Sender<Result<ServerMessage, tonic::Status>>,
|
||||
) -> anyhow::Result<ConversationResponse> {
|
||||
use futures::StreamExt;
|
||||
use mistralai_client::v1::conversation_stream::{self, ConversationEvent};
|
||||
|
||||
let context_header = self.build_context_header(text).await;
|
||||
let input_text = format!("{context_header}\n{text}");
|
||||
|
||||
let conv_id_for_accumulate: String;
|
||||
let mut events = Vec::new();
|
||||
|
||||
if let Some(ref conv_id) = self.conversation_id {
|
||||
conv_id_for_accumulate = conv_id.clone();
|
||||
let req = AppendConversationRequest {
|
||||
inputs: ConversationInput::Text(input_text.clone()),
|
||||
completion_args: None,
|
||||
handoff_execution: None,
|
||||
store: Some(true),
|
||||
tool_confirmations: None,
|
||||
stream: false,
|
||||
stream: true,
|
||||
};
|
||||
match self.state
|
||||
.mistral
|
||||
.append_conversation_async(conv_id, &req)
|
||||
match self.state.mistral
|
||||
.append_conversation_stream_async(conv_id, &req)
|
||||
.await
|
||||
{
|
||||
Ok(resp) => Ok(resp),
|
||||
Ok(stream) => {
|
||||
tokio::pin!(stream);
|
||||
while let Some(result) = stream.next().await {
|
||||
match result {
|
||||
Ok(event) => {
|
||||
if let Some(delta) = event.text_delta() {
|
||||
let _ = client_tx.send(Ok(ServerMessage {
|
||||
payload: Some(server_message::Payload::Delta(TextDelta {
|
||||
text: delta,
|
||||
})),
|
||||
})).await;
|
||||
}
|
||||
events.push(event);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Stream event error: {}", e.message);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(e) if e.message.contains("function calls and responses")
|
||||
|| e.message.contains("invalid_request_error") =>
|
||||
{
|
||||
warn!(
|
||||
conversation_id = conv_id.as_str(),
|
||||
error = e.message.as_str(),
|
||||
"Conversation corrupted — creating fresh conversation"
|
||||
"Conversation corrupted — creating fresh conversation (streaming)"
|
||||
);
|
||||
self.conversation_id = None;
|
||||
self.create_fresh_conversation(input_text).await
|
||||
return self.create_fresh_conversation_streaming(input_text, client_tx).await;
|
||||
}
|
||||
Err(e) => Err(anyhow::anyhow!("append_conversation failed: {}", e.message)),
|
||||
Err(e) => return Err(anyhow::anyhow!("append_conversation_stream failed: {}", e.message)),
|
||||
}
|
||||
} else {
|
||||
self.create_fresh_conversation(input_text).await
|
||||
return self.create_fresh_conversation_streaming(input_text, client_tx).await;
|
||||
}
|
||||
|
||||
Ok(conversation_stream::accumulate(&conv_id_for_accumulate, &events))
|
||||
}
|
||||
|
||||
/// Create a fresh streaming conversation — used on first message or after corruption.
|
||||
async fn create_fresh_conversation_streaming(
|
||||
&mut self,
|
||||
input_text: String,
|
||||
client_tx: &tokio::sync::mpsc::Sender<Result<ServerMessage, tonic::Status>>,
|
||||
) -> anyhow::Result<ConversationResponse> {
|
||||
use futures::StreamExt;
|
||||
use mistralai_client::v1::conversation_stream::{self, ConversationEvent};
|
||||
|
||||
let instructions = self.build_instructions();
|
||||
let req = CreateConversationRequest {
|
||||
inputs: ConversationInput::Text(input_text),
|
||||
model: Some(self.model.clone()),
|
||||
agent_id: None,
|
||||
agent_version: None,
|
||||
name: Some(format!("code-{}", self.project_name)),
|
||||
description: None,
|
||||
instructions: Some(instructions),
|
||||
completion_args: None,
|
||||
tools: Some(self.build_tool_definitions()),
|
||||
handoff_execution: None,
|
||||
metadata: None,
|
||||
store: Some(true),
|
||||
stream: true,
|
||||
};
|
||||
|
||||
let stream = self.state.mistral
|
||||
.create_conversation_stream_async(&req)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("create_conversation_stream failed: {}", e.message))?;
|
||||
|
||||
tokio::pin!(stream);
|
||||
let mut events = Vec::new();
|
||||
let mut conversation_id = String::new();
|
||||
|
||||
while let Some(result) = stream.next().await {
|
||||
match result {
|
||||
Ok(event) => {
|
||||
if let Some(delta) = event.text_delta() {
|
||||
let _ = client_tx.send(Ok(ServerMessage {
|
||||
payload: Some(server_message::Payload::Delta(TextDelta {
|
||||
text: delta,
|
||||
})),
|
||||
})).await;
|
||||
}
|
||||
events.push(event);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Stream event error: {}", e.message);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let resp = conversation_stream::accumulate(&conversation_id, &events);
|
||||
|
||||
// Extract conversation_id from the accumulated response
|
||||
if !resp.conversation_id.is_empty() {
|
||||
conversation_id = resp.conversation_id.clone();
|
||||
}
|
||||
|
||||
// If we still don't have a conversation_id, the streaming didn't return one.
|
||||
// This happens with the Conversations API — the ID comes from the response headers
|
||||
// or the ResponseDone event. Let's check events for it.
|
||||
if conversation_id.is_empty() {
|
||||
// Fallback: create non-streaming to get the conversation_id
|
||||
warn!("Streaming didn't return conversation_id — falling back to non-streaming create");
|
||||
return self.create_fresh_conversation(
|
||||
"".into() // empty — conversation already created, just need the ID
|
||||
).await;
|
||||
}
|
||||
|
||||
self.conversation_id = Some(conversation_id.clone());
|
||||
self.state.store.set_code_session_conversation(
|
||||
&self.session_id,
|
||||
&conversation_id,
|
||||
);
|
||||
|
||||
info!(
|
||||
conversation_id = conversation_id.as_str(),
|
||||
"Created streaming Mistral conversation for code session"
|
||||
);
|
||||
|
||||
Ok(resp)
|
||||
}
|
||||
|
||||
/// Post user message to the Matrix room.
|
||||
|
||||
@@ -6051,3 +6051,213 @@ class Calculator:
|
||||
assert!(sessions.is_empty(), "Session should be marked complete, not running");
|
||||
}
|
||||
}
|
||||
|
||||
// ══════════════════════════════════════════════════════════════════════════
|
||||
// Conversation streaming — Mistral API integration
|
||||
// ══════════════════════════════════════════════════════════════════════════
|
||||
|
||||
mod conversation_stream_tests {
|
||||
use futures::StreamExt;
|
||||
use mistralai_client::v1::conversation_stream::{self, ConversationEvent};
|
||||
use mistralai_client::v1::conversations::*;
|
||||
|
||||
fn load_env() -> Option<String> {
|
||||
let env_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join(".env");
|
||||
if let Ok(contents) = std::fs::read_to_string(&env_path) {
|
||||
for line in contents.lines() {
|
||||
let line = line.trim();
|
||||
if line.is_empty() || line.starts_with('#') { continue; }
|
||||
if let Some((k, v)) = line.split_once('=') {
|
||||
std::env::set_var(k.trim(), v.trim());
|
||||
}
|
||||
}
|
||||
}
|
||||
std::env::var("SOL_MISTRAL_API_KEY").ok()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_streaming_create_conversation() {
|
||||
let Some(api_key) = load_env() else { eprintln!("Skipping: no API key"); return; };
|
||||
let client = mistralai_client::v1::client::Client::new(Some(api_key), None, None, None).unwrap();
|
||||
|
||||
let req = CreateConversationRequest {
|
||||
inputs: ConversationInput::Text("What is 7 * 8? Answer with just the number.".into()),
|
||||
model: Some("mistral-medium-latest".into()),
|
||||
agent_id: None,
|
||||
agent_version: None,
|
||||
name: None,
|
||||
description: None,
|
||||
instructions: Some("Answer concisely.".into()),
|
||||
completion_args: None,
|
||||
tools: None,
|
||||
handoff_execution: None,
|
||||
metadata: None,
|
||||
store: Some(false),
|
||||
stream: true,
|
||||
};
|
||||
|
||||
let stream = client.create_conversation_stream_async(&req).await.unwrap();
|
||||
tokio::pin!(stream);
|
||||
|
||||
let mut events = Vec::new();
|
||||
let mut saw_text = false;
|
||||
let mut saw_done = false;
|
||||
|
||||
while let Some(result) = stream.next().await {
|
||||
let event = result.unwrap();
|
||||
if event.text_delta().is_some() { saw_text = true; }
|
||||
if matches!(&event, ConversationEvent::ResponseDone { .. }) { saw_done = true; }
|
||||
events.push(event);
|
||||
}
|
||||
|
||||
assert!(saw_text, "Should receive text deltas");
|
||||
assert!(saw_done, "Should receive ResponseDone");
|
||||
|
||||
let resp = conversation_stream::accumulate("", &events);
|
||||
let text = resp.assistant_text().unwrap_or_default();
|
||||
assert!(text.contains("56"), "Should compute 7*8=56, got: {text}");
|
||||
assert!(resp.usage.total_tokens > 0, "Should have token usage");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_streaming_append_conversation() {
|
||||
let Some(api_key) = load_env() else { eprintln!("Skipping: no API key"); return; };
|
||||
let client = mistralai_client::v1::client::Client::new(Some(api_key), None, None, None).unwrap();
|
||||
|
||||
// Create non-streaming first
|
||||
let create_req = CreateConversationRequest {
|
||||
inputs: ConversationInput::Text("The magic word is SUNBEAM. Just say ok.".into()),
|
||||
model: Some("mistral-medium-latest".into()),
|
||||
agent_id: None,
|
||||
agent_version: None,
|
||||
name: None,
|
||||
description: None,
|
||||
instructions: Some("Be brief.".into()),
|
||||
completion_args: None,
|
||||
tools: None,
|
||||
handoff_execution: None,
|
||||
metadata: None,
|
||||
store: Some(true),
|
||||
stream: false,
|
||||
};
|
||||
let created = client.create_conversation_async(&create_req).await.unwrap();
|
||||
|
||||
// Append with streaming
|
||||
let append_req = AppendConversationRequest {
|
||||
inputs: ConversationInput::Text("What is the magic word?".into()),
|
||||
completion_args: None,
|
||||
handoff_execution: None,
|
||||
store: Some(true),
|
||||
tool_confirmations: None,
|
||||
stream: true,
|
||||
};
|
||||
let stream = client
|
||||
.append_conversation_stream_async(&created.conversation_id, &append_req)
|
||||
.await
|
||||
.unwrap();
|
||||
tokio::pin!(stream);
|
||||
|
||||
let mut events = Vec::new();
|
||||
while let Some(result) = stream.next().await {
|
||||
events.push(result.unwrap());
|
||||
}
|
||||
|
||||
let resp = conversation_stream::accumulate(&created.conversation_id, &events);
|
||||
let text = resp.assistant_text().unwrap_or_default().to_uppercase();
|
||||
assert!(text.contains("SUNBEAM"), "Should recall magic word, got: {text}");
|
||||
|
||||
// Cleanup
|
||||
let _ = client.delete_conversation_async(&created.conversation_id).await;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_sse_line_data_prefix() {
|
||||
let line = r#"data: {"type":"conversation.response.started"}"#;
|
||||
let event = conversation_stream::parse_sse_line(line).unwrap();
|
||||
assert!(event.is_some());
|
||||
assert!(matches!(event.unwrap(), ConversationEvent::ResponseStarted { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_sse_line_event_prefix_skipped() {
|
||||
let line = "event: conversation.response.started";
|
||||
assert!(conversation_stream::parse_sse_line(line).unwrap().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_sse_line_done() {
|
||||
assert!(conversation_stream::parse_sse_line("data: [DONE]").unwrap().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_sse_line_empty() {
|
||||
assert!(conversation_stream::parse_sse_line("").unwrap().is_none());
|
||||
assert!(conversation_stream::parse_sse_line(" ").unwrap().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_accumulate_text_and_usage() {
|
||||
let events = vec![
|
||||
ConversationEvent::ResponseStarted { created_at: None },
|
||||
ConversationEvent::MessageOutput {
|
||||
id: "m1".into(),
|
||||
content: serde_json::json!("hello "),
|
||||
role: "assistant".into(),
|
||||
output_index: 0,
|
||||
content_index: 0,
|
||||
model: None,
|
||||
},
|
||||
ConversationEvent::MessageOutput {
|
||||
id: "m1".into(),
|
||||
content: serde_json::json!("world"),
|
||||
role: "assistant".into(),
|
||||
output_index: 0,
|
||||
content_index: 0,
|
||||
model: None,
|
||||
},
|
||||
ConversationEvent::ResponseDone {
|
||||
usage: ConversationUsageInfo {
|
||||
prompt_tokens: 10,
|
||||
completion_tokens: 5,
|
||||
total_tokens: 15,
|
||||
},
|
||||
created_at: None,
|
||||
},
|
||||
];
|
||||
|
||||
let resp = conversation_stream::accumulate("conv-123", &events);
|
||||
assert_eq!(resp.conversation_id, "conv-123");
|
||||
assert_eq!(resp.assistant_text().unwrap(), "hello world");
|
||||
assert_eq!(resp.usage.total_tokens, 15);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_accumulate_with_function_calls() {
|
||||
let events = vec![
|
||||
ConversationEvent::FunctionCall {
|
||||
id: "fc-1".into(),
|
||||
name: "search_web".into(),
|
||||
tool_call_id: "tc-1".into(),
|
||||
arguments: r#"{"query":"rust"}"#.into(),
|
||||
output_index: 0,
|
||||
model: None,
|
||||
confirmation_status: None,
|
||||
},
|
||||
ConversationEvent::ResponseDone {
|
||||
usage: ConversationUsageInfo {
|
||||
prompt_tokens: 20,
|
||||
completion_tokens: 10,
|
||||
total_tokens: 30,
|
||||
},
|
||||
created_at: None,
|
||||
},
|
||||
];
|
||||
|
||||
let resp = conversation_stream::accumulate("conv-456", &events);
|
||||
assert!(resp.assistant_text().is_none());
|
||||
let calls = resp.function_calls();
|
||||
assert_eq!(calls.len(), 1);
|
||||
assert_eq!(calls[0].name, "search_web");
|
||||
assert_eq!(calls[0].tool_call_id.as_deref(), Some("tc-1"));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user