feat: streaming Conversations API support (v1.2.0)
Add conversation_stream module with full streaming support for the Mistral Conversations API:
- ConversationEvent enum matching the API's SSE event types: ResponseStarted, MessageOutput (text deltas), FunctionCall, ResponseDone (with usage), ResponseError, plus tool execution and agent handoff events
- parse_sse_line() handles the SSE format (skips event: lines, parses data: JSON, handles [DONE] and comments)
- accumulate() collects streaming events into a ConversationResponse
- create_conversation_stream_async() and append_conversation_stream_async() client methods
- Byte-boundary buffering in sse_to_conversation_events, which handles JSON split across TCP frames
- Integration tests hit the real Mistral API: create stream, append stream, stream/non-stream output equivalence
This commit is contained in:
@@ -2,7 +2,7 @@
|
|||||||
name = "mistralai-client"
|
name = "mistralai-client"
|
||||||
description = "Mistral AI API client library for Rust (unofficial)."
|
description = "Mistral AI API client library for Rust (unofficial)."
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
version = "1.1.0"
|
version = "1.2.0"
|
||||||
|
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
rust-version = "1.76.0"
|
rust-version = "1.76.0"
|
||||||
|
|||||||
101
src/v1/client.rs
101
src/v1/client.rs
@@ -10,8 +10,8 @@ use std::{
|
|||||||
};
|
};
|
||||||
|
|
||||||
use crate::v1::{
|
use crate::v1::{
|
||||||
agents, audio, batch, chat, chat_stream, constants, conversations, embedding, error, files,
|
agents, audio, batch, chat, chat_stream, constants, conversation_stream, conversations,
|
||||||
fim, fine_tuning, model_list, moderation, ocr, tool, utils,
|
embedding, error, files, fim, fine_tuning, model_list, moderation, ocr, tool, utils,
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
@@ -1054,6 +1054,103 @@ impl Client {
|
|||||||
.map_err(|e| self.to_api_error(e))
|
.map_err(|e| self.to_api_error(e))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create a conversation with streaming response.
|
||||||
|
/// Returns a stream of `ConversationEvent`s as SSE events arrive.
|
||||||
|
pub async fn create_conversation_stream_async(
|
||||||
|
&self,
|
||||||
|
request: &conversations::CreateConversationRequest,
|
||||||
|
) -> Result<
|
||||||
|
impl futures::Stream<Item = Result<conversation_stream::ConversationEvent, error::ApiError>>,
|
||||||
|
error::ApiError,
|
||||||
|
> {
|
||||||
|
// Ensure stream is true
|
||||||
|
let mut req = request.clone();
|
||||||
|
req.stream = true;
|
||||||
|
|
||||||
|
let response = self.post_stream("/conversations", &req).await?;
|
||||||
|
if !response.status().is_success() {
|
||||||
|
let status = response.status();
|
||||||
|
let text = response.text().await.unwrap_or_default();
|
||||||
|
return Err(error::ApiError {
|
||||||
|
message: format!("{}: {}", status, text),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(self.sse_to_conversation_events(response))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Append to a conversation with streaming response.
|
||||||
|
/// Returns a stream of `ConversationEvent`s as SSE events arrive.
|
||||||
|
pub async fn append_conversation_stream_async(
|
||||||
|
&self,
|
||||||
|
conversation_id: &str,
|
||||||
|
request: &conversations::AppendConversationRequest,
|
||||||
|
) -> Result<
|
||||||
|
impl futures::Stream<Item = Result<conversation_stream::ConversationEvent, error::ApiError>>,
|
||||||
|
error::ApiError,
|
||||||
|
> {
|
||||||
|
let mut req = request.clone();
|
||||||
|
req.stream = true;
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.post_stream(&format!("/conversations/{}", conversation_id), &req)
|
||||||
|
.await?;
|
||||||
|
if !response.status().is_success() {
|
||||||
|
let status = response.status();
|
||||||
|
let text = response.text().await.unwrap_or_default();
|
||||||
|
return Err(error::ApiError {
|
||||||
|
message: format!("{}: {}", status, text),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(self.sse_to_conversation_events(response))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert a raw SSE response into a stream of conversation events.
|
||||||
|
/// Handles byte-boundary splits by buffering incomplete lines.
|
||||||
|
fn sse_to_conversation_events(
|
||||||
|
&self,
|
||||||
|
response: reqwest::Response,
|
||||||
|
) -> impl futures::Stream<Item = Result<conversation_stream::ConversationEvent, error::ApiError>> {
|
||||||
|
use futures::stream;
|
||||||
|
|
||||||
|
let mut buffer = String::new();
|
||||||
|
|
||||||
|
response.bytes_stream().flat_map(move |bytes_result| {
|
||||||
|
match bytes_result {
|
||||||
|
Ok(bytes) => {
|
||||||
|
let text = match String::from_utf8(bytes.to_vec()) {
|
||||||
|
Ok(t) => t,
|
||||||
|
Err(e) => return stream::iter(vec![Err(error::ApiError {
|
||||||
|
message: format!("UTF-8 decode error: {e}"),
|
||||||
|
})]),
|
||||||
|
};
|
||||||
|
|
||||||
|
buffer.push_str(&text);
|
||||||
|
|
||||||
|
let mut events = Vec::new();
|
||||||
|
|
||||||
|
// Process complete lines only (ending with \n)
|
||||||
|
while let Some(newline_pos) = buffer.find('\n') {
|
||||||
|
let line = buffer[..newline_pos].to_string();
|
||||||
|
buffer = buffer[newline_pos + 1..].to_string();
|
||||||
|
|
||||||
|
match conversation_stream::parse_sse_line(&line) {
|
||||||
|
Ok(Some(event)) => events.push(Ok(event)),
|
||||||
|
Ok(None) => {}
|
||||||
|
Err(e) => events.push(Err(e)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
stream::iter(events)
|
||||||
|
}
|
||||||
|
Err(e) => stream::iter(vec![Err(error::ApiError {
|
||||||
|
message: format!("Stream read error: {e}"),
|
||||||
|
})]),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
pub fn get_conversation(
|
pub fn get_conversation(
|
||||||
&self,
|
&self,
|
||||||
conversation_id: &str,
|
conversation_id: &str,
|
||||||
|
|||||||
395
src/v1/conversation_stream.rs
Normal file
395
src/v1/conversation_stream.rs
Normal file
@@ -0,0 +1,395 @@
|
|||||||
|
//! Streaming support for the Conversations API.
|
||||||
|
//!
|
||||||
|
//! When `stream: true` is set on a conversation request, the API returns
|
||||||
|
//! Server-Sent Events (SSE). Each event has an `event:` type line and a
|
||||||
|
//! `data:` JSON payload, discriminated by the `type` field.
|
||||||
|
//!
|
||||||
|
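//! Core event types (see `ConversationEvent` for the full list, which also
//! covers server-side tool execution and agent handoff events):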
|
||||||
|
//! - `conversation.response.started` — generation began
|
||||||
|
//! - `message.output.delta` — partial assistant text
|
||||||
|
//! - `function.call.delta` — a function call (tool call)
|
||||||
|
//! - `conversation.response.done` — generation complete (has usage)
|
||||||
|
//! - `conversation.response.error` — error during generation
|
||||||
|
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::v1::{conversations, error};
|
||||||
|
|
||||||
|
// ── SSE event types ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// A streaming event from the Conversations API.
|
||||||
|
/// The `type` field discriminates the variant.
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
#[serde(tag = "type")]
|
||||||
|
pub enum ConversationEvent {
|
||||||
|
/// Generation started.
|
||||||
|
#[serde(rename = "conversation.response.started")]
|
||||||
|
ResponseStarted {
|
||||||
|
#[serde(default)]
|
||||||
|
created_at: Option<String>,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Partial assistant text output.
|
||||||
|
#[serde(rename = "message.output.delta")]
|
||||||
|
MessageOutput {
|
||||||
|
id: String,
|
||||||
|
content: serde_json::Value, // string or array of chunks
|
||||||
|
#[serde(default)]
|
||||||
|
role: String,
|
||||||
|
#[serde(default)]
|
||||||
|
output_index: u32,
|
||||||
|
#[serde(default)]
|
||||||
|
content_index: u32,
|
||||||
|
#[serde(default)]
|
||||||
|
model: Option<String>,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// A function call from the model.
|
||||||
|
#[serde(rename = "function.call.delta")]
|
||||||
|
FunctionCall {
|
||||||
|
id: String,
|
||||||
|
name: String,
|
||||||
|
tool_call_id: String,
|
||||||
|
arguments: String,
|
||||||
|
#[serde(default)]
|
||||||
|
output_index: u32,
|
||||||
|
#[serde(default)]
|
||||||
|
model: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
confirmation_status: Option<String>,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Generation complete — includes token usage.
|
||||||
|
#[serde(rename = "conversation.response.done")]
|
||||||
|
ResponseDone {
|
||||||
|
usage: conversations::ConversationUsageInfo,
|
||||||
|
#[serde(default)]
|
||||||
|
created_at: Option<String>,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Error during generation.
|
||||||
|
#[serde(rename = "conversation.response.error")]
|
||||||
|
ResponseError {
|
||||||
|
message: String,
|
||||||
|
code: i32,
|
||||||
|
#[serde(default)]
|
||||||
|
created_at: Option<String>,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Tool execution started (server-side).
|
||||||
|
#[serde(rename = "tool.execution.started")]
|
||||||
|
ToolExecutionStarted {
|
||||||
|
#[serde(flatten)]
|
||||||
|
extra: serde_json::Value,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Tool execution delta (server-side).
|
||||||
|
#[serde(rename = "tool.execution.delta")]
|
||||||
|
ToolExecutionDelta {
|
||||||
|
#[serde(flatten)]
|
||||||
|
extra: serde_json::Value,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Tool execution done (server-side).
|
||||||
|
#[serde(rename = "tool.execution.done")]
|
||||||
|
ToolExecutionDone {
|
||||||
|
#[serde(flatten)]
|
||||||
|
extra: serde_json::Value,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Agent handoff started.
|
||||||
|
#[serde(rename = "agent.handoff.started")]
|
||||||
|
AgentHandoffStarted {
|
||||||
|
#[serde(flatten)]
|
||||||
|
extra: serde_json::Value,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Agent handoff done.
|
||||||
|
#[serde(rename = "agent.handoff.done")]
|
||||||
|
AgentHandoffDone {
|
||||||
|
#[serde(flatten)]
|
||||||
|
extra: serde_json::Value,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ConversationEvent {
|
||||||
|
/// Extract text content from a MessageOutput event.
|
||||||
|
pub fn text_delta(&self) -> Option<String> {
|
||||||
|
match self {
|
||||||
|
ConversationEvent::MessageOutput { content, .. } => {
|
||||||
|
// content can be a string or an array of chunks
|
||||||
|
if let Some(s) = content.as_str() {
|
||||||
|
Some(s.to_string())
|
||||||
|
} else if let Some(arr) = content.as_array() {
|
||||||
|
// Array of chunks — extract text from TextChunk items
|
||||||
|
let mut text = String::new();
|
||||||
|
for chunk in arr {
|
||||||
|
if let Some(t) = chunk.get("text").and_then(|v| v.as_str()) {
|
||||||
|
text.push_str(t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if text.is_empty() { None } else { Some(text) }
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── SSE parsing ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Parse a single SSE `data:` line into a conversation event.
|
||||||
|
///
|
||||||
|
/// Returns:
|
||||||
|
/// - `Ok(Some(event))` — parsed event
|
||||||
|
/// - `Ok(None)` — `[DONE]` signal or empty/comment line
|
||||||
|
/// - `Err(e)` — parse error
|
||||||
|
pub fn parse_sse_line(line: &str) -> Result<Option<ConversationEvent>, error::ApiError> {
|
||||||
|
let line = line.trim();
|
||||||
|
|
||||||
|
if line.is_empty() || line.starts_with(':') || line.starts_with("event:") {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
if line == "data: [DONE]" || line == "[DONE]" {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSE data lines start with "data: "
|
||||||
|
let json = match line.strip_prefix("data: ") {
|
||||||
|
Some(j) => j.trim(),
|
||||||
|
None => return Ok(None), // not a data line
|
||||||
|
};
|
||||||
|
if json.is_empty() {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
serde_json::from_str::<ConversationEvent>(json).map(Some).map_err(|e| {
|
||||||
|
error::ApiError {
|
||||||
|
message: format!("Failed to parse conversation stream event: {e}\nRaw: {json}"),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Accumulate streaming events into a final `ConversationResponse`.
|
||||||
|
pub fn accumulate(
|
||||||
|
conversation_id: &str,
|
||||||
|
events: &[ConversationEvent],
|
||||||
|
) -> conversations::ConversationResponse {
|
||||||
|
let mut full_text = String::new();
|
||||||
|
let mut function_calls = Vec::new();
|
||||||
|
let mut usage = conversations::ConversationUsageInfo {
|
||||||
|
prompt_tokens: 0,
|
||||||
|
completion_tokens: 0,
|
||||||
|
total_tokens: 0,
|
||||||
|
};
|
||||||
|
|
||||||
|
for event in events {
|
||||||
|
match event {
|
||||||
|
ConversationEvent::MessageOutput { content, .. } => {
|
||||||
|
if let Some(s) = content.as_str() {
|
||||||
|
full_text.push_str(s);
|
||||||
|
} else if let Some(arr) = content.as_array() {
|
||||||
|
for chunk in arr {
|
||||||
|
if let Some(t) = chunk.get("text").and_then(|v| v.as_str()) {
|
||||||
|
full_text.push_str(t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ConversationEvent::FunctionCall {
|
||||||
|
id, name, tool_call_id, arguments, ..
|
||||||
|
} => {
|
||||||
|
function_calls.push(conversations::ConversationEntry::FunctionCall(
|
||||||
|
conversations::FunctionCallEntry {
|
||||||
|
name: name.clone(),
|
||||||
|
arguments: arguments.clone(),
|
||||||
|
id: Some(id.clone()),
|
||||||
|
object: None,
|
||||||
|
tool_call_id: Some(tool_call_id.clone()),
|
||||||
|
created_at: None,
|
||||||
|
completed_at: None,
|
||||||
|
},
|
||||||
|
));
|
||||||
|
}
|
||||||
|
ConversationEvent::ResponseDone { usage: u, .. } => {
|
||||||
|
usage = u.clone();
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut outputs = Vec::new();
|
||||||
|
if !full_text.is_empty() {
|
||||||
|
outputs.push(conversations::ConversationEntry::MessageOutput(
|
||||||
|
conversations::MessageOutputEntry {
|
||||||
|
role: "assistant".into(),
|
||||||
|
content: crate::v1::chat::ChatMessageContent::Text(full_text),
|
||||||
|
id: None,
|
||||||
|
object: None,
|
||||||
|
model: None,
|
||||||
|
created_at: None,
|
||||||
|
completed_at: None,
|
||||||
|
},
|
||||||
|
));
|
||||||
|
}
|
||||||
|
outputs.extend(function_calls);
|
||||||
|
|
||||||
|
conversations::ConversationResponse {
|
||||||
|
conversation_id: conversation_id.to_string(),
|
||||||
|
outputs,
|
||||||
|
usage,
|
||||||
|
object: "conversation.response".into(),
|
||||||
|
guardrails: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse a line that is expected to yield an event.
    fn parse_event(line: &str) -> ConversationEvent {
        parse_sse_line(line).unwrap().unwrap()
    }

    /// Build a `MessageOutput` event with plain string content.
    fn text_event(text: &str) -> ConversationEvent {
        ConversationEvent::MessageOutput {
            id: "m1".into(),
            content: serde_json::json!(text),
            role: "assistant".into(),
            output_index: 0,
            content_index: 0,
            model: None,
        }
    }

    /// Build a `ResponseDone` event with the given token counts.
    fn done_event(prompt: u32, completion: u32, total: u32) -> ConversationEvent {
        ConversationEvent::ResponseDone {
            usage: conversations::ConversationUsageInfo {
                prompt_tokens: prompt,
                completion_tokens: completion,
                total_tokens: total,
            },
            created_at: None,
        }
    }

    #[test]
    fn test_parse_done() {
        let parsed = parse_sse_line("data: [DONE]").unwrap();
        assert!(parsed.is_none());
    }

    #[test]
    fn test_parse_empty() {
        for blank in ["", " "] {
            assert!(parse_sse_line(blank).unwrap().is_none());
        }
    }

    #[test]
    fn test_parse_comment() {
        let parsed = parse_sse_line(": keep-alive").unwrap();
        assert!(parsed.is_none());
    }

    #[test]
    fn test_parse_response_started() {
        let event = parse_event(
            r#"data: {"type":"conversation.response.started","created_at":"2026-03-24T12:00:00Z"}"#,
        );
        assert!(matches!(event, ConversationEvent::ResponseStarted { .. }));
    }

    #[test]
    fn test_parse_message_output_string() {
        let event = parse_event(
            r#"data: {"type":"message.output.delta","id":"msg-1","content":"hello ","role":"assistant"}"#,
        );
        assert_eq!(event.text_delta(), Some("hello ".into()));
    }

    #[test]
    fn test_parse_message_output_chunks() {
        let event = parse_event(
            r#"data: {"type":"message.output.delta","id":"msg-1","content":[{"type":"text","text":"world"}],"role":"assistant"}"#,
        );
        assert_eq!(event.text_delta(), Some("world".into()));
    }

    #[test]
    fn test_parse_function_call() {
        let event = parse_event(
            r#"data: {"type":"function.call.delta","id":"fc-1","name":"search_web","tool_call_id":"tc-1","arguments":"{\"query\":\"test\"}"}"#,
        );
        let ConversationEvent::FunctionCall {
            name,
            arguments,
            tool_call_id,
            ..
        } = event
        else {
            panic!("Expected FunctionCall");
        };
        assert_eq!(name, "search_web");
        assert_eq!(tool_call_id, "tc-1");
        assert!(arguments.contains("test"));
    }

    #[test]
    fn test_parse_response_done() {
        let event = parse_event(
            r#"data: {"type":"conversation.response.done","usage":{"prompt_tokens":100,"completion_tokens":50,"total_tokens":150}}"#,
        );
        let ConversationEvent::ResponseDone { usage, .. } = event else {
            panic!("Expected ResponseDone");
        };
        assert_eq!(usage.prompt_tokens, 100);
        assert_eq!(usage.completion_tokens, 50);
        assert_eq!(usage.total_tokens, 150);
    }

    #[test]
    fn test_parse_response_error() {
        let event = parse_event(
            r#"data: {"type":"conversation.response.error","message":"rate limited","code":429}"#,
        );
        let ConversationEvent::ResponseError { message, code, .. } = event else {
            panic!("Expected ResponseError");
        };
        assert_eq!(message, "rate limited");
        assert_eq!(code, 429);
    }

    #[test]
    fn test_accumulate() {
        let events = vec![
            ConversationEvent::ResponseStarted { created_at: None },
            text_event("hello "),
            text_event("world"),
            done_event(10, 5, 15),
        ];

        let resp = accumulate("conv-1", &events);

        assert_eq!(resp.conversation_id, "conv-1");
        assert_eq!(resp.assistant_text(), Some("hello world".into()));
        assert_eq!(resp.usage.total_tokens, 15);
    }

    #[test]
    fn test_accumulate_with_function_calls() {
        let call = ConversationEvent::FunctionCall {
            id: "fc-1".into(),
            name: "search".into(),
            tool_call_id: "tc-1".into(),
            arguments: r#"{"q":"test"}"#.into(),
            output_index: 0,
            model: None,
            confirmation_status: None,
        };
        let events = vec![call, done_event(20, 10, 30)];

        let resp = accumulate("conv-2", &events);

        // No text deltas were streamed, so there is no assistant message.
        assert!(resp.assistant_text().is_none());
        let calls = resp.function_calls();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "search");
    }
}
|
||||||
@@ -6,6 +6,7 @@ pub mod chat_stream;
|
|||||||
pub mod client;
|
pub mod client;
|
||||||
pub mod common;
|
pub mod common;
|
||||||
pub mod constants;
|
pub mod constants;
|
||||||
|
pub mod conversation_stream;
|
||||||
pub mod conversations;
|
pub mod conversations;
|
||||||
pub mod embedding;
|
pub mod embedding;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
|
|||||||
183
tests/v1_conversation_stream_test.rs
Normal file
183
tests/v1_conversation_stream_test.rs
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
use futures::StreamExt;
|
||||||
|
use mistralai_client::v1::{
|
||||||
|
client::Client,
|
||||||
|
conversation_stream::ConversationEvent,
|
||||||
|
conversations::*,
|
||||||
|
};
|
||||||
|
|
||||||
|
mod setup;
|
||||||
|
|
||||||
|
fn make_client() -> Client {
|
||||||
|
Client::new(None, None, None, None).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_create_conversation_stream() {
|
||||||
|
setup::setup();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
let req = CreateConversationRequest {
|
||||||
|
inputs: ConversationInput::Text("What is 2 + 2? Answer in one word.".to_string()),
|
||||||
|
model: Some("mistral-medium-latest".to_string()),
|
||||||
|
agent_id: None,
|
||||||
|
agent_version: None,
|
||||||
|
name: None,
|
||||||
|
description: None,
|
||||||
|
instructions: Some("Respond concisely.".to_string()),
|
||||||
|
completion_args: None,
|
||||||
|
tools: None,
|
||||||
|
handoff_execution: None,
|
||||||
|
metadata: None,
|
||||||
|
store: Some(true),
|
||||||
|
stream: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
let stream = client.create_conversation_stream_async(&req).await.unwrap();
|
||||||
|
tokio::pin!(stream);
|
||||||
|
|
||||||
|
let mut saw_started = false;
|
||||||
|
let mut saw_output = false;
|
||||||
|
let mut saw_done = false;
|
||||||
|
let mut full_text = String::new();
|
||||||
|
let mut conversation_id = String::new();
|
||||||
|
let mut usage_tokens = 0u32;
|
||||||
|
|
||||||
|
while let Some(result) = stream.next().await {
|
||||||
|
let event = result.unwrap();
|
||||||
|
match &event {
|
||||||
|
ConversationEvent::ResponseStarted { .. } => {
|
||||||
|
saw_started = true;
|
||||||
|
}
|
||||||
|
ConversationEvent::MessageOutput { .. } => {
|
||||||
|
saw_output = true;
|
||||||
|
if let Some(delta) = event.text_delta() {
|
||||||
|
full_text.push_str(&delta);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
ConversationEvent::ResponseDone { usage, .. } => {
|
||||||
|
saw_done = true;
|
||||||
|
usage_tokens = usage.total_tokens;
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert!(saw_started, "Should receive ResponseStarted event");
|
||||||
|
assert!(saw_output, "Should receive at least one MessageOutput event");
|
||||||
|
assert!(saw_done, "Should receive ResponseDone event");
|
||||||
|
assert!(!full_text.is_empty(), "Should accumulate text from deltas");
|
||||||
|
assert!(usage_tokens > 0, "Should have token usage");
|
||||||
|
|
||||||
|
// Accumulate and verify
|
||||||
|
// (we can't accumulate from the consumed stream, but we verified the pieces above)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_append_conversation_stream() {
|
||||||
|
setup::setup();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
// Create conversation (non-streaming) first
|
||||||
|
let create_req = CreateConversationRequest {
|
||||||
|
inputs: ConversationInput::Text("Remember: the secret word is BANANA.".to_string()),
|
||||||
|
model: Some("mistral-medium-latest".to_string()),
|
||||||
|
agent_id: None,
|
||||||
|
agent_version: None,
|
||||||
|
name: None,
|
||||||
|
description: None,
|
||||||
|
instructions: Some("Keep responses short.".to_string()),
|
||||||
|
completion_args: None,
|
||||||
|
tools: None,
|
||||||
|
handoff_execution: None,
|
||||||
|
metadata: None,
|
||||||
|
store: Some(true),
|
||||||
|
stream: false,
|
||||||
|
};
|
||||||
|
let created = client.create_conversation_async(&create_req).await.unwrap();
|
||||||
|
|
||||||
|
// Append with streaming
|
||||||
|
let append_req = AppendConversationRequest {
|
||||||
|
inputs: ConversationInput::Text("What is the secret word?".to_string()),
|
||||||
|
completion_args: None,
|
||||||
|
handoff_execution: None,
|
||||||
|
store: Some(true),
|
||||||
|
tool_confirmations: None,
|
||||||
|
stream: true,
|
||||||
|
};
|
||||||
|
|
||||||
|
let stream = client
|
||||||
|
.append_conversation_stream_async(&created.conversation_id, &append_req)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
tokio::pin!(stream);
|
||||||
|
|
||||||
|
let mut events = Vec::new();
|
||||||
|
while let Some(result) = stream.next().await {
|
||||||
|
events.push(result.unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Should have started + output(s) + done
|
||||||
|
assert!(
|
||||||
|
events.iter().any(|e| matches!(e, ConversationEvent::ResponseStarted { .. })),
|
||||||
|
"Should have ResponseStarted"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
events.iter().any(|e| matches!(e, ConversationEvent::ResponseDone { .. })),
|
||||||
|
"Should have ResponseDone"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Accumulate and check the text
|
||||||
|
let resp = mistralai_client::v1::conversation_stream::accumulate(
|
||||||
|
&created.conversation_id,
|
||||||
|
&events,
|
||||||
|
);
|
||||||
|
let text = resp.assistant_text().unwrap_or_default().to_uppercase();
|
||||||
|
assert!(text.contains("BANANA"), "Should recall the secret word, got: {text}");
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
client.delete_conversation_async(&created.conversation_id).await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_stream_accumulate_matches_non_stream() {
|
||||||
|
setup::setup();
|
||||||
|
let client = make_client();
|
||||||
|
|
||||||
|
let question = "What is the capital of Japan? One word.";
|
||||||
|
|
||||||
|
// Non-streaming
|
||||||
|
let req = CreateConversationRequest {
|
||||||
|
inputs: ConversationInput::Text(question.to_string()),
|
||||||
|
model: Some("mistral-medium-latest".to_string()),
|
||||||
|
agent_id: None,
|
||||||
|
agent_version: None,
|
||||||
|
name: None,
|
||||||
|
description: None,
|
||||||
|
instructions: Some("Answer in exactly one word.".to_string()),
|
||||||
|
completion_args: None,
|
||||||
|
tools: None,
|
||||||
|
handoff_execution: None,
|
||||||
|
metadata: None,
|
||||||
|
store: Some(false),
|
||||||
|
stream: false,
|
||||||
|
};
|
||||||
|
let non_stream = client.create_conversation_async(&req).await.unwrap();
|
||||||
|
let non_stream_text = non_stream.assistant_text().unwrap_or_default().to_lowercase();
|
||||||
|
|
||||||
|
// Streaming
|
||||||
|
let mut stream_req = req.clone();
|
||||||
|
stream_req.stream = true;
|
||||||
|
let stream = client.create_conversation_stream_async(&stream_req).await.unwrap();
|
||||||
|
tokio::pin!(stream);
|
||||||
|
|
||||||
|
let mut events = Vec::new();
|
||||||
|
while let Some(result) = stream.next().await {
|
||||||
|
events.push(result.unwrap());
|
||||||
|
}
|
||||||
|
let accumulated = mistralai_client::v1::conversation_stream::accumulate("", &events);
|
||||||
|
let stream_text = accumulated.assistant_text().unwrap_or_default().to_lowercase();
|
||||||
|
|
||||||
|
// Both should contain "tokyo"
|
||||||
|
assert!(non_stream_text.contains("tokyo"), "Non-stream should say Tokyo: {non_stream_text}");
|
||||||
|
assert!(stream_text.contains("tokyo"), "Stream should say Tokyo: {stream_text}");
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user