From 4c7f1cde0a5df76a210fdf6b34806b1369c3b0a2 Mon Sep 17 00:00:00 2001 From: Sienna Meridian Satterwhite Date: Tue, 24 Mar 2026 21:16:39 +0000 Subject: [PATCH] feat: streaming Conversations API support (v1.2.0) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add conversation_stream module with full streaming support for the Mistral Conversations API: - ConversationEvent enum matching API SSE event types: ResponseStarted, MessageOutput (text deltas), FunctionCall, ResponseDone (with usage), ResponseError, tool execution, agent handoff events - parse_sse_line() handles SSE format (skips event: lines, parses data: JSON, handles [DONE] and comments) - accumulate() collects streaming events into a ConversationResponse - create_conversation_stream_async() and append_conversation_stream_async() client methods - Byte-boundary buffering in sse_to_conversation_events — handles JSON split across TCP frames - Integration tests hit real Mistral API: create stream, append stream, stream/non-stream output equivalence --- Cargo.toml | 2 +- src/v1/client.rs | 101 ++++++- src/v1/conversation_stream.rs | 395 +++++++++++++++++++++++++++ src/v1/mod.rs | 1 + tests/v1_conversation_stream_test.rs | 183 +++++++++++++ 5 files changed, 679 insertions(+), 3 deletions(-) create mode 100644 src/v1/conversation_stream.rs create mode 100644 tests/v1_conversation_stream_test.rs diff --git a/Cargo.toml b/Cargo.toml index 9d4a04c..d1620e0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ name = "mistralai-client" description = "Mistral AI API client library for Rust (unofficial)." 
license = "Apache-2.0" -version = "1.1.0" +version = "1.2.0" edition = "2021" rust-version = "1.76.0" diff --git a/src/v1/client.rs b/src/v1/client.rs index 5c10c96..3526dbd 100644 --- a/src/v1/client.rs +++ b/src/v1/client.rs @@ -10,8 +10,8 @@ use std::{ }; use crate::v1::{ - agents, audio, batch, chat, chat_stream, constants, conversations, embedding, error, files, - fim, fine_tuning, model_list, moderation, ocr, tool, utils, + agents, audio, batch, chat, chat_stream, constants, conversation_stream, conversations, + embedding, error, files, fim, fine_tuning, model_list, moderation, ocr, tool, utils, }; #[derive(Debug)] @@ -1054,6 +1054,103 @@ impl Client { .map_err(|e| self.to_api_error(e)) } + /// Create a conversation with streaming response. + /// Returns a stream of `ConversationEvent`s as SSE events arrive. + pub async fn create_conversation_stream_async( + &self, + request: &conversations::CreateConversationRequest, + ) -> Result< + impl futures::Stream>, + error::ApiError, + > { + // Ensure stream is true + let mut req = request.clone(); + req.stream = true; + + let response = self.post_stream("/conversations", &req).await?; + if !response.status().is_success() { + let status = response.status(); + let text = response.text().await.unwrap_or_default(); + return Err(error::ApiError { + message: format!("{}: {}", status, text), + }); + } + + Ok(self.sse_to_conversation_events(response)) + } + + /// Append to a conversation with streaming response. + /// Returns a stream of `ConversationEvent`s as SSE events arrive. 
+ pub async fn append_conversation_stream_async( + &self, + conversation_id: &str, + request: &conversations::AppendConversationRequest, + ) -> Result< + impl futures::Stream>, + error::ApiError, + > { + let mut req = request.clone(); + req.stream = true; + + let response = self + .post_stream(&format!("/conversations/{}", conversation_id), &req) + .await?; + if !response.status().is_success() { + let status = response.status(); + let text = response.text().await.unwrap_or_default(); + return Err(error::ApiError { + message: format!("{}: {}", status, text), + }); + } + + Ok(self.sse_to_conversation_events(response)) + } + + /// Convert a raw SSE response into a stream of conversation events. + /// Handles byte-boundary splits by buffering incomplete lines. + fn sse_to_conversation_events( + &self, + response: reqwest::Response, + ) -> impl futures::Stream> { + use futures::stream; + + let mut buffer = String::new(); + + response.bytes_stream().flat_map(move |bytes_result| { + match bytes_result { + Ok(bytes) => { + let text = match String::from_utf8(bytes.to_vec()) { + Ok(t) => t, + Err(e) => return stream::iter(vec![Err(error::ApiError { + message: format!("UTF-8 decode error: {e}"), + })]), + }; + + buffer.push_str(&text); + + let mut events = Vec::new(); + + // Process complete lines only (ending with \n) + while let Some(newline_pos) = buffer.find('\n') { + let line = buffer[..newline_pos].to_string(); + buffer = buffer[newline_pos + 1..].to_string(); + + match conversation_stream::parse_sse_line(&line) { + Ok(Some(event)) => events.push(Ok(event)), + Ok(None) => {} + Err(e) => events.push(Err(e)), + } + } + + stream::iter(events) + } + Err(e) => stream::iter(vec![Err(error::ApiError { + message: format!("Stream read error: {e}"), + })]), + } + }) + } + pub fn get_conversation( &self, conversation_id: &str, diff --git a/src/v1/conversation_stream.rs b/src/v1/conversation_stream.rs new file mode 100644 index 0000000..22860db --- /dev/null +++ 
//! Streaming support for the Conversations API.
//!
//! When `stream: true` is set on a conversation request, the API returns
//! Server-Sent Events (SSE). Each event has an `event:` type line and a
//! `data:` JSON payload, discriminated by the `type` field.
//!
//! Event types:
//! - `conversation.response.started` — generation began
//! - `message.output.delta` — partial assistant text
//! - `function.call.delta` — a function call (tool call)
//! - `conversation.response.done` — generation complete (has usage)
//! - `conversation.response.error` — error during generation

