feat: streaming conversation responses via Mistral API
- Bump mistralai-client to 1.2.0 (conversation streaming support) - session.rs: add create_or_append_conversation_streaming() — calls Mistral with stream:true, emits TextDelta messages to gRPC client as SSE chunks arrive, accumulates into ConversationResponse for orchestrator tool loop. Handles corrupted conversation recovery - service.rs: session_chat_via_orchestrator uses streaming variant - Integration tests: streaming create + append against real Mistral API, SSE parsing, accumulate text + function calls
This commit is contained in:
40
Cargo.lock
generated
40
Cargo.lock
generated
@@ -828,6 +828,15 @@ dependencies = [
|
|||||||
"cfg-if",
|
"cfg-if",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "crossbeam-channel"
|
||||||
|
version = "0.5.15"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2"
|
||||||
|
dependencies = [
|
||||||
|
"crossbeam-utils",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "crossbeam-deque"
|
name = "crossbeam-deque"
|
||||||
version = "0.8.6"
|
version = "0.8.6"
|
||||||
@@ -1357,7 +1366,7 @@ dependencies = [
|
|||||||
"libc",
|
"libc",
|
||||||
"option-ext",
|
"option-ext",
|
||||||
"redox_users",
|
"redox_users",
|
||||||
"windows-sys 0.61.2",
|
"windows-sys 0.59.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -1482,7 +1491,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||||||
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
|
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"libc",
|
"libc",
|
||||||
"windows-sys 0.61.2",
|
"windows-sys 0.59.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -3070,9 +3079,9 @@ dependencies = [
|
|||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "mistralai-client"
|
name = "mistralai-client"
|
||||||
version = "1.1.0"
|
version = "1.2.0"
|
||||||
source = "sparse+https://src.sunbeam.pt/api/packages/studio/cargo/"
|
source = "sparse+https://src.sunbeam.pt/api/packages/studio/cargo/"
|
||||||
checksum = "6a17c5600508f30965a8e6ab78e947a0db7017edef8ffb022aa9ac71f50ebaff"
|
checksum = "24258b2ba72432f9c147a5a80e6d862c7b3c43db859135c619c948cf6c9e4fdf"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-stream",
|
"async-stream",
|
||||||
"async-trait",
|
"async-trait",
|
||||||
@@ -3154,7 +3163,7 @@ version = "0.50.3"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
|
checksum = "7957b9740744892f114936ab4a57b3f487491bbeafaf8083688b16841a4240e5"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"windows-sys 0.61.2",
|
"windows-sys 0.59.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -3777,7 +3786,7 @@ dependencies = [
|
|||||||
"once_cell",
|
"once_cell",
|
||||||
"socket2",
|
"socket2",
|
||||||
"tracing",
|
"tracing",
|
||||||
"windows-sys 0.60.2",
|
"windows-sys 0.59.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -4290,7 +4299,7 @@ dependencies = [
|
|||||||
"errno",
|
"errno",
|
||||||
"libc",
|
"libc",
|
||||||
"linux-raw-sys 0.12.1",
|
"linux-raw-sys 0.12.1",
|
||||||
"windows-sys 0.61.2",
|
"windows-sys 0.59.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -4736,6 +4745,7 @@ dependencies = [
|
|||||||
"tonic-prost",
|
"tonic-prost",
|
||||||
"tonic-prost-build",
|
"tonic-prost-build",
|
||||||
"tracing",
|
"tracing",
|
||||||
|
"tracing-appender",
|
||||||
"tracing-subscriber",
|
"tracing-subscriber",
|
||||||
"tree-sitter",
|
"tree-sitter",
|
||||||
"tree-sitter-python",
|
"tree-sitter-python",
|
||||||
@@ -5355,7 +5365,7 @@ dependencies = [
|
|||||||
"getrandom 0.4.2",
|
"getrandom 0.4.2",
|
||||||
"once_cell",
|
"once_cell",
|
||||||
"rustix 1.1.4",
|
"rustix 1.1.4",
|
||||||
"windows-sys 0.61.2",
|
"windows-sys 0.59.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@@ -5838,6 +5848,18 @@ dependencies = [
|
|||||||
"tracing-core",
|
"tracing-core",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tracing-appender"
|
||||||
|
version = "0.2.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "786d480bce6247ab75f005b14ae1624ad978d3029d9113f0a22fa1ac773faeaf"
|
||||||
|
dependencies = [
|
||||||
|
"crossbeam-channel",
|
||||||
|
"thiserror 2.0.18",
|
||||||
|
"time",
|
||||||
|
"tracing-subscriber",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tracing-attributes"
|
name = "tracing-attributes"
|
||||||
version = "0.1.31"
|
version = "0.1.31"
|
||||||
@@ -6425,7 +6447,7 @@ version = "0.1.11"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
|
checksum = "c2a7b1c03c876122aa43f3020e6c3c3ee5c05081c9a00739faf7503aeba10d22"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"windows-sys 0.61.2",
|
"windows-sys 0.59.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ name = "sol"
|
|||||||
path = "src/main.rs"
|
path = "src/main.rs"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
mistralai-client = { version = "1.1.0", registry = "sunbeam" }
|
mistralai-client = { version = "1.2.0", registry = "sunbeam" }
|
||||||
matrix-sdk = { version = "0.9", features = ["e2e-encryption", "sqlite"] }
|
matrix-sdk = { version = "0.9", features = ["e2e-encryption", "sqlite"] }
|
||||||
opensearch = "2"
|
opensearch = "2"
|
||||||
tokio = { version = "1", features = ["full"] }
|
tokio = { version = "1", features = ["full"] }
|
||||||
|
|||||||
@@ -314,7 +314,7 @@ async fn session_chat_via_orchestrator(
|
|||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
use crate::orchestrator::event::*;
|
use crate::orchestrator::event::*;
|
||||||
|
|
||||||
let conversation_response = session.create_or_append_conversation(text).await?;
|
let conversation_response = session.create_or_append_conversation_streaming(text, tx).await?;
|
||||||
session.post_to_matrix(text).await;
|
session.post_to_matrix(text).await;
|
||||||
|
|
||||||
let request_id = RequestId::new();
|
let request_id = RequestId::new();
|
||||||
|
|||||||
@@ -388,46 +388,160 @@ you also have access to server-side tools: search_archive, search_web, research,
|
|||||||
tools
|
tools
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create or append to the Mistral conversation. Returns the response
|
/// Create or append to the Mistral conversation with streaming.
|
||||||
/// for the orchestrator to run through its tool loop.
|
/// Emits TextDelta messages to the gRPC client as chunks arrive.
|
||||||
pub async fn create_or_append_conversation(
|
/// Returns the accumulated response for the orchestrator's tool loop.
|
||||||
|
pub async fn create_or_append_conversation_streaming(
|
||||||
&mut self,
|
&mut self,
|
||||||
text: &str,
|
text: &str,
|
||||||
|
client_tx: &tokio::sync::mpsc::Sender<Result<ServerMessage, tonic::Status>>,
|
||||||
) -> anyhow::Result<ConversationResponse> {
|
) -> anyhow::Result<ConversationResponse> {
|
||||||
|
use futures::StreamExt;
|
||||||
|
use mistralai_client::v1::conversation_stream::{self, ConversationEvent};
|
||||||
|
|
||||||
let context_header = self.build_context_header(text).await;
|
let context_header = self.build_context_header(text).await;
|
||||||
let input_text = format!("{context_header}\n{text}");
|
let input_text = format!("{context_header}\n{text}");
|
||||||
|
|
||||||
|
let conv_id_for_accumulate: String;
|
||||||
|
let mut events = Vec::new();
|
||||||
|
|
||||||
if let Some(ref conv_id) = self.conversation_id {
|
if let Some(ref conv_id) = self.conversation_id {
|
||||||
|
conv_id_for_accumulate = conv_id.clone();
|
||||||
let req = AppendConversationRequest {
|
let req = AppendConversationRequest {
|
||||||
inputs: ConversationInput::Text(input_text.clone()),
|
inputs: ConversationInput::Text(input_text.clone()),
|
||||||
completion_args: None,
|
completion_args: None,
|
||||||
handoff_execution: None,
|
handoff_execution: None,
|
||||||
store: Some(true),
|
store: Some(true),
|
||||||
tool_confirmations: None,
|
tool_confirmations: None,
|
||||||
stream: false,
|
stream: true,
|
||||||
};
|
};
|
||||||
match self.state
|
match self.state.mistral
|
||||||
.mistral
|
.append_conversation_stream_async(conv_id, &req)
|
||||||
.append_conversation_async(conv_id, &req)
|
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(resp) => Ok(resp),
|
Ok(stream) => {
|
||||||
|
tokio::pin!(stream);
|
||||||
|
while let Some(result) = stream.next().await {
|
||||||
|
match result {
|
||||||
|
Ok(event) => {
|
||||||
|
if let Some(delta) = event.text_delta() {
|
||||||
|
let _ = client_tx.send(Ok(ServerMessage {
|
||||||
|
payload: Some(server_message::Payload::Delta(TextDelta {
|
||||||
|
text: delta,
|
||||||
|
})),
|
||||||
|
})).await;
|
||||||
|
}
|
||||||
|
events.push(event);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Stream event error: {}", e.message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
Err(e) if e.message.contains("function calls and responses")
|
Err(e) if e.message.contains("function calls and responses")
|
||||||
|| e.message.contains("invalid_request_error") =>
|
|| e.message.contains("invalid_request_error") =>
|
||||||
{
|
{
|
||||||
warn!(
|
warn!(
|
||||||
conversation_id = conv_id.as_str(),
|
conversation_id = conv_id.as_str(),
|
||||||
error = e.message.as_str(),
|
error = e.message.as_str(),
|
||||||
"Conversation corrupted — creating fresh conversation"
|
"Conversation corrupted — creating fresh conversation (streaming)"
|
||||||
);
|
);
|
||||||
self.conversation_id = None;
|
self.conversation_id = None;
|
||||||
self.create_fresh_conversation(input_text).await
|
return self.create_fresh_conversation_streaming(input_text, client_tx).await;
|
||||||
}
|
}
|
||||||
Err(e) => Err(anyhow::anyhow!("append_conversation failed: {}", e.message)),
|
Err(e) => return Err(anyhow::anyhow!("append_conversation_stream failed: {}", e.message)),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
self.create_fresh_conversation(input_text).await
|
return self.create_fresh_conversation_streaming(input_text, client_tx).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Ok(conversation_stream::accumulate(&conv_id_for_accumulate, &events))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a fresh streaming conversation — used on first message or after corruption.
|
||||||
|
async fn create_fresh_conversation_streaming(
|
||||||
|
&mut self,
|
||||||
|
input_text: String,
|
||||||
|
client_tx: &tokio::sync::mpsc::Sender<Result<ServerMessage, tonic::Status>>,
|
||||||
|
) -> anyhow::Result<ConversationResponse> {
|
||||||
|
use futures::StreamExt;
|
||||||
|
use mistralai_client::v1::conversation_stream::{self, ConversationEvent};
|
||||||
|
|
||||||
|
let instructions = self.build_instructions();
|
||||||
|
let req = CreateConversationRequest {
|
||||||
|
inputs: ConversationInput::Text(input_text),
|
||||||
|
model: Some(self.model.clone()),
|
||||||
|
agent_id: None,
|
||||||
|
agent_version: None,
|
||||||
|
name: Some(format!("code-{}", self.project_name)),
|
||||||
|
description: None,
|
||||||
|
instructions: Some(instructions),
|
||||||
|
completion_args: None,
|
||||||
|
tools: Some(self.build_tool_definitions()),
|
||||||
|
handoff_execution: None,
|
||||||
|
metadata: None,
|
||||||
|
store: Some(true),
|
||||||
|
stream: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
let stream = self.state.mistral
|
||||||
|
.create_conversation_stream_async(&req)
|
||||||
|
.await
|
||||||
|
.map_err(|e| anyhow::anyhow!("create_conversation_stream failed: {}", e.message))?;
|
||||||
|
|
||||||
|
tokio::pin!(stream);
|
||||||
|
let mut events = Vec::new();
|
||||||
|
let mut conversation_id = String::new();
|
||||||
|
|
||||||
|
while let Some(result) = stream.next().await {
|
||||||
|
match result {
|
||||||
|
Ok(event) => {
|
||||||
|
if let Some(delta) = event.text_delta() {
|
||||||
|
let _ = client_tx.send(Ok(ServerMessage {
|
||||||
|
payload: Some(server_message::Payload::Delta(TextDelta {
|
||||||
|
text: delta,
|
||||||
|
})),
|
||||||
|
})).await;
|
||||||
|
}
|
||||||
|
events.push(event);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Stream event error: {}", e.message);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let resp = conversation_stream::accumulate(&conversation_id, &events);
|
||||||
|
|
||||||
|
// Extract conversation_id from the accumulated response
|
||||||
|
if !resp.conversation_id.is_empty() {
|
||||||
|
conversation_id = resp.conversation_id.clone();
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we still don't have a conversation_id, the streaming didn't return one.
|
||||||
|
// This happens with the Conversations API — the ID comes from the response headers
|
||||||
|
// or the ResponseDone event. Let's check events for it.
|
||||||
|
if conversation_id.is_empty() {
|
||||||
|
// Fallback: create non-streaming to get the conversation_id
|
||||||
|
warn!("Streaming didn't return conversation_id — falling back to non-streaming create");
|
||||||
|
return self.create_fresh_conversation(
|
||||||
|
"".into() // empty — conversation already created, just need the ID
|
||||||
|
).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
self.conversation_id = Some(conversation_id.clone());
|
||||||
|
self.state.store.set_code_session_conversation(
|
||||||
|
&self.session_id,
|
||||||
|
&conversation_id,
|
||||||
|
);
|
||||||
|
|
||||||
|
info!(
|
||||||
|
conversation_id = conversation_id.as_str(),
|
||||||
|
"Created streaming Mistral conversation for code session"
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Post user message to the Matrix room.
|
/// Post user message to the Matrix room.
|
||||||
|
|||||||
@@ -6051,3 +6051,213 @@ class Calculator:
|
|||||||
assert!(sessions.is_empty(), "Session should be marked complete, not running");
|
assert!(sessions.is_empty(), "Session should be marked complete, not running");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ══════════════════════════════════════════════════════════════════════════
|
||||||
|
// Conversation streaming — Mistral API integration
|
||||||
|
// ══════════════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
mod conversation_stream_tests {
|
||||||
|
use futures::StreamExt;
|
||||||
|
use mistralai_client::v1::conversation_stream::{self, ConversationEvent};
|
||||||
|
use mistralai_client::v1::conversations::*;
|
||||||
|
|
||||||
|
fn load_env() -> Option<String> {
|
||||||
|
let env_path = std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join(".env");
|
||||||
|
if let Ok(contents) = std::fs::read_to_string(&env_path) {
|
||||||
|
for line in contents.lines() {
|
||||||
|
let line = line.trim();
|
||||||
|
if line.is_empty() || line.starts_with('#') { continue; }
|
||||||
|
if let Some((k, v)) = line.split_once('=') {
|
||||||
|
std::env::set_var(k.trim(), v.trim());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
std::env::var("SOL_MISTRAL_API_KEY").ok()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_streaming_create_conversation() {
|
||||||
|
let Some(api_key) = load_env() else { eprintln!("Skipping: no API key"); return; };
|
||||||
|
let client = mistralai_client::v1::client::Client::new(Some(api_key), None, None, None).unwrap();
|
||||||
|
|
||||||
|
let req = CreateConversationRequest {
|
||||||
|
inputs: ConversationInput::Text("What is 7 * 8? Answer with just the number.".into()),
|
||||||
|
model: Some("mistral-medium-latest".into()),
|
||||||
|
agent_id: None,
|
||||||
|
agent_version: None,
|
||||||
|
name: None,
|
||||||
|
description: None,
|
||||||
|
instructions: Some("Answer concisely.".into()),
|
||||||
|
completion_args: None,
|
||||||
|
tools: None,
|
||||||
|
handoff_execution: None,
|
||||||
|
metadata: None,
|
||||||
|
store: Some(false),
|
||||||
|
stream: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
let stream = client.create_conversation_stream_async(&req).await.unwrap();
|
||||||
|
tokio::pin!(stream);
|
||||||
|
|
||||||
|
let mut events = Vec::new();
|
||||||
|
let mut saw_text = false;
|
||||||
|
let mut saw_done = false;
|
||||||
|
|
||||||
|
while let Some(result) = stream.next().await {
|
||||||
|
let event = result.unwrap();
|
||||||
|
if event.text_delta().is_some() { saw_text = true; }
|
||||||
|
if matches!(&event, ConversationEvent::ResponseDone { .. }) { saw_done = true; }
|
||||||
|
events.push(event);
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(saw_text, "Should receive text deltas");
|
||||||
|
assert!(saw_done, "Should receive ResponseDone");
|
||||||
|
|
||||||
|
let resp = conversation_stream::accumulate("", &events);
|
||||||
|
let text = resp.assistant_text().unwrap_or_default();
|
||||||
|
assert!(text.contains("56"), "Should compute 7*8=56, got: {text}");
|
||||||
|
assert!(resp.usage.total_tokens > 0, "Should have token usage");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_streaming_append_conversation() {
|
||||||
|
let Some(api_key) = load_env() else { eprintln!("Skipping: no API key"); return; };
|
||||||
|
let client = mistralai_client::v1::client::Client::new(Some(api_key), None, None, None).unwrap();
|
||||||
|
|
||||||
|
// Create non-streaming first
|
||||||
|
let create_req = CreateConversationRequest {
|
||||||
|
inputs: ConversationInput::Text("The magic word is SUNBEAM. Just say ok.".into()),
|
||||||
|
model: Some("mistral-medium-latest".into()),
|
||||||
|
agent_id: None,
|
||||||
|
agent_version: None,
|
||||||
|
name: None,
|
||||||
|
description: None,
|
||||||
|
instructions: Some("Be brief.".into()),
|
||||||
|
completion_args: None,
|
||||||
|
tools: None,
|
||||||
|
handoff_execution: None,
|
||||||
|
metadata: None,
|
||||||
|
store: Some(true),
|
||||||
|
stream: false,
|
||||||
|
};
|
||||||
|
let created = client.create_conversation_async(&create_req).await.unwrap();
|
||||||
|
|
||||||
|
// Append with streaming
|
||||||
|
let append_req = AppendConversationRequest {
|
||||||
|
inputs: ConversationInput::Text("What is the magic word?".into()),
|
||||||
|
completion_args: None,
|
||||||
|
handoff_execution: None,
|
||||||
|
store: Some(true),
|
||||||
|
tool_confirmations: None,
|
||||||
|
stream: true,
|
||||||
|
};
|
||||||
|
let stream = client
|
||||||
|
.append_conversation_stream_async(&created.conversation_id, &append_req)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
tokio::pin!(stream);
|
||||||
|
|
||||||
|
let mut events = Vec::new();
|
||||||
|
while let Some(result) = stream.next().await {
|
||||||
|
events.push(result.unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
let resp = conversation_stream::accumulate(&created.conversation_id, &events);
|
||||||
|
let text = resp.assistant_text().unwrap_or_default().to_uppercase();
|
||||||
|
assert!(text.contains("SUNBEAM"), "Should recall magic word, got: {text}");
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
let _ = client.delete_conversation_async(&created.conversation_id).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_sse_line_data_prefix() {
|
||||||
|
let line = r#"data: {"type":"conversation.response.started"}"#;
|
||||||
|
let event = conversation_stream::parse_sse_line(line).unwrap();
|
||||||
|
assert!(event.is_some());
|
||||||
|
assert!(matches!(event.unwrap(), ConversationEvent::ResponseStarted { .. }));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_sse_line_event_prefix_skipped() {
|
||||||
|
let line = "event: conversation.response.started";
|
||||||
|
assert!(conversation_stream::parse_sse_line(line).unwrap().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_sse_line_done() {
|
||||||
|
assert!(conversation_stream::parse_sse_line("data: [DONE]").unwrap().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_parse_sse_line_empty() {
|
||||||
|
assert!(conversation_stream::parse_sse_line("").unwrap().is_none());
|
||||||
|
assert!(conversation_stream::parse_sse_line(" ").unwrap().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_accumulate_text_and_usage() {
|
||||||
|
let events = vec![
|
||||||
|
ConversationEvent::ResponseStarted { created_at: None },
|
||||||
|
ConversationEvent::MessageOutput {
|
||||||
|
id: "m1".into(),
|
||||||
|
content: serde_json::json!("hello "),
|
||||||
|
role: "assistant".into(),
|
||||||
|
output_index: 0,
|
||||||
|
content_index: 0,
|
||||||
|
model: None,
|
||||||
|
},
|
||||||
|
ConversationEvent::MessageOutput {
|
||||||
|
id: "m1".into(),
|
||||||
|
content: serde_json::json!("world"),
|
||||||
|
role: "assistant".into(),
|
||||||
|
output_index: 0,
|
||||||
|
content_index: 0,
|
||||||
|
model: None,
|
||||||
|
},
|
||||||
|
ConversationEvent::ResponseDone {
|
||||||
|
usage: ConversationUsageInfo {
|
||||||
|
prompt_tokens: 10,
|
||||||
|
completion_tokens: 5,
|
||||||
|
total_tokens: 15,
|
||||||
|
},
|
||||||
|
created_at: None,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
let resp = conversation_stream::accumulate("conv-123", &events);
|
||||||
|
assert_eq!(resp.conversation_id, "conv-123");
|
||||||
|
assert_eq!(resp.assistant_text().unwrap(), "hello world");
|
||||||
|
assert_eq!(resp.usage.total_tokens, 15);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_accumulate_with_function_calls() {
|
||||||
|
let events = vec![
|
||||||
|
ConversationEvent::FunctionCall {
|
||||||
|
id: "fc-1".into(),
|
||||||
|
name: "search_web".into(),
|
||||||
|
tool_call_id: "tc-1".into(),
|
||||||
|
arguments: r#"{"query":"rust"}"#.into(),
|
||||||
|
output_index: 0,
|
||||||
|
model: None,
|
||||||
|
confirmation_status: None,
|
||||||
|
},
|
||||||
|
ConversationEvent::ResponseDone {
|
||||||
|
usage: ConversationUsageInfo {
|
||||||
|
prompt_tokens: 20,
|
||||||
|
completion_tokens: 10,
|
||||||
|
total_tokens: 30,
|
||||||
|
},
|
||||||
|
created_at: None,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
let resp = conversation_stream::accumulate("conv-456", &events);
|
||||||
|
assert!(resp.assistant_text().is_none());
|
||||||
|
let calls = resp.function_calls();
|
||||||
|
assert_eq!(calls.len(), 1);
|
||||||
|
assert_eq!(calls[0].name, "search_web");
|
||||||
|
assert_eq!(calls[0].tool_call_id.as_deref(), Some("tc-1"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user