From a29c3c010998b12ee1ebc5fe60d074902cbf11c4 Mon Sep 17 00:00:00 2001 From: Sienna Meridian Satterwhite Date: Sat, 21 Mar 2026 20:58:25 +0000 Subject: [PATCH] feat: add Agents API, Conversations API, and multimodal support (v1.1.0) Agents API (beta): create, get, update, delete, list agents with tools, handoffs, completion args, and guardrails support. Conversations API (beta): create, append, history, messages, restart, delete, list conversations. Supports agent-backed and model-only conversations with function calling and handoff execution modes. Multimodal: ChatMessageContent enum (Text/Parts) with ContentPart variants for text and image_url. Backwards-compatible constructors. new_user_message_with_images() for mixed content messages. Chat: reasoning field on ChatResponseChoice for Magistral models. HTTP: PATCH methods for agent updates. 81 tests (30 live API integration + 35 serde unit + 16 existing). --- Cargo.toml | 2 +- src/v1/agents.rs | 195 +++++++- src/v1/chat.rs | 109 ++++- src/v1/client.rs | 328 +++++++++++++- src/v1/conversations.rs | 377 ++++++++++++++++ src/v1/mod.rs | 1 + tests/v1_agents_api_test.rs | 372 ++++++++++++++++ tests/v1_agents_types_test.rs | 119 +++++ tests/v1_chat_multimodal_api_test.rs | 156 +++++++ tests/v1_chat_multimodal_test.rs | 204 +++++++++ tests/v1_client_chat_async_test.rs | 2 +- tests/v1_client_chat_test.rs | 2 +- tests/v1_constants_test.rs | 2 +- tests/v1_conversations_api_test.rs | 642 +++++++++++++++++++++++++++ tests/v1_conversations_types_test.rs | 226 ++++++++++ 15 files changed, 2721 insertions(+), 16 deletions(-) create mode 100644 src/v1/conversations.rs create mode 100644 tests/v1_agents_api_test.rs create mode 100644 tests/v1_agents_types_test.rs create mode 100644 tests/v1_chat_multimodal_api_test.rs create mode 100644 tests/v1_chat_multimodal_test.rs create mode 100644 tests/v1_conversations_api_test.rs create mode 100644 tests/v1_conversations_types_test.rs diff --git a/Cargo.toml b/Cargo.toml index d134f05..9d4a04c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ name = "mistralai-client" description = "Mistral AI API client library for Rust (unofficial)." license = "Apache-2.0" -version = "1.0.0" +version = "1.1.0" edition = "2021" rust-version = "1.76.0" diff --git a/src/v1/agents.rs b/src/v1/agents.rs index 4b4f1e2..35a2620 100644 --- a/src/v1/agents.rs +++ b/src/v1/agents.rs @@ -2,8 +2,9 @@ use serde::{Deserialize, Serialize}; use crate::v1::{chat, common, constants, tool}; -// ----------------------------------------------------------------------------- -// Request +// ============================================================================= +// Agent Completions (existing — POST /v1/agents/completions) +// ============================================================================= #[derive(Debug)] pub struct AgentCompletionParams { @@ -84,9 +85,7 @@ impl AgentCompletionRequest { } } -// ----------------------------------------------------------------------------- -// Response (same shape as chat completions) - +// Agent completion response (same shape as chat completions) #[derive(Clone, Debug, Deserialize, Serialize)] pub struct AgentCompletionResponse { pub id: String, @@ -96,3 +95,189 @@ pub struct AgentCompletionResponse { pub choices: Vec, pub usage: common::ResponseUsage, } + +// ============================================================================= +// Agents API — CRUD (Beta) +// POST/GET/PATCH/DELETE /v1/agents +// ============================================================================= + +// ----------------------------------------------------------------------------- +// Tool types for agents + +/// A function tool definition for an agent. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct FunctionTool { + pub function: tool::ToolFunction, +} + +/// Tool types available to agents. +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum AgentTool { + #[serde(rename = "function")] + Function(FunctionTool), + #[serde(rename = "web_search")] + WebSearch {}, + #[serde(rename = "web_search_premium")] + WebSearchPremium {}, + #[serde(rename = "code_interpreter")] + CodeInterpreter {}, + #[serde(rename = "image_generation")] + ImageGeneration {}, + #[serde(rename = "document_library")] + DocumentLibrary {}, +} + +impl AgentTool { + /// Create a function tool from name, description, and JSON Schema parameters. + pub fn function(name: String, description: String, parameters: serde_json::Value) -> Self { + Self::Function(FunctionTool { + function: tool::ToolFunction { + name, + description, + parameters, + }, + }) + } + + pub fn web_search() -> Self { + Self::WebSearch {} + } + + pub fn code_interpreter() -> Self { + Self::CodeInterpreter {} + } + + pub fn image_generation() -> Self { + Self::ImageGeneration {} + } + + pub fn document_library() -> Self { + Self::DocumentLibrary {} + } +} + +// ----------------------------------------------------------------------------- +// Completion args (subset of chat params allowed for agents) + +#[derive(Clone, Debug, Default, Serialize, Deserialize)] +pub struct CompletionArgs { + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub min_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub random_seed: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub frequency_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub prediction: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_effort: Option, +} + +// ----------------------------------------------------------------------------- +// Create agent request + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct CreateAgentRequest { + pub model: String, + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub instructions: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub handoffs: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub completion_args: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option, +} + +// ----------------------------------------------------------------------------- +// Update agent request + +#[derive(Clone, Debug, Default, Serialize, Deserialize)] +pub struct UpdateAgentRequest { + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub instructions: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub handoffs: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub completion_args: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option, +} + +// ----------------------------------------------------------------------------- +// Agent response + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Agent { + pub id: String, + pub object: String, + pub name: String, + pub model: String, + pub created_at: String, + pub updated_at: String, + #[serde(default)] + pub version: u64, + #[serde(default)] + pub versions: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub instructions: Option, + #[serde(default)] + pub tools: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub handoffs: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub completion_args: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option, + #[serde(default)] + pub deployment_chat: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub source: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub version_message: Option, + #[serde(default)] + pub guardrails: Vec, +} + +/// List agents response. The API returns a raw JSON array of agents. +#[derive(Clone, Debug, Deserialize, Serialize)] +#[serde(transparent)] +pub struct AgentListResponse { + pub data: Vec, +} + +/// Delete agent response. The API returns 204 No Content on success. +#[derive(Clone, Debug)] +pub struct AgentDeleteResponse { + pub deleted: bool, +} diff --git a/src/v1/chat.rs b/src/v1/chat.rs index c4733be..f8f8c45 100644 --- a/src/v1/chat.rs +++ b/src/v1/chat.rs @@ -2,13 +2,98 @@ use serde::{Deserialize, Serialize}; use crate::v1::{common, constants, tool}; +// ----------------------------------------------------------------------------- +// Content parts (multimodal) + +/// A single part of a multimodal message. +#[derive(Clone, Debug, Deserialize, Serialize)] +#[serde(tag = "type")] +pub enum ContentPart { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "image_url")] + ImageUrl { + image_url: ImageUrl, + }, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ImageUrl { + pub url: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub detail: Option, +} + +/// Message content: either a plain text string or multimodal content parts. +/// +/// Serializes as a JSON string for text, or a JSON array for parts. +/// All existing `new_*_message()` constructors produce `Text` variants, +/// so existing code continues to work unchanged. +#[derive(Clone, Debug, Deserialize, Serialize)] +#[serde(untagged)] +pub enum ChatMessageContent { + Text(String), + Parts(Vec), +} + +impl ChatMessageContent { + /// Extract the text content. For multimodal messages, concatenates all text parts. + pub fn text(&self) -> String { + match self { + Self::Text(s) => s.clone(), + Self::Parts(parts) => parts + .iter() + .filter_map(|p| match p { + ContentPart::Text { text } => Some(text.as_str()), + _ => None, + }) + .collect::>() + .join(""), + } + } + + /// Returns the content as a string slice if it is a plain text message. + pub fn as_text(&self) -> Option<&str> { + match self { + Self::Text(s) => Some(s), + Self::Parts(_) => None, + } + } + + /// Returns true if this is a multimodal message with image parts. + pub fn has_images(&self) -> bool { + match self { + Self::Text(_) => false, + Self::Parts(parts) => parts.iter().any(|p| matches!(p, ContentPart::ImageUrl { .. })), + } + } +} + +impl std::fmt::Display for ChatMessageContent { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.text()) + } +} + +impl From for ChatMessageContent { + fn from(s: String) -> Self { + Self::Text(s) + } +} + +impl From<&str> for ChatMessageContent { + fn from(s: &str) -> Self { + Self::Text(s.to_string()) + } +} + // ----------------------------------------------------------------------------- // Definitions #[derive(Clone, Debug, Deserialize, Serialize)] pub struct ChatMessage { pub role: ChatMessageRole, - pub content: String, + pub content: ChatMessageContent, #[serde(skip_serializing_if = "Option::is_none")] pub tool_calls: Option>, /// Tool call ID, required when role is Tool. @@ -22,7 +107,7 @@ impl ChatMessage { pub fn new_system_message(content: &str) -> Self { Self { role: ChatMessageRole::System, - content: content.to_string(), + content: ChatMessageContent::Text(content.to_string()), tool_calls: None, tool_call_id: None, name: None, @@ -32,7 +117,7 @@ impl ChatMessage { pub fn new_assistant_message(content: &str, tool_calls: Option>) -> Self { Self { role: ChatMessageRole::Assistant, - content: content.to_string(), + content: ChatMessageContent::Text(content.to_string()), tool_calls, tool_call_id: None, name: None, @@ -42,7 +127,18 @@ impl ChatMessage { pub fn new_user_message(content: &str) -> Self { Self { role: ChatMessageRole::User, - content: content.to_string(), + content: ChatMessageContent::Text(content.to_string()), + tool_calls: None, + tool_call_id: None, + name: None, + } + } + + /// Create a user message with mixed text and image content. + pub fn new_user_message_with_images(parts: Vec) -> Self { + Self { + role: ChatMessageRole::User, + content: ChatMessageContent::Parts(parts), tool_calls: None, tool_call_id: None, name: None, @@ -52,7 +148,7 @@ impl ChatMessage { pub fn new_tool_message(content: &str, tool_call_id: &str, name: Option<&str>) -> Self { Self { role: ChatMessageRole::Tool, - content: content.to_string(), + content: ChatMessageContent::Text(content.to_string()), tool_calls: None, tool_call_id: Some(tool_call_id.to_string()), name: name.map(|n| n.to_string()), @@ -238,6 +334,9 @@ pub struct ChatResponseChoice { pub index: u32, pub message: ChatMessage, pub finish_reason: ChatResponseChoiceFinishReason, + /// Reasoning content returned by Magistral models. + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning: Option, } #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] diff --git a/src/v1/client.rs b/src/v1/client.rs index 7b2ecc8..5c10c96 100644 --- a/src/v1/client.rs +++ b/src/v1/client.rs @@ -10,8 +10,8 @@ use std::{ }; use crate::v1::{ - agents, audio, batch, chat, chat_stream, constants, embedding, error, files, fim, fine_tuning, - model_list, moderation, ocr, tool, utils, + agents, audio, batch, chat, chat_stream, constants, conversations, embedding, error, files, + fim, fine_tuning, model_list, moderation, ocr, tool, utils, }; #[derive(Debug)] @@ -900,6 +900,300 @@ impl Client { .map_err(|e| self.to_api_error(e)) } + // ========================================================================= + // Agents CRUD (Beta — /v1/agents) + // ========================================================================= + + pub fn create_agent( + &self, + request: &agents::CreateAgentRequest, + ) -> Result { + let response = self.post_sync("/agents", request)?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn create_agent_async( + &self, + request: &agents::CreateAgentRequest, + ) -> Result { + let response = self.post_async("/agents", request).await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + pub fn get_agent(&self, agent_id: &str) -> Result { + let response = self.get_sync(&format!("/agents/{}", agent_id))?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn get_agent_async( + &self, + agent_id: &str, + ) -> Result { + let response = self.get_async(&format!("/agents/{}", agent_id)).await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + pub fn update_agent( + &self, + agent_id: &str, + request: &agents::UpdateAgentRequest, + ) -> Result { + let response = self.patch_sync(&format!("/agents/{}", agent_id), request)?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn update_agent_async( + &self, + agent_id: &str, + request: &agents::UpdateAgentRequest, + ) -> Result { + let response = self + .patch_async(&format!("/agents/{}", agent_id), request) + .await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + pub fn delete_agent( + &self, + agent_id: &str, + ) -> Result { + let _response = self.delete_sync(&format!("/agents/{}", agent_id))?; + Ok(agents::AgentDeleteResponse { deleted: true }) + } + + pub async fn delete_agent_async( + &self, + agent_id: &str, + ) -> Result { + let _response = self + .delete_async(&format!("/agents/{}", agent_id)) + .await?; + Ok(agents::AgentDeleteResponse { deleted: true }) + } + + pub fn list_agents(&self) -> Result { + let response = self.get_sync("/agents")?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn list_agents_async( + &self, + ) -> Result { + let response = self.get_async("/agents").await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + // ========================================================================= + // Conversations (Beta — /v1/conversations) + // ========================================================================= + + pub fn create_conversation( + &self, + request: &conversations::CreateConversationRequest, + ) -> Result { + let response = self.post_sync("/conversations", request)?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn create_conversation_async( + &self, + request: &conversations::CreateConversationRequest, + ) -> Result { + let response = self.post_async("/conversations", request).await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + pub fn append_conversation( + &self, + conversation_id: &str, + request: &conversations::AppendConversationRequest, + ) -> Result { + let response = + self.post_sync(&format!("/conversations/{}", conversation_id), request)?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn append_conversation_async( + &self, + conversation_id: &str, + request: &conversations::AppendConversationRequest, + ) -> Result { + let response = self + .post_async(&format!("/conversations/{}", conversation_id), request) + .await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + pub fn get_conversation( + &self, + conversation_id: &str, + ) -> Result { + let response = self.get_sync(&format!("/conversations/{}", conversation_id))?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn get_conversation_async( + &self, + conversation_id: &str, + ) -> Result { + let response = self + .get_async(&format!("/conversations/{}", conversation_id)) + .await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + pub fn get_conversation_history( + &self, + conversation_id: &str, + ) -> Result { + let response = + self.get_sync(&format!("/conversations/{}/history", conversation_id))?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn get_conversation_history_async( + &self, + conversation_id: &str, + ) -> Result { + let response = self + .get_async(&format!("/conversations/{}/history", conversation_id)) + .await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + pub fn get_conversation_messages( + &self, + conversation_id: &str, + ) -> Result { + let response = + self.get_sync(&format!("/conversations/{}/messages", conversation_id))?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn get_conversation_messages_async( + &self, + conversation_id: &str, + ) -> Result { + let response = self + .get_async(&format!("/conversations/{}/messages", conversation_id)) + .await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + pub fn restart_conversation( + &self, + conversation_id: &str, + request: &conversations::RestartConversationRequest, + ) -> Result { + let response = self.post_sync( + &format!("/conversations/{}/restart", conversation_id), + request, + )?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn restart_conversation_async( + &self, + conversation_id: &str, + request: &conversations::RestartConversationRequest, + ) -> Result { + let response = self + .post_async( + &format!("/conversations/{}/restart", conversation_id), + request, + ) + .await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + pub fn delete_conversation( + &self, + conversation_id: &str, + ) -> Result { + let _response = + self.delete_sync(&format!("/conversations/{}", conversation_id))?; + Ok(conversations::ConversationDeleteResponse { deleted: true }) + } + + pub async fn delete_conversation_async( + &self, + conversation_id: &str, + ) -> Result { + let _response = self + .delete_async(&format!("/conversations/{}", conversation_id)) + .await?; + Ok(conversations::ConversationDeleteResponse { deleted: true }) + } + + pub fn list_conversations( + &self, + ) -> Result { + let response = self.get_sync("/conversations")?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn list_conversations_async( + &self, + ) -> Result { + let response = self.get_async("/conversations").await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + // ========================================================================= // Function Calling // ========================================================================= @@ -1112,6 +1406,36 @@ impl Client { self.handle_async_response(result).await } + fn patch_sync( + &self, + path: &str, + params: &T, + ) -> Result { + let reqwest_client = reqwest::blocking::Client::new(); + let url = format!("{}{}", self.endpoint, path); + debug!("Request URL: {}", url); + utils::debug_pretty_json_from_struct("Request Body", params); + + let request = self.build_request_sync(reqwest_client.patch(url).json(params)); + let result = request.send(); + self.handle_sync_response(result) + } + + async fn patch_async( + &self, + path: &str, + params: &T, + ) -> Result { + let reqwest_client = reqwest::Client::new(); + let url = format!("{}{}", self.endpoint, path); + debug!("Request URL: {}", url); + utils::debug_pretty_json_from_struct("Request Body", params); + + let request = self.build_request_async(reqwest_client.patch(url).json(params)); + let result = request.send().await; + self.handle_async_response(result).await + } + fn delete_sync(&self, path: &str) -> Result { let reqwest_client = reqwest::blocking::Client::new(); let url = format!("{}{}", self.endpoint, path); diff --git a/src/v1/conversations.rs b/src/v1/conversations.rs new file mode 100644 index 0000000..9b2dd83 --- /dev/null +++ b/src/v1/conversations.rs @@ -0,0 +1,377 @@ +use serde::{Deserialize, Serialize}; + +use crate::v1::{agents, chat}; + +// ============================================================================= +// Conversations API (Beta) +// POST/GET/DELETE /v1/conversations +// ============================================================================= + +// ----------------------------------------------------------------------------- +// Conversation entries (inputs and outputs) +// All entries share common fields: id, object, type, created_at, completed_at + +/// Input entry — a message sent to the conversation. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct MessageInputEntry { + pub role: String, + pub content: chat::ChatMessageContent, + #[serde(skip_serializing_if = "Option::is_none")] + pub prefix: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub object: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub created_at: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub completed_at: Option, +} + +/// Output entry — an assistant message produced by the model. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct MessageOutputEntry { + pub role: String, + pub content: chat::ChatMessageContent, + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub object: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub created_at: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub completed_at: Option, +} + +/// A function call requested by the model. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct FunctionCallEntry { + pub name: String, + pub arguments: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub object: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub created_at: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub completed_at: Option, +} + +/// Result of a function call, sent back to the conversation. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct FunctionResultEntry { + pub tool_call_id: String, + pub result: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub object: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub created_at: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub completed_at: Option, +} + +/// A built-in tool execution (web_search, code_interpreter, etc.). +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ToolExecutionEntry { + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub object: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub created_at: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub completed_at: Option, + #[serde(flatten)] + pub extra: serde_json::Value, +} + +/// Agent handoff entry — transfer between agents. +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct AgentHandoffEntry { + #[serde(skip_serializing_if = "Option::is_none")] + pub previous_agent_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub next_agent_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub object: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub created_at: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub completed_at: Option, +} + +/// Union of all conversation entry types. +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum ConversationEntry { + #[serde(rename = "message.input")] + MessageInput(MessageInputEntry), + #[serde(rename = "message.output")] + MessageOutput(MessageOutputEntry), + #[serde(rename = "function.call")] + FunctionCall(FunctionCallEntry), + #[serde(rename = "function.result")] + FunctionResult(FunctionResultEntry), + #[serde(rename = "tool.execution")] + ToolExecution(ToolExecutionEntry), + #[serde(rename = "agent.handoff")] + AgentHandoff(AgentHandoffEntry), +} + +// ----------------------------------------------------------------------------- +// Conversation inputs (flexible: string or array of entries) + +/// Conversation input: either a plain string or structured entry array. +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(untagged)] +pub enum ConversationInput { + Text(String), + Entries(Vec), +} + +impl From<&str> for ConversationInput { + fn from(s: &str) -> Self { + Self::Text(s.to_string()) + } +} + +impl From for ConversationInput { + fn from(s: String) -> Self { + Self::Text(s) + } +} + +impl From> for ConversationInput { + fn from(entries: Vec) -> Self { + Self::Entries(entries) + } +} + +// ----------------------------------------------------------------------------- +// Handoff execution mode + +#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] +pub enum HandoffExecution { + #[serde(rename = "server")] + Server, + #[serde(rename = "client")] + Client, +} + +impl Default for HandoffExecution { + fn default() -> Self { + Self::Server + } +} + +// ----------------------------------------------------------------------------- +// Create conversation request (POST /v1/conversations) + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct CreateConversationRequest { + pub inputs: ConversationInput, + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub agent_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub agent_version: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub instructions: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub completion_args: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub handoff_execution: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub store: Option, + #[serde(default)] + pub stream: bool, +} + +// ----------------------------------------------------------------------------- +// Append to conversation request (POST /v1/conversations/{id}) + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct AppendConversationRequest { + pub inputs: ConversationInput, + #[serde(skip_serializing_if = "Option::is_none")] + pub completion_args: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub handoff_execution: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub store: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_confirmations: Option>, + #[serde(default)] + pub stream: bool, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct ToolCallConfirmation { + pub tool_call_id: String, + pub result: String, +} + +// ----------------------------------------------------------------------------- +// Restart conversation request (POST /v1/conversations/{id}/restart) + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct RestartConversationRequest { + pub from_entry_id: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub inputs: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub completion_args: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub agent_version: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub handoff_execution: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub store: Option, + #[serde(default)] + pub stream: bool, +} + +// ----------------------------------------------------------------------------- +// Conversation response (returned by create, append, restart) + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ConversationUsageInfo { + #[serde(default)] + pub prompt_tokens: u32, + #[serde(default)] + pub completion_tokens: u32, + #[serde(default)] + pub total_tokens: u32, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ConversationResponse { + pub conversation_id: String, + pub outputs: Vec, + pub usage: ConversationUsageInfo, + #[serde(default)] + pub object: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub guardrails: Option, +} + +impl ConversationResponse { + /// Extract the assistant's text response from the outputs, if any. + pub fn assistant_text(&self) -> Option { + for entry in &self.outputs { + if let ConversationEntry::MessageOutput(msg) = entry { + return Some(msg.content.text()); + } + } + None + } + + /// Extract all function call entries from the outputs. + pub fn function_calls(&self) -> Vec<&FunctionCallEntry> { + self.outputs + .iter() + .filter_map(|e| match e { + ConversationEntry::FunctionCall(fc) => Some(fc), + _ => None, + }) + .collect() + } + + /// Check if any outputs are agent handoff entries. + pub fn has_handoff(&self) -> bool { + self.outputs + .iter() + .any(|e| matches!(e, ConversationEntry::AgentHandoff(_))) + } +} + +// ----------------------------------------------------------------------------- +// Conversation history response (GET /v1/conversations/{id}/history) + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ConversationHistoryResponse { + pub conversation_id: String, + pub entries: Vec, + #[serde(default)] + pub object: String, +} + +// ----------------------------------------------------------------------------- +// Conversation messages response (GET /v1/conversations/{id}/messages) +// Note: may have same shape as history; keeping separate for API clarity + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ConversationMessagesResponse { + pub conversation_id: String, + #[serde(alias = "messages", alias = "entries")] + pub messages: Vec, + #[serde(default)] + pub object: String, +} + +// ----------------------------------------------------------------------------- +// Conversation info (GET /v1/conversations/{id}) + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Conversation { + pub id: String, + #[serde(default)] + pub object: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub agent_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub instructions: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub created_at: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub updated_at: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub completion_args: Option, + #[serde(default)] + pub tools: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub guardrails: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option, +} + +/// List conversations response. API returns a raw JSON array. +#[derive(Clone, Debug, Deserialize, Serialize)] +#[serde(transparent)] +pub struct ConversationListResponse { + pub data: Vec, +} + +/// Delete conversation response. API returns 204 No Content. +#[derive(Clone, Debug)] +pub struct ConversationDeleteResponse { + pub deleted: bool, +} diff --git a/src/v1/mod.rs b/src/v1/mod.rs index e1140b6..12635b5 100644 --- a/src/v1/mod.rs +++ b/src/v1/mod.rs @@ -6,6 +6,7 @@ pub mod chat_stream; pub mod client; pub mod common; pub mod constants; +pub mod conversations; pub mod embedding; pub mod error; pub mod files; diff --git a/tests/v1_agents_api_test.rs b/tests/v1_agents_api_test.rs new file mode 100644 index 0000000..13ae8b3 --- /dev/null +++ b/tests/v1_agents_api_test.rs @@ -0,0 +1,372 @@ +use mistralai_client::v1::{ + agents::*, + client::Client, +}; + +mod setup; + +fn make_client() -> Client { + Client::new(None, None, None, None).unwrap() +} + +// --------------------------------------------------------------------------- +// Sync tests +// --------------------------------------------------------------------------- + +#[test] +fn test_create_and_delete_agent() { + setup::setup(); + let client = make_client(); + + let req = CreateAgentRequest { + model: "mistral-medium-latest".to_string(), + name: "test-create-delete".to_string(), + description: Some("Integration test agent".to_string()), + instructions: Some("You are a test agent. Respond briefly.".to_string()), + tools: None, + handoffs: None, + completion_args: None, + metadata: None, + }; + + let agent = client.create_agent(&req).unwrap(); + assert!(!agent.id.is_empty()); + assert_eq!(agent.name, "test-create-delete"); + assert_eq!(agent.model, "mistral-medium-latest"); + assert_eq!(agent.object, "agent"); + // Version starts at 0 in the API + assert!(agent.description.as_deref() == Some("Integration test agent")); + assert!(agent.instructions.as_deref() == Some("You are a test agent. Respond briefly.")); + + // Cleanup + let del = client.delete_agent(&agent.id).unwrap(); + assert!(del.deleted); +} + +#[test] +fn test_create_agent_with_tools() { + setup::setup(); + let client = make_client(); + + let req = CreateAgentRequest { + model: "mistral-medium-latest".to_string(), + name: "test-agent-tools".to_string(), + description: None, + instructions: Some("You can search.".to_string()), + tools: Some(vec![ + AgentTool::function( + "search".to_string(), + "Search for things".to_string(), + serde_json::json!({ + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"} + }, + "required": ["query"] + }), + ), + AgentTool::web_search(), + ]), + handoffs: None, + completion_args: Some(CompletionArgs { + temperature: Some(0.3), + ..Default::default() + }), + metadata: None, + }; + + let agent = client.create_agent(&req).unwrap(); + assert_eq!(agent.tools.len(), 2); + assert!(matches!(&agent.tools[0], AgentTool::Function(_))); + assert!(matches!(&agent.tools[1], AgentTool::WebSearch {})); + + // Verify completion_args round-tripped + let args = agent.completion_args.as_ref().unwrap(); + assert!((args.temperature.unwrap() - 0.3).abs() < 0.01); + + client.delete_agent(&agent.id).unwrap(); +} + +#[test] +fn test_get_agent() { + setup::setup(); + let client = make_client(); + + let req = CreateAgentRequest { + model: "mistral-medium-latest".to_string(), + name: "test-get-agent".to_string(), + description: Some("Get test".to_string()), + instructions: None, + tools: None, + handoffs: None, + completion_args: None, + metadata: None, + }; + + let created = client.create_agent(&req).unwrap(); + let fetched = client.get_agent(&created.id).unwrap(); + + assert_eq!(fetched.id, created.id); + assert_eq!(fetched.name, "test-get-agent"); + assert_eq!(fetched.model, "mistral-medium-latest"); + assert_eq!(fetched.description.as_deref(), Some("Get test")); + + client.delete_agent(&created.id).unwrap(); +} + +#[test] +fn test_update_agent() { + setup::setup(); + let client = make_client(); + + let req = CreateAgentRequest { + model: "mistral-medium-latest".to_string(), + name: "test-update-agent".to_string(), + description: Some("Before update".to_string()), + instructions: Some("Original instructions".to_string()), + tools: None, + handoffs: None, + completion_args: None, + metadata: None, + }; + + let created = client.create_agent(&req).unwrap(); + + let update = UpdateAgentRequest { + name: Some("test-update-agent-renamed".to_string()), + description: Some("After update".to_string()), + instructions: Some("Updated instructions".to_string()), + ..Default::default() + }; + + let updated = client.update_agent(&created.id, &update).unwrap(); + assert_eq!(updated.id, created.id); + assert_eq!(updated.name, "test-update-agent-renamed"); + assert_eq!(updated.description.as_deref(), Some("After update")); + assert_eq!(updated.instructions.as_deref(), Some("Updated instructions")); + // Version should have incremented + assert!(updated.version >= created.version); + + client.delete_agent(&created.id).unwrap(); +} + +#[test] +fn test_list_agents() { + setup::setup(); + let client = make_client(); + + // Create two agents + let req1 = CreateAgentRequest { + model: "mistral-medium-latest".to_string(), + name: "test-list-agent-1".to_string(), + description: None, + instructions: None, + tools: None, + handoffs: None, + completion_args: None, + metadata: None, + }; + let req2 = CreateAgentRequest { + model: "mistral-medium-latest".to_string(), + name: "test-list-agent-2".to_string(), + description: None, + instructions: None, + tools: None, + handoffs: None, + completion_args: None, + metadata: None, + }; + + let a1 = client.create_agent(&req1).unwrap(); + let a2 = client.create_agent(&req2).unwrap(); + + let list = client.list_agents().unwrap(); + assert!(list.data.len() >= 2); + + // Our two agents should be in the list + let ids: Vec<&str> = list.data.iter().map(|a| a.id.as_str()).collect(); + assert!(ids.contains(&a1.id.as_str())); + assert!(ids.contains(&a2.id.as_str())); + + // Cleanup + client.delete_agent(&a1.id).unwrap(); + client.delete_agent(&a2.id).unwrap(); +} + +// --------------------------------------------------------------------------- +// Async tests +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_create_and_delete_agent_async() { + setup::setup(); + let client = make_client(); + + let req = CreateAgentRequest { + model: "mistral-medium-latest".to_string(), + name: "test-async-create-delete".to_string(), + description: Some("Async integration test".to_string()), + instructions: Some("Respond briefly.".to_string()), + tools: None, + handoffs: None, + completion_args: None, + metadata: None, + }; + + let agent = client.create_agent_async(&req).await.unwrap(); + assert!(!agent.id.is_empty()); + assert_eq!(agent.name, "test-async-create-delete"); + assert_eq!(agent.object, "agent"); + + let del = client.delete_agent_async(&agent.id).await.unwrap(); + assert!(del.deleted); +} + +#[tokio::test] +async fn test_get_agent_async() { + setup::setup(); + let client = make_client(); + + let req = CreateAgentRequest { + model: "mistral-medium-latest".to_string(), + name: "test-async-get".to_string(), + description: None, + instructions: None, + tools: None, + handoffs: None, + completion_args: None, + metadata: None, + }; + + let created = client.create_agent_async(&req).await.unwrap(); + let fetched = client.get_agent_async(&created.id).await.unwrap(); + assert_eq!(fetched.id, created.id); + assert_eq!(fetched.name, "test-async-get"); + + client.delete_agent_async(&created.id).await.unwrap(); +} + +#[tokio::test] +async fn test_update_agent_async() { + setup::setup(); + let client = make_client(); + + let req = CreateAgentRequest { + model: "mistral-medium-latest".to_string(), + name: "test-async-update".to_string(), + description: Some("Before".to_string()), + instructions: None, + tools: None, + handoffs: None, + completion_args: None, + metadata: None, + }; + + let created = client.create_agent_async(&req).await.unwrap(); + + let update = UpdateAgentRequest { + description: Some("After".to_string()), + ..Default::default() + }; + let updated = client.update_agent_async(&created.id, &update).await.unwrap(); + assert_eq!(updated.description.as_deref(), Some("After")); + + client.delete_agent_async(&created.id).await.unwrap(); +} + +#[tokio::test] +async fn test_list_agents_async() { + setup::setup(); + let client = make_client(); + + let req = CreateAgentRequest { + model: "mistral-medium-latest".to_string(), + name: "test-async-list".to_string(), + description: None, + instructions: None, + tools: None, + handoffs: None, + completion_args: None, + metadata: None, + }; + + let agent = client.create_agent_async(&req).await.unwrap(); + let list = client.list_agents_async().await.unwrap(); + assert!(list.data.iter().any(|a| a.id == agent.id)); + + client.delete_agent_async(&agent.id).await.unwrap(); +} + +#[test] +fn test_create_agent_with_handoffs() { + setup::setup(); + let client = make_client(); + + // Create a target agent first + let target_req = CreateAgentRequest { + model: "mistral-medium-latest".to_string(), + name: "test-handoff-target".to_string(), + description: Some("Target agent for handoff".to_string()), + instructions: Some("You handle math questions.".to_string()), + tools: None, + handoffs: None, + completion_args: None, + metadata: None, + }; + let target = client.create_agent(&target_req).unwrap(); + + // Create orchestrator with handoff to target + let orch_req = CreateAgentRequest { + model: "mistral-medium-latest".to_string(), + name: "test-handoff-orchestrator".to_string(), + description: Some("Orchestrator with handoffs".to_string()), + instructions: Some("Delegate math questions.".to_string()), + tools: None, + handoffs: Some(vec![target.id.clone()]), + completion_args: None, + metadata: None, + }; + let orch = client.create_agent(&orch_req).unwrap(); + assert_eq!(orch.handoffs.as_ref().unwrap().len(), 1); + assert_eq!(orch.handoffs.as_ref().unwrap()[0], target.id); + + // Cleanup + client.delete_agent(&orch.id).unwrap(); + client.delete_agent(&target.id).unwrap(); +} + +#[test] +fn test_agent_completion_with_created_agent() { + setup::setup(); + let client = make_client(); + + let req = CreateAgentRequest { + model: "mistral-medium-latest".to_string(), + name: "test-completion-agent".to_string(), + description: None, + instructions: Some("Always respond with exactly the word 'pong'.".to_string()), + tools: None, + handoffs: None, + completion_args: Some(CompletionArgs { + temperature: Some(0.0), + ..Default::default() + }), + metadata: None, + }; + + let agent = client.create_agent(&req).unwrap(); + + // Use the existing agent_completion method with the created agent + use mistralai_client::v1::chat::ChatMessage; + let messages = vec![ChatMessage::new_user_message("ping")]; + let response = client + .agent_completion(agent.id.clone(), messages, None) + .unwrap(); + + assert!(!response.choices.is_empty()); + let text = response.choices[0].message.content.text().to_lowercase(); + assert!(text.contains("pong"), "Expected 'pong', got: {text}"); + assert!(response.usage.total_tokens > 0); + + client.delete_agent(&agent.id).unwrap(); +} diff --git a/tests/v1_agents_types_test.rs b/tests/v1_agents_types_test.rs new file mode 100644 index 0000000..9fbea02 --- /dev/null +++ b/tests/v1_agents_types_test.rs @@ -0,0 +1,119 @@ +use mistralai_client::v1::agents::*; + +#[test] +fn test_create_agent_request_serialization() { + let req = CreateAgentRequest { + model: "mistral-medium-latest".to_string(), + name: "sol-orchestrator".to_string(), + description: Some("Virtual librarian".to_string()), + instructions: Some("You are Sol.".to_string()), + tools: Some(vec![AgentTool::web_search()]), + handoffs: Some(vec!["agent_abc123".to_string()]), + completion_args: Some(CompletionArgs { + temperature: Some(0.3), + ..Default::default() + }), + metadata: None, + }; + + let json = serde_json::to_value(&req).unwrap(); + assert_eq!(json["model"], "mistral-medium-latest"); + assert_eq!(json["name"], "sol-orchestrator"); + assert_eq!(json["tools"][0]["type"], "web_search"); + assert_eq!(json["handoffs"][0], "agent_abc123"); + assert!(json["completion_args"]["temperature"].as_f64().unwrap() - 0.3 < 0.001); +} + +#[test] +fn test_agent_response_deserialization() { + let json = serde_json::json!({ + "id": "ag_abc123", + "object": "agent", + "name": "sol-orchestrator", + "model": "mistral-medium-latest", + "created_at": "2026-03-21T10:00:00Z", + "updated_at": "2026-03-21T10:00:00Z", + "version": 1, + "versions": [1], + "description": "Virtual librarian", + "instructions": "You are Sol.", + "tools": [ + {"type": "function", "function": {"name": "search", "description": "Search", "parameters": {}}}, + {"type": "web_search"}, + {"type": "code_interpreter"} + ], + "handoffs": ["ag_def456"], + "completion_args": {"temperature": 0.3, "response_format": {"type": "text"}} + }); + + let agent: Agent = serde_json::from_value(json).unwrap(); + assert_eq!(agent.id, "ag_abc123"); + assert_eq!(agent.name, "sol-orchestrator"); + assert_eq!(agent.version, 1); + assert_eq!(agent.tools.len(), 3); + assert!(matches!(&agent.tools[0], AgentTool::Function(_))); + assert!(matches!(&agent.tools[1], AgentTool::WebSearch {})); + assert!(matches!(&agent.tools[2], AgentTool::CodeInterpreter {})); + assert_eq!(agent.handoffs.as_ref().unwrap()[0], "ag_def456"); +} + +#[test] +fn test_agent_tool_function_constructor() { + let tool = AgentTool::function( + "search_archive".to_string(), + "Search messages".to_string(), + serde_json::json!({"type": "object", "properties": {"query": {"type": "string"}}}), + ); + + let json = serde_json::to_value(&tool).unwrap(); + assert_eq!(json["type"], "function"); + assert_eq!(json["function"]["name"], "search_archive"); +} + +#[test] +fn test_completion_args_default_skips_none() { + let args = CompletionArgs::default(); + let json = serde_json::to_value(&args).unwrap(); + // All fields are None, so the JSON object should be empty + assert_eq!(json, serde_json::json!({})); +} + +#[test] +fn test_agent_delete_response() { + // AgentDeleteResponse is not deserialized from JSON — the API returns 204 No Content. + // The client constructs it directly. + let resp = AgentDeleteResponse { deleted: true }; + assert!(resp.deleted); +} + +#[test] +fn test_agent_list_response() { + // API returns a raw JSON array (no wrapper object) + let json = serde_json::json!([ + { + "id": "ag_1", + "object": "agent", + "name": "agent-1", + "model": "mistral-medium-latest", + "created_at": "2026-03-21T10:00:00Z", + "updated_at": "2026-03-21T10:00:00Z", + "version": 0, + "tools": [] + } + ]); + let resp: AgentListResponse = serde_json::from_value(json).unwrap(); + assert_eq!(resp.data.len(), 1); + assert_eq!(resp.data[0].name, "agent-1"); +} + +#[test] +fn test_update_agent_partial() { + let req = UpdateAgentRequest { + instructions: Some("New instructions".to_string()), + ..Default::default() + }; + let json = serde_json::to_value(&req).unwrap(); + assert_eq!(json["instructions"], "New instructions"); + assert!(json.get("model").is_none()); + assert!(json.get("name").is_none()); +} diff --git a/tests/v1_chat_multimodal_api_test.rs b/tests/v1_chat_multimodal_api_test.rs new file mode 100644 index 0000000..aa6a5db --- /dev/null +++ b/tests/v1_chat_multimodal_api_test.rs @@ -0,0 +1,156 @@ +use mistralai_client::v1::{ + chat::{ + ChatMessage, ChatParams, ChatResponseChoiceFinishReason, ContentPart, ImageUrl, + }, + client::Client, + constants::Model, +}; + +mod setup; + +fn make_client() -> Client { + Client::new(None, None, None, None).unwrap() +} + +#[test] +fn test_multimodal_chat_with_image_url() { + setup::setup(); + let client = make_client(); + + // Use a small, publicly accessible image + let msg = ChatMessage::new_user_message_with_images(vec![ + ContentPart::Text { + text: "Describe this image in one sentence.".to_string(), + }, + ContentPart::ImageUrl { + image_url: ImageUrl { + url: "https://picsum.photos/id/237/200/300".to_string(), + detail: None, + }, + }, + ]); + + let model = Model::new("pixtral-large-latest".to_string()); + let options = ChatParams { + max_tokens: Some(100), + temperature: Some(0.0), + ..Default::default() + }; + + let response = client.chat(model, vec![msg], Some(options)).unwrap(); + + assert_eq!( + response.choices[0].finish_reason, + ChatResponseChoiceFinishReason::Stop + ); + let text = response.choices[0].message.content.text(); + assert!(!text.is_empty(), "Expected non-empty description"); + assert!(response.usage.total_tokens > 0); +} + +#[tokio::test] +async fn test_multimodal_chat_with_image_url_async() { + setup::setup(); + let client = make_client(); + + let msg = ChatMessage::new_user_message_with_images(vec![ + ContentPart::Text { + text: "What colors do you see in this image? Reply in one sentence.".to_string(), + }, + ContentPart::ImageUrl { + image_url: ImageUrl { + url: "https://picsum.photos/id/237/200/300".to_string(), + detail: None, + }, + }, + ]); + + let model = Model::new("pixtral-large-latest".to_string()); + let options = ChatParams { + max_tokens: Some(100), + temperature: Some(0.0), + ..Default::default() + }; + + let response = client + .chat_async(model, vec![msg], Some(options)) + .await + .unwrap(); + + let text = response.choices[0].message.content.text(); + assert!(!text.is_empty(), "Expected non-empty description"); + assert!(response.usage.total_tokens > 0); +} + +#[test] +fn test_mixed_text_and_image_messages() { + setup::setup(); + let client = make_client(); + + // First message: just text + let msg1 = ChatMessage::new_user_message("I'm going to show you an image next."); + + // Second message: text + image + let msg2 = ChatMessage::new_user_message_with_images(vec![ + ContentPart::Text { + text: "Here it is. What do you see?".to_string(), + }, + ContentPart::ImageUrl { + image_url: ImageUrl { + url: "https://picsum.photos/id/237/200/300".to_string(), + detail: None, + }, + }, + ]); + + let model = Model::new("pixtral-large-latest".to_string()); + let options = ChatParams { + max_tokens: Some(100), + temperature: Some(0.0), + ..Default::default() + }; + + let response = client.chat(model, vec![msg1, msg2], Some(options)).unwrap(); + let text = response.choices[0].message.content.text(); + assert!(!text.is_empty()); +} + +#[test] +fn test_text_only_message_still_works() { + setup::setup(); + let client = make_client(); + + // Verify that text-only messages (the common case) still work fine + // with the new ChatMessageContent type + let msg = ChatMessage::new_user_message("What is 7 + 8?"); + let model = Model::mistral_small_latest(); + let options = ChatParams { + temperature: Some(0.0), + max_tokens: Some(50), + ..Default::default() + }; + + let response = client.chat(model, vec![msg], Some(options)).unwrap(); + let text = response.choices[0].message.content.text(); + assert!(text.contains("15"), "Expected '15', got: {text}"); +} + +#[test] +fn test_reasoning_field_presence() { + setup::setup(); + let client = make_client(); + + // Normal model should not have reasoning + let msg = ChatMessage::new_user_message("What is 2 + 2?"); + let model = Model::mistral_small_latest(); + let options = ChatParams { + temperature: Some(0.0), + max_tokens: Some(50), + ..Default::default() + }; + + let response = client.chat(model, vec![msg], Some(options)).unwrap(); + // reasoning is None for non-Magistral models (or it might just be absent) + // This test verifies the field deserializes correctly either way + let _ = response.choices[0].reasoning.as_ref(); +} diff --git a/tests/v1_chat_multimodal_test.rs b/tests/v1_chat_multimodal_test.rs new file mode 100644 index 0000000..dd0f8be --- /dev/null +++ b/tests/v1_chat_multimodal_test.rs @@ -0,0 +1,204 @@ +use mistralai_client::v1::chat::*; + +#[test] +fn test_content_part_text_serialization() { + let part = ContentPart::Text { + text: "hello".to_string(), + }; + let json = serde_json::to_value(&part).unwrap(); + assert_eq!(json["type"], "text"); + assert_eq!(json["text"], "hello"); +} + +#[test] +fn test_content_part_image_url_serialization() { + let part = ContentPart::ImageUrl { + image_url: ImageUrl { + url: "https://example.com/image.png".to_string(), + detail: Some("high".to_string()), + }, + }; + let json = serde_json::to_value(&part).unwrap(); + assert_eq!(json["type"], "image_url"); + assert_eq!(json["image_url"]["url"], "https://example.com/image.png"); + assert_eq!(json["image_url"]["detail"], "high"); +} + +#[test] +fn test_content_part_image_url_no_detail() { + let part = ContentPart::ImageUrl { + image_url: ImageUrl { + url: "data:image/png;base64,abc123".to_string(), + detail: None, + }, + }; + let json = serde_json::to_value(&part).unwrap(); + assert_eq!(json["type"], "image_url"); + assert!(json["image_url"].get("detail").is_none()); +} + +#[test] +fn test_chat_message_content_text() { + let content = ChatMessageContent::Text("hello world".to_string()); + assert_eq!(content.text(), "hello world"); + assert_eq!(content.as_text(), Some("hello world")); + assert!(!content.has_images()); + assert_eq!(content.to_string(), "hello world"); +} + +#[test] +fn test_chat_message_content_parts() { + let content = ChatMessageContent::Parts(vec![ + ContentPart::Text { + text: "What is this? ".to_string(), + }, + ContentPart::ImageUrl { + image_url: ImageUrl { + url: "https://example.com/cat.jpg".to_string(), + detail: None, + }, + }, + ]); + + assert_eq!(content.text(), "What is this? "); + assert!(content.as_text().is_none()); + assert!(content.has_images()); +} + +#[test] +fn test_chat_message_content_text_serialization() { + let content = ChatMessageContent::Text("hello".to_string()); + let json = serde_json::to_value(&content).unwrap(); + assert_eq!(json, serde_json::json!("hello")); +} + +#[test] +fn test_chat_message_content_parts_serialization() { + let content = ChatMessageContent::Parts(vec![ContentPart::Text { + text: "hello".to_string(), + }]); + let json = serde_json::to_value(&content).unwrap(); + assert!(json.is_array()); + assert_eq!(json[0]["type"], "text"); +} + +#[test] +fn test_chat_message_content_text_deserialization() { + let content: ChatMessageContent = serde_json::from_value(serde_json::json!("hello")).unwrap(); + assert_eq!(content.text(), "hello"); +} + +#[test] +fn test_chat_message_content_parts_deserialization() { + let content: ChatMessageContent = serde_json::from_value(serde_json::json!([ + {"type": "text", "text": "describe this"}, + {"type": "image_url", "image_url": {"url": "https://example.com/img.jpg"}} + ])) + .unwrap(); + assert_eq!(content.text(), "describe this"); + assert!(content.has_images()); +} + +#[test] +fn test_new_user_message_text_content() { + let msg = ChatMessage::new_user_message("hello"); + let json = serde_json::to_value(&msg).unwrap(); + assert_eq!(json["role"], "user"); + assert_eq!(json["content"], "hello"); +} + +#[test] +fn test_new_user_message_with_images() { + let msg = ChatMessage::new_user_message_with_images(vec![ + ContentPart::Text { + text: "What is this?".to_string(), + }, + ContentPart::ImageUrl { + image_url: ImageUrl { + url: "data:image/png;base64,abc123".to_string(), + detail: None, + }, + }, + ]); + + let json = serde_json::to_value(&msg).unwrap(); + assert_eq!(json["role"], "user"); + assert!(json["content"].is_array()); + assert_eq!(json["content"][0]["type"], "text"); + assert_eq!(json["content"][1]["type"], "image_url"); +} + +#[test] +fn test_chat_message_content_from_str() { + let content: ChatMessageContent = "test".into(); + assert_eq!(content.text(), "test"); +} + +#[test] +fn test_chat_message_content_from_string() { + let content: ChatMessageContent = String::from("test").into(); + assert_eq!(content.text(), "test"); +} + +#[test] +fn test_chat_response_choice_with_reasoning() { + let json = serde_json::json!({ + "index": 0, + "message": { + "role": "assistant", + "content": "The answer is 42." + }, + "finish_reason": "stop", + "reasoning": "Let me think about this step by step..." + }); + + let choice: ChatResponseChoice = serde_json::from_value(json).unwrap(); + assert_eq!(choice.reasoning.as_deref(), Some("Let me think about this step by step...")); + assert_eq!(choice.message.content.text(), "The answer is 42."); +} + +#[test] +fn test_chat_response_choice_without_reasoning() { + let json = serde_json::json!({ + "index": 0, + "message": { + "role": "assistant", + "content": "Hello" + }, + "finish_reason": "stop" + }); + + let choice: ChatResponseChoice = serde_json::from_value(json).unwrap(); + assert!(choice.reasoning.is_none()); +} + +#[test] +fn test_full_chat_response_roundtrip() { + let json = serde_json::json!({ + "id": "chat-abc123", + "object": "chat.completion", + "created": 1711000000, + "model": "mistral-medium-latest", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "Hi there!" + }, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15 + } + }); + + let resp: ChatResponse = serde_json::from_value(json).unwrap(); + assert_eq!(resp.choices[0].message.content.text(), "Hi there!"); + assert_eq!(resp.usage.total_tokens, 15); + + // Re-serialize and verify + let re_json = serde_json::to_value(&resp).unwrap(); + assert_eq!(re_json["choices"][0]["message"]["content"], "Hi there!"); +} diff --git a/tests/v1_client_chat_async_test.rs b/tests/v1_client_chat_async_test.rs index 6afb942..16d1a85 100644 --- a/tests/v1_client_chat_async_test.rs +++ b/tests/v1_client_chat_async_test.rs @@ -39,7 +39,7 @@ async fn test_client_chat_async() { expect!(response.choices[0] .message .content - .clone() + .text() .contains("Tower")) .to_be(true); diff --git a/tests/v1_client_chat_test.rs b/tests/v1_client_chat_test.rs index 1ecf769..4a2edc0 100644 --- a/tests/v1_client_chat_test.rs +++ b/tests/v1_client_chat_test.rs @@ -33,7 +33,7 @@ fn test_client_chat() { expect!(response.choices[0] .message .content - .clone() + .text() .contains("Tower")) .to_be(true); expect!(response.choices[0].finish_reason.clone()).to_be(ChatResponseChoiceFinishReason::Stop); diff --git a/tests/v1_constants_test.rs b/tests/v1_constants_test.rs index 7039af4..9d8f7df 100644 --- a/tests/v1_constants_test.rs +++ b/tests/v1_constants_test.rs @@ -32,6 +32,6 @@ fn test_model_constants() { expect!(response.object).to_be("chat.completion".to_string()); expect!(response.choices.len()).to_be(1); expect!(response.choices[0].index).to_be(0); - expect!(response.choices[0].message.content.len()).to_be_greater_than(0); + expect!(response.choices[0].message.content.text().len()).to_be_greater_than(0); } } diff --git a/tests/v1_conversations_api_test.rs b/tests/v1_conversations_api_test.rs new file mode 100644 index 0000000..6b6623f --- /dev/null +++ b/tests/v1_conversations_api_test.rs @@ -0,0 +1,642 @@ +use mistralai_client::v1::{ + agents::*, + client::Client, + conversations::*, +}; + +mod setup; + +fn make_client() -> Client { + Client::new(None, None, None, None).unwrap() +} + +/// Helper: create a disposable agent for conversation tests (sync). +fn create_test_agent(client: &Client, name: &str) -> Agent { + let req = make_agent_request(name); + client.create_agent(&req).unwrap() +} + +/// Helper: create a disposable agent for conversation tests (async). +async fn create_test_agent_async(client: &Client, name: &str) -> Agent { + let req = make_agent_request(name); + client.create_agent_async(&req).await.unwrap() +} + +fn make_agent_request(name: &str) -> CreateAgentRequest { + CreateAgentRequest { + model: "mistral-medium-latest".to_string(), + name: name.to_string(), + description: Some("Conversation test agent".to_string()), + instructions: Some("You are a helpful test agent. Keep responses short.".to_string()), + tools: None, + handoffs: None, + completion_args: Some(CompletionArgs { + temperature: Some(0.0), + ..Default::default() + }), + metadata: None, + } +} + +// --------------------------------------------------------------------------- +// Sync tests +// --------------------------------------------------------------------------- + +#[test] +fn test_create_conversation_with_agent() { + setup::setup(); + let client = make_client(); + let agent = create_test_agent(&client, "conv-test-create"); + + let req = CreateConversationRequest { + inputs: ConversationInput::Text("What is 2 + 2?".to_string()), + model: None, + agent_id: Some(agent.id.clone()), + agent_version: None, + name: None, + description: None, + instructions: None, + completion_args: None, + tools: None, + handoff_execution: None, + metadata: None, + store: None, + stream: false, + }; + + let response = client.create_conversation(&req).unwrap(); + assert!(!response.conversation_id.is_empty()); + assert_eq!(response.object, "conversation.response"); + assert!(!response.outputs.is_empty()); + assert!(response.usage.total_tokens > 0); + + // Should have an assistant response + let text = response.assistant_text(); + assert!(text.is_some(), "Expected assistant text in outputs"); + assert!(text.unwrap().contains('4'), "Expected answer containing '4'"); + + // Cleanup + client.delete_conversation(&response.conversation_id).unwrap(); + client.delete_agent(&agent.id).unwrap(); +} + +#[test] +fn test_create_conversation_without_agent() { + setup::setup(); + let client = make_client(); + + let req = CreateConversationRequest { + inputs: ConversationInput::Text("Say hello.".to_string()), + model: Some("mistral-medium-latest".to_string()), + agent_id: None, + agent_version: None, + name: None, + description: None, + instructions: Some("Always respond with exactly 'hello'.".to_string()), + completion_args: Some(CompletionArgs { + temperature: Some(0.0), + ..Default::default() + }), + tools: None, + handoff_execution: None, + metadata: None, + store: None, + stream: false, + }; + + let response = client.create_conversation(&req).unwrap(); + assert!(!response.conversation_id.is_empty()); + let text = response.assistant_text().unwrap().to_lowercase(); + assert!(text.contains("hello"), "Expected 'hello', got: {text}"); + + client.delete_conversation(&response.conversation_id).unwrap(); +} + +#[test] +fn test_append_to_conversation() { + setup::setup(); + let client = make_client(); + let agent = create_test_agent(&client, "conv-test-append"); + + // Create conversation + let create_req = CreateConversationRequest { + inputs: ConversationInput::Text("Remember the number 42.".to_string()), + model: None, + agent_id: Some(agent.id.clone()), + agent_version: None, + name: None, + description: None, + instructions: None, + completion_args: None, + tools: None, + handoff_execution: None, + metadata: None, + store: None, + stream: false, + }; + let created = client.create_conversation(&create_req).unwrap(); + + // Append follow-up + let append_req = AppendConversationRequest { + inputs: ConversationInput::Text("What number did I ask you to remember?".to_string()), + completion_args: None, + handoff_execution: None, + store: None, + tool_confirmations: None, + stream: false, + }; + let appended = client + .append_conversation(&created.conversation_id, &append_req) + .unwrap(); + + assert_eq!(appended.conversation_id, created.conversation_id); + assert!(!appended.outputs.is_empty()); + let text = appended.assistant_text().unwrap(); + assert!(text.contains("42"), "Expected '42' in response, got: {text}"); + assert!(appended.usage.total_tokens > 0); + + client.delete_conversation(&created.conversation_id).unwrap(); + client.delete_agent(&agent.id).unwrap(); +} + +#[test] +fn test_get_conversation_info() { + setup::setup(); + let client = make_client(); + let agent = create_test_agent(&client, "conv-test-get-info"); + + let create_req = CreateConversationRequest { + inputs: ConversationInput::Text("Hello.".to_string()), + model: None, + agent_id: Some(agent.id.clone()), + agent_version: None, + name: None, + description: None, + instructions: None, + completion_args: None, + tools: None, + handoff_execution: None, + metadata: None, + store: None, + stream: false, + }; + let created = client.create_conversation(&create_req).unwrap(); + + let info = client.get_conversation(&created.conversation_id).unwrap(); + assert_eq!(info.id, created.conversation_id); + assert_eq!(info.agent_id.as_deref(), Some(agent.id.as_str())); + + client.delete_conversation(&created.conversation_id).unwrap(); + client.delete_agent(&agent.id).unwrap(); +} + +#[test] +fn test_get_conversation_history() { + setup::setup(); + let client = make_client(); + let agent = create_test_agent(&client, "conv-test-history"); + + // Create and do two turns + let create_req = CreateConversationRequest { + inputs: ConversationInput::Text("First message.".to_string()), + model: None, + agent_id: Some(agent.id.clone()), + agent_version: None, + name: None, + description: None, + instructions: None, + completion_args: None, + tools: None, + handoff_execution: None, + metadata: None, + store: None, + stream: false, + }; + let created = client.create_conversation(&create_req).unwrap(); + + let append_req = AppendConversationRequest { + inputs: ConversationInput::Text("Second message.".to_string()), + completion_args: None, + handoff_execution: None, + store: None, + tool_confirmations: None, + stream: false, + }; + client + .append_conversation(&created.conversation_id, &append_req) + .unwrap(); + + // Get history — should have at least 4 entries (user, assistant, user, assistant) + let history = client + .get_conversation_history(&created.conversation_id) + .unwrap(); + assert_eq!(history.conversation_id, created.conversation_id); + assert_eq!(history.object, "conversation.history"); + assert!( + history.entries.len() >= 4, + "Expected >= 4 history entries, got {}", + history.entries.len() + ); + + // First entry should be a message input + assert!(matches!( + &history.entries[0], + ConversationEntry::MessageInput(_) + )); + + client.delete_conversation(&created.conversation_id).unwrap(); + client.delete_agent(&agent.id).unwrap(); +} + +#[test] +fn test_get_conversation_messages() { + setup::setup(); + let client = make_client(); + let agent = create_test_agent(&client, "conv-test-messages"); + + let create_req = CreateConversationRequest { + inputs: ConversationInput::Text("Hello there.".to_string()), + model: None, + agent_id: Some(agent.id.clone()), + agent_version: None, + name: None, + description: None, + instructions: None, + completion_args: None, + tools: None, + handoff_execution: None, + metadata: None, + store: None, + stream: false, + }; + let created = client.create_conversation(&create_req).unwrap(); + + let messages = client + .get_conversation_messages(&created.conversation_id) + .unwrap(); + assert_eq!(messages.conversation_id, created.conversation_id); + assert!(!messages.messages.is_empty()); + + client.delete_conversation(&created.conversation_id).unwrap(); + client.delete_agent(&agent.id).unwrap(); +} + +#[test] +fn test_list_conversations() { + setup::setup(); + let client = make_client(); + + let req = CreateConversationRequest { + inputs: ConversationInput::Text("List test.".to_string()), + model: Some("mistral-medium-latest".to_string()), + agent_id: None, + agent_version: None, + name: None, + description: None, + instructions: None, + completion_args: None, + tools: None, + handoff_execution: None, + metadata: None, + store: None, + stream: false, + }; + let created = client.create_conversation(&req).unwrap(); + + let list = client.list_conversations().unwrap(); + // API returns raw array (no wrapper object) + assert!(list.data.iter().any(|c| c.id == created.conversation_id)); + + client.delete_conversation(&created.conversation_id).unwrap(); +} + +#[test] +fn test_delete_conversation() { + setup::setup(); + let client = make_client(); + + let req = CreateConversationRequest { + inputs: ConversationInput::Text("To be deleted.".to_string()), + model: Some("mistral-medium-latest".to_string()), + agent_id: None, + agent_version: None, + name: None, + description: None, + instructions: None, + completion_args: None, + tools: None, + handoff_execution: None, + metadata: None, + store: None, + stream: false, + }; + let created = client.create_conversation(&req).unwrap(); + + let del = client.delete_conversation(&created.conversation_id).unwrap(); + assert!(del.deleted); + + // Should no longer appear in list + let list = client.list_conversations().unwrap(); + assert!(!list.data.iter().any(|c| c.id == created.conversation_id)); +} + +#[test] +fn test_conversation_with_structured_entries() { + setup::setup(); + let client = make_client(); + + use mistralai_client::v1::chat::ChatMessageContent; + + let entries = vec![ConversationEntry::MessageInput(MessageInputEntry { + role: "user".to_string(), + content: ChatMessageContent::Text("What is the capital of France?".to_string()), + prefix: None, + id: None, + object: None, + created_at: None, + completed_at: None, + })]; + + let req = CreateConversationRequest { + inputs: ConversationInput::Entries(entries), + model: Some("mistral-medium-latest".to_string()), + agent_id: None, + agent_version: None, + name: None, + description: None, + instructions: Some("Respond in one word.".to_string()), + completion_args: Some(CompletionArgs { + temperature: Some(0.0), + ..Default::default() + }), + tools: None, + handoff_execution: None, + metadata: None, + store: None, + stream: false, + }; + + let response = client.create_conversation(&req).unwrap(); + let text = response.assistant_text().unwrap().to_lowercase(); + assert!(text.contains("paris"), "Expected 'Paris', got: {text}"); + + client.delete_conversation(&response.conversation_id).unwrap(); +} + +#[test] +fn test_conversation_with_function_calling() { + setup::setup(); + let client = make_client(); + + // Create agent with a function tool + let agent_req = CreateAgentRequest { + model: "mistral-medium-latest".to_string(), + name: "conv-test-function".to_string(), + description: None, + instructions: Some("When asked about temperature, use the get_temperature tool.".to_string()), + tools: Some(vec![AgentTool::function( + "get_temperature".to_string(), + "Get the current temperature in a city".to_string(), + serde_json::json!({ + "type": "object", + "properties": { + "city": {"type": "string", "description": "City name"} + }, + "required": ["city"] + }), + )]), + handoffs: None, + completion_args: Some(CompletionArgs { + temperature: Some(0.0), + ..Default::default() + }), + metadata: None, + }; + let agent = client.create_agent(&agent_req).unwrap(); + + // Create conversation — model should call the function + let conv_req = CreateConversationRequest { + inputs: ConversationInput::Text("What is the temperature in Paris?".to_string()), + model: None, + agent_id: Some(agent.id.clone()), + agent_version: None, + name: None, + description: None, + instructions: None, + completion_args: None, + tools: None, + handoff_execution: Some(HandoffExecution::Client), + metadata: None, + store: None, + stream: false, + }; + let response = client.create_conversation(&conv_req).unwrap(); + + // With client-side execution, we should see function calls in outputs + let function_calls = response.function_calls(); + if !function_calls.is_empty() { + assert_eq!(function_calls[0].name, "get_temperature"); + let args: serde_json::Value = + serde_json::from_str(&function_calls[0].arguments).unwrap(); + assert!(args["city"].as_str().is_some()); + + // Send back the function result + let tool_call_id = function_calls[0] + .tool_call_id + .as_deref() + .unwrap_or("unknown"); + + let result_entries = vec![ConversationEntry::FunctionResult(FunctionResultEntry { + tool_call_id: tool_call_id.to_string(), + result: "22°C".to_string(), + id: None, + object: None, + created_at: None, + completed_at: None, + })]; + + let append_req = AppendConversationRequest { + inputs: ConversationInput::Entries(result_entries), + completion_args: None, + handoff_execution: None, + store: None, + tool_confirmations: None, + stream: false, + }; + let final_response = client + .append_conversation(&response.conversation_id, &append_req) + .unwrap(); + + // Now we should get an assistant text response + let text = final_response.assistant_text(); + assert!(text.is_some(), "Expected final text after function result"); + assert!( + text.unwrap().contains("22"), + "Expected temperature in response" + ); + } + // If the API handled it server-side instead, we should still have a response + else { + assert!( + response.assistant_text().is_some(), + "Expected either function calls or assistant text" + ); + } + + client.delete_conversation(&response.conversation_id).unwrap(); + client.delete_agent(&agent.id).unwrap(); +} + +// --------------------------------------------------------------------------- +// Async tests +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn test_create_conversation_async() { + setup::setup(); + let client = make_client(); + let agent = create_test_agent_async(&client, "conv-async-create").await; + + let req = CreateConversationRequest { + inputs: ConversationInput::Text("Async test: what is 3 + 3?".to_string()), + model: None, + agent_id: Some(agent.id.clone()), + agent_version: None, + name: None, + description: None, + instructions: None, + completion_args: None, + tools: None, + handoff_execution: None, + metadata: None, + store: None, + stream: false, + }; + + let response = client.create_conversation_async(&req).await.unwrap(); + assert!(!response.conversation_id.is_empty()); + let text = response.assistant_text().unwrap(); + assert!(text.contains('6'), "Expected '6', got: {text}"); + + client + .delete_conversation_async(&response.conversation_id) + .await + .unwrap(); + client.delete_agent_async(&agent.id).await.unwrap(); +} + +#[tokio::test] +async fn test_append_conversation_async() { + setup::setup(); + let client = make_client(); + let agent = create_test_agent_async(&client, "conv-async-append").await; + + let create_req = CreateConversationRequest { + inputs: ConversationInput::Text("My name is Alice.".to_string()), + model: None, + agent_id: Some(agent.id.clone()), + agent_version: None, + name: None, + description: None, + instructions: None, + completion_args: None, + tools: None, + handoff_execution: None, + metadata: None, + store: None, + stream: false, + }; + let created = client.create_conversation_async(&create_req).await.unwrap(); + + let append_req = AppendConversationRequest { + inputs: ConversationInput::Text("What is my name?".to_string()), + completion_args: None, + handoff_execution: None, + store: None, + tool_confirmations: None, + stream: false, + }; + let appended = client + .append_conversation_async(&created.conversation_id, &append_req) + .await + .unwrap(); + + let text = appended.assistant_text().unwrap(); + assert!( + text.to_lowercase().contains("alice"), + "Expected 'Alice' in response, got: {text}" + ); + + client + .delete_conversation_async(&created.conversation_id) + .await + .unwrap(); + client.delete_agent_async(&agent.id).await.unwrap(); +} + +#[tokio::test] +async fn test_get_conversation_history_async() { + setup::setup(); + let client = make_client(); + let agent = create_test_agent_async(&client, "conv-async-history").await; + + let create_req = CreateConversationRequest { + inputs: ConversationInput::Text("Hello.".to_string()), + model: None, + agent_id: Some(agent.id.clone()), + agent_version: None, + name: None, + description: None, + instructions: None, + completion_args: None, + tools: None, + handoff_execution: None, + metadata: None, + store: None, + stream: false, + }; + let created = client.create_conversation_async(&create_req).await.unwrap(); + + let history = client + .get_conversation_history_async(&created.conversation_id) + .await + .unwrap(); + assert!(history.entries.len() >= 2); // at least user + assistant + + client + .delete_conversation_async(&created.conversation_id) + .await + .unwrap(); + client.delete_agent_async(&agent.id).await.unwrap(); +} + +#[tokio::test] +async fn test_list_conversations_async() { + setup::setup(); + let client = make_client(); + + let req = CreateConversationRequest { + inputs: ConversationInput::Text("Async list test.".to_string()), + model: Some("mistral-medium-latest".to_string()), + agent_id: None, + agent_version: None, + name: None, + description: None, + instructions: None, + completion_args: None, + tools: None, + handoff_execution: None, + metadata: None, + store: None, + stream: false, + }; + let created = client.create_conversation_async(&req).await.unwrap(); + + let list = client.list_conversations_async().await.unwrap(); + assert!(list.data.iter().any(|c| c.id == created.conversation_id)); + + client + .delete_conversation_async(&created.conversation_id) + .await + .unwrap(); +} diff --git a/tests/v1_conversations_types_test.rs b/tests/v1_conversations_types_test.rs new file mode 100644 index 0000000..2cefa63 --- /dev/null +++ b/tests/v1_conversations_types_test.rs @@ -0,0 +1,226 @@ +use mistralai_client::v1::chat::ChatMessageContent; +use mistralai_client::v1::conversations::*; + +#[test] +fn test_conversation_input_from_string() { + let input: ConversationInput = "hello".into(); + let json = serde_json::to_value(&input).unwrap(); + assert_eq!(json, serde_json::json!("hello")); +} + +#[test] +fn test_conversation_input_from_entries() { + let entries = vec![ConversationEntry::MessageInput(MessageInputEntry { + role: "user".to_string(), + content: ChatMessageContent::Text("hello".to_string()), + prefix: None, + id: None, + object: None, + created_at: None, + completed_at: None, + })]; + let input: ConversationInput = entries.into(); + let json = serde_json::to_value(&input).unwrap(); + assert!(json.is_array()); + assert_eq!(json[0]["type"], "message.input"); + assert_eq!(json[0]["content"], "hello"); +} + +#[test] +fn test_create_conversation_request() { + let req = CreateConversationRequest { + inputs: ConversationInput::Text("What is 2+2?".to_string()), + model: None, + agent_id: Some("ag_abc123".to_string()), + agent_version: None, + name: None, + description: None, + instructions: None, + completion_args: None, + tools: None, + handoff_execution: Some(HandoffExecution::Server), + metadata: None, + store: None, + stream: false, + }; + + let json = serde_json::to_value(&req).unwrap(); + assert_eq!(json["inputs"], "What is 2+2?"); + assert_eq!(json["agent_id"], "ag_abc123"); + assert_eq!(json["handoff_execution"], "server"); + assert_eq!(json["stream"], false); +} + +#[test] +fn test_conversation_response_deserialization() { + let json = serde_json::json!({ + "conversation_id": "conv_abc123", + "outputs": [ + { + "type": "message.output", + "role": "assistant", + "content": "4" + } + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15 + }, + "object": "conversation.response" + }); + + let resp: ConversationResponse = serde_json::from_value(json).unwrap(); + assert_eq!(resp.conversation_id, "conv_abc123"); + assert_eq!(resp.assistant_text().unwrap(), "4"); + assert_eq!(resp.usage.total_tokens, 15); + assert!(!resp.has_handoff()); +} + +#[test] +fn test_conversation_response_with_function_calls() { + let json = serde_json::json!({ + "conversation_id": "conv_abc123", + "outputs": [ + { + "type": "function.call", + "name": "search_archive", + "arguments": "{\"query\":\"error rate\"}", + "tool_call_id": "tc_1" + }, + { + "type": "message.output", + "role": "assistant", + "content": "error rate is 0.3%" + } + ], + "usage": {"prompt_tokens": 20, "completion_tokens": 10, "total_tokens": 30}, + "object": "conversation.response" + }); + + let resp: ConversationResponse = serde_json::from_value(json).unwrap(); + let fc = resp.function_calls(); + assert_eq!(fc.len(), 1); + assert_eq!(fc[0].name, "search_archive"); + assert_eq!(resp.assistant_text().unwrap(), "error rate is 0.3%"); +} + +#[test] +fn test_conversation_response_with_handoff() { + let json = serde_json::json!({ + "conversation_id": "conv_abc123", + "outputs": [ + { + "type": "agent.handoff", + "previous_agent_id": "ag_orch", + "next_agent_id": "ag_obs" + } + ], + "usage": {"prompt_tokens": 5, "completion_tokens": 0, "total_tokens": 5}, + "object": "conversation.response" + }); + + let resp: ConversationResponse = serde_json::from_value(json).unwrap(); + assert!(resp.has_handoff()); + assert!(resp.assistant_text().is_none()); +} + +#[test] +fn test_conversation_history_response() { + let json = serde_json::json!({ + "conversation_id": "conv_abc123", + "entries": [ + {"type": "message.input", "role": "user", "content": "hi"}, + {"type": "message.output", "role": "assistant", "content": "hello"}, + {"type": "message.input", "role": "user", "content": "search for cats"}, + {"type": "function.call", "name": "search", "arguments": "{\"q\":\"cats\"}"}, + {"type": "function.result", "tool_call_id": "tc_1", "result": "found 3 results"}, + {"type": "message.output", "role": "assistant", "content": "found 3 results about cats"} + ], + "object": "conversation.history" + }); + + let resp: ConversationHistoryResponse = serde_json::from_value(json).unwrap(); + assert_eq!(resp.entries.len(), 6); + assert!(matches!(&resp.entries[0], ConversationEntry::MessageInput(_))); + assert!(matches!(&resp.entries[3], ConversationEntry::FunctionCall(_))); + assert!(matches!(&resp.entries[4], ConversationEntry::FunctionResult(_))); +} + +#[test] +fn test_append_conversation_request() { + let req = AppendConversationRequest { + inputs: ConversationInput::Text("follow-up question".to_string()), + completion_args: None, + handoff_execution: None, + store: None, + tool_confirmations: None, + stream: false, + }; + + let json = serde_json::to_value(&req).unwrap(); + assert_eq!(json["inputs"], "follow-up question"); + assert_eq!(json["stream"], false); +} + +#[test] +fn test_restart_conversation_request() { + let req = RestartConversationRequest { + from_entry_id: "entry_3".to_string(), + inputs: Some(ConversationInput::Text("different question".to_string())), + completion_args: None, + agent_version: None, + handoff_execution: Some(HandoffExecution::Client), + metadata: None, + store: None, + stream: false, + }; + + let json = serde_json::to_value(&req).unwrap(); + assert_eq!(json["from_entry_id"], "entry_3"); + assert_eq!(json["handoff_execution"], "client"); +} + +#[test] +fn test_tool_call_confirmation() { + let req = AppendConversationRequest { + inputs: ConversationInput::Entries(vec![ConversationEntry::FunctionResult( + FunctionResultEntry { + tool_call_id: "tc_1".to_string(), + result: "search returned 5 results".to_string(), + id: None, + object: None, + created_at: None, + completed_at: None, + }, + )]), + completion_args: None, + handoff_execution: None, + store: None, + tool_confirmations: None, + stream: false, + }; + + let json = serde_json::to_value(&req).unwrap(); + assert_eq!(json["inputs"][0]["type"], "function.result"); + assert_eq!(json["inputs"][0]["tool_call_id"], "tc_1"); +} + +#[test] +fn test_handoff_execution_default() { + assert_eq!(HandoffExecution::default(), HandoffExecution::Server); +} + +#[test] +fn test_conversation_list_response() { + // API returns a raw JSON array + let json = serde_json::json!([ + {"id": "conv_1", "object": "conversation", "agent_id": "ag_1", "created_at": "2026-03-21T00:00:00Z"}, + {"id": "conv_2", "object": "conversation", "model": "mistral-medium-latest"} + ]); + + let resp: ConversationListResponse = serde_json::from_value(json).unwrap(); + assert_eq!(resp.data.len(), 2); + assert_eq!(resp.data[0].agent_id.as_deref(), Some("ag_1")); + assert!(resp.data[1].agent_id.is_none()); +}