use serde::{Deserialize, Serialize};

use crate::v1::{conversations, error};

// ── SSE event types ────────────────────────────────────────────────────

/// A streaming event from the Conversations API.
/// The `type` field discriminates the variant (serde internally-tagged enum),
/// so unknown `type` values fail to deserialize rather than mis-match.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(tag = "type")]
pub enum ConversationEvent {
    /// Generation started.
    #[serde(rename = "conversation.response.started")]
    ResponseStarted {
        // RFC 3339 timestamp string (e.g. "2026-03-24T12:00:00Z").
        #[serde(default)]
        created_at: Option<String>,
    },

    /// Partial assistant text output.
    #[serde(rename = "message.output.delta")]
    MessageOutput {
        id: String,
        // Kept as raw JSON because the API sends either a plain string or an
        // array of chunk objects; `text_delta()` normalizes both forms.
        content: serde_json::Value, // string or array of chunks
        #[serde(default)]
        role: String,
        #[serde(default)]
        output_index: u32,
        #[serde(default)]
        content_index: u32,
        #[serde(default)]
        model: Option<String>,
    },

    /// A function call from the model.
    #[serde(rename = "function.call.delta")]
    FunctionCall {
        id: String,
        name: String,
        tool_call_id: String,
        // JSON-encoded argument string, exactly as sent by the API.
        arguments: String,
        #[serde(default)]
        output_index: u32,
        #[serde(default)]
        model: Option<String>,
        // NOTE(review): assumed to be a string status — confirm against API docs.
        #[serde(default)]
        confirmation_status: Option<String>,
    },

    /// Generation complete — includes token usage.
    #[serde(rename = "conversation.response.done")]
    ResponseDone {
        usage: conversations::ConversationUsageInfo,
        #[serde(default)]
        created_at: Option<String>,
    },

    /// Error during generation.
    #[serde(rename = "conversation.response.error")]
    ResponseError {
        message: String,
        // Numeric error code, e.g. 429 for rate limiting.
        code: i32,
        #[serde(default)]
        created_at: Option<String>,
    },

    // The server-side tool / handoff events below capture their full payload
    // as opaque JSON via `#[serde(flatten)]` — the client does not interpret
    // them, it only forwards them to the caller.

