Add the conversation_stream module with full streaming support for the Mistral Conversations API:
- `ConversationEvent` enum matching the API's SSE event types: ResponseStarted, MessageOutput (text deltas), FunctionCall, ResponseDone (with usage), ResponseError, plus tool-execution and agent-handoff events.
- `parse_sse_line()` handles the SSE format: it skips `event:` lines, parses `data:` JSON, and handles `[DONE]` and comment lines.
- `accumulate()` collects streamed events into a `ConversationResponse`.
- `create_conversation_stream_async()` and `append_conversation_stream_async()` client methods.
- Byte-boundary buffering in `sse_to_conversation_events`, which handles JSON split across TCP frames.
- Integration tests against the real Mistral API: create stream, append stream, and stream/non-stream output equivalence.
184 lines · 5.9 KiB · Rust
use futures::StreamExt;
|
|
use mistralai_client::v1::{
|
|
client::Client,
|
|
conversation_stream::ConversationEvent,
|
|
conversations::*,
|
|
};
|
|
|
|
mod setup;
|
|
|
|
fn make_client() -> Client {
|
|
Client::new(None, None, None, None).unwrap()
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_create_conversation_stream() {
|
|
setup::setup();
|
|
let client = make_client();
|
|
|
|
let req = CreateConversationRequest {
|
|
inputs: ConversationInput::Text("What is 2 + 2? Answer in one word.".to_string()),
|
|
model: Some("mistral-medium-latest".to_string()),
|
|
agent_id: None,
|
|
agent_version: None,
|
|
name: None,
|
|
description: None,
|
|
instructions: Some("Respond concisely.".to_string()),
|
|
completion_args: None,
|
|
tools: None,
|
|
handoff_execution: None,
|
|
metadata: None,
|
|
store: Some(true),
|
|
stream: true,
|
|
};
|
|
|
|
let stream = client.create_conversation_stream_async(&req).await.unwrap();
|
|
tokio::pin!(stream);
|
|
|
|
let mut saw_started = false;
|
|
let mut saw_output = false;
|
|
let mut saw_done = false;
|
|
let mut full_text = String::new();
|
|
let mut conversation_id = String::new();
|
|
let mut usage_tokens = 0u32;
|
|
|
|
while let Some(result) = stream.next().await {
|
|
let event = result.unwrap();
|
|
match &event {
|
|
ConversationEvent::ResponseStarted { .. } => {
|
|
saw_started = true;
|
|
}
|
|
ConversationEvent::MessageOutput { .. } => {
|
|
saw_output = true;
|
|
if let Some(delta) = event.text_delta() {
|
|
full_text.push_str(&delta);
|
|
}
|
|
}
|
|
ConversationEvent::ResponseDone { usage, .. } => {
|
|
saw_done = true;
|
|
usage_tokens = usage.total_tokens;
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
|
|
assert!(saw_started, "Should receive ResponseStarted event");
|
|
assert!(saw_output, "Should receive at least one MessageOutput event");
|
|
assert!(saw_done, "Should receive ResponseDone event");
|
|
assert!(!full_text.is_empty(), "Should accumulate text from deltas");
|
|
assert!(usage_tokens > 0, "Should have token usage");
|
|
|
|
// Accumulate and verify
|
|
// (we can't accumulate from the consumed stream, but we verified the pieces above)
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_append_conversation_stream() {
|
|
setup::setup();
|
|
let client = make_client();
|
|
|
|
// Create conversation (non-streaming) first
|
|
let create_req = CreateConversationRequest {
|
|
inputs: ConversationInput::Text("Remember: the secret word is BANANA.".to_string()),
|
|
model: Some("mistral-medium-latest".to_string()),
|
|
agent_id: None,
|
|
agent_version: None,
|
|
name: None,
|
|
description: None,
|
|
instructions: Some("Keep responses short.".to_string()),
|
|
completion_args: None,
|
|
tools: None,
|
|
handoff_execution: None,
|
|
metadata: None,
|
|
store: Some(true),
|
|
stream: false,
|
|
};
|
|
let created = client.create_conversation_async(&create_req).await.unwrap();
|
|
|
|
// Append with streaming
|
|
let append_req = AppendConversationRequest {
|
|
inputs: ConversationInput::Text("What is the secret word?".to_string()),
|
|
completion_args: None,
|
|
handoff_execution: None,
|
|
store: Some(true),
|
|
tool_confirmations: None,
|
|
stream: true,
|
|
};
|
|
|
|
let stream = client
|
|
.append_conversation_stream_async(&created.conversation_id, &append_req)
|
|
.await
|
|
.unwrap();
|
|
tokio::pin!(stream);
|
|
|
|
let mut events = Vec::new();
|
|
while let Some(result) = stream.next().await {
|
|
events.push(result.unwrap());
|
|
}
|
|
|
|
// Should have started + output(s) + done
|
|
assert!(
|
|
events.iter().any(|e| matches!(e, ConversationEvent::ResponseStarted { .. })),
|
|
"Should have ResponseStarted"
|
|
);
|
|
assert!(
|
|
events.iter().any(|e| matches!(e, ConversationEvent::ResponseDone { .. })),
|
|
"Should have ResponseDone"
|
|
);
|
|
|
|
// Accumulate and check the text
|
|
let resp = mistralai_client::v1::conversation_stream::accumulate(
|
|
&created.conversation_id,
|
|
&events,
|
|
);
|
|
let text = resp.assistant_text().unwrap_or_default().to_uppercase();
|
|
assert!(text.contains("BANANA"), "Should recall the secret word, got: {text}");
|
|
|
|
// Cleanup
|
|
client.delete_conversation_async(&created.conversation_id).await.unwrap();
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_stream_accumulate_matches_non_stream() {
|
|
setup::setup();
|
|
let client = make_client();
|
|
|
|
let question = "What is the capital of Japan? One word.";
|
|
|
|
// Non-streaming
|
|
let req = CreateConversationRequest {
|
|
inputs: ConversationInput::Text(question.to_string()),
|
|
model: Some("mistral-medium-latest".to_string()),
|
|
agent_id: None,
|
|
agent_version: None,
|
|
name: None,
|
|
description: None,
|
|
instructions: Some("Answer in exactly one word.".to_string()),
|
|
completion_args: None,
|
|
tools: None,
|
|
handoff_execution: None,
|
|
metadata: None,
|
|
store: Some(false),
|
|
stream: false,
|
|
};
|
|
let non_stream = client.create_conversation_async(&req).await.unwrap();
|
|
let non_stream_text = non_stream.assistant_text().unwrap_or_default().to_lowercase();
|
|
|
|
// Streaming
|
|
let mut stream_req = req.clone();
|
|
stream_req.stream = true;
|
|
let stream = client.create_conversation_stream_async(&stream_req).await.unwrap();
|
|
tokio::pin!(stream);
|
|
|
|
let mut events = Vec::new();
|
|
while let Some(result) = stream.next().await {
|
|
events.push(result.unwrap());
|
|
}
|
|
let accumulated = mistralai_client::v1::conversation_stream::accumulate("", &events);
|
|
let stream_text = accumulated.assistant_text().unwrap_or_default().to_lowercase();
|
|
|
|
// Both should contain "tokyo"
|
|
assert!(non_stream_text.contains("tokyo"), "Non-stream should say Tokyo: {non_stream_text}");
|
|
assert!(stream_text.contains("tokyo"), "Stream should say Tokyo: {stream_text}");
|
|
}
|