From 79bc40bb1571128a893d7e480e3f03c93933794b Mon Sep 17 00:00:00 2001 From: Sienna Meridian Satterwhite Date: Fri, 20 Mar 2026 17:16:26 +0000 Subject: [PATCH] Update to latest Mistral AI API (v1.0.0) - Replace closed Model enum with flexible string-based Model type with constructor methods for all current models (Mistral Large 3, Small 4, Magistral, Codestral, Devstral, Pixtral, Voxtral, etc.) - Add new API endpoints: FIM completions, Files, Fine-tuning, Batch jobs, OCR, Audio transcription, Moderations/Classifications, and Agent completions (sync + async for all) - Add new chat fields: frequency_penalty, presence_penalty, stop, n, parallel_tool_calls, reasoning_effort, min_tokens, json_schema response format - Add embedding fields: output_dimension, output_dtype - Tool parameters now accept raw JSON Schema (serde_json::Value) instead of limited enum types - Add tool call IDs and Required tool choice variant - Add DELETE HTTP method support and multipart file upload - Bump thiserror to v2, add reqwest multipart feature - Remove strum dependency (no longer needed) - Update all tests and examples for new API --- Cargo.toml | 27 +- examples/chat.rs | 14 +- examples/chat_async.rs | 15 +- examples/chat_with_function_calling.rs | 32 +- examples/chat_with_function_calling_async.rs | 32 +- examples/chat_with_streaming.rs | 21 +- examples/embeddings.rs | 5 +- examples/embeddings_async.rs | 5 +- examples/fim.rs | 21 + examples/ocr.rs | 25 + src/v1/agents.rs | 98 ++ src/v1/audio.rs | 78 ++ src/v1/batch.rs | 53 + src/v1/chat.rs | 187 +-- src/v1/chat_stream.rs | 15 +- src/v1/client.rs | 1102 +++++++++++++----- src/v1/common.rs | 1 + src/v1/constants.rs | 152 ++- src/v1/embedding.rs | 36 +- src/v1/files.rs | 55 + src/v1/fim.rs | 101 ++ src/v1/fine_tuning.rs | 101 ++ src/v1/mod.rs | 8 + src/v1/model_list.rs | 33 +- src/v1/moderation.rs | 70 ++ src/v1/ocr.rs | 96 ++ src/v1/tool.rs | 79 +- tests/v1_client_chat_async_test.rs | 34 +- tests/v1_client_chat_stream_test.rs | 46 +- 
tests/v1_client_chat_test.rs | 28 +- tests/v1_client_embeddings_async_test.rs | 5 +- tests/v1_client_embeddings_test.rs | 5 +- tests/v1_constants_test.rs | 19 +- 33 files changed, 1977 insertions(+), 622 deletions(-) create mode 100644 examples/fim.rs create mode 100644 examples/ocr.rs create mode 100644 src/v1/agents.rs create mode 100644 src/v1/audio.rs create mode 100644 src/v1/batch.rs create mode 100644 src/v1/files.rs create mode 100644 src/v1/fim.rs create mode 100644 src/v1/fine_tuning.rs create mode 100644 src/v1/moderation.rs create mode 100644 src/v1/ocr.rs diff --git a/Cargo.toml b/Cargo.toml index 8df6709..c507e4b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,7 +2,7 @@ name = "mistralai-client" description = "Mistral AI API client library for Rust (unofficial)." license = "Apache-2.0" -version = "0.14.0" +version = "1.0.0" edition = "2021" rust-version = "1.76.0" @@ -15,18 +15,17 @@ readme = "README.md" repository = "https://github.com/ivangabriele/mistralai-client-rs" [dependencies] -async-stream = "0.3.5" -async-trait = "0.1.77" -env_logger = "0.11.3" -futures = "0.3.30" -log = "0.4.21" -reqwest = { version = "0.12.0", features = ["json", "blocking", "stream"] } -serde = { version = "1.0.197", features = ["derive"] } -serde_json = "1.0.114" -strum = "0.26.1" -thiserror = "1.0.57" -tokio = { version = "1.36.0", features = ["full"] } -tokio-stream = "0.1.14" +async-stream = "0.3" +async-trait = "0.1" +env_logger = "0.11" +futures = "0.3" +log = "0.4" +reqwest = { version = "0.12", features = ["json", "blocking", "stream", "multipart"] } +serde = { version = "1", features = ["derive"] } +serde_json = "1" +thiserror = "2" +tokio = { version = "1", features = ["full"] } +tokio-stream = "0.1" [dev-dependencies] -jrest = "0.2.3" +jrest = "0.2" diff --git a/examples/chat.rs b/examples/chat.rs index ad3be09..a924d84 100644 --- a/examples/chat.rs +++ b/examples/chat.rs @@ -1,5 +1,5 @@ use mistralai_client::v1::{ - chat::{ChatMessage, ChatMessageRole, 
ChatParams}, + chat::{ChatMessage, ChatParams}, client::Client, constants::Model, }; @@ -8,14 +8,12 @@ fn main() { // This example suppose you have set the `MISTRAL_API_KEY` environment variable. let client = Client::new(None, None, None, None).unwrap(); - let model = Model::OpenMistral7b; - let messages = vec![ChatMessage { - role: ChatMessageRole::User, - content: "Just guess the next word: \"Eiffel ...\"?".to_string(), - tool_calls: None, - }]; + let model = Model::mistral_small_latest(); + let messages = vec![ChatMessage::new_user_message( + "Just guess the next word: \"Eiffel ...\"?", + )]; let options = ChatParams { - temperature: 0.0, + temperature: Some(0.0), random_seed: Some(42), ..Default::default() }; diff --git a/examples/chat_async.rs b/examples/chat_async.rs index a3f35a5..6b62848 100644 --- a/examples/chat_async.rs +++ b/examples/chat_async.rs @@ -1,5 +1,5 @@ use mistralai_client::v1::{ - chat::{ChatMessage, ChatMessageRole, ChatParams}, + chat::{ChatMessage, ChatParams}, client::Client, constants::Model, }; @@ -9,14 +9,12 @@ async fn main() { // This example suppose you have set the `MISTRAL_API_KEY` environment variable. let client = Client::new(None, None, None, None).unwrap(); - let model = Model::OpenMistral7b; - let messages = vec![ChatMessage { - role: ChatMessageRole::User, - content: "Just guess the next word: \"Eiffel ...\"?".to_string(), - tool_calls: None, - }]; + let model = Model::mistral_small_latest(); + let messages = vec![ChatMessage::new_user_message( + "Just guess the next word: \"Eiffel ...\"?", + )]; let options = ChatParams { - temperature: 0.0, + temperature: Some(0.0), random_seed: Some(42), ..Default::default() }; @@ -29,5 +27,4 @@ async fn main() { "{:?}: {}", result.choices[0].message.role, result.choices[0].message.content ); - // => "Assistant: Tower. The Eiffel Tower is a famous landmark in Paris, France." 
} diff --git a/examples/chat_with_function_calling.rs b/examples/chat_with_function_calling.rs index 67660fb..b14fbd1 100644 --- a/examples/chat_with_function_calling.rs +++ b/examples/chat_with_function_calling.rs @@ -1,8 +1,8 @@ use mistralai_client::v1::{ - chat::{ChatMessage, ChatMessageRole, ChatParams}, + chat::{ChatMessage, ChatParams}, client::Client, constants::Model, - tool::{Function, Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType}, + tool::{Function, Tool, ToolChoice}, }; use serde::Deserialize; use std::any::Any; @@ -16,7 +16,6 @@ struct GetCityTemperatureFunction; #[async_trait::async_trait] impl Function for GetCityTemperatureFunction { async fn execute(&self, arguments: String) -> Box { - // Deserialize arguments, perform the logic, and return the result let GetCityTemperatureArguments { city } = serde_json::from_str(&arguments).unwrap(); let temperature = match city.as_str() { @@ -32,11 +31,16 @@ fn main() { let tools = vec![Tool::new( "get_city_temperature".to_string(), "Get the current temperature in a city.".to_string(), - vec![ToolFunctionParameter::new( - "city".to_string(), - "The name of the city.".to_string(), - ToolFunctionParameterType::String, - )], + serde_json::json!({ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The name of the city." + } + }, + "required": ["city"] + }), )]; // This example suppose you have set the `MISTRAL_API_KEY` environment variable. 
@@ -46,14 +50,12 @@ fn main() { Box::new(GetCityTemperatureFunction), ); - let model = Model::MistralSmallLatest; - let messages = vec![ChatMessage { - role: ChatMessageRole::User, - content: "What's the temperature in Paris?".to_string(), - tool_calls: None, - }]; + let model = Model::mistral_small_latest(); + let messages = vec![ChatMessage::new_user_message( + "What's the temperature in Paris?", + )]; let options = ChatParams { - temperature: 0.0, + temperature: Some(0.0), random_seed: Some(42), tool_choice: Some(ToolChoice::Auto), tools: Some(tools), diff --git a/examples/chat_with_function_calling_async.rs b/examples/chat_with_function_calling_async.rs index 5b91329..330f114 100644 --- a/examples/chat_with_function_calling_async.rs +++ b/examples/chat_with_function_calling_async.rs @@ -1,8 +1,8 @@ use mistralai_client::v1::{ - chat::{ChatMessage, ChatMessageRole, ChatParams}, + chat::{ChatMessage, ChatParams}, client::Client, constants::Model, - tool::{Function, Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType}, + tool::{Function, Tool, ToolChoice}, }; use serde::Deserialize; use std::any::Any; @@ -16,7 +16,6 @@ struct GetCityTemperatureFunction; #[async_trait::async_trait] impl Function for GetCityTemperatureFunction { async fn execute(&self, arguments: String) -> Box { - // Deserialize arguments, perform the logic, and return the result let GetCityTemperatureArguments { city } = serde_json::from_str(&arguments).unwrap(); let temperature = match city.as_str() { @@ -33,11 +32,16 @@ async fn main() { let tools = vec![Tool::new( "get_city_temperature".to_string(), "Get the current temperature in a city.".to_string(), - vec![ToolFunctionParameter::new( - "city".to_string(), - "The name of the city.".to_string(), - ToolFunctionParameterType::String, - )], + serde_json::json!({ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The name of the city." 
+ } + }, + "required": ["city"] + }), )]; // This example suppose you have set the `MISTRAL_API_KEY` environment variable. @@ -47,14 +51,12 @@ async fn main() { Box::new(GetCityTemperatureFunction), ); - let model = Model::MistralSmallLatest; - let messages = vec![ChatMessage { - role: ChatMessageRole::User, - content: "What's the temperature in Paris?".to_string(), - tool_calls: None, - }]; + let model = Model::mistral_small_latest(); + let messages = vec![ChatMessage::new_user_message( + "What's the temperature in Paris?", + )]; let options = ChatParams { - temperature: 0.0, + temperature: Some(0.0), random_seed: Some(42), tool_choice: Some(ToolChoice::Auto), tools: Some(tools), diff --git a/examples/chat_with_streaming.rs b/examples/chat_with_streaming.rs index f5ad8d4..bedab71 100644 --- a/examples/chat_with_streaming.rs +++ b/examples/chat_with_streaming.rs @@ -1,6 +1,6 @@ use futures::stream::StreamExt; use mistralai_client::v1::{ - chat::{ChatMessage, ChatMessageRole, ChatParams}, + chat::{ChatMessage, ChatParams}, client::Client, constants::Model, }; @@ -11,14 +11,10 @@ async fn main() { // This example suppose you have set the `MISTRAL_API_KEY` environment variable. 
let client = Client::new(None, None, None, None).unwrap(); - let model = Model::OpenMistral7b; - let messages = vec![ChatMessage { - role: ChatMessageRole::User, - content: "Tell me a short happy story.".to_string(), - tool_calls: None, - }]; + let model = Model::mistral_small_latest(); + let messages = vec![ChatMessage::new_user_message("Tell me a short happy story.")]; let options = ChatParams { - temperature: 0.0, + temperature: Some(0.0), random_seed: Some(42), ..Default::default() }; @@ -31,9 +27,10 @@ async fn main() { .for_each(|chunk_result| async { match chunk_result { Ok(chunks) => chunks.iter().for_each(|chunk| { - print!("{}", chunk.choices[0].delta.content); - io::stdout().flush().unwrap(); - // => "Once upon a time, [...]" + if let Some(content) = &chunk.choices[0].delta.content { + print!("{}", content); + io::stdout().flush().unwrap(); + } }), Err(error) => { eprintln!("Error processing chunk: {:?}", error) @@ -41,5 +38,5 @@ async fn main() { } }) .await; - print!("\n") // To persist the last chunk output. + println!(); } diff --git a/examples/embeddings.rs b/examples/embeddings.rs index 898e7d4..7359d99 100644 --- a/examples/embeddings.rs +++ b/examples/embeddings.rs @@ -1,10 +1,10 @@ -use mistralai_client::v1::{client::Client, constants::EmbedModel}; +use mistralai_client::v1::{client::Client, constants::Model}; fn main() { // This example suppose you have set the `MISTRAL_API_KEY` environment variable. 
let client: Client = Client::new(None, None, None, None).unwrap(); - let model = EmbedModel::MistralEmbed; + let model = Model::mistral_embed(); let input = vec!["Embed this sentence.", "As well as this one."] .iter() .map(|s| s.to_string()) @@ -13,5 +13,4 @@ fn main() { let response = client.embeddings(model, input, options).unwrap(); println!("First Embedding: {:?}", response.data[0]); - // => "First Embedding: {...}" } diff --git a/examples/embeddings_async.rs b/examples/embeddings_async.rs index a93d374..1987bb5 100644 --- a/examples/embeddings_async.rs +++ b/examples/embeddings_async.rs @@ -1,11 +1,11 @@ -use mistralai_client::v1::{client::Client, constants::EmbedModel}; +use mistralai_client::v1::{client::Client, constants::Model}; #[tokio::main] async fn main() { // This example suppose you have set the `MISTRAL_API_KEY` environment variable. let client: Client = Client::new(None, None, None, None).unwrap(); - let model = EmbedModel::MistralEmbed; + let model = Model::mistral_embed(); let input = vec!["Embed this sentence.", "As well as this one."] .iter() .map(|s| s.to_string()) @@ -17,5 +17,4 @@ async fn main() { .await .unwrap(); println!("First Embedding: {:?}", response.data[0]); - // => "First Embedding: {...}" } diff --git a/examples/fim.rs b/examples/fim.rs new file mode 100644 index 0000000..7d7274f --- /dev/null +++ b/examples/fim.rs @@ -0,0 +1,21 @@ +use mistralai_client::v1::{ + client::Client, + constants::Model, + fim::FimParams, +}; + +fn main() { + // This example suppose you have set the `MISTRAL_API_KEY` environment variable. 
+ let client = Client::new(None, None, None, None).unwrap(); + + let model = Model::codestral_latest(); + let prompt = "def fibonacci(n):".to_string(); + let options = FimParams { + suffix: Some("\n return result".to_string()), + temperature: Some(0.0), + ..Default::default() + }; + + let response = client.fim(model, prompt, Some(options)).unwrap(); + println!("Completion: {}", response.choices[0].message.content); +} diff --git a/examples/ocr.rs b/examples/ocr.rs new file mode 100644 index 0000000..f077e7d --- /dev/null +++ b/examples/ocr.rs @@ -0,0 +1,25 @@ +use mistralai_client::v1::{ + client::Client, + constants::Model, + ocr::{OcrDocument, OcrRequest}, +}; + +fn main() { + // This example suppose you have set the `MISTRAL_API_KEY` environment variable. + let client = Client::new(None, None, None, None).unwrap(); + + let request = OcrRequest { + model: Model::mistral_ocr_latest(), + document: OcrDocument::from_url("https://arxiv.org/pdf/2201.04234"), + pages: Some(vec![0]), + table_format: None, + include_image_base64: None, + image_limit: None, + }; + + let response = client.ocr(&request).unwrap(); + for page in &response.pages { + println!("--- Page {} ---", page.index); + println!("{}", &page.markdown[..200.min(page.markdown.len())]); + } +} diff --git a/src/v1/agents.rs b/src/v1/agents.rs new file mode 100644 index 0000000..4b4f1e2 --- /dev/null +++ b/src/v1/agents.rs @@ -0,0 +1,98 @@ +use serde::{Deserialize, Serialize}; + +use crate::v1::{chat, common, constants, tool}; + +// ----------------------------------------------------------------------------- +// Request + +#[derive(Debug)] +pub struct AgentCompletionParams { + pub max_tokens: Option, + pub min_tokens: Option, + pub temperature: Option, + pub top_p: Option, + pub random_seed: Option, + pub stop: Option>, + pub response_format: Option, + pub tools: Option>, + pub tool_choice: Option, +} +impl Default for AgentCompletionParams { + fn default() -> Self { + Self { + max_tokens: None, + min_tokens: 
None, + temperature: None, + top_p: None, + random_seed: None, + stop: None, + response_format: None, + tools: None, + tool_choice: None, + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct AgentCompletionRequest { + pub agent_id: String, + pub messages: Vec, + pub stream: bool, + + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub min_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub random_seed: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, +} +impl AgentCompletionRequest { + pub fn new( + agent_id: String, + messages: Vec, + stream: bool, + options: Option, + ) -> Self { + let opts = options.unwrap_or_default(); + + Self { + agent_id, + messages, + stream, + max_tokens: opts.max_tokens, + min_tokens: opts.min_tokens, + temperature: opts.temperature, + top_p: opts.top_p, + random_seed: opts.random_seed, + stop: opts.stop, + response_format: opts.response_format, + tools: opts.tools, + tool_choice: opts.tool_choice, + } + } +} + +// ----------------------------------------------------------------------------- +// Response (same shape as chat completions) + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct AgentCompletionResponse { + pub id: String, + pub object: String, + pub created: u64, + pub model: constants::Model, + pub choices: Vec, + pub usage: common::ResponseUsage, +} diff --git a/src/v1/audio.rs b/src/v1/audio.rs new file mode 100644 index 0000000..98cdfb7 --- /dev/null +++ b/src/v1/audio.rs @@ -0,0 +1,78 
@@ +use serde::{Deserialize, Serialize}; + +use crate::v1::constants; + +// ----------------------------------------------------------------------------- +// Request (multipart form, but we define the params struct) + +#[derive(Debug)] +pub struct AudioTranscriptionParams { + pub model: constants::Model, + pub language: Option, + pub temperature: Option, + pub diarize: Option, + pub timestamp_granularities: Option>, +} +impl Default for AudioTranscriptionParams { + fn default() -> Self { + Self { + model: constants::Model::voxtral_mini_transcribe(), + language: None, + temperature: None, + diarize: None, + timestamp_granularities: None, + } + } +} + +#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum TimestampGranularity { + Segment, + Word, +} + +// ----------------------------------------------------------------------------- +// Response + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct AudioTranscriptionResponse { + pub text: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub model: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub language: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub segments: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub words: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub usage: Option, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct TranscriptionSegment { + pub id: u32, + pub start: f32, + pub end: f32, + pub text: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub speaker: Option, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct TranscriptionWord { + pub word: String, + pub start: f32, + pub end: f32, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct AudioUsage { + #[serde(skip_serializing_if = "Option::is_none")] + pub prompt_audio_seconds: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub 
prompt_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub total_tokens: Option, +} diff --git a/src/v1/batch.rs b/src/v1/batch.rs new file mode 100644 index 0000000..e1e1be1 --- /dev/null +++ b/src/v1/batch.rs @@ -0,0 +1,53 @@ +use serde::{Deserialize, Serialize}; + +use crate::v1::constants; + +// ----------------------------------------------------------------------------- +// Request + +#[derive(Debug, Serialize, Deserialize)] +pub struct BatchJobRequest { + pub input_files: Vec, + pub model: constants::Model, + pub endpoint: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option, +} + +// ----------------------------------------------------------------------------- +// Response + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct BatchJobResponse { + pub id: String, + pub object: String, + pub model: constants::Model, + pub endpoint: String, + pub input_files: Vec, + pub status: String, + pub created_at: u64, + #[serde(skip_serializing_if = "Option::is_none")] + pub completed_at: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub output_file: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub error_file: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub metadata: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub total_requests: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub completed_requests: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub succeeded_requests: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub failed_requests: Option, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct BatchJobListResponse { + pub data: Vec, + pub object: String, + #[serde(default)] + pub total: u32, +} diff --git a/src/v1/chat.rs b/src/v1/chat.rs index 2309029..1411357 100644 --- a/src/v1/chat.rs +++ b/src/v1/chat.rs @@ -11,13 +11,31 @@ pub struct ChatMessage { pub content: String, 
#[serde(skip_serializing_if = "Option::is_none")] pub tool_calls: Option>, + /// Tool call ID, required when role is Tool. + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, + /// Function name, used when role is Tool. + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, } impl ChatMessage { + pub fn new_system_message(content: &str) -> Self { + Self { + role: ChatMessageRole::System, + content: content.to_string(), + tool_calls: None, + tool_call_id: None, + name: None, + } + } + pub fn new_assistant_message(content: &str, tool_calls: Option>) -> Self { Self { role: ChatMessageRole::Assistant, content: content.to_string(), tool_calls, + tool_call_id: None, + name: None, } } @@ -26,6 +44,18 @@ impl ChatMessage { role: ChatMessageRole::User, content: content.to_string(), tool_calls: None, + tool_call_id: None, + name: None, + } + } + + pub fn new_tool_message(content: &str, tool_call_id: &str, name: Option<&str>) -> Self { + Self { + role: ChatMessageRole::Tool, + content: content.to_string(), + tool_calls: None, + tool_call_id: Some(tool_call_id.to_string()), + name: name.map(|n| n.to_string()), } } } @@ -44,17 +74,32 @@ pub enum ChatMessageRole { } /// The format that the model must output. -/// -/// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information. 
#[derive(Clone, Debug, Serialize, Deserialize)] pub struct ResponseFormat { #[serde(rename = "type")] pub type_: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub json_schema: Option, } impl ResponseFormat { + pub fn text() -> Self { + Self { + type_: "text".to_string(), + json_schema: None, + } + } + pub fn json_object() -> Self { Self { type_: "json_object".to_string(), + json_schema: None, + } + } + + pub fn json_schema(schema: serde_json::Value) -> Self { + Self { + type_: "json_schema".to_string(), + json_schema: Some(schema), } } } @@ -63,91 +108,83 @@ impl ResponseFormat { // Request /// The parameters for the chat request. -/// -/// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information. #[derive(Clone, Debug)] pub struct ChatParams { - /// The maximum number of tokens to generate in the completion. - /// - /// Defaults to `None`. pub max_tokens: Option, - /// The seed to use for random sampling. If set, different calls will generate deterministic results. - /// - /// Defaults to `None`. + pub min_tokens: Option, pub random_seed: Option, - /// The format that the model must output. - /// - /// Defaults to `None`. pub response_format: Option, - /// Whether to inject a safety prompt before all conversations. - /// - /// Defaults to `false`. pub safe_prompt: bool, - /// What sampling temperature to use, between `Some(0.0)` and `Some(1.0)`. - /// - /// Defaults to `0.7`. - pub temperature: f32, - /// Specifies if/how functions are called. - /// - /// Defaults to `None`. + pub temperature: Option, pub tool_choice: Option, - /// A list of available tools for the model. - /// - /// Defaults to `None`. pub tools: Option>, - /// Nucleus sampling, where the model considers the results of the tokens with `top_p` probability mass. - /// - /// Defaults to `1.0`. 
- pub top_p: f32, + pub top_p: Option, + pub stop: Option>, + pub n: Option, + pub frequency_penalty: Option, + pub presence_penalty: Option, + pub parallel_tool_calls: Option, + /// For reasoning models (Magistral). "high" or "none". + pub reasoning_effort: Option, } impl Default for ChatParams { fn default() -> Self { Self { max_tokens: None, + min_tokens: None, random_seed: None, safe_prompt: false, response_format: None, - temperature: 0.7, + temperature: None, tool_choice: None, tools: None, - top_p: 1.0, - } - } -} -impl ChatParams { - pub fn json_default() -> Self { - Self { - max_tokens: None, - random_seed: None, - safe_prompt: false, - response_format: None, - temperature: 0.7, - tool_choice: None, - tools: None, - top_p: 1.0, + top_p: None, + stop: None, + n: None, + frequency_penalty: None, + presence_penalty: None, + parallel_tool_calls: None, + reasoning_effort: None, } } } #[derive(Debug, Serialize, Deserialize)] pub struct ChatRequest { - pub messages: Vec, pub model: constants::Model, + pub messages: Vec, + pub stream: bool, #[serde(skip_serializing_if = "Option::is_none")] pub max_tokens: Option, #[serde(skip_serializing_if = "Option::is_none")] + pub min_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub random_seed: Option, #[serde(skip_serializing_if = "Option::is_none")] pub response_format: Option, - pub safe_prompt: bool, - pub stream: bool, - pub temperature: f32, + #[serde(skip_serializing_if = "Option::is_none")] + pub safe_prompt: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, #[serde(skip_serializing_if = "Option::is_none")] pub tool_choice: Option, #[serde(skip_serializing_if = "Option::is_none")] pub tools: Option>, - pub top_p: f32, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option, + #[serde(skip_serializing_if 
= "Option::is_none")] + pub frequency_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub parallel_tool_calls: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_effort: Option, } impl ChatRequest { pub fn new( @@ -156,30 +193,28 @@ impl ChatRequest { stream: bool, options: Option, ) -> Self { - let ChatParams { - max_tokens, - random_seed, - safe_prompt, - temperature, - tool_choice, - tools, - top_p, - response_format, - } = options.unwrap_or_default(); + let opts = options.unwrap_or_default(); + let safe_prompt = if opts.safe_prompt { Some(true) } else { None }; Self { - messages, model, - - max_tokens, - random_seed, - safe_prompt, + messages, stream, - temperature, - tool_choice, - tools, - top_p, - response_format, + max_tokens: opts.max_tokens, + min_tokens: opts.min_tokens, + random_seed: opts.random_seed, + safe_prompt, + temperature: opts.temperature, + tool_choice: opts.tool_choice, + tools: opts.tools, + top_p: opts.top_p, + response_format: opts.response_format, + stop: opts.stop, + n: opts.n, + frequency_penalty: opts.frequency_penalty, + presence_penalty: opts.presence_penalty, + parallel_tool_calls: opts.parallel_tool_calls, + reasoning_effort: opts.reasoning_effort, } } } @@ -192,7 +227,7 @@ pub struct ChatResponse { pub id: String, pub object: String, /// Unix timestamp (in seconds). - pub created: u32, + pub created: u64, pub model: constants::Model, pub choices: Vec, pub usage: common::ResponseUsage, @@ -203,14 +238,18 @@ pub struct ChatResponseChoice { pub index: u32, pub message: ChatMessage, pub finish_reason: ChatResponseChoiceFinishReason, - // TODO Check this prop (seen in API responses but undocumented). - // pub logprobs: ??? 
} #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] pub enum ChatResponseChoiceFinishReason { #[serde(rename = "stop")] Stop, + #[serde(rename = "length")] + Length, #[serde(rename = "tool_calls")] ToolCalls, + #[serde(rename = "model_length")] + ModelLength, + #[serde(rename = "error")] + Error, } diff --git a/src/v1/chat_stream.rs b/src/v1/chat_stream.rs index 1daf481..bce8a89 100644 --- a/src/v1/chat_stream.rs +++ b/src/v1/chat_stream.rs @@ -1,7 +1,7 @@ use serde::{Deserialize, Serialize}; use serde_json::from_str; -use crate::v1::{chat, common, constants, error}; +use crate::v1::{chat, common, constants, error, tool}; // ----------------------------------------------------------------------------- // Response @@ -11,12 +11,11 @@ pub struct ChatStreamChunk { pub id: String, pub object: String, /// Unix timestamp (in seconds). - pub created: u32, + pub created: u64, pub model: constants::Model, pub choices: Vec, + #[serde(skip_serializing_if = "Option::is_none")] pub usage: Option, - // TODO Check this prop (seen in API responses but undocumented). - // pub logprobs: ???, } #[derive(Clone, Debug, Deserialize, Serialize)] @@ -24,14 +23,15 @@ pub struct ChatStreamChunkChoice { pub index: u32, pub delta: ChatStreamChunkChoiceDelta, pub finish_reason: Option, - // TODO Check this prop (seen in API responses but undocumented). - // pub logprobs: ???, } #[derive(Clone, Debug, Deserialize, Serialize)] pub struct ChatStreamChunkChoiceDelta { pub role: Option, - pub content: String, + #[serde(default)] + pub content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, } /// Extracts serialized chunks from a stream message. 
@@ -47,7 +47,6 @@ pub fn get_chunk_from_stream_message_line( return Ok(Some(vec![])); } - // Attempt to deserialize the JSON string into ChatStreamChunk match from_str::(chunk_as_json) { Ok(chunk) => Ok(Some(vec![chunk])), Err(e) => Err(error::ApiError { diff --git a/src/v1/client.rs b/src/v1/client.rs index 9a8ac81..7b2ecc8 100644 --- a/src/v1/client.rs +++ b/src/v1/client.rs @@ -5,10 +5,14 @@ use reqwest::Error as ReqwestError; use std::{ any::Any, collections::HashMap, + path::Path, sync::{Arc, Mutex}, }; -use crate::v1::{chat, chat_stream, constants, embedding, error, model_list, tool, utils}; +use crate::v1::{ + agents, audio, batch, chat, chat_stream, constants, embedding, error, files, fim, fine_tuning, + model_list, moderation, ocr, tool, utils, +}; #[derive(Debug)] pub struct Client { @@ -26,25 +30,10 @@ impl Client { /// /// # Arguments /// - /// * `api_key` - An optional API key. - /// If not provided, the method will try to use the `MISTRAL_API_KEY` environment variable. - /// * `endpoint` - An optional custom API endpoint. Defaults to the official API endpoint if not provided. - /// * `max_retries` - Optional maximum number of retries for failed requests. Defaults to `5`. - /// * `timeout` - Optional timeout in seconds for requests. Defaults to `120`. - /// - /// # Examples - /// - /// ``` - /// use mistralai_client::v1::client::Client; - /// - /// let client = Client::new(Some("your_api_key_here".to_string()), None, Some(3), Some(60)); - /// assert!(client.is_ok()); - /// ``` - /// - /// # Errors - /// - /// This method fails whenever neither the `api_key` is provided - /// nor the `MISTRAL_API_KEY` environment variable is set. + /// * `api_key` - An optional API key. If not provided, uses `MISTRAL_API_KEY` env var. + /// * `endpoint` - An optional custom API endpoint. Defaults to `https://api.mistral.ai/v1`. + /// * `max_retries` - Optional maximum number of retries. Defaults to `5`. + /// * `timeout` - Optional timeout in seconds. Defaults to `120`. 
pub fn new( api_key: Option, endpoint: Option, @@ -69,43 +58,15 @@ impl Client { endpoint, max_retries, timeout, - functions, last_function_call_result, }) } - /// Synchronously sends a chat completion request and returns the response. - /// - /// # Arguments - /// - /// * `model` - The [Model] to use for the chat completion. - /// * `messages` - A vector of [ChatMessage] to send as part of the chat. - /// * `options` - Optional [ChatParams] to customize the request. - /// - /// # Returns - /// - /// Returns a [Result] containing the `ChatResponse` if the request is successful, - /// or an [ApiError] if there is an error. - /// - /// # Examples - /// - /// ``` - /// use mistralai_client::v1::{ - /// chat::{ChatMessage, ChatMessageRole}, - /// client::Client, - /// constants::Model, - /// }; - /// - /// let client = Client::new(None, None, None, None).unwrap(); - /// let messages = vec![ChatMessage { - /// role: ChatMessageRole::User, - /// content: "Hello, world!".to_string(), - /// tool_calls: None, - /// }]; - /// let response = client.chat(Model::OpenMistral7b, messages, None).unwrap(); - /// println!("{:?}: {}", response.choices[0].message.role, response.choices[0].message.content); - /// ``` + // ========================================================================= + // Chat Completions + // ========================================================================= + pub fn chat( &self, model: constants::Model, @@ -119,49 +80,13 @@ impl Client { match result { Ok(data) => { utils::debug_pretty_json_from_struct("Response Data", &data); - self.call_function_if_any(data.clone()); - Ok(data) } Err(error) => Err(self.to_api_error(error)), } } - /// Asynchronously sends a chat completion request and returns the response. - /// - /// # Arguments - /// - /// * `model` - The [Model] to use for the chat completion. - /// * `messages` - A vector of [ChatMessage] to send as part of the chat. - /// * `options` - Optional [ChatParams] to customize the request. 
- /// - /// # Returns - /// - /// Returns a [Result] containing a `Stream` of `ChatStreamChunk` if the request is successful, - /// or an [ApiError] if there is an error. - /// - /// # Examples - /// - /// ``` - /// use mistralai_client::v1::{ - /// chat::{ChatMessage, ChatMessageRole}, - /// client::Client, - /// constants::Model, - /// }; - /// - /// #[tokio::main] - /// async fn main() { - /// let client = Client::new(None, None, None, None).unwrap(); - /// let messages = vec![ChatMessage { - /// role: ChatMessageRole::User, - /// content: "Hello, world!".to_string(), - /// tool_calls: None, - /// }]; - /// let response = client.chat_async(Model::OpenMistral7b, messages, None).await.unwrap(); - /// println!("{:?}: {}", response.choices[0].message.role, response.choices[0].message.content); - /// } - /// ``` pub async fn chat_async( &self, model: constants::Model, @@ -175,68 +100,13 @@ impl Client { match result { Ok(data) => { utils::debug_pretty_json_from_struct("Response Data", &data); - self.call_function_if_any_async(data.clone()).await; - Ok(data) } Err(error) => Err(self.to_api_error(error)), } } - /// Asynchronously sends a chat completion request and returns a stream of message chunks. - /// - /// # Arguments - /// - /// * `model` - The [Model] to use for the chat completion. - /// * `messages` - A vector of [ChatMessage] to send as part of the chat. - /// * `options` - Optional [ChatParams] to customize the request. - /// - /// # Returns - /// - /// Returns a [Result] containing a `Stream` of `ChatStreamChunk` if the request is successful, - /// or an [ApiError] if there is an error. 
- /// - /// # Examples - /// - /// ``` - /// use futures::stream::StreamExt; - /// use mistralai_client::v1::{ - /// chat::{ChatMessage, ChatMessageRole}, - /// client::Client, - /// constants::Model, - /// }; - /// use std::io::{self, Write}; - /// - /// #[tokio::main] - /// async fn main() { - /// let client = Client::new(None, None, None, None).unwrap(); - /// let messages = vec![ChatMessage { - /// role: ChatMessageRole::User, - /// content: "Hello, world!".to_string(), - /// tool_calls: None, - /// }]; - /// - /// let stream_result = client - /// .chat_stream(Model::OpenMistral7b,messages, None) - /// .await - /// .unwrap(); - /// stream_result - /// .for_each(|chunk_result| async { - /// match chunk_result { - /// Ok(chunks) => chunks.iter().for_each(|chunk| { - /// print!("{}", chunk.choices[0].delta.content); - /// io::stdout().flush().unwrap(); - /// // => "Once upon a time, [...]" - /// }), - /// Err(error) => { - /// eprintln!("Error processing chunk: {:?}", error) - /// } - /// } - /// }) - /// .await; - /// print!("\n") // To persist the last chunk output. 
- /// } pub async fn chat_stream( &self, model: constants::Model, @@ -292,9 +162,13 @@ impl Client { Ok(deserialized_stream) } + // ========================================================================= + // Embeddings + // ========================================================================= + pub fn embeddings( &self, - model: constants::EmbedModel, + model: constants::Model, input: Vec, options: Option, ) -> Result { @@ -305,7 +179,6 @@ impl Client { match result { Ok(data) => { utils::debug_pretty_json_from_struct("Response Data", &data); - Ok(data) } Err(error) => Err(self.to_api_error(error)), @@ -314,7 +187,7 @@ impl Client { pub async fn embeddings_async( &self, - model: constants::EmbedModel, + model: constants::Model, input: Vec, options: Option, ) -> Result { @@ -325,26 +198,64 @@ impl Client { match result { Ok(data) => { utils::debug_pretty_json_from_struct("Response Data", &data); - Ok(data) } Err(error) => Err(self.to_api_error(error)), } } - pub fn get_last_function_call_result(&self) -> Option> { - let mut result_lock = self.last_function_call_result.lock().unwrap(); + // ========================================================================= + // FIM (Fill-in-the-Middle) + // ========================================================================= - result_lock.take() + pub fn fim( + &self, + model: constants::Model, + prompt: String, + options: Option, + ) -> Result { + let request = fim::FimRequest::new(model, prompt, false, options); + + let response = self.post_sync("/fim/completions", &request)?; + let result = response.json::(); + match result { + Ok(data) => { + utils::debug_pretty_json_from_struct("Response Data", &data); + Ok(data) + } + Err(error) => Err(self.to_api_error(error)), + } } + pub async fn fim_async( + &self, + model: constants::Model, + prompt: String, + options: Option, + ) -> Result { + let request = fim::FimRequest::new(model, prompt, false, options); + + let response = self.post_async("/fim/completions", 
&request).await?; + let result = response.json::().await; + match result { + Ok(data) => { + utils::debug_pretty_json_from_struct("Response Data", &data); + Ok(data) + } + Err(error) => Err(self.to_api_error(error)), + } + } + + // ========================================================================= + // Models + // ========================================================================= + pub fn list_models(&self) -> Result { let response = self.get_sync("/models")?; let result = response.json::(); match result { Ok(data) => { utils::debug_pretty_json_from_struct("Response Data", &data); - Ok(data) } Err(error) => Err(self.to_api_error(error)), @@ -359,68 +270,707 @@ impl Client { match result { Ok(data) => { utils::debug_pretty_json_from_struct("Response Data", &data); - Ok(data) } Err(error) => Err(self.to_api_error(error)), } } + pub fn get_model(&self, model_id: &str) -> Result { + let response = self.get_sync(&format!("/models/{}", model_id))?; + let result = response.json::(); + match result { + Ok(data) => { + utils::debug_pretty_json_from_struct("Response Data", &data); + Ok(data) + } + Err(error) => Err(self.to_api_error(error)), + } + } + + pub async fn get_model_async( + &self, + model_id: &str, + ) -> Result { + let response = self.get_async(&format!("/models/{}", model_id)).await?; + let result = response.json::().await; + match result { + Ok(data) => { + utils::debug_pretty_json_from_struct("Response Data", &data); + Ok(data) + } + Err(error) => Err(self.to_api_error(error)), + } + } + + pub fn delete_model( + &self, + model_id: &str, + ) -> Result { + let response = self.delete_sync(&format!("/models/{}", model_id))?; + let result = response.json::(); + match result { + Ok(data) => Ok(data), + Err(error) => Err(self.to_api_error(error)), + } + } + + pub async fn delete_model_async( + &self, + model_id: &str, + ) -> Result { + let response = self + .delete_async(&format!("/models/{}", model_id)) + .await?; + let result = 
response.json::().await; + match result { + Ok(data) => Ok(data), + Err(error) => Err(self.to_api_error(error)), + } + } + + // ========================================================================= + // Files + // ========================================================================= + + pub fn list_files(&self) -> Result { + let response = self.get_sync("/files")?; + let result = response.json::(); + match result { + Ok(data) => Ok(data), + Err(error) => Err(self.to_api_error(error)), + } + } + + pub async fn list_files_async(&self) -> Result { + let response = self.get_async("/files").await?; + let result = response.json::().await; + match result { + Ok(data) => Ok(data), + Err(error) => Err(self.to_api_error(error)), + } + } + + pub fn upload_file( + &self, + file_path: &Path, + purpose: files::FilePurpose, + ) -> Result { + let reqwest_client = reqwest::blocking::Client::new(); + let url = format!("{}/files", self.endpoint); + + let purpose_str = serde_json::to_value(&purpose) + .unwrap() + .as_str() + .unwrap() + .to_string(); + + let file_bytes = std::fs::read(file_path).map_err(|e| error::ApiError { + message: format!("Failed to read file: {}", e), + })?; + + let file_name = file_path + .file_name() + .unwrap_or_default() + .to_string_lossy() + .to_string(); + + let part = reqwest::blocking::multipart::Part::bytes(file_bytes).file_name(file_name); + + let form = reqwest::blocking::multipart::Form::new() + .text("purpose", purpose_str) + .part("file", part); + + let request = self.build_request_sync_no_accept(reqwest_client.post(url).multipart(form)); + let result = request.send(); + match result { + Ok(response) => { + if response.status().is_success() { + response + .json::() + .map_err(|e| self.to_api_error(e)) + } else { + let status = response.status(); + let body = response.text().unwrap_or_default(); + Err(error::ApiError { + message: format!("{}: {}", status, body), + }) + } + } + Err(error) => Err(error::ApiError { + message: 
error.to_string(), + }), + } + } + + pub async fn upload_file_async( + &self, + file_path: &Path, + purpose: files::FilePurpose, + ) -> Result { + let reqwest_client = reqwest::Client::new(); + let url = format!("{}/files", self.endpoint); + + let purpose_str = serde_json::to_value(&purpose) + .unwrap() + .as_str() + .unwrap() + .to_string(); + + let file_bytes = tokio::fs::read(file_path).await.map_err(|e| error::ApiError { + message: format!("Failed to read file: {}", e), + })?; + + let file_name = file_path + .file_name() + .unwrap_or_default() + .to_string_lossy() + .to_string(); + + let part = reqwest::multipart::Part::bytes(file_bytes).file_name(file_name); + + let form = reqwest::multipart::Form::new() + .text("purpose", purpose_str) + .part("file", part); + + let request = self.build_request_async_no_accept(reqwest_client.post(url).multipart(form)); + let result = request.send().await; + match result { + Ok(response) => { + if response.status().is_success() { + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } else { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + Err(error::ApiError { + message: format!("{}: {}", status, body), + }) + } + } + Err(error) => Err(error::ApiError { + message: error.to_string(), + }), + } + } + + pub fn get_file(&self, file_id: &str) -> Result { + let response = self.get_sync(&format!("/files/{}", file_id))?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn get_file_async( + &self, + file_id: &str, + ) -> Result { + let response = self.get_async(&format!("/files/{}", file_id)).await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + pub fn delete_file(&self, file_id: &str) -> Result { + let response = self.delete_sync(&format!("/files/{}", file_id))?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn delete_file_async( + &self, + file_id: &str, + ) -> Result { + let response 
= self.delete_async(&format!("/files/{}", file_id)).await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + pub fn get_file_url(&self, file_id: &str) -> Result { + let response = self.get_sync(&format!("/files/{}/url", file_id))?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn get_file_url_async( + &self, + file_id: &str, + ) -> Result { + let response = self.get_async(&format!("/files/{}/url", file_id)).await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + // ========================================================================= + // Fine-Tuning + // ========================================================================= + + pub fn create_fine_tuning_job( + &self, + request: &fine_tuning::FineTuningJobRequest, + ) -> Result { + let response = self.post_sync("/fine_tuning/jobs", request)?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn create_fine_tuning_job_async( + &self, + request: &fine_tuning::FineTuningJobRequest, + ) -> Result { + let response = self.post_async("/fine_tuning/jobs", request).await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + pub fn list_fine_tuning_jobs( + &self, + ) -> Result { + let response = self.get_sync("/fine_tuning/jobs")?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn list_fine_tuning_jobs_async( + &self, + ) -> Result { + let response = self.get_async("/fine_tuning/jobs").await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + pub fn get_fine_tuning_job( + &self, + job_id: &str, + ) -> Result { + let response = self.get_sync(&format!("/fine_tuning/jobs/{}", job_id))?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn get_fine_tuning_job_async( + &self, + job_id: &str, + ) -> Result { + let response = self + .get_async(&format!("/fine_tuning/jobs/{}", job_id)) + .await?; + 
response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + pub fn cancel_fine_tuning_job( + &self, + job_id: &str, + ) -> Result { + let response = self.post_sync_empty(&format!("/fine_tuning/jobs/{}/cancel", job_id))?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn cancel_fine_tuning_job_async( + &self, + job_id: &str, + ) -> Result { + let response = self + .post_async_empty(&format!("/fine_tuning/jobs/{}/cancel", job_id)) + .await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + pub fn start_fine_tuning_job( + &self, + job_id: &str, + ) -> Result { + let response = self.post_sync_empty(&format!("/fine_tuning/jobs/{}/start", job_id))?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn start_fine_tuning_job_async( + &self, + job_id: &str, + ) -> Result { + let response = self + .post_async_empty(&format!("/fine_tuning/jobs/{}/start", job_id)) + .await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + // ========================================================================= + // Batch Jobs + // ========================================================================= + + pub fn create_batch_job( + &self, + request: &batch::BatchJobRequest, + ) -> Result { + let response = self.post_sync("/batch/jobs", request)?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn create_batch_job_async( + &self, + request: &batch::BatchJobRequest, + ) -> Result { + let response = self.post_async("/batch/jobs", request).await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + pub fn list_batch_jobs(&self) -> Result { + let response = self.get_sync("/batch/jobs")?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn list_batch_jobs_async( + &self, + ) -> Result { + let response = self.get_async("/batch/jobs").await?; + response + .json::() + .await + 
.map_err(|e| self.to_api_error(e)) + } + + pub fn get_batch_job( + &self, + job_id: &str, + ) -> Result { + let response = self.get_sync(&format!("/batch/jobs/{}", job_id))?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn get_batch_job_async( + &self, + job_id: &str, + ) -> Result { + let response = self + .get_async(&format!("/batch/jobs/{}", job_id)) + .await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + pub fn cancel_batch_job( + &self, + job_id: &str, + ) -> Result { + let response = self.post_sync_empty(&format!("/batch/jobs/{}/cancel", job_id))?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn cancel_batch_job_async( + &self, + job_id: &str, + ) -> Result { + let response = self + .post_async_empty(&format!("/batch/jobs/{}/cancel", job_id)) + .await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + // ========================================================================= + // OCR + // ========================================================================= + + pub fn ocr( + &self, + request: &ocr::OcrRequest, + ) -> Result { + let response = self.post_sync("/ocr", request)?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn ocr_async( + &self, + request: &ocr::OcrRequest, + ) -> Result { + let response = self.post_async("/ocr", request).await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + // ========================================================================= + // Audio Transcription + // ========================================================================= + + pub async fn transcribe_audio_async( + &self, + file_path: &Path, + params: Option, + ) -> Result { + let opts = params.unwrap_or_default(); + let reqwest_client = reqwest::Client::new(); + let url = format!("{}/audio/transcriptions", self.endpoint); + + let file_bytes = 
tokio::fs::read(file_path).await.map_err(|e| error::ApiError { + message: format!("Failed to read file: {}", e), + })?; + + let file_name = file_path + .file_name() + .unwrap_or_default() + .to_string_lossy() + .to_string(); + + let part = reqwest::multipart::Part::bytes(file_bytes).file_name(file_name); + + let mut form = reqwest::multipart::Form::new() + .text("model", opts.model.0) + .part("file", part); + + if let Some(lang) = opts.language { + form = form.text("language", lang); + } + if let Some(temp) = opts.temperature { + form = form.text("temperature", temp.to_string()); + } + if let Some(diarize) = opts.diarize { + form = form.text("diarize", diarize.to_string()); + } + + let request = self.build_request_async_no_accept(reqwest_client.post(url).multipart(form)); + let result = request.send().await; + match result { + Ok(response) => { + if response.status().is_success() { + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } else { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + Err(error::ApiError { + message: format!("{}: {}", status, body), + }) + } + } + Err(error) => Err(error::ApiError { + message: error.to_string(), + }), + } + } + + // ========================================================================= + // Moderations & Classifications + // ========================================================================= + + pub fn moderations( + &self, + model: constants::Model, + input: Vec, + ) -> Result { + let request = moderation::ModerationRequest::new(model, input); + let response = self.post_sync("/moderations", &request)?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn moderations_async( + &self, + model: constants::Model, + input: Vec, + ) -> Result { + let request = moderation::ModerationRequest::new(model, input); + let response = self.post_async("/moderations", &request).await?; + response + .json::() + .await + .map_err(|e| 
self.to_api_error(e)) + } + + pub fn chat_moderations( + &self, + request: &moderation::ChatModerationRequest, + ) -> Result { + let response = self.post_sync("/chat/moderations", request)?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn chat_moderations_async( + &self, + request: &moderation::ChatModerationRequest, + ) -> Result { + let response = self.post_async("/chat/moderations", request).await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + pub fn classifications( + &self, + request: &moderation::ClassificationRequest, + ) -> Result { + let response = self.post_sync("/classifications", request)?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn classifications_async( + &self, + request: &moderation::ClassificationRequest, + ) -> Result { + let response = self.post_async("/classifications", request).await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + // ========================================================================= + // Agent Completions + // ========================================================================= + + pub fn agent_completion( + &self, + agent_id: String, + messages: Vec, + options: Option, + ) -> Result { + let request = agents::AgentCompletionRequest::new(agent_id, messages, false, options); + let response = self.post_sync("/agents/completions", &request)?; + response + .json::() + .map_err(|e| self.to_api_error(e)) + } + + pub async fn agent_completion_async( + &self, + agent_id: String, + messages: Vec, + options: Option, + ) -> Result { + let request = agents::AgentCompletionRequest::new(agent_id, messages, false, options); + let response = self.post_async("/agents/completions", &request).await?; + response + .json::() + .await + .map_err(|e| self.to_api_error(e)) + } + + // ========================================================================= + // Function Calling + // 
========================================================================= + pub fn register_function(&mut self, name: String, function: Box) { let mut functions = self.functions.lock().unwrap(); - functions.insert(name, function); } + pub fn get_last_function_call_result(&self) -> Option> { + let mut result_lock = self.last_function_call_result.lock().unwrap(); + result_lock.take() + } + + // ========================================================================= + // HTTP Transport + // ========================================================================= + + fn user_agent(&self) -> String { + format!( + "mistralai-client-rs/{}", + env!("CARGO_PKG_VERSION") + ) + } + fn build_request_sync( &self, request: reqwest::blocking::RequestBuilder, ) -> reqwest::blocking::RequestBuilder { - let user_agent = format!( - "ivangabriele/mistralai-client-rs/{}", - env!("CARGO_PKG_VERSION") - ); - - let request_builder = request + request .bearer_auth(&self.api_key) .header("Accept", "application/json") - .header("User-Agent", user_agent); + .header("User-Agent", self.user_agent()) + } - request_builder + fn build_request_sync_no_accept( + &self, + request: reqwest::blocking::RequestBuilder, + ) -> reqwest::blocking::RequestBuilder { + request + .bearer_auth(&self.api_key) + .header("User-Agent", self.user_agent()) } fn build_request_async(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder { - let user_agent = format!( - "ivangabriele/mistralai-client-rs/{}", - env!("CARGO_PKG_VERSION") - ); - - let request_builder = request + request .bearer_auth(&self.api_key) .header("Accept", "application/json") - .header("User-Agent", user_agent); + .header("User-Agent", self.user_agent()) + } - request_builder + fn build_request_async_no_accept( + &self, + request: reqwest::RequestBuilder, + ) -> reqwest::RequestBuilder { + request + .bearer_auth(&self.api_key) + .header("User-Agent", self.user_agent()) } fn build_request_stream(&self, request: 
reqwest::RequestBuilder) -> reqwest::RequestBuilder { - let user_agent = format!( - "ivangabriele/mistralai-client-rs/{}", - env!("CARGO_PKG_VERSION") - ); - - let request_builder = request + request .bearer_auth(&self.api_key) .header("Accept", "text/event-stream") - .header("User-Agent", user_agent); - - request_builder + .header("User-Agent", self.user_agent()) } - fn call_function_if_any(&self, response: chat::ChatResponse) -> () { - let next_result = match response.choices.get(0) { - Some(first_choice) => match first_choice.message.tool_calls.to_owned() { - Some(tool_calls) => match tool_calls.get(0) { + fn call_function_if_any(&self, response: chat::ChatResponse) { + let next_result = match response.choices.first() { + Some(first_choice) => match first_choice.message.tool_calls.as_ref() { + Some(tool_calls) => match tool_calls.first() { Some(first_tool_call) => { let functions = self.functions.lock().unwrap(); match functions.get(&first_tool_call.function.name) { @@ -431,7 +981,6 @@ impl Client { .execute(first_tool_call.function.arguments.to_owned()) .await }); - Some(result) } None => None, @@ -448,10 +997,10 @@ impl Client { *last_result_lock = next_result; } - async fn call_function_if_any_async(&self, response: chat::ChatResponse) -> () { - let next_result = match response.choices.get(0) { - Some(first_choice) => match first_choice.message.tool_calls.to_owned() { - Some(tool_calls) => match tool_calls.get(0) { + async fn call_function_if_any_async(&self, response: chat::ChatResponse) { + let next_result = match response.choices.first() { + Some(first_choice) => match first_choice.message.tool_calls.as_ref() { + Some(tool_calls) => match tool_calls.first() { Some(first_tool_call) => { let functions = self.functions.lock().unwrap(); match functions.get(&first_tool_call.function.name) { @@ -459,7 +1008,6 @@ impl Client { let result = function .execute(first_tool_call.function.arguments.to_owned()) .await; - Some(result) } None => None, @@ -482,27 +1030,8 @@ 
impl Client { debug!("Request URL: {}", url); let request = self.build_request_sync(reqwest_client.get(url)); - let result = request.send(); - match result { - Ok(response) => { - if response.status().is_success() { - Ok(response) - } else { - let response_status = response.status(); - let response_body = response.text().unwrap_or_default(); - debug!("Response Status: {}", &response_status); - utils::debug_pretty_json_from_string("Response Data", &response_body); - - Err(error::ApiError { - message: format!("{}: {}", response_status, response_body), - }) - } - } - Err(error) => Err(error::ApiError { - message: error.to_string(), - }), - } + self.handle_sync_response(result) } async fn get_async(&self, path: &str) -> Result { @@ -510,29 +1039,9 @@ impl Client { let url = format!("{}{}", self.endpoint, path); debug!("Request URL: {}", url); - let request_builder = reqwest_client.get(url); - let request = self.build_request_async(request_builder); - + let request = self.build_request_async(reqwest_client.get(url)); let result = request.send().await; - match result { - Ok(response) => { - if response.status().is_success() { - Ok(response) - } else { - let response_status = response.status(); - let response_body = response.text().await.unwrap_or_default(); - debug!("Response Status: {}", &response_status); - utils::debug_pretty_json_from_string("Response Data", &response_body); - - Err(error::ApiError { - message: format!("{}: {}", response_status, response_body), - }) - } - } - Err(error) => Err(error::ApiError { - message: error.to_string(), - }), - } + self.handle_async_response(result).await } fn post_sync( @@ -545,29 +1054,22 @@ impl Client { debug!("Request URL: {}", url); utils::debug_pretty_json_from_struct("Request Body", params); - let request_builder = reqwest_client.post(url).json(params); - let request = self.build_request_sync(request_builder); - + let request = self.build_request_sync(reqwest_client.post(url).json(params)); let result = request.send(); - 
match result { - Ok(response) => { - if response.status().is_success() { - Ok(response) - } else { - let response_status = response.status(); - let response_body = response.text().unwrap_or_default(); - debug!("Response Status: {}", &response_status); - utils::debug_pretty_json_from_string("Response Data", &response_body); + self.handle_sync_response(result) + } - Err(error::ApiError { - message: format!("{}: {}", response_body, response_status), - }) - } - } - Err(error) => Err(error::ApiError { - message: error.to_string(), - }), - } + fn post_sync_empty( + &self, + path: &str, + ) -> Result { + let reqwest_client = reqwest::blocking::Client::new(); + let url = format!("{}{}", self.endpoint, path); + debug!("Request URL: {}", url); + + let request = self.build_request_sync(reqwest_client.post(url)); + let result = request.send(); + self.handle_sync_response(result) } async fn post_async( @@ -580,29 +1082,19 @@ impl Client { debug!("Request URL: {}", url); utils::debug_pretty_json_from_struct("Request Body", params); - let request_builder = reqwest_client.post(url).json(params); - let request = self.build_request_async(request_builder); - + let request = self.build_request_async(reqwest_client.post(url).json(params)); let result = request.send().await; - match result { - Ok(response) => { - if response.status().is_success() { - Ok(response) - } else { - let response_status = response.status(); - let response_body = response.text().await.unwrap_or_default(); - debug!("Response Status: {}", &response_status); - utils::debug_pretty_json_from_string("Response Data", &response_body); + self.handle_async_response(result).await + } - Err(error::ApiError { - message: format!("{}: {}", response_status, response_body), - }) - } - } - Err(error) => Err(error::ApiError { - message: error.to_string(), - }), - } + async fn post_async_empty(&self, path: &str) -> Result { + let reqwest_client = reqwest::Client::new(); + let url = format!("{}{}", self.endpoint, path); + 
debug!("Request URL: {}", url); + + let request = self.build_request_async(reqwest_client.post(url)); + let result = request.send().await; + self.handle_async_response(result).await } async fn post_stream( @@ -615,22 +1107,70 @@ impl Client { debug!("Request URL: {}", url); utils::debug_pretty_json_from_struct("Request Body", params); - let request_builder = reqwest_client.post(url).json(params); - let request = self.build_request_stream(request_builder); - + let request = self.build_request_stream(reqwest_client.post(url).json(params)); let result = request.send().await; + self.handle_async_response(result).await + } + + fn delete_sync(&self, path: &str) -> Result { + let reqwest_client = reqwest::blocking::Client::new(); + let url = format!("{}{}", self.endpoint, path); + debug!("Request URL: {}", url); + + let request = self.build_request_sync(reqwest_client.delete(url)); + let result = request.send(); + self.handle_sync_response(result) + } + + async fn delete_async(&self, path: &str) -> Result { + let reqwest_client = reqwest::Client::new(); + let url = format!("{}{}", self.endpoint, path); + debug!("Request URL: {}", url); + + let request = self.build_request_async(reqwest_client.delete(url)); + let result = request.send().await; + self.handle_async_response(result).await + } + + fn handle_sync_response( + &self, + result: Result, + ) -> Result { match result { Ok(response) => { if response.status().is_success() { Ok(response) } else { - let response_status = response.status(); - let response_body = response.text().await.unwrap_or_default(); - debug!("Response Status: {}", &response_status); - utils::debug_pretty_json_from_string("Response Data", &response_body); - + let status = response.status(); + let body = response.text().unwrap_or_default(); + debug!("Response Status: {}", &status); + utils::debug_pretty_json_from_string("Response Data", &body); Err(error::ApiError { - message: format!("{}: {}", response_status, response_body), + message: format!("{}: 
{}", status, body), + }) + } + } + Err(error) => Err(error::ApiError { + message: error.to_string(), + }), + } + } + + async fn handle_async_response( + &self, + result: Result<reqwest::Response, reqwest::Error>, + ) -> Result<reqwest::Response, error::ApiError> { + match result { + Ok(response) => { + if response.status().is_success() { + Ok(response) + } else { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + debug!("Response Status: {}", &status); + utils::debug_pretty_json_from_string("Response Data", &body); + Err(error::ApiError { + message: format!("{}: {}", status, body), + }) + } + } diff --git a/src/v1/common.rs b/src/v1/common.rs index 160073c..3597580 100644 --- a/src/v1/common.rs +++ b/src/v1/common.rs @@ -3,6 +3,7 @@ use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, Deserialize, Serialize)] pub struct ResponseUsage { pub prompt_tokens: u32, + #[serde(default)] pub completion_tokens: u32, pub total_tokens: u32, } diff --git a/src/v1/constants.rs b/src/v1/constants.rs index 52c0976..80168b7 100644 --- a/src/v1/constants.rs +++ b/src/v1/constants.rs @@ -1,35 +1,131 @@ +use std::fmt; + use serde::{Deserialize, Serialize}; pub const API_URL_BASE: &str = "https://api.mistral.ai/v1"; -#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] -pub enum Model { - #[serde(rename = "open-mistral-7b")] - OpenMistral7b, - #[serde(rename = "open-mixtral-8x7b")] - OpenMixtral8x7b, - #[serde(rename = "open-mixtral-8x22b")] - OpenMixtral8x22b, - #[serde(rename = "open-mistral-nemo", alias = "open-mistral-nemo-2407")] - OpenMistralNemo, - #[serde(rename = "mistral-tiny")] - MistralTiny, - #[serde(rename = "mistral-small-latest", alias = "mistral-small-2402")] - MistralSmallLatest, - #[serde(rename = "mistral-medium-latest", alias = "mistral-medium-2312")] - MistralMediumLatest, - #[serde(rename = "mistral-large-latest", alias = "mistral-large-2407")] - MistralLargeLatest, - #[serde(rename = "mistral-large-2402")] - MistralLarge, - #[serde(rename = "codestral-latest", alias = 
"codestral-2405")] - CodestralLatest, - #[serde(rename = "open-codestral-mamba")] - CodestralMamba, +/// A Mistral AI model identifier. +/// +/// Use the associated constants for known models, or construct with `Model::new()` for any model string. +#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] +#[serde(transparent)] +pub struct Model(pub String); + +impl Model { + pub fn new(id: impl Into<String>) -> Self { + Self(id.into()) + } + + // Flagship / Premier + pub fn mistral_large_latest() -> Self { + Self::new("mistral-large-latest") + } + pub fn mistral_large_3() -> Self { + Self::new("mistral-large-3-25-12") + } + pub fn mistral_medium_latest() -> Self { + Self::new("mistral-medium-latest") + } + pub fn mistral_medium_3_1() -> Self { + Self::new("mistral-medium-3-1-25-08") + } + pub fn mistral_small_latest() -> Self { + Self::new("mistral-small-latest") + } + pub fn mistral_small_4() -> Self { + Self::new("mistral-small-4-0-26-03") + } + pub fn mistral_small_3_2() -> Self { + Self::new("mistral-small-3-2-25-06") + } + + // Ministral + pub fn ministral_3_14b() -> Self { + Self::new("ministral-3-14b-25-12") + } + pub fn ministral_3_8b() -> Self { + Self::new("ministral-3-8b-25-12") + } + pub fn ministral_3_3b() -> Self { + Self::new("ministral-3-3b-25-12") + } + + // Reasoning + pub fn magistral_medium_latest() -> Self { + Self::new("magistral-medium-latest") + } + pub fn magistral_small_latest() -> Self { + Self::new("magistral-small-latest") + } + + // Code + pub fn codestral_latest() -> Self { + Self::new("codestral-latest") + } + pub fn codestral_2508() -> Self { + Self::new("codestral-2508") + } + pub fn codestral_embed() -> Self { + Self::new("codestral-embed-25-05") + } + pub fn devstral_2() -> Self { + Self::new("devstral-2-25-12") + } + pub fn devstral_small_2() -> Self { + Self::new("devstral-small-2-25-12") + } + + // Multimodal / Vision + pub fn pixtral_large() -> Self { + Self::new("pixtral-large-2411") + } + + // Audio + pub fn 
voxtral_mini_transcribe() -> Self { + Self::new("voxtral-mini-transcribe-2-26-02") + } + pub fn voxtral_small() -> Self { + Self::new("voxtral-small-25-07") + } + pub fn voxtral_mini() -> Self { + Self::new("voxtral-mini-25-07") + } + + // Legacy (kept for backward compatibility) + pub fn open_mistral_nemo() -> Self { + Self::new("open-mistral-nemo") + } + + // Embedding + pub fn mistral_embed() -> Self { + Self::new("mistral-embed") + } + + // Moderation + pub fn mistral_moderation_latest() -> Self { + Self::new("mistral-moderation-26-03") + } + + // OCR + pub fn mistral_ocr_latest() -> Self { + Self::new("mistral-ocr-latest") + } } -#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] -pub enum EmbedModel { - #[serde(rename = "mistral-embed")] - MistralEmbed, +impl fmt::Display for Model { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From<&str> for Model { + fn from(s: &str) -> Self { + Self(s.to_string()) + } +} + +impl From<String> for Model { + fn from(s: String) -> Self { + Self(s) + } } diff --git a/src/v1/embedding.rs b/src/v1/embedding.rs index 7d1c2d7..080a57b 100644 --- a/src/v1/embedding.rs +++ b/src/v1/embedding.rs @@ -8,42 +8,63 @@ use crate::v1::{common, constants}; #[derive(Debug)] pub struct EmbeddingRequestOptions { pub encoding_format: Option<EmbeddingRequestEncodingFormat>, + pub output_dimension: Option<u32>, + pub output_dtype: Option<EmbeddingOutputDtype>, } impl Default for EmbeddingRequestOptions { fn default() -> Self { Self { encoding_format: None, + output_dimension: None, + output_dtype: None, } } } #[derive(Debug, Serialize, Deserialize)] pub struct EmbeddingRequest { - pub model: constants::EmbedModel, + pub model: constants::Model, pub input: Vec<String>, #[serde(skip_serializing_if = "Option::is_none")] pub encoding_format: Option<EmbeddingRequestEncodingFormat>, + #[serde(skip_serializing_if = "Option::is_none")] + pub output_dimension: Option<u32>, + #[serde(skip_serializing_if = "Option::is_none")] + pub output_dtype: Option<EmbeddingOutputDtype>, } impl EmbeddingRequest { pub fn new( - model: 
constants::EmbedModel, + model: constants::Model, input: Vec, options: Option, ) -> Self { - let EmbeddingRequestOptions { encoding_format } = options.unwrap_or_default(); + let opts = options.unwrap_or_default(); Self { model, input, - encoding_format, + encoding_format: opts.encoding_format, + output_dimension: opts.output_dimension, + output_dtype: opts.output_dtype, } } } #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] -#[allow(non_camel_case_types)] +#[serde(rename_all = "lowercase")] pub enum EmbeddingRequestEncodingFormat { - float, + Float, + Base64, +} + +#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum EmbeddingOutputDtype { + Float, + Int8, + Uint8, + Binary, + Ubinary, } // ----------------------------------------------------------------------------- @@ -51,9 +72,8 @@ pub enum EmbeddingRequestEncodingFormat { #[derive(Clone, Debug, Deserialize, Serialize)] pub struct EmbeddingResponse { - pub id: String, pub object: String, - pub model: constants::EmbedModel, + pub model: constants::Model, pub data: Vec, pub usage: common::ResponseUsage, } diff --git a/src/v1/files.rs b/src/v1/files.rs new file mode 100644 index 0000000..e667983 --- /dev/null +++ b/src/v1/files.rs @@ -0,0 +1,55 @@ +use serde::{Deserialize, Serialize}; + +// ----------------------------------------------------------------------------- +// Request + +#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub enum FilePurpose { + FineTune, + Batch, + Ocr, +} + +// ----------------------------------------------------------------------------- +// Response + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct FileListResponse { + pub data: Vec, + pub object: String, + #[serde(default)] + pub total: u32, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct FileObject { + pub id: String, + pub object: String, + pub bytes: u64, + pub created_at: u64, + 
pub filename: String, + pub purpose: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub sample_type: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub source: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub num_lines: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub mimetype: Option, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct FileDeleteResponse { + pub id: String, + pub object: String, + pub deleted: bool, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct FileUrlResponse { + pub url: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub expires_at: Option, +} diff --git a/src/v1/fim.rs b/src/v1/fim.rs new file mode 100644 index 0000000..5d6ff05 --- /dev/null +++ b/src/v1/fim.rs @@ -0,0 +1,101 @@ +use serde::{Deserialize, Serialize}; + +use crate::v1::{common, constants}; + +// ----------------------------------------------------------------------------- +// Request + +#[derive(Debug)] +pub struct FimParams { + pub suffix: Option, + pub max_tokens: Option, + pub min_tokens: Option, + pub temperature: Option, + pub top_p: Option, + pub stop: Option>, + pub random_seed: Option, +} +impl Default for FimParams { + fn default() -> Self { + Self { + suffix: None, + max_tokens: None, + min_tokens: None, + temperature: None, + top_p: None, + stop: None, + random_seed: None, + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct FimRequest { + pub model: constants::Model, + pub prompt: String, + pub stream: bool, + + #[serde(skip_serializing_if = "Option::is_none")] + pub suffix: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub min_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub 
stop: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub random_seed: Option, +} +impl FimRequest { + pub fn new( + model: constants::Model, + prompt: String, + stream: bool, + options: Option, + ) -> Self { + let opts = options.unwrap_or_default(); + + Self { + model, + prompt, + stream, + suffix: opts.suffix, + max_tokens: opts.max_tokens, + min_tokens: opts.min_tokens, + temperature: opts.temperature, + top_p: opts.top_p, + stop: opts.stop, + random_seed: opts.random_seed, + } + } +} + +// ----------------------------------------------------------------------------- +// Response + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct FimResponse { + pub id: String, + pub object: String, + pub created: u64, + pub model: constants::Model, + pub choices: Vec, + pub usage: common::ResponseUsage, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct FimResponseChoice { + pub index: u32, + pub message: FimResponseMessage, + pub finish_reason: String, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct FimResponseMessage { + pub role: String, + pub content: String, +} diff --git a/src/v1/fine_tuning.rs b/src/v1/fine_tuning.rs new file mode 100644 index 0000000..c5cce2d --- /dev/null +++ b/src/v1/fine_tuning.rs @@ -0,0 +1,101 @@ +use serde::{Deserialize, Serialize}; + +use crate::v1::constants; + +// ----------------------------------------------------------------------------- +// Request + +#[derive(Debug, Serialize, Deserialize)] +pub struct FineTuningJobRequest { + pub model: constants::Model, + pub training_files: Vec, + + #[serde(skip_serializing_if = "Option::is_none")] + pub validation_files: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub hyperparameters: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub suffix: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub auto_start: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub job_type: Option, + 
#[serde(skip_serializing_if = "Option::is_none")] + pub integrations: Option>, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct TrainingFile { + pub file_id: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub weight: Option, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Hyperparameters { + #[serde(skip_serializing_if = "Option::is_none")] + pub learning_rate: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub training_steps: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub warmup_fraction: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub epochs: Option, +} + +#[derive(Clone, Debug, Serialize, Deserialize)] +pub struct Integration { + pub r#type: String, + pub project: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub api_key: Option, +} + +// ----------------------------------------------------------------------------- +// Response + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct FineTuningJobResponse { + pub id: String, + pub object: String, + pub model: constants::Model, + pub status: FineTuningJobStatus, + pub created_at: u64, + #[serde(skip_serializing_if = "Option::is_none")] + pub modified_at: Option, + pub training_files: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub validation_files: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub hyperparameters: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub fine_tuned_model: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub suffix: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub integrations: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub trained_tokens: Option, +} + +#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum FineTuningJobStatus { + 
Queued, + Running, + Success, + Failed, + TimeoutExceeded, + CancellationRequested, + Cancelled, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct FineTuningJobListResponse { + pub data: Vec, + pub object: String, + #[serde(default)] + pub total: u32, +} diff --git a/src/v1/mod.rs b/src/v1/mod.rs index 72165bb..e1140b6 100644 --- a/src/v1/mod.rs +++ b/src/v1/mod.rs @@ -1,3 +1,6 @@ +pub mod agents; +pub mod audio; +pub mod batch; pub mod chat; pub mod chat_stream; pub mod client; @@ -5,6 +8,11 @@ pub mod common; pub mod constants; pub mod embedding; pub mod error; +pub mod files; +pub mod fim; +pub mod fine_tuning; pub mod model_list; +pub mod moderation; +pub mod ocr; pub mod tool; pub mod utils; diff --git a/src/v1/model_list.rs b/src/v1/model_list.rs index cd69493..d0350f2 100644 --- a/src/v1/model_list.rs +++ b/src/v1/model_list.rs @@ -15,23 +15,44 @@ pub struct ModelListData { pub id: String, pub object: String, /// Unix timestamp (in seconds). - pub created: u32, + pub created: u64, pub owned_by: String, + #[serde(skip_serializing_if = "Option::is_none")] pub root: Option, + #[serde(default)] pub archived: bool, - pub name: String, - pub description: String, - pub capabilities: ModelListDataCapabilies, - pub max_context_length: u32, + #[serde(default)] + pub name: Option, + #[serde(default)] + pub description: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub capabilities: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_context_length: Option, + #[serde(default)] pub aliases: Vec, /// ISO 8601 date (`YYYY-MM-DDTHH:MM:SSZ`). 
+ #[serde(skip_serializing_if = "Option::is_none")] pub deprecation: Option, } #[derive(Clone, Debug, Deserialize, Serialize)] -pub struct ModelListDataCapabilies { +pub struct ModelListDataCapabilities { + #[serde(default)] pub completion_chat: bool, + #[serde(default)] pub completion_fim: bool, + #[serde(default)] pub function_calling: bool, + #[serde(default)] pub fine_tuning: bool, + #[serde(default)] + pub vision: bool, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ModelDeleteResponse { + pub id: String, + pub object: String, + pub deleted: bool, } diff --git a/src/v1/moderation.rs b/src/v1/moderation.rs new file mode 100644 index 0000000..b8e199b --- /dev/null +++ b/src/v1/moderation.rs @@ -0,0 +1,70 @@ +use serde::{Deserialize, Serialize}; + +use crate::v1::constants; + +// ----------------------------------------------------------------------------- +// Request + +#[derive(Debug, Serialize, Deserialize)] +pub struct ModerationRequest { + pub model: constants::Model, + pub input: Vec, +} +impl ModerationRequest { + pub fn new(model: constants::Model, input: Vec) -> Self { + Self { model, input } + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ChatModerationRequest { + pub model: constants::Model, + pub input: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ChatModerationInput { + pub role: String, + pub content: String, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ClassificationRequest { + pub model: constants::Model, + pub input: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct ChatClassificationRequest { + pub model: constants::Model, + pub input: Vec, +} + +// ----------------------------------------------------------------------------- +// Response + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ModerationResponse { + pub id: String, + pub model: constants::Model, + pub results: Vec, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct 
ModerationResult { + pub categories: serde_json::Value, + pub category_scores: serde_json::Value, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ClassificationResponse { + pub id: String, + pub model: constants::Model, + pub results: Vec, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ClassificationResult { + pub categories: serde_json::Value, + pub category_scores: serde_json::Value, +} diff --git a/src/v1/ocr.rs b/src/v1/ocr.rs new file mode 100644 index 0000000..3e4cff3 --- /dev/null +++ b/src/v1/ocr.rs @@ -0,0 +1,96 @@ +use serde::{Deserialize, Serialize}; + +use crate::v1::constants; + +// ----------------------------------------------------------------------------- +// Request + +#[derive(Debug, Serialize, Deserialize)] +pub struct OcrRequest { + pub model: constants::Model, + pub document: OcrDocument, + + #[serde(skip_serializing_if = "Option::is_none")] + pub pages: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub table_format: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub include_image_base64: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub image_limit: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct OcrDocument { + #[serde(rename = "type")] + pub type_: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub document_url: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub file_id: Option, +} +impl OcrDocument { + pub fn from_url(url: &str) -> Self { + Self { + type_: "document_url".to_string(), + document_url: Some(url.to_string()), + file_id: None, + } + } + + pub fn from_file_id(file_id: &str) -> Self { + Self { + type_: "file_id".to_string(), + document_url: None, + file_id: Some(file_id.to_string()), + } + } +} + +#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum OcrTableFormat { + Markdown, + Html, +} + +// 
----------------------------------------------------------------------------- +// Response + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct OcrResponse { + pub pages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub usage_info: Option, + pub model: constants::Model, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct OcrPage { + pub index: u32, + pub markdown: String, + #[serde(default)] + pub images: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub dimensions: Option, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct OcrImage { + pub id: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub image_base64: Option, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct OcrPageDimensions { + pub width: f32, + pub height: f32, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct OcrUsageInfo { + pub pages_processed: u32, + #[serde(skip_serializing_if = "Option::is_none")] + pub doc_size_bytes: Option, +} diff --git a/src/v1/tool.rs b/src/v1/tool.rs index 9612182..17fbea3 100644 --- a/src/v1/tool.rs +++ b/src/v1/tool.rs @@ -1,12 +1,16 @@ use async_trait::async_trait; use serde::{Deserialize, Serialize}; -use std::{any::Any, collections::HashMap, fmt::Debug}; +use std::{any::Any, fmt::Debug}; // ----------------------------------------------------------------------------- // Definitions #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] pub struct ToolCall { + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub r#type: Option, pub function: ToolCallFunction, } @@ -22,31 +26,12 @@ pub struct Tool { pub function: ToolFunction, } impl Tool { + /// Create a tool with a JSON Schema parameters object. 
pub fn new( function_name: String, function_description: String, - function_parameters: Vec, + parameters: serde_json::Value, ) -> Self { - let properties: HashMap = function_parameters - .into_iter() - .map(|param| { - ( - param.name, - ToolFunctionParameterProperty { - r#type: param.r#type, - description: param.description, - }, - ) - }) - .collect(); - let property_names = properties.keys().cloned().collect(); - - let parameters = ToolFunctionParameters { - r#type: ToolFunctionParametersType::Object, - properties, - required: property_names, - }; - Self { r#type: ToolType::Function, function: ToolFunction { @@ -63,50 +48,9 @@ impl Tool { #[derive(Clone, Debug, Deserialize, Serialize)] pub struct ToolFunction { - name: String, - description: String, - parameters: ToolFunctionParameters, -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct ToolFunctionParameter { - name: String, - description: String, - r#type: ToolFunctionParameterType, -} -impl ToolFunctionParameter { - pub fn new(name: String, description: String, r#type: ToolFunctionParameterType) -> Self { - Self { - name, - r#type, - description, - } - } -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct ToolFunctionParameters { - r#type: ToolFunctionParametersType, - properties: HashMap, - required: Vec, -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct ToolFunctionParameterProperty { - r#type: ToolFunctionParameterType, - description: String, -} - -#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] -pub enum ToolFunctionParametersType { - #[serde(rename = "object")] - Object, -} - -#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] -pub enum ToolFunctionParameterType { - #[serde(rename = "string")] - String, + pub name: String, + pub description: String, + pub parameters: serde_json::Value, } #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] @@ -127,6 +71,9 @@ pub enum ToolChoice { /// The model won't call a function and will 
generate a message instead. #[serde(rename = "none")] None, + /// The model must call at least one tool. + #[serde(rename = "required")] + Required, } // ----------------------------------------------------------------------------- diff --git a/tests/v1_client_chat_async_test.rs b/tests/v1_client_chat_async_test.rs index c06aa11..6afb942 100644 --- a/tests/v1_client_chat_async_test.rs +++ b/tests/v1_client_chat_async_test.rs @@ -3,7 +3,7 @@ use mistralai_client::v1::{ chat::{ChatMessage, ChatMessageRole, ChatParams, ChatResponseChoiceFinishReason}, client::Client, constants::Model, - tool::{Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType}, + tool::{Tool, ToolChoice}, }; mod setup; @@ -14,12 +14,12 @@ async fn test_client_chat_async() { let client = Client::new(None, None, None, None).unwrap(); - let model = Model::OpenMistral7b; + let model = Model::mistral_small_latest(); let messages = vec![ChatMessage::new_user_message( "Guess the next word: \"Eiffel ...\"?", )]; let options = ChatParams { - temperature: 0.0, + temperature: Some(0.0), random_seed: Some(42), ..Default::default() }; @@ -29,7 +29,6 @@ async fn test_client_chat_async() { .await .unwrap(); - expect!(response.model).to_be(Model::OpenMistral7b); expect!(response.object).to_be("chat.completion".to_string()); expect!(response.choices.len()).to_be(1); @@ -56,21 +55,26 @@ async fn test_client_chat_async_with_function_calling() { let tools = vec![Tool::new( "get_city_temperature".to_string(), "Get the current temperature in a city.".to_string(), - vec![ToolFunctionParameter::new( - "city".to_string(), - "The name of the city.".to_string(), - ToolFunctionParameterType::String, - )], + serde_json::json!({ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The name of the city." 
+ } + }, + "required": ["city"] + }), )]; let client = Client::new(None, None, None, None).unwrap(); - let model = Model::MistralSmallLatest; + let model = Model::mistral_small_latest(); let messages = vec![ChatMessage::new_user_message( "What's the current temperature in Paris?", )]; let options = ChatParams { - temperature: 0.0, + temperature: Some(0.0), random_seed: Some(42), tool_choice: Some(ToolChoice::Any), tools: Some(tools), @@ -82,7 +86,6 @@ async fn test_client_chat_async_with_function_calling() { .await .unwrap(); - expect!(response.model).to_be(Model::MistralSmallLatest); expect!(response.object).to_be("chat.completion".to_string()); expect!(response.choices.len()).to_be(1); @@ -91,13 +94,6 @@ async fn test_client_chat_async_with_function_calling() { .to_be(ChatResponseChoiceFinishReason::ToolCalls); expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant); - expect!(response.choices[0].message.content.clone()).to_be("".to_string()); - // expect!(response.choices[0].message.tool_calls.clone()).to_be(Some(vec![ToolCall { - // function: ToolCallFunction { - // name: "get_city_temperature".to_string(), - // arguments: "{\"city\": \"Paris\"}".to_string(), - // }, - // }])); expect!(response.usage.prompt_tokens).to_be_greater_than(0); expect!(response.usage.completion_tokens).to_be_greater_than(0); diff --git a/tests/v1_client_chat_stream_test.rs b/tests/v1_client_chat_stream_test.rs index 23379fa..aae04c8 100644 --- a/tests/v1_client_chat_stream_test.rs +++ b/tests/v1_client_chat_stream_test.rs @@ -1,16 +1,18 @@ +// Streaming tests require a live API key and are not run in CI. +// Uncomment to test locally. 
+ // use futures::stream::StreamExt; -// use jrest::expect; // use mistralai_client::v1::{ -// chat_completion::{ChatParams, ChatMessage, ChatMessageRole}, +// chat::{ChatMessage, ChatParams}, // client::Client, // constants::Model, // }; - +// // #[tokio::test] // async fn test_client_chat_stream() { // let client = Client::new(None, None, None, None).unwrap(); - -// let model = Model::OpenMistral7b; +// +// let model = Model::mistral_small_latest(); // let messages = vec![ChatMessage::new_user_message( // "Just guess the next word: \"Eiffel ...\"?", // )]; @@ -19,22 +21,24 @@ // random_seed: Some(42), // ..Default::default() // }; - -// let stream_result = client.chat_stream(model, messages, Some(options)).await; -// let mut stream = stream_result.expect("Failed to create stream."); -// while let Some(maybe_chunk_result) = stream.next().await { -// match maybe_chunk_result { -// Some(Ok(chunk)) => { -// if chunk.choices[0].delta.role == Some(ChatMessageRole::Assistant) -// || chunk.choices[0].finish_reason == Some("stop".to_string()) -// { -// expect!(chunk.choices[0].delta.content.len()).to_be(0); -// } else { -// expect!(chunk.choices[0].delta.content.len()).to_be_greater_than(0); +// +// let stream = client +// .chat_stream(model, messages, Some(options)) +// .await +// .expect("Failed to create stream."); +// +// stream +// .for_each(|chunk_result| async { +// match chunk_result { +// Ok(chunks) => { +// for chunk in &chunks { +// if let Some(content) = &chunk.choices[0].delta.content { +// print!("{}", content); +// } +// } // } +// Err(error) => eprintln!("Error: {:?}", error), // } -// Some(Err(error)) => eprintln!("Error processing chunk: {:?}", error), -// None => (), -// } -// } +// }) +// .await; // } diff --git a/tests/v1_client_chat_test.rs b/tests/v1_client_chat_test.rs index e489591..1ecf769 100644 --- a/tests/v1_client_chat_test.rs +++ b/tests/v1_client_chat_test.rs @@ -3,7 +3,7 @@ use mistralai_client::v1::{ chat::{ChatMessage, ChatMessageRole, 
ChatParams, ChatResponseChoiceFinishReason}, client::Client, constants::Model, - tool::{Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType}, + tool::{Tool, ToolChoice}, }; mod setup; @@ -14,19 +14,18 @@ fn test_client_chat() { let client = Client::new(None, None, None, None).unwrap(); - let model = Model::OpenMistral7b; + let model = Model::mistral_small_latest(); let messages = vec![ChatMessage::new_user_message( "Guess the next word: \"Eiffel ...\"?", )]; let options = ChatParams { - temperature: 0.0, + temperature: Some(0.0), random_seed: Some(42), ..Default::default() }; let response = client.chat(model, messages, Some(options)).unwrap(); - expect!(response.model).to_be(Model::OpenMistral7b); expect!(response.object).to_be("chat.completion".to_string()); expect!(response.choices.len()).to_be(1); expect!(response.choices[0].index).to_be(0); @@ -50,21 +49,26 @@ fn test_client_chat_with_function_calling() { let tools = vec![Tool::new( "get_city_temperature".to_string(), "Get the current temperature in a city.".to_string(), - vec![ToolFunctionParameter::new( - "city".to_string(), - "The name of the city.".to_string(), - ToolFunctionParameterType::String, - )], + serde_json::json!({ + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The name of the city." 
+ } + }, + "required": ["city"] + }), )]; let client = Client::new(None, None, None, None).unwrap(); - let model = Model::MistralSmallLatest; + let model = Model::mistral_small_latest(); let messages = vec![ChatMessage::new_user_message( "What's the current temperature in Paris?", )]; let options = ChatParams { - temperature: 0.0, + temperature: Some(0.0), random_seed: Some(42), tool_choice: Some(ToolChoice::Auto), tools: Some(tools), @@ -73,12 +77,10 @@ fn test_client_chat_with_function_calling() { let response = client.chat(model, messages, Some(options)).unwrap(); - expect!(response.model).to_be(Model::MistralSmallLatest); expect!(response.object).to_be("chat.completion".to_string()); expect!(response.choices.len()).to_be(1); expect!(response.choices[0].index).to_be(0); expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant); - expect!(response.choices[0].message.content.clone()).to_be("".to_string()); expect!(response.choices[0].finish_reason.clone()) .to_be(ChatResponseChoiceFinishReason::ToolCalls); expect!(response.usage.prompt_tokens).to_be_greater_than(0); diff --git a/tests/v1_client_embeddings_async_test.rs b/tests/v1_client_embeddings_async_test.rs index ad0c689..9fd972f 100644 --- a/tests/v1_client_embeddings_async_test.rs +++ b/tests/v1_client_embeddings_async_test.rs @@ -1,11 +1,11 @@ use jrest::expect; -use mistralai_client::v1::{client::Client, constants::EmbedModel}; +use mistralai_client::v1::{client::Client, constants::Model}; #[tokio::test] async fn test_client_embeddings_async() { let client: Client = Client::new(None, None, None, None).unwrap(); - let model = EmbedModel::MistralEmbed; + let model = Model::mistral_embed(); let input = vec!["Embed this sentence.", "As well as this one."] .iter() .map(|s| s.to_string()) @@ -17,7 +17,6 @@ async fn test_client_embeddings_async() { .await .unwrap(); - expect!(response.model).to_be(EmbedModel::MistralEmbed); expect!(response.object).to_be("list".to_string()); 
expect!(response.data.len()).to_be(2); expect!(response.data[0].index).to_be(0); diff --git a/tests/v1_client_embeddings_test.rs b/tests/v1_client_embeddings_test.rs index bb32fa4..d4c2a80 100644 --- a/tests/v1_client_embeddings_test.rs +++ b/tests/v1_client_embeddings_test.rs @@ -1,11 +1,11 @@ use jrest::expect; -use mistralai_client::v1::{client::Client, constants::EmbedModel}; +use mistralai_client::v1::{client::Client, constants::Model}; #[test] fn test_client_embeddings() { let client: Client = Client::new(None, None, None, None).unwrap(); - let model = EmbedModel::MistralEmbed; + let model = Model::mistral_embed(); let input = vec!["Embed this sentence.", "As well as this one."] .iter() .map(|s| s.to_string()) @@ -14,7 +14,6 @@ fn test_client_embeddings() { let response = client.embeddings(model, input, options).unwrap(); - expect!(response.model).to_be(EmbedModel::MistralEmbed); expect!(response.object).to_be("list".to_string()); expect!(response.data.len()).to_be(2); expect!(response.data[0].index).to_be(0); diff --git a/tests/v1_constants_test.rs b/tests/v1_constants_test.rs index 7903acf..7039af4 100644 --- a/tests/v1_constants_test.rs +++ b/tests/v1_constants_test.rs @@ -6,26 +6,19 @@ use mistralai_client::v1::{ }; #[test] -fn test_model_constant() { +fn test_model_constants() { let models = vec![ - Model::OpenMistral7b, - Model::OpenMixtral8x7b, - Model::OpenMixtral8x22b, - Model::OpenMistralNemo, - Model::MistralTiny, - Model::MistralSmallLatest, - Model::MistralMediumLatest, - Model::MistralLargeLatest, - Model::MistralLarge, - Model::CodestralLatest, - Model::CodestralMamba, + Model::mistral_small_latest(), + Model::mistral_large_latest(), + Model::open_mistral_nemo(), + Model::codestral_latest(), ]; let client = Client::new(None, None, None, None).unwrap(); let messages = vec![ChatMessage::new_user_message("A number between 0 and 100?")]; let options = ChatParams { - temperature: 0.0, + temperature: Some(0.0), random_seed: Some(42), 
..Default::default() };