diff --git a/src/v1/chat.rs b/src/v1/chat.rs index 2309029..1411357 100644 --- a/src/v1/chat.rs +++ b/src/v1/chat.rs @@ -11,13 +11,31 @@ pub struct ChatMessage { pub content: String, #[serde(skip_serializing_if = "Option::is_none")] pub tool_calls: Option>, + /// Tool call ID, required when role is Tool. + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, + /// Function name, used when role is Tool. + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, } impl ChatMessage { + pub fn new_system_message(content: &str) -> Self { + Self { + role: ChatMessageRole::System, + content: content.to_string(), + tool_calls: None, + tool_call_id: None, + name: None, + } + } + pub fn new_assistant_message(content: &str, tool_calls: Option>) -> Self { Self { role: ChatMessageRole::Assistant, content: content.to_string(), tool_calls, + tool_call_id: None, + name: None, } } @@ -26,6 +44,18 @@ impl ChatMessage { role: ChatMessageRole::User, content: content.to_string(), tool_calls: None, + tool_call_id: None, + name: None, + } + } + + pub fn new_tool_message(content: &str, tool_call_id: &str, name: Option<&str>) -> Self { + Self { + role: ChatMessageRole::Tool, + content: content.to_string(), + tool_calls: None, + tool_call_id: Some(tool_call_id.to_string()), + name: name.map(|n| n.to_string()), } } } @@ -44,17 +74,32 @@ pub enum ChatMessageRole { } /// The format that the model must output. -/// -/// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information. 
#[derive(Clone, Debug, Serialize, Deserialize)] pub struct ResponseFormat { #[serde(rename = "type")] pub type_: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub json_schema: Option, } impl ResponseFormat { + pub fn text() -> Self { + Self { + type_: "text".to_string(), + json_schema: None, + } + } + pub fn json_object() -> Self { Self { type_: "json_object".to_string(), + json_schema: None, + } + } + + pub fn json_schema(schema: serde_json::Value) -> Self { + Self { + type_: "json_schema".to_string(), + json_schema: Some(schema), } } } @@ -63,91 +108,83 @@ impl ResponseFormat { // Request /// The parameters for the chat request. -/// -/// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information. #[derive(Clone, Debug)] pub struct ChatParams { - /// The maximum number of tokens to generate in the completion. - /// - /// Defaults to `None`. pub max_tokens: Option, - /// The seed to use for random sampling. If set, different calls will generate deterministic results. - /// - /// Defaults to `None`. + pub min_tokens: Option, pub random_seed: Option, - /// The format that the model must output. - /// - /// Defaults to `None`. pub response_format: Option, - /// Whether to inject a safety prompt before all conversations. - /// - /// Defaults to `false`. pub safe_prompt: bool, - /// What sampling temperature to use, between `Some(0.0)` and `Some(1.0)`. - /// - /// Defaults to `0.7`. - pub temperature: f32, - /// Specifies if/how functions are called. - /// - /// Defaults to `None`. + pub temperature: Option, pub tool_choice: Option, - /// A list of available tools for the model. - /// - /// Defaults to `None`. pub tools: Option>, - /// Nucleus sampling, where the model considers the results of the tokens with `top_p` probability mass. - /// - /// Defaults to `1.0`. 
- pub top_p: f32, + pub top_p: Option, + pub stop: Option>, + pub n: Option, + pub frequency_penalty: Option, + pub presence_penalty: Option, + pub parallel_tool_calls: Option, + /// For reasoning models (Magistral). "high" or "none". + pub reasoning_effort: Option, } impl Default for ChatParams { fn default() -> Self { Self { max_tokens: None, + min_tokens: None, random_seed: None, safe_prompt: false, response_format: None, - temperature: 0.7, + temperature: None, tool_choice: None, tools: None, - top_p: 1.0, - } - } -} -impl ChatParams { - pub fn json_default() -> Self { - Self { - max_tokens: None, - random_seed: None, - safe_prompt: false, - response_format: None, - temperature: 0.7, - tool_choice: None, - tools: None, - top_p: 1.0, + top_p: None, + stop: None, + n: None, + frequency_penalty: None, + presence_penalty: None, + parallel_tool_calls: None, + reasoning_effort: None, } } } #[derive(Debug, Serialize, Deserialize)] pub struct ChatRequest { - pub messages: Vec, pub model: constants::Model, + pub messages: Vec, + pub stream: bool, #[serde(skip_serializing_if = "Option::is_none")] pub max_tokens: Option, #[serde(skip_serializing_if = "Option::is_none")] + pub min_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub random_seed: Option, #[serde(skip_serializing_if = "Option::is_none")] pub response_format: Option, - pub safe_prompt: bool, - pub stream: bool, - pub temperature: f32, + #[serde(skip_serializing_if = "Option::is_none")] + pub safe_prompt: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, #[serde(skip_serializing_if = "Option::is_none")] pub tool_choice: Option, #[serde(skip_serializing_if = "Option::is_none")] pub tools: Option>, - pub top_p: f32, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub n: Option, + #[serde(skip_serializing_if 
= "Option::is_none")] + pub frequency_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub parallel_tool_calls: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub reasoning_effort: Option, } impl ChatRequest { pub fn new( @@ -156,30 +193,28 @@ impl ChatRequest { stream: bool, options: Option, ) -> Self { - let ChatParams { - max_tokens, - random_seed, - safe_prompt, - temperature, - tool_choice, - tools, - top_p, - response_format, - } = options.unwrap_or_default(); + let opts = options.unwrap_or_default(); + let safe_prompt = if opts.safe_prompt { Some(true) } else { None }; Self { - messages, model, - - max_tokens, - random_seed, - safe_prompt, + messages, stream, - temperature, - tool_choice, - tools, - top_p, - response_format, + max_tokens: opts.max_tokens, + min_tokens: opts.min_tokens, + random_seed: opts.random_seed, + safe_prompt, + temperature: opts.temperature, + tool_choice: opts.tool_choice, + tools: opts.tools, + top_p: opts.top_p, + response_format: opts.response_format, + stop: opts.stop, + n: opts.n, + frequency_penalty: opts.frequency_penalty, + presence_penalty: opts.presence_penalty, + parallel_tool_calls: opts.parallel_tool_calls, + reasoning_effort: opts.reasoning_effort, } } } @@ -192,7 +227,7 @@ pub struct ChatResponse { pub id: String, pub object: String, /// Unix timestamp (in seconds). - pub created: u32, + pub created: u64, pub model: constants::Model, pub choices: Vec, pub usage: common::ResponseUsage, @@ -203,14 +238,18 @@ pub struct ChatResponseChoice { pub index: u32, pub message: ChatMessage, pub finish_reason: ChatResponseChoiceFinishReason, - // TODO Check this prop (seen in API responses but undocumented). - // pub logprobs: ??? 
} #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] pub enum ChatResponseChoiceFinishReason { #[serde(rename = "stop")] Stop, + #[serde(rename = "length")] + Length, #[serde(rename = "tool_calls")] ToolCalls, + #[serde(rename = "model_length")] + ModelLength, + #[serde(rename = "error")] + Error, } diff --git a/src/v1/chat_stream.rs b/src/v1/chat_stream.rs index 1daf481..bce8a89 100644 --- a/src/v1/chat_stream.rs +++ b/src/v1/chat_stream.rs @@ -1,7 +1,7 @@ use serde::{Deserialize, Serialize}; use serde_json::from_str; -use crate::v1::{chat, common, constants, error}; +use crate::v1::{chat, common, constants, error, tool}; // ----------------------------------------------------------------------------- // Response @@ -11,12 +11,11 @@ pub struct ChatStreamChunk { pub id: String, pub object: String, /// Unix timestamp (in seconds). - pub created: u32, + pub created: u64, pub model: constants::Model, pub choices: Vec, + #[serde(skip_serializing_if = "Option::is_none")] pub usage: Option, - // TODO Check this prop (seen in API responses but undocumented). - // pub logprobs: ???, } #[derive(Clone, Debug, Deserialize, Serialize)] @@ -24,14 +23,15 @@ pub struct ChatStreamChunkChoice { pub index: u32, pub delta: ChatStreamChunkChoiceDelta, pub finish_reason: Option, - // TODO Check this prop (seen in API responses but undocumented). - // pub logprobs: ???, } #[derive(Clone, Debug, Deserialize, Serialize)] pub struct ChatStreamChunkChoiceDelta { pub role: Option, - pub content: String, + #[serde(default)] + pub content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, } /// Extracts serialized chunks from a stream message. 
@@ -47,7 +47,6 @@ pub fn get_chunk_from_stream_message_line( return Ok(Some(vec![])); } - // Attempt to deserialize the JSON string into ChatStreamChunk match from_str::(chunk_as_json) { Ok(chunk) => Ok(Some(vec![chunk])), Err(e) => Err(error::ApiError { diff --git a/src/v1/client.rs b/src/v1/client.rs index 9a8ac81..ef396f3 100644 --- a/src/v1/client.rs +++ b/src/v1/client.rs @@ -26,25 +26,10 @@ impl Client { /// /// # Arguments /// - /// * `api_key` - An optional API key. - /// If not provided, the method will try to use the `MISTRAL_API_KEY` environment variable. - /// * `endpoint` - An optional custom API endpoint. Defaults to the official API endpoint if not provided. - /// * `max_retries` - Optional maximum number of retries for failed requests. Defaults to `5`. - /// * `timeout` - Optional timeout in seconds for requests. Defaults to `120`. - /// - /// # Examples - /// - /// ``` - /// use mistralai_client::v1::client::Client; - /// - /// let client = Client::new(Some("your_api_key_here".to_string()), None, Some(3), Some(60)); - /// assert!(client.is_ok()); - /// ``` - /// - /// # Errors - /// - /// This method fails whenever neither the `api_key` is provided - /// nor the `MISTRAL_API_KEY` environment variable is set. + /// * `api_key` - An optional API key. If not provided, uses `MISTRAL_API_KEY` env var. + /// * `endpoint` - An optional custom API endpoint. Defaults to `https://api.mistral.ai/v1`. + /// * `max_retries` - Optional maximum number of retries. Defaults to `5`. + /// * `timeout` - Optional timeout in seconds. Defaults to `120`. pub fn new( api_key: Option, endpoint: Option, @@ -69,43 +54,15 @@ impl Client { endpoint, max_retries, timeout, - functions, last_function_call_result, }) } - /// Synchronously sends a chat completion request and returns the response. - /// - /// # Arguments - /// - /// * `model` - The [Model] to use for the chat completion. - /// * `messages` - A vector of [ChatMessage] to send as part of the chat. 
- /// * `options` - Optional [ChatParams] to customize the request. - /// - /// # Returns - /// - /// Returns a [Result] containing the `ChatResponse` if the request is successful, - /// or an [ApiError] if there is an error. - /// - /// # Examples - /// - /// ``` - /// use mistralai_client::v1::{ - /// chat::{ChatMessage, ChatMessageRole}, - /// client::Client, - /// constants::Model, - /// }; - /// - /// let client = Client::new(None, None, None, None).unwrap(); - /// let messages = vec![ChatMessage { - /// role: ChatMessageRole::User, - /// content: "Hello, world!".to_string(), - /// tool_calls: None, - /// }]; - /// let response = client.chat(Model::OpenMistral7b, messages, None).unwrap(); - /// println!("{:?}: {}", response.choices[0].message.role, response.choices[0].message.content); - /// ``` + // ========================================================================= + // Chat Completions + // ========================================================================= + pub fn chat( &self, model: constants::Model, @@ -119,49 +76,13 @@ impl Client { match result { Ok(data) => { utils::debug_pretty_json_from_struct("Response Data", &data); - self.call_function_if_any(data.clone()); - Ok(data) } Err(error) => Err(self.to_api_error(error)), } } - /// Asynchronously sends a chat completion request and returns the response. - /// - /// # Arguments - /// - /// * `model` - The [Model] to use for the chat completion. - /// * `messages` - A vector of [ChatMessage] to send as part of the chat. - /// * `options` - Optional [ChatParams] to customize the request. - /// - /// # Returns - /// - /// Returns a [Result] containing a `Stream` of `ChatStreamChunk` if the request is successful, - /// or an [ApiError] if there is an error. 
- /// - /// # Examples - /// - /// ``` - /// use mistralai_client::v1::{ - /// chat::{ChatMessage, ChatMessageRole}, - /// client::Client, - /// constants::Model, - /// }; - /// - /// #[tokio::main] - /// async fn main() { - /// let client = Client::new(None, None, None, None).unwrap(); - /// let messages = vec![ChatMessage { - /// role: ChatMessageRole::User, - /// content: "Hello, world!".to_string(), - /// tool_calls: None, - /// }]; - /// let response = client.chat_async(Model::OpenMistral7b, messages, None).await.unwrap(); - /// println!("{:?}: {}", response.choices[0].message.role, response.choices[0].message.content); - /// } - /// ``` pub async fn chat_async( &self, model: constants::Model, @@ -175,68 +96,13 @@ impl Client { match result { Ok(data) => { utils::debug_pretty_json_from_struct("Response Data", &data); - self.call_function_if_any_async(data.clone()).await; - Ok(data) } Err(error) => Err(self.to_api_error(error)), } } - /// Asynchronously sends a chat completion request and returns a stream of message chunks. - /// - /// # Arguments - /// - /// * `model` - The [Model] to use for the chat completion. - /// * `messages` - A vector of [ChatMessage] to send as part of the chat. - /// * `options` - Optional [ChatParams] to customize the request. - /// - /// # Returns - /// - /// Returns a [Result] containing a `Stream` of `ChatStreamChunk` if the request is successful, - /// or an [ApiError] if there is an error. 
- /// - /// # Examples - /// - /// ``` - /// use futures::stream::StreamExt; - /// use mistralai_client::v1::{ - /// chat::{ChatMessage, ChatMessageRole}, - /// client::Client, - /// constants::Model, - /// }; - /// use std::io::{self, Write}; - /// - /// #[tokio::main] - /// async fn main() { - /// let client = Client::new(None, None, None, None).unwrap(); - /// let messages = vec![ChatMessage { - /// role: ChatMessageRole::User, - /// content: "Hello, world!".to_string(), - /// tool_calls: None, - /// }]; - /// - /// let stream_result = client - /// .chat_stream(Model::OpenMistral7b,messages, None) - /// .await - /// .unwrap(); - /// stream_result - /// .for_each(|chunk_result| async { - /// match chunk_result { - /// Ok(chunks) => chunks.iter().for_each(|chunk| { - /// print!("{}", chunk.choices[0].delta.content); - /// io::stdout().flush().unwrap(); - /// // => "Once upon a time, [...]" - /// }), - /// Err(error) => { - /// eprintln!("Error processing chunk: {:?}", error) - /// } - /// } - /// }) - /// .await; - /// print!("\n") // To persist the last chunk output. 
- /// } pub async fn chat_stream( &self, model: constants::Model, @@ -292,9 +158,13 @@ impl Client { Ok(deserialized_stream) } + // ========================================================================= + // Embeddings + // ========================================================================= + pub fn embeddings( &self, - model: constants::EmbedModel, + model: constants::Model, input: Vec, options: Option, ) -> Result { @@ -305,7 +175,6 @@ impl Client { match result { Ok(data) => { utils::debug_pretty_json_from_struct("Response Data", &data); - Ok(data) } Err(error) => Err(self.to_api_error(error)), @@ -314,7 +183,7 @@ impl Client { pub async fn embeddings_async( &self, - model: constants::EmbedModel, + model: constants::Model, input: Vec, options: Option, ) -> Result { @@ -325,18 +194,15 @@ impl Client { match result { Ok(data) => { utils::debug_pretty_json_from_struct("Response Data", &data); - Ok(data) } Err(error) => Err(self.to_api_error(error)), } } - pub fn get_last_function_call_result(&self) -> Option> { - let mut result_lock = self.last_function_call_result.lock().unwrap(); - - result_lock.take() - } + // ========================================================================= + // Models + // ========================================================================= pub fn list_models(&self) -> Result { let response = self.get_sync("/models")?; @@ -344,7 +210,6 @@ impl Client { match result { Ok(data) => { utils::debug_pretty_json_from_struct("Response Data", &data); - Ok(data) } Err(error) => Err(self.to_api_error(error)), @@ -359,68 +224,136 @@ impl Client { match result { Ok(data) => { utils::debug_pretty_json_from_struct("Response Data", &data); - Ok(data) } Err(error) => Err(self.to_api_error(error)), } } + pub fn get_model(&self, model_id: &str) -> Result { + let response = self.get_sync(&format!("/models/{}", model_id))?; + let result = response.json::(); + match result { + Ok(data) => { + utils::debug_pretty_json_from_struct("Response 
Data", &data); + Ok(data) + } + Err(error) => Err(self.to_api_error(error)), + } + } + + pub async fn get_model_async( + &self, + model_id: &str, + ) -> Result { + let response = self.get_async(&format!("/models/{}", model_id)).await?; + let result = response.json::().await; + match result { + Ok(data) => { + utils::debug_pretty_json_from_struct("Response Data", &data); + Ok(data) + } + Err(error) => Err(self.to_api_error(error)), + } + } + + pub fn delete_model( + &self, + model_id: &str, + ) -> Result { + let response = self.delete_sync(&format!("/models/{}", model_id))?; + let result = response.json::(); + match result { + Ok(data) => Ok(data), + Err(error) => Err(self.to_api_error(error)), + } + } + + pub async fn delete_model_async( + &self, + model_id: &str, + ) -> Result { + let response = self + .delete_async(&format!("/models/{}", model_id)) + .await?; + let result = response.json::().await; + match result { + Ok(data) => Ok(data), + Err(error) => Err(self.to_api_error(error)), + } + } + + // ========================================================================= + // Function Calling + // ========================================================================= + pub fn register_function(&mut self, name: String, function: Box) { let mut functions = self.functions.lock().unwrap(); - functions.insert(name, function); } + pub fn get_last_function_call_result(&self) -> Option> { + let mut result_lock = self.last_function_call_result.lock().unwrap(); + result_lock.take() + } + + // ========================================================================= + // HTTP Transport + // ========================================================================= + + fn user_agent(&self) -> String { + format!( + "mistralai-client-rs/{}", + env!("CARGO_PKG_VERSION") + ) + } + fn build_request_sync( &self, request: reqwest::blocking::RequestBuilder, ) -> reqwest::blocking::RequestBuilder { - let user_agent = format!( - "ivangabriele/mistralai-client-rs/{}", - 
env!("CARGO_PKG_VERSION") - ); - - let request_builder = request + request .bearer_auth(&self.api_key) .header("Accept", "application/json") - .header("User-Agent", user_agent); + .header("User-Agent", self.user_agent()) + } - request_builder + fn build_request_sync_no_accept( + &self, + request: reqwest::blocking::RequestBuilder, + ) -> reqwest::blocking::RequestBuilder { + request + .bearer_auth(&self.api_key) + .header("User-Agent", self.user_agent()) } fn build_request_async(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder { - let user_agent = format!( - "ivangabriele/mistralai-client-rs/{}", - env!("CARGO_PKG_VERSION") - ); - - let request_builder = request + request .bearer_auth(&self.api_key) .header("Accept", "application/json") - .header("User-Agent", user_agent); + .header("User-Agent", self.user_agent()) + } - request_builder + fn build_request_async_no_accept( + &self, + request: reqwest::RequestBuilder, + ) -> reqwest::RequestBuilder { + request + .bearer_auth(&self.api_key) + .header("User-Agent", self.user_agent()) } fn build_request_stream(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder { - let user_agent = format!( - "ivangabriele/mistralai-client-rs/{}", - env!("CARGO_PKG_VERSION") - ); - - let request_builder = request + request .bearer_auth(&self.api_key) .header("Accept", "text/event-stream") - .header("User-Agent", user_agent); - - request_builder + .header("User-Agent", self.user_agent()) } - fn call_function_if_any(&self, response: chat::ChatResponse) -> () { - let next_result = match response.choices.get(0) { - Some(first_choice) => match first_choice.message.tool_calls.to_owned() { - Some(tool_calls) => match tool_calls.get(0) { + fn call_function_if_any(&self, response: chat::ChatResponse) { + let next_result = match response.choices.first() { + Some(first_choice) => match first_choice.message.tool_calls.as_ref() { + Some(tool_calls) => match tool_calls.first() { Some(first_tool_call) => { let 
functions = self.functions.lock().unwrap(); match functions.get(&first_tool_call.function.name) { @@ -431,7 +364,6 @@ impl Client { .execute(first_tool_call.function.arguments.to_owned()) .await }); - Some(result) } None => None, @@ -448,10 +380,10 @@ impl Client { *last_result_lock = next_result; } - async fn call_function_if_any_async(&self, response: chat::ChatResponse) -> () { - let next_result = match response.choices.get(0) { - Some(first_choice) => match first_choice.message.tool_calls.to_owned() { - Some(tool_calls) => match tool_calls.get(0) { + async fn call_function_if_any_async(&self, response: chat::ChatResponse) { + let next_result = match response.choices.first() { + Some(first_choice) => match first_choice.message.tool_calls.as_ref() { + Some(tool_calls) => match tool_calls.first() { Some(first_tool_call) => { let functions = self.functions.lock().unwrap(); match functions.get(&first_tool_call.function.name) { @@ -459,7 +391,6 @@ impl Client { let result = function .execute(first_tool_call.function.arguments.to_owned()) .await; - Some(result) } None => None, @@ -482,27 +413,8 @@ impl Client { debug!("Request URL: {}", url); let request = self.build_request_sync(reqwest_client.get(url)); - let result = request.send(); - match result { - Ok(response) => { - if response.status().is_success() { - Ok(response) - } else { - let response_status = response.status(); - let response_body = response.text().unwrap_or_default(); - debug!("Response Status: {}", &response_status); - utils::debug_pretty_json_from_string("Response Data", &response_body); - - Err(error::ApiError { - message: format!("{}: {}", response_status, response_body), - }) - } - } - Err(error) => Err(error::ApiError { - message: error.to_string(), - }), - } + self.handle_sync_response(result) } async fn get_async(&self, path: &str) -> Result { @@ -510,29 +422,9 @@ impl Client { let url = format!("{}{}", self.endpoint, path); debug!("Request URL: {}", url); - let request_builder = 
reqwest_client.get(url); - let request = self.build_request_async(request_builder); - + let request = self.build_request_async(reqwest_client.get(url)); let result = request.send().await; - match result { - Ok(response) => { - if response.status().is_success() { - Ok(response) - } else { - let response_status = response.status(); - let response_body = response.text().await.unwrap_or_default(); - debug!("Response Status: {}", &response_status); - utils::debug_pretty_json_from_string("Response Data", &response_body); - - Err(error::ApiError { - message: format!("{}: {}", response_status, response_body), - }) - } - } - Err(error) => Err(error::ApiError { - message: error.to_string(), - }), - } + self.handle_async_response(result).await } fn post_sync( @@ -545,29 +437,22 @@ impl Client { debug!("Request URL: {}", url); utils::debug_pretty_json_from_struct("Request Body", params); - let request_builder = reqwest_client.post(url).json(params); - let request = self.build_request_sync(request_builder); - + let request = self.build_request_sync(reqwest_client.post(url).json(params)); let result = request.send(); - match result { - Ok(response) => { - if response.status().is_success() { - Ok(response) - } else { - let response_status = response.status(); - let response_body = response.text().unwrap_or_default(); - debug!("Response Status: {}", &response_status); - utils::debug_pretty_json_from_string("Response Data", &response_body); + self.handle_sync_response(result) + } - Err(error::ApiError { - message: format!("{}: {}", response_body, response_status), - }) - } - } - Err(error) => Err(error::ApiError { - message: error.to_string(), - }), - } + fn post_sync_empty( + &self, + path: &str, + ) -> Result { + let reqwest_client = reqwest::blocking::Client::new(); + let url = format!("{}{}", self.endpoint, path); + debug!("Request URL: {}", url); + + let request = self.build_request_sync(reqwest_client.post(url)); + let result = request.send(); + 
self.handle_sync_response(result) } async fn post_async( @@ -580,29 +465,19 @@ impl Client { debug!("Request URL: {}", url); utils::debug_pretty_json_from_struct("Request Body", params); - let request_builder = reqwest_client.post(url).json(params); - let request = self.build_request_async(request_builder); - + let request = self.build_request_async(reqwest_client.post(url).json(params)); let result = request.send().await; - match result { - Ok(response) => { - if response.status().is_success() { - Ok(response) - } else { - let response_status = response.status(); - let response_body = response.text().await.unwrap_or_default(); - debug!("Response Status: {}", &response_status); - utils::debug_pretty_json_from_string("Response Data", &response_body); + self.handle_async_response(result).await + } - Err(error::ApiError { - message: format!("{}: {}", response_status, response_body), - }) - } - } - Err(error) => Err(error::ApiError { - message: error.to_string(), - }), - } + async fn post_async_empty(&self, path: &str) -> Result { + let reqwest_client = reqwest::Client::new(); + let url = format!("{}{}", self.endpoint, path); + debug!("Request URL: {}", url); + + let request = self.build_request_async(reqwest_client.post(url)); + let result = request.send().await; + self.handle_async_response(result).await } async fn post_stream( @@ -615,22 +490,70 @@ impl Client { debug!("Request URL: {}", url); utils::debug_pretty_json_from_struct("Request Body", params); - let request_builder = reqwest_client.post(url).json(params); - let request = self.build_request_stream(request_builder); - + let request = self.build_request_stream(reqwest_client.post(url).json(params)); let result = request.send().await; + self.handle_async_response(result).await + } + + fn delete_sync(&self, path: &str) -> Result { + let reqwest_client = reqwest::blocking::Client::new(); + let url = format!("{}{}", self.endpoint, path); + debug!("Request URL: {}", url); + + let request = 
self.build_request_sync(reqwest_client.delete(url)); + let result = request.send(); + self.handle_sync_response(result) + } + + async fn delete_async(&self, path: &str) -> Result { + let reqwest_client = reqwest::Client::new(); + let url = format!("{}{}", self.endpoint, path); + debug!("Request URL: {}", url); + + let request = self.build_request_async(reqwest_client.delete(url)); + let result = request.send().await; + self.handle_async_response(result).await + } + + fn handle_sync_response( + &self, + result: Result, + ) -> Result { match result { Ok(response) => { if response.status().is_success() { Ok(response) } else { - let response_status = response.status(); - let response_body = response.text().await.unwrap_or_default(); - debug!("Response Status: {}", &response_status); - utils::debug_pretty_json_from_string("Response Data", &response_body); - + let status = response.status(); + let body = response.text().unwrap_or_default(); + debug!("Response Status: {}", &status); + utils::debug_pretty_json_from_string("Response Data", &body); Err(error::ApiError { - message: format!("{}: {}", response_status, response_body), + message: format!("{}: {}", status, body), + }) + } + } + Err(error) => Err(error::ApiError { + message: error.to_string(), + }), + } + } + + async fn handle_async_response( + &self, + result: Result, + ) -> Result { + match result { + Ok(response) => { + if response.status().is_success() { + Ok(response) + } else { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + debug!("Response Status: {}", &status); + utils::debug_pretty_json_from_string("Response Data", &body); + Err(error::ApiError { + message: format!("{}: {}", status, body), }) } } diff --git a/src/v1/common.rs b/src/v1/common.rs index 160073c..3597580 100644 --- a/src/v1/common.rs +++ b/src/v1/common.rs @@ -3,6 +3,7 @@ use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, Deserialize, Serialize)] pub struct ResponseUsage { pub 
prompt_tokens: u32, + #[serde(default)] pub completion_tokens: u32, pub total_tokens: u32, } diff --git a/src/v1/constants.rs b/src/v1/constants.rs index 52c0976..80168b7 100644 --- a/src/v1/constants.rs +++ b/src/v1/constants.rs @@ -1,35 +1,131 @@ +use std::fmt; + use serde::{Deserialize, Serialize}; pub const API_URL_BASE: &str = "https://api.mistral.ai/v1"; -#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] -pub enum Model { - #[serde(rename = "open-mistral-7b")] - OpenMistral7b, - #[serde(rename = "open-mixtral-8x7b")] - OpenMixtral8x7b, - #[serde(rename = "open-mixtral-8x22b")] - OpenMixtral8x22b, - #[serde(rename = "open-mistral-nemo", alias = "open-mistral-nemo-2407")] - OpenMistralNemo, - #[serde(rename = "mistral-tiny")] - MistralTiny, - #[serde(rename = "mistral-small-latest", alias = "mistral-small-2402")] - MistralSmallLatest, - #[serde(rename = "mistral-medium-latest", alias = "mistral-medium-2312")] - MistralMediumLatest, - #[serde(rename = "mistral-large-latest", alias = "mistral-large-2407")] - MistralLargeLatest, - #[serde(rename = "mistral-large-2402")] - MistralLarge, - #[serde(rename = "codestral-latest", alias = "codestral-2405")] - CodestralLatest, - #[serde(rename = "open-codestral-mamba")] - CodestralMamba, +/// A Mistral AI model identifier. +/// +/// Use the associated constants for known models, or construct with `Model::new()` for any model string. 
+#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)] +#[serde(transparent)] +pub struct Model(pub String); + +impl Model { + pub fn new(id: impl Into) -> Self { + Self(id.into()) + } + + // Flagship / Premier + pub fn mistral_large_latest() -> Self { + Self::new("mistral-large-latest") + } + pub fn mistral_large_3() -> Self { + Self::new("mistral-large-3-25-12") + } + pub fn mistral_medium_latest() -> Self { + Self::new("mistral-medium-latest") + } + pub fn mistral_medium_3_1() -> Self { + Self::new("mistral-medium-3-1-25-08") + } + pub fn mistral_small_latest() -> Self { + Self::new("mistral-small-latest") + } + pub fn mistral_small_4() -> Self { + Self::new("mistral-small-4-0-26-03") + } + pub fn mistral_small_3_2() -> Self { + Self::new("mistral-small-3-2-25-06") + } + + // Ministral + pub fn ministral_3_14b() -> Self { + Self::new("ministral-3-14b-25-12") + } + pub fn ministral_3_8b() -> Self { + Self::new("ministral-3-8b-25-12") + } + pub fn ministral_3_3b() -> Self { + Self::new("ministral-3-3b-25-12") + } + + // Reasoning + pub fn magistral_medium_latest() -> Self { + Self::new("magistral-medium-latest") + } + pub fn magistral_small_latest() -> Self { + Self::new("magistral-small-latest") + } + + // Code + pub fn codestral_latest() -> Self { + Self::new("codestral-latest") + } + pub fn codestral_2508() -> Self { + Self::new("codestral-2508") + } + pub fn codestral_embed() -> Self { + Self::new("codestral-embed-25-05") + } + pub fn devstral_2() -> Self { + Self::new("devstral-2-25-12") + } + pub fn devstral_small_2() -> Self { + Self::new("devstral-small-2-25-12") + } + + // Multimodal / Vision + pub fn pixtral_large() -> Self { + Self::new("pixtral-large-2411") + } + + // Audio + pub fn voxtral_mini_transcribe() -> Self { + Self::new("voxtral-mini-transcribe-2-26-02") + } + pub fn voxtral_small() -> Self { + Self::new("voxtral-small-25-07") + } + pub fn voxtral_mini() -> Self { + Self::new("voxtral-mini-25-07") + } + + // Legacy (kept for backward 
compatibility) + pub fn open_mistral_nemo() -> Self { + Self::new("open-mistral-nemo") + } + + // Embedding + pub fn mistral_embed() -> Self { + Self::new("mistral-embed") + } + + // Moderation + pub fn mistral_moderation_latest() -> Self { + Self::new("mistral-moderation-26-03") + } + + // OCR + pub fn mistral_ocr_latest() -> Self { + Self::new("mistral-ocr-latest") + } } -#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] -pub enum EmbedModel { - #[serde(rename = "mistral-embed")] - MistralEmbed, +impl fmt::Display for Model { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl From<&str> for Model { + fn from(s: &str) -> Self { + Self(s.to_string()) + } +} + +impl From<String> for Model { + fn from(s: String) -> Self { + Self(s) + } } diff --git a/src/v1/embedding.rs b/src/v1/embedding.rs index 7d1c2d7..080a57b 100644 --- a/src/v1/embedding.rs +++ b/src/v1/embedding.rs @@ -8,42 +8,63 @@ use crate::v1::{common, constants}; #[derive(Debug)] pub struct EmbeddingRequestOptions { pub encoding_format: Option<EmbeddingRequestEncodingFormat>, + pub output_dimension: Option<u32>, + pub output_dtype: Option<EmbeddingOutputDtype>, } impl Default for EmbeddingRequestOptions { fn default() -> Self { Self { encoding_format: None, + output_dimension: None, + output_dtype: None, } } } #[derive(Debug, Serialize, Deserialize)] pub struct EmbeddingRequest { - pub model: constants::EmbedModel, + pub model: constants::Model, pub input: Vec<String>, #[serde(skip_serializing_if = "Option::is_none")] pub encoding_format: Option<EmbeddingRequestEncodingFormat>, + #[serde(skip_serializing_if = "Option::is_none")] + pub output_dimension: Option<u32>, + #[serde(skip_serializing_if = "Option::is_none")] + pub output_dtype: Option<EmbeddingOutputDtype>, } impl EmbeddingRequest { pub fn new( - model: constants::EmbedModel, + model: constants::Model, input: Vec<String>, options: Option<EmbeddingRequestOptions>, ) -> Self { - let EmbeddingRequestOptions { encoding_format } = options.unwrap_or_default(); + let opts = options.unwrap_or_default(); Self { model, input, - encoding_format, + 
encoding_format: opts.encoding_format, + output_dimension: opts.output_dimension, + output_dtype: opts.output_dtype, } } } #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] -#[allow(non_camel_case_types)] +#[serde(rename_all = "lowercase")] pub enum EmbeddingRequestEncodingFormat { - float, + Float, + Base64, +} + +#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] +#[serde(rename_all = "lowercase")] +pub enum EmbeddingOutputDtype { + Float, + Int8, + Uint8, + Binary, + Ubinary, } // ----------------------------------------------------------------------------- @@ -51,9 +72,8 @@ pub enum EmbeddingRequestEncodingFormat { #[derive(Clone, Debug, Deserialize, Serialize)] pub struct EmbeddingResponse { - pub id: String, pub object: String, - pub model: constants::EmbedModel, + pub model: constants::Model, pub data: Vec, pub usage: common::ResponseUsage, } diff --git a/src/v1/model_list.rs b/src/v1/model_list.rs index cd69493..d0350f2 100644 --- a/src/v1/model_list.rs +++ b/src/v1/model_list.rs @@ -15,23 +15,44 @@ pub struct ModelListData { pub id: String, pub object: String, /// Unix timestamp (in seconds). - pub created: u32, + pub created: u64, pub owned_by: String, + #[serde(skip_serializing_if = "Option::is_none")] pub root: Option<String>, + #[serde(default)] pub archived: bool, - pub name: String, - pub description: String, - pub capabilities: ModelListDataCapabilies, - pub max_context_length: u32, + #[serde(default)] + pub name: Option<String>, + #[serde(default)] + pub description: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub capabilities: Option<ModelListDataCapabilities>, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_context_length: Option<u32>, + #[serde(default)] pub aliases: Vec<String>, /// ISO 8601 date (`YYYY-MM-DDTHH:MM:SSZ`). 
+ #[serde(skip_serializing_if = "Option::is_none")] pub deprecation: Option<String>, } #[derive(Clone, Debug, Deserialize, Serialize)] -pub struct ModelListDataCapabilies { +pub struct ModelListDataCapabilities { + #[serde(default)] pub completion_chat: bool, + #[serde(default)] pub completion_fim: bool, + #[serde(default)] pub function_calling: bool, + #[serde(default)] pub fine_tuning: bool, + #[serde(default)] + pub vision: bool, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct ModelDeleteResponse { + pub id: String, + pub object: String, + pub deleted: bool, } diff --git a/src/v1/tool.rs b/src/v1/tool.rs index 9612182..17fbea3 100644 --- a/src/v1/tool.rs +++ b/src/v1/tool.rs @@ -1,12 +1,16 @@ use async_trait::async_trait; use serde::{Deserialize, Serialize}; -use std::{any::Any, collections::HashMap, fmt::Debug}; +use std::{any::Any, fmt::Debug}; // ----------------------------------------------------------------------------- // Definitions #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] pub struct ToolCall { + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option<String>, + #[serde(skip_serializing_if = "Option::is_none")] + pub r#type: Option<ToolType>, pub function: ToolCallFunction, } @@ -22,31 +26,12 @@ pub struct Tool { pub function: ToolFunction, } impl Tool { + /// Create a tool with a JSON Schema parameters object. 
pub fn new( function_name: String, function_description: String, - function_parameters: Vec<ToolFunctionParameter>, + parameters: serde_json::Value, ) -> Self { - let properties: HashMap<String, ToolFunctionParameterProperty> = function_parameters - .into_iter() - .map(|param| { - ( - param.name, - ToolFunctionParameterProperty { - r#type: param.r#type, - description: param.description, - }, - ) - }) - .collect(); - let property_names = properties.keys().cloned().collect(); - - let parameters = ToolFunctionParameters { - r#type: ToolFunctionParametersType::Object, - properties, - required: property_names, - }; - Self { r#type: ToolType::Function, function: ToolFunction { @@ -63,50 +48,9 @@ impl Tool { #[derive(Clone, Debug, Deserialize, Serialize)] pub struct ToolFunction { - name: String, - description: String, - parameters: ToolFunctionParameters, -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct ToolFunctionParameter { - name: String, - description: String, - r#type: ToolFunctionParameterType, -} -impl ToolFunctionParameter { - pub fn new(name: String, description: String, r#type: ToolFunctionParameterType) -> Self { - Self { - name, - r#type, - description, - } - } -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct ToolFunctionParameters { - r#type: ToolFunctionParametersType, - properties: HashMap<String, ToolFunctionParameterProperty>, - required: Vec<String>, -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct ToolFunctionParameterProperty { - r#type: ToolFunctionParameterType, - description: String, -} - -#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] -pub enum ToolFunctionParametersType { - #[serde(rename = "object")] - Object, -} - -#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] -pub enum ToolFunctionParameterType { - #[serde(rename = "string")] - String, + pub name: String, + pub description: String, + pub parameters: serde_json::Value, } #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] @@ -127,6 +71,9 @@ pub enum ToolChoice { /// The model won't call a function and will 
generate a message instead. #[serde(rename = "none")] None, + /// The model must call at least one tool. + #[serde(rename = "required")] + Required, } // -----------------------------------------------------------------------------