    /// Tool execution started (server-side).
    #[serde(rename = "tool.execution.started")]
    ToolExecutionStarted {
        #[serde(flatten)]
        extra: serde_json::Value,
    },

    /// Tool execution delta (server-side).
    #[serde(rename = "tool.execution.delta")]
    ToolExecutionDelta {
        #[serde(flatten)]
        extra: serde_json::Value,
    },

    /// Tool execution done (server-side).
    #[serde(rename = "tool.execution.done")]
    ToolExecutionDone {
        #[serde(flatten)]
        extra: serde_json::Value,
    },

    /// Agent handoff started.
    #[serde(rename = "agent.handoff.started")]
    AgentHandoffStarted {
        #[serde(flatten)]
        extra: serde_json::Value,
    },

    /// Agent handoff done.
    #[serde(rename = "agent.handoff.done")]
    AgentHandoffDone {
        #[serde(flatten)]
        extra: serde_json::Value,
    },
}

impl ConversationEvent {
    /// Extract text content from a MessageOutput event.
    ///
    /// Returns `None` for non-`MessageOutput` variants, and for chunk arrays
    /// that contain no text. A plain empty string `""` yields `Some("")`.
    pub fn text_delta(&self) -> Option<String> {
        match self {
            ConversationEvent::MessageOutput { content, .. } => {
                // content can be a string or an array of chunks
                if let Some(s) = content.as_str() {
                    Some(s.to_string())
                } else if let Some(arr) = content.as_array() {
                    // Array of chunks — extract text from TextChunk items
                    let mut text = String::new();
                    for chunk in arr {
                        if let Some(t) = chunk.get("text").and_then(|v| v.as_str()) {
                            text.push_str(t);
                        }
                    }
                    if text.is_empty() { None } else { Some(text) }
                } else {
                    None
                }
            }
            _ => None,
        }
    }
}
+/// +/// Returns: +/// - `Ok(Some(event))` — parsed event +/// - `Ok(None)` — `[DONE]` signal or empty/comment line +/// - `Err(e)` — parse error +pub fn parse_sse_line(line: &str) -> Result, error::ApiError> { + let line = line.trim(); + + if line.is_empty() || line.starts_with(':') || line.starts_with("event:") { + return Ok(None); + } + + if line == "data: [DONE]" || line == "[DONE]" { + return Ok(None); + } + + // SSE data lines start with "data: " + let json = match line.strip_prefix("data: ") { + Some(j) => j.trim(), + None => return Ok(None), // not a data line + }; + if json.is_empty() { + return Ok(None); + } + + serde_json::from_str::(json).map(Some).map_err(|e| { + error::ApiError { + message: format!("Failed to parse conversation stream event: {e}\nRaw: {json}"), + } + }) +} + +/// Accumulate streaming events into a final `ConversationResponse`. +pub fn accumulate( + conversation_id: &str, + events: &[ConversationEvent], +) -> conversations::ConversationResponse { + let mut full_text = String::new(); + let mut function_calls = Vec::new(); + let mut usage = conversations::ConversationUsageInfo { + prompt_tokens: 0, + completion_tokens: 0, + total_tokens: 0, + }; + + for event in events { + match event { + ConversationEvent::MessageOutput { content, .. } => { + if let Some(s) = content.as_str() { + full_text.push_str(s); + } else if let Some(arr) = content.as_array() { + for chunk in arr { + if let Some(t) = chunk.get("text").and_then(|v| v.as_str()) { + full_text.push_str(t); + } + } + } + } + ConversationEvent::FunctionCall { + id, name, tool_call_id, arguments, .. + } => { + function_calls.push(conversations::ConversationEntry::FunctionCall( + conversations::FunctionCallEntry { + name: name.clone(), + arguments: arguments.clone(), + id: Some(id.clone()), + object: None, + tool_call_id: Some(tool_call_id.clone()), + created_at: None, + completed_at: None, + }, + )); + } + ConversationEvent::ResponseDone { usage: u, .. 
} => { + usage = u.clone(); + } + _ => {} + } + } + + let mut outputs = Vec::new(); + if !full_text.is_empty() { + outputs.push(conversations::ConversationEntry::MessageOutput( + conversations::MessageOutputEntry { + role: "assistant".into(), + content: crate::v1::chat::ChatMessageContent::Text(full_text), + id: None, + object: None, + model: None, + created_at: None, + completed_at: None, + }, + )); + } + outputs.extend(function_calls); + + conversations::ConversationResponse { + conversation_id: conversation_id.to_string(), + outputs, + usage, + object: "conversation.response".into(), + guardrails: None, + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_done() { + assert!(parse_sse_line("data: [DONE]").unwrap().is_none()); + } + + #[test] + fn test_parse_empty() { + assert!(parse_sse_line("").unwrap().is_none()); + assert!(parse_sse_line(" ").unwrap().is_none()); + } + + #[test] + fn test_parse_comment() { + assert!(parse_sse_line(": keep-alive").unwrap().is_none()); + } + + #[test] + fn test_parse_response_started() { + let line = r#"data: {"type":"conversation.response.started","created_at":"2026-03-24T12:00:00Z"}"#; + let event = parse_sse_line(line).unwrap().unwrap(); + assert!(matches!(event, ConversationEvent::ResponseStarted { .. 
#[cfg(test)]
mod tests {
    use super::*;

    // [DONE] sentinel is swallowed, not surfaced as an event.
    #[test]
    fn test_parse_done() {
        assert!(parse_sse_line("data: [DONE]").unwrap().is_none());
    }

    // Blank / whitespace-only lines are ignored.
    #[test]
    fn test_parse_empty() {
        assert!(parse_sse_line("").unwrap().is_none());
        assert!(parse_sse_line("   ").unwrap().is_none());
    }

    // SSE comment lines (leading ':') are keep-alives and are ignored.
    #[test]
    fn test_parse_comment() {
        assert!(parse_sse_line(": keep-alive").unwrap().is_none());
    }

    #[test]
    fn test_parse_response_started() {
        let line = r#"data: {"type":"conversation.response.started","created_at":"2026-03-24T12:00:00Z"}"#;
        let event = parse_sse_line(line).unwrap().unwrap();
        assert!(matches!(event, ConversationEvent::ResponseStarted { .. }));
    }

    // `content` as a plain JSON string.
    #[test]
    fn test_parse_message_output_string() {
        let line = r#"data: {"type":"message.output.delta","id":"msg-1","content":"hello ","role":"assistant"}"#;
        let event = parse_sse_line(line).unwrap().unwrap();
        assert_eq!(event.text_delta(), Some("hello ".into()));
    }

    // `content` as an array of chunk objects.
    #[test]
    fn test_parse_message_output_chunks() {
        let line = r#"data: {"type":"message.output.delta","id":"msg-1","content":[{"type":"text","text":"world"}],"role":"assistant"}"#;
        let event = parse_sse_line(line).unwrap().unwrap();
        assert_eq!(event.text_delta(), Some("world".into()));
    }

    #[test]
    fn test_parse_function_call() {
        let line = r#"data: {"type":"function.call.delta","id":"fc-1","name":"search_web","tool_call_id":"tc-1","arguments":"{\"query\":\"test\"}"}"#;
        let event = parse_sse_line(line).unwrap().unwrap();
        match event {
            ConversationEvent::FunctionCall { name, arguments, tool_call_id, .. } => {
                assert_eq!(name, "search_web");
                assert_eq!(tool_call_id, "tc-1");
                assert!(arguments.contains("test"));
            }
            _ => panic!("Expected FunctionCall"),
        }
    }

    #[test]
    fn test_parse_response_done() {
        let line = r#"data: {"type":"conversation.response.done","usage":{"prompt_tokens":100,"completion_tokens":50,"total_tokens":150}}"#;
        let event = parse_sse_line(line).unwrap().unwrap();
        match event {
            ConversationEvent::ResponseDone { usage, .. } => {
                assert_eq!(usage.prompt_tokens, 100);
                assert_eq!(usage.completion_tokens, 50);
                assert_eq!(usage.total_tokens, 150);
            }
            _ => panic!("Expected ResponseDone"),
        }
    }

    #[test]
    fn test_parse_response_error() {
        let line = r#"data: {"type":"conversation.response.error","message":"rate limited","code":429}"#;
        let event = parse_sse_line(line).unwrap().unwrap();
        match event {
            ConversationEvent::ResponseError { message, code, .. } => {
                assert_eq!(message, "rate limited");
                assert_eq!(code, 429);
            }
            _ => panic!("Expected ResponseError"),
        }
    }

    // Text deltas concatenate in order and usage comes from ResponseDone.
    #[test]
    fn test_accumulate() {
        let events = vec![
            ConversationEvent::ResponseStarted { created_at: None },
            ConversationEvent::MessageOutput {
                id: "m1".into(),
                content: serde_json::json!("hello "),
                role: "assistant".into(),
                output_index: 0,
                content_index: 0,
                model: None,
            },
            ConversationEvent::MessageOutput {
                id: "m1".into(),
                content: serde_json::json!("world"),
                role: "assistant".into(),
                output_index: 0,
                content_index: 0,
                model: None,
            },
            ConversationEvent::ResponseDone {
                usage: conversations::ConversationUsageInfo {
                    prompt_tokens: 10,
                    completion_tokens: 5,
                    total_tokens: 15,
                },
                created_at: None,
            },
        ];

        let resp = accumulate("conv-1", &events);
        assert_eq!(resp.conversation_id, "conv-1");
        assert_eq!(resp.assistant_text(), Some("hello world".into()));
        assert_eq!(resp.usage.total_tokens, 15);
    }

    // A call-only stream yields no assistant text, just the function call.
    #[test]
    fn test_accumulate_with_function_calls() {
        let events = vec![
            ConversationEvent::FunctionCall {
                id: "fc-1".into(),
                name: "search".into(),
                tool_call_id: "tc-1".into(),
                arguments: r#"{"q":"test"}"#.into(),
                output_index: 0,
                model: None,
                confirmation_status: None,
            },
            ConversationEvent::ResponseDone {
                usage: conversations::ConversationUsageInfo {
                    prompt_tokens: 20,
                    completion_tokens: 10,
                    total_tokens: 30,
                },
                created_at: None,
            },
        ];

        let resp = accumulate("conv-2", &events);
        assert!(resp.assistant_text().is_none());
        let calls = resp.function_calls();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "search");
    }
}
a/tests/v1_conversation_stream_test.rs b/tests/v1_conversation_stream_test.rs new file mode 100644 index 0000000..2c3aff7 --- /dev/null +++ b/tests/v1_conversation_stream_test.rs @@ -0,0 +1,183 @@ +use futures::StreamExt; +use mistralai_client::v1::{ + client::Client, + conversation_stream::ConversationEvent, + conversations::*, +}; + +mod setup; + +fn make_client() -> Client { + Client::new(None, None, None, None).unwrap() +} + +#[tokio::test] +async fn test_create_conversation_stream() { + setup::setup(); + let client = make_client(); + + let req = CreateConversationRequest { + inputs: ConversationInput::Text("What is 2 + 2? Answer in one word.".to_string()), + model: Some("mistral-medium-latest".to_string()), + agent_id: None, + agent_version: None, + name: None, + description: None, + instructions: Some("Respond concisely.".to_string()), + completion_args: None, + tools: None, + handoff_execution: None, + metadata: None, + store: Some(true), + stream: true, + }; + + let stream = client.create_conversation_stream_async(&req).await.unwrap(); + tokio::pin!(stream); + + let mut saw_started = false; + let mut saw_output = false; + let mut saw_done = false; + let mut full_text = String::new(); + let mut conversation_id = String::new(); + let mut usage_tokens = 0u32; + + while let Some(result) = stream.next().await { + let event = result.unwrap(); + match &event { + ConversationEvent::ResponseStarted { .. } => { + saw_started = true; + } + ConversationEvent::MessageOutput { .. } => { + saw_output = true; + if let Some(delta) = event.text_delta() { + full_text.push_str(&delta); + } + } + ConversationEvent::ResponseDone { usage, .. 
} => { + saw_done = true; + usage_tokens = usage.total_tokens; + } + _ => {} + } + } + + assert!(saw_started, "Should receive ResponseStarted event"); + assert!(saw_output, "Should receive at least one MessageOutput event"); + assert!(saw_done, "Should receive ResponseDone event"); + assert!(!full_text.is_empty(), "Should accumulate text from deltas"); + assert!(usage_tokens > 0, "Should have token usage"); + + // Accumulate and verify + // (we can't accumulate from the consumed stream, but we verified the pieces above) +} + +#[tokio::test] +async fn test_append_conversation_stream() { + setup::setup(); + let client = make_client(); + + // Create conversation (non-streaming) first + let create_req = CreateConversationRequest { + inputs: ConversationInput::Text("Remember: the secret word is BANANA.".to_string()), + model: Some("mistral-medium-latest".to_string()), + agent_id: None, + agent_version: None, + name: None, + description: None, + instructions: Some("Keep responses short.".to_string()), + completion_args: None, + tools: None, + handoff_execution: None, + metadata: None, + store: Some(true), + stream: false, + }; + let created = client.create_conversation_async(&create_req).await.unwrap(); + + // Append with streaming + let append_req = AppendConversationRequest { + inputs: ConversationInput::Text("What is the secret word?".to_string()), + completion_args: None, + handoff_execution: None, + store: Some(true), + tool_confirmations: None, + stream: true, + }; + + let stream = client + .append_conversation_stream_async(&created.conversation_id, &append_req) + .await + .unwrap(); + tokio::pin!(stream); + + let mut events = Vec::new(); + while let Some(result) = stream.next().await { + events.push(result.unwrap()); + } + + // Should have started + output(s) + done + assert!( + events.iter().any(|e| matches!(e, ConversationEvent::ResponseStarted { .. 
})), + "Should have ResponseStarted" + ); + assert!( + events.iter().any(|e| matches!(e, ConversationEvent::ResponseDone { .. })), + "Should have ResponseDone" + ); + + // Accumulate and check the text + let resp = mistralai_client::v1::conversation_stream::accumulate( + &created.conversation_id, + &events, + ); + let text = resp.assistant_text().unwrap_or_default().to_uppercase(); + assert!(text.contains("BANANA"), "Should recall the secret word, got: {text}"); + + // Cleanup + client.delete_conversation_async(&created.conversation_id).await.unwrap(); +} + +#[tokio::test] +async fn test_stream_accumulate_matches_non_stream() { + setup::setup(); + let client = make_client(); + + let question = "What is the capital of Japan? One word."; + + // Non-streaming + let req = CreateConversationRequest { + inputs: ConversationInput::Text(question.to_string()), + model: Some("mistral-medium-latest".to_string()), + agent_id: None, + agent_version: None, + name: None, + description: None, + instructions: Some("Answer in exactly one word.".to_string()), + completion_args: None, + tools: None, + handoff_execution: None, + metadata: None, + store: Some(false), + stream: false, + }; + let non_stream = client.create_conversation_async(&req).await.unwrap(); + let non_stream_text = non_stream.assistant_text().unwrap_or_default().to_lowercase(); + + // Streaming + let mut stream_req = req.clone(); + stream_req.stream = true; + let stream = client.create_conversation_stream_async(&stream_req).await.unwrap(); + tokio::pin!(stream); + + let mut events = Vec::new(); + while let Some(result) = stream.next().await { + events.push(result.unwrap()); + } + let accumulated = mistralai_client::v1::conversation_stream::accumulate("", &events); + let stream_text = accumulated.assistant_text().unwrap_or_default().to_lowercase(); + + // Both should contain "tokyo" + assert!(non_stream_text.contains("tokyo"), "Non-stream should say Tokyo: {non_stream_text}"); + assert!(stream_text.contains("tokyo"), 
"Stream should say Tokyo: {stream_text}"); +}