refactor!: modernize core types and client for latest Mistral API

BREAKING CHANGE: Model is now a string-based struct with constructor
methods instead of a closed enum. EmbedModel is removed — use
Model::mistral_embed() instead. Tool parameters now accept
serde_json::Value (JSON Schema) instead of limited enum types.

- Replace Model enum with flexible Model(String) supporting all
  current models: Large 3, Small 4, Medium 3.1, Magistral, Codestral,
  Devstral, Pixtral, Voxtral, Ministral, and arbitrary strings
- Remove EmbedModel enum (consolidated into Model)
- Chat: add frequency_penalty, presence_penalty, stop, n, min_tokens,
  parallel_tool_calls, reasoning_effort, json_schema response format
- Embeddings: add output_dimension and output_dtype fields
- Tools: accept raw JSON Schema, add tool call IDs and Required choice
- Stream delta content is now Option<String> for tool call chunks
- Add Length, ModelLength, Error finish reason variants
- DRY HTTP transport with shared response handlers
- Add DELETE method support and model get/delete endpoints
- Make model_list fields more lenient with Option/default for API compat
This commit is contained in:
2026-03-20 17:54:29 +00:00
parent 83396773ce
commit bbb6aaed1c
8 changed files with 518 additions and 472 deletions

View File

@@ -11,13 +11,31 @@ pub struct ChatMessage {
pub content: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<tool::ToolCall>>,
/// Tool call ID, required when role is Tool.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
/// Function name, used when role is Tool.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
}
impl ChatMessage {
pub fn new_system_message(content: &str) -> Self {
Self {
role: ChatMessageRole::System,
content: content.to_string(),
tool_calls: None,
tool_call_id: None,
name: None,
}
}
pub fn new_assistant_message(content: &str, tool_calls: Option<Vec<tool::ToolCall>>) -> Self {
Self {
role: ChatMessageRole::Assistant,
content: content.to_string(),
tool_calls,
tool_call_id: None,
name: None,
}
}
@@ -26,6 +44,18 @@ impl ChatMessage {
role: ChatMessageRole::User,
content: content.to_string(),
tool_calls: None,
tool_call_id: None,
name: None,
}
}
pub fn new_tool_message(content: &str, tool_call_id: &str, name: Option<&str>) -> Self {
Self {
role: ChatMessageRole::Tool,
content: content.to_string(),
tool_calls: None,
tool_call_id: Some(tool_call_id.to_string()),
name: name.map(|n| n.to_string()),
}
}
}
@@ -44,17 +74,32 @@ pub enum ChatMessageRole {
}
/// The format that the model must output.
///
/// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ResponseFormat {
    /// Wire-level `type` tag: `"text"`, `"json_object"`, or `"json_schema"`.
    #[serde(rename = "type")]
    pub type_: String,
    /// JSON Schema constraining the output; only meaningful with the `json_schema` type.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub json_schema: Option<serde_json::Value>,
}
impl ResponseFormat {
    /// Shared builder behind the public constructors below.
    fn with_type(type_: &str, json_schema: Option<serde_json::Value>) -> Self {
        Self {
            type_: type_.to_string(),
            json_schema,
        }
    }
    /// Plain-text output.
    pub fn text() -> Self {
        Self::with_type("text", None)
    }
    /// Free-form JSON object output.
    pub fn json_object() -> Self {
        Self::with_type("json_object", None)
    }
    /// Output constrained to the given JSON Schema document.
    pub fn json_schema(schema: serde_json::Value) -> Self {
        Self::with_type("json_schema", Some(schema))
    }
}
@@ -63,91 +108,83 @@ impl ResponseFormat {
// Request
/// The parameters for the chat request.
///
/// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information.
#[derive(Clone, Debug)]
pub struct ChatParams {
/// The maximum number of tokens to generate in the completion.
///
/// Defaults to `None`.
pub max_tokens: Option<u32>,
/// The seed to use for random sampling. If set, different calls will generate deterministic results.
///
/// Defaults to `None`.
pub min_tokens: Option<u32>,
pub random_seed: Option<u32>,
/// The format that the model must output.
///
/// Defaults to `None`.
pub response_format: Option<ResponseFormat>,
/// Whether to inject a safety prompt before all conversations.
///
/// Defaults to `false`.
pub safe_prompt: bool,
/// What sampling temperature to use, between `Some(0.0)` and `Some(1.0)`.
///
/// Defaults to `0.7`.
pub temperature: f32,
/// Specifies if/how functions are called.
///
/// Defaults to `None`.
pub temperature: Option<f32>,
pub tool_choice: Option<tool::ToolChoice>,
/// A list of available tools for the model.
///
/// Defaults to `None`.
pub tools: Option<Vec<tool::Tool>>,
/// Nucleus sampling, where the model considers the results of the tokens with `top_p` probability mass.
///
/// Defaults to `1.0`.
pub top_p: f32,
pub top_p: Option<f32>,
pub stop: Option<Vec<String>>,
pub n: Option<u32>,
pub frequency_penalty: Option<f32>,
pub presence_penalty: Option<f32>,
pub parallel_tool_calls: Option<bool>,
/// For reasoning models (Magistral). "high" or "none".
pub reasoning_effort: Option<String>,
}
impl Default for ChatParams {
fn default() -> Self {
Self {
max_tokens: None,
min_tokens: None,
random_seed: None,
safe_prompt: false,
response_format: None,
temperature: 0.7,
temperature: None,
tool_choice: None,
tools: None,
top_p: 1.0,
}
}
}
impl ChatParams {
pub fn json_default() -> Self {
Self {
max_tokens: None,
random_seed: None,
safe_prompt: false,
response_format: None,
temperature: 0.7,
tool_choice: None,
tools: None,
top_p: 1.0,
top_p: None,
stop: None,
n: None,
frequency_penalty: None,
presence_penalty: None,
parallel_tool_calls: None,
reasoning_effort: None,
}
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatRequest {
pub messages: Vec<ChatMessage>,
pub model: constants::Model,
pub messages: Vec<ChatMessage>,
pub stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub min_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub random_seed: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<ResponseFormat>,
pub safe_prompt: bool,
pub stream: bool,
pub temperature: f32,
#[serde(skip_serializing_if = "Option::is_none")]
pub safe_prompt: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<tool::ToolChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<tool::Tool>>,
pub top_p: f32,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub n: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub parallel_tool_calls: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_effort: Option<String>,
}
impl ChatRequest {
pub fn new(
@@ -156,30 +193,28 @@ impl ChatRequest {
stream: bool,
options: Option<ChatParams>,
) -> Self {
let ChatParams {
max_tokens,
random_seed,
safe_prompt,
temperature,
tool_choice,
tools,
top_p,
response_format,
} = options.unwrap_or_default();
let opts = options.unwrap_or_default();
let safe_prompt = if opts.safe_prompt { Some(true) } else { None };
Self {
messages,
model,
max_tokens,
random_seed,
safe_prompt,
messages,
stream,
temperature,
tool_choice,
tools,
top_p,
response_format,
max_tokens: opts.max_tokens,
min_tokens: opts.min_tokens,
random_seed: opts.random_seed,
safe_prompt,
temperature: opts.temperature,
tool_choice: opts.tool_choice,
tools: opts.tools,
top_p: opts.top_p,
response_format: opts.response_format,
stop: opts.stop,
n: opts.n,
frequency_penalty: opts.frequency_penalty,
presence_penalty: opts.presence_penalty,
parallel_tool_calls: opts.parallel_tool_calls,
reasoning_effort: opts.reasoning_effort,
}
}
}
@@ -192,7 +227,7 @@ pub struct ChatResponse {
pub id: String,
pub object: String,
/// Unix timestamp (in seconds).
pub created: u32,
pub created: u64,
pub model: constants::Model,
pub choices: Vec<ChatResponseChoice>,
pub usage: common::ResponseUsage,
@@ -203,14 +238,18 @@ pub struct ChatResponseChoice {
pub index: u32,
pub message: ChatMessage,
pub finish_reason: ChatResponseChoiceFinishReason,
// TODO Check this prop (seen in API responses but undocumented).
// pub logprobs: ???
}
/// Why the model stopped generating this choice.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum ChatResponseChoiceFinishReason {
/// The model reached a natural end (or hit a configured stop sequence).
#[serde(rename = "stop")]
Stop,
/// Generation was cut off by the `max_tokens` limit.
#[serde(rename = "length")]
Length,
/// The model stopped to emit tool calls.
#[serde(rename = "tool_calls")]
ToolCalls,
/// The model's context-window limit was reached.
#[serde(rename = "model_length")]
ModelLength,
/// The API reported an error for this choice.
#[serde(rename = "error")]
Error,
}

View File

@@ -1,7 +1,7 @@
use serde::{Deserialize, Serialize};
use serde_json::from_str;
use crate::v1::{chat, common, constants, error};
use crate::v1::{chat, common, constants, error, tool};
// -----------------------------------------------------------------------------
// Response
@@ -11,12 +11,11 @@ pub struct ChatStreamChunk {
pub id: String,
pub object: String,
/// Unix timestamp (in seconds).
pub created: u32,
pub created: u64,
pub model: constants::Model,
pub choices: Vec<ChatStreamChunkChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<common::ResponseUsage>,
// TODO Check this prop (seen in API responses but undocumented).
// pub logprobs: ???,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
@@ -24,14 +23,15 @@ pub struct ChatStreamChunkChoice {
pub index: u32,
pub delta: ChatStreamChunkChoiceDelta,
pub finish_reason: Option<String>,
// TODO Check this prop (seen in API responses but undocumented).
// pub logprobs: ???,
}
/// Incremental payload of a streaming chunk.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ChatStreamChunkChoiceDelta {
/// Role is typically only present in the first chunk of a stream — TODO confirm.
pub role: Option<chat::ChatMessageRole>,
/// Text fragment; `#[serde(default)]` makes a missing field deserialize to `None`
/// (e.g. for chunks that carry only tool-call data).
#[serde(default)]
pub content: Option<String>,
/// Tool-call fragments, when the model is emitting tool calls.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<tool::ToolCall>>,
}
/// Extracts serialized chunks from a stream message.
@@ -47,7 +47,6 @@ pub fn get_chunk_from_stream_message_line(
return Ok(Some(vec![]));
}
// Attempt to deserialize the JSON string into ChatStreamChunk
match from_str::<ChatStreamChunk>(chunk_as_json) {
Ok(chunk) => Ok(Some(vec![chunk])),
Err(e) => Err(error::ApiError {

View File

@@ -26,25 +26,10 @@ impl Client {
///
/// # Arguments
///
/// * `api_key` - An optional API key.
/// If not provided, the method will try to use the `MISTRAL_API_KEY` environment variable.
/// * `endpoint` - An optional custom API endpoint. Defaults to the official API endpoint if not provided.
/// * `max_retries` - Optional maximum number of retries for failed requests. Defaults to `5`.
/// * `timeout` - Optional timeout in seconds for requests. Defaults to `120`.
///
/// # Examples
///
/// ```
/// use mistralai_client::v1::client::Client;
///
/// let client = Client::new(Some("your_api_key_here".to_string()), None, Some(3), Some(60));
/// assert!(client.is_ok());
/// ```
///
/// # Errors
///
/// This method fails whenever neither the `api_key` is provided
/// nor the `MISTRAL_API_KEY` environment variable is set.
/// * `api_key` - An optional API key. If not provided, uses `MISTRAL_API_KEY` env var.
/// * `endpoint` - An optional custom API endpoint. Defaults to `https://api.mistral.ai/v1`.
/// * `max_retries` - Optional maximum number of retries. Defaults to `5`.
/// * `timeout` - Optional timeout in seconds. Defaults to `120`.
pub fn new(
api_key: Option<String>,
endpoint: Option<String>,
@@ -69,43 +54,15 @@ impl Client {
endpoint,
max_retries,
timeout,
functions,
last_function_call_result,
})
}
/// Synchronously sends a chat completion request and returns the response.
///
/// # Arguments
///
/// * `model` - The [Model] to use for the chat completion.
/// * `messages` - A vector of [ChatMessage] to send as part of the chat.
/// * `options` - Optional [ChatParams] to customize the request.
///
/// # Returns
///
/// Returns a [Result] containing the `ChatResponse` if the request is successful,
/// or an [ApiError] if there is an error.
///
/// # Examples
///
/// ```
/// use mistralai_client::v1::{
/// chat::{ChatMessage, ChatMessageRole},
/// client::Client,
/// constants::Model,
/// };
///
/// let client = Client::new(None, None, None, None).unwrap();
/// let messages = vec![ChatMessage {
/// role: ChatMessageRole::User,
/// content: "Hello, world!".to_string(),
/// tool_calls: None,
/// }];
/// let response = client.chat(Model::OpenMistral7b, messages, None).unwrap();
/// println!("{:?}: {}", response.choices[0].message.role, response.choices[0].message.content);
/// ```
// =========================================================================
// Chat Completions
// =========================================================================
pub fn chat(
&self,
model: constants::Model,
@@ -119,49 +76,13 @@ impl Client {
match result {
Ok(data) => {
utils::debug_pretty_json_from_struct("Response Data", &data);
self.call_function_if_any(data.clone());
Ok(data)
}
Err(error) => Err(self.to_api_error(error)),
}
}
/// Asynchronously sends a chat completion request and returns the response.
///
/// # Arguments
///
/// * `model` - The [Model] to use for the chat completion.
/// * `messages` - A vector of [ChatMessage] to send as part of the chat.
/// * `options` - Optional [ChatParams] to customize the request.
///
/// # Returns
///
/// Returns a [Result] containing a `Stream` of `ChatStreamChunk` if the request is successful,
/// or an [ApiError] if there is an error.
///
/// # Examples
///
/// ```
/// use mistralai_client::v1::{
/// chat::{ChatMessage, ChatMessageRole},
/// client::Client,
/// constants::Model,
/// };
///
/// #[tokio::main]
/// async fn main() {
/// let client = Client::new(None, None, None, None).unwrap();
/// let messages = vec![ChatMessage {
/// role: ChatMessageRole::User,
/// content: "Hello, world!".to_string(),
/// tool_calls: None,
/// }];
/// let response = client.chat_async(Model::OpenMistral7b, messages, None).await.unwrap();
/// println!("{:?}: {}", response.choices[0].message.role, response.choices[0].message.content);
/// }
/// ```
pub async fn chat_async(
&self,
model: constants::Model,
@@ -175,68 +96,13 @@ impl Client {
match result {
Ok(data) => {
utils::debug_pretty_json_from_struct("Response Data", &data);
self.call_function_if_any_async(data.clone()).await;
Ok(data)
}
Err(error) => Err(self.to_api_error(error)),
}
}
/// Asynchronously sends a chat completion request and returns a stream of message chunks.
///
/// # Arguments
///
/// * `model` - The [Model] to use for the chat completion.
/// * `messages` - A vector of [ChatMessage] to send as part of the chat.
/// * `options` - Optional [ChatParams] to customize the request.
///
/// # Returns
///
/// Returns a [Result] containing a `Stream` of `ChatStreamChunk` if the request is successful,
/// or an [ApiError] if there is an error.
///
/// # Examples
///
/// ```
/// use futures::stream::StreamExt;
/// use mistralai_client::v1::{
/// chat::{ChatMessage, ChatMessageRole},
/// client::Client,
/// constants::Model,
/// };
/// use std::io::{self, Write};
///
/// #[tokio::main]
/// async fn main() {
/// let client = Client::new(None, None, None, None).unwrap();
/// let messages = vec![ChatMessage {
/// role: ChatMessageRole::User,
/// content: "Hello, world!".to_string(),
/// tool_calls: None,
/// }];
///
/// let stream_result = client
/// .chat_stream(Model::OpenMistral7b,messages, None)
/// .await
/// .unwrap();
/// stream_result
/// .for_each(|chunk_result| async {
/// match chunk_result {
/// Ok(chunks) => chunks.iter().for_each(|chunk| {
/// print!("{}", chunk.choices[0].delta.content);
/// io::stdout().flush().unwrap();
/// // => "Once upon a time, [...]"
/// }),
/// Err(error) => {
/// eprintln!("Error processing chunk: {:?}", error)
/// }
/// }
/// })
/// .await;
/// print!("\n") // To persist the last chunk output.
/// }
pub async fn chat_stream(
&self,
model: constants::Model,
@@ -292,9 +158,13 @@ impl Client {
Ok(deserialized_stream)
}
// =========================================================================
// Embeddings
// =========================================================================
pub fn embeddings(
&self,
model: constants::EmbedModel,
model: constants::Model,
input: Vec<String>,
options: Option<embedding::EmbeddingRequestOptions>,
) -> Result<embedding::EmbeddingResponse, error::ApiError> {
@@ -305,7 +175,6 @@ impl Client {
match result {
Ok(data) => {
utils::debug_pretty_json_from_struct("Response Data", &data);
Ok(data)
}
Err(error) => Err(self.to_api_error(error)),
@@ -314,7 +183,7 @@ impl Client {
pub async fn embeddings_async(
&self,
model: constants::EmbedModel,
model: constants::Model,
input: Vec<String>,
options: Option<embedding::EmbeddingRequestOptions>,
) -> Result<embedding::EmbeddingResponse, error::ApiError> {
@@ -325,18 +194,15 @@ impl Client {
match result {
Ok(data) => {
utils::debug_pretty_json_from_struct("Response Data", &data);
Ok(data)
}
Err(error) => Err(self.to_api_error(error)),
}
}
pub fn get_last_function_call_result(&self) -> Option<Box<dyn Any + Send>> {
let mut result_lock = self.last_function_call_result.lock().unwrap();
result_lock.take()
}
// =========================================================================
// Models
// =========================================================================
pub fn list_models(&self) -> Result<model_list::ModelListResponse, error::ApiError> {
let response = self.get_sync("/models")?;
@@ -344,7 +210,6 @@ impl Client {
match result {
Ok(data) => {
utils::debug_pretty_json_from_struct("Response Data", &data);
Ok(data)
}
Err(error) => Err(self.to_api_error(error)),
@@ -359,68 +224,136 @@ impl Client {
match result {
Ok(data) => {
utils::debug_pretty_json_from_struct("Response Data", &data);
Ok(data)
}
Err(error) => Err(self.to_api_error(error)),
}
}
/// Synchronously fetches a single model's metadata (`GET /models/{id}`).
///
/// # Errors
///
/// Returns an [error::ApiError] if the HTTP request fails or the body
/// cannot be deserialized into [model_list::ModelListData].
pub fn get_model(&self, model_id: &str) -> Result<model_list::ModelListData, error::ApiError> {
let response = self.get_sync(&format!("/models/{}", model_id))?;
let result = response.json::<model_list::ModelListData>();
match result {
Ok(data) => {
utils::debug_pretty_json_from_struct("Response Data", &data);
Ok(data)
}
Err(error) => Err(self.to_api_error(error)),
}
}
/// Asynchronously fetches a single model's metadata (`GET /models/{id}`).
///
/// # Errors
///
/// Returns an [error::ApiError] if the HTTP request fails or the body
/// cannot be deserialized into [model_list::ModelListData].
pub async fn get_model_async(
&self,
model_id: &str,
) -> Result<model_list::ModelListData, error::ApiError> {
let response = self.get_async(&format!("/models/{}", model_id)).await?;
let result = response.json::<model_list::ModelListData>().await;
match result {
Ok(data) => {
utils::debug_pretty_json_from_struct("Response Data", &data);
Ok(data)
}
Err(error) => Err(self.to_api_error(error)),
}
}
pub fn delete_model(
&self,
model_id: &str,
) -> Result<model_list::ModelDeleteResponse, error::ApiError> {
let response = self.delete_sync(&format!("/models/{}", model_id))?;
let result = response.json::<model_list::ModelDeleteResponse>();
match result {
Ok(data) => Ok(data),
Err(error) => Err(self.to_api_error(error)),
}
}
pub async fn delete_model_async(
&self,
model_id: &str,
) -> Result<model_list::ModelDeleteResponse, error::ApiError> {
let response = self
.delete_async(&format!("/models/{}", model_id))
.await?;
let result = response.json::<model_list::ModelDeleteResponse>().await;
match result {
Ok(data) => Ok(data),
Err(error) => Err(self.to_api_error(error)),
}
}
// =========================================================================
// Function Calling
// =========================================================================
pub fn register_function(&mut self, name: String, function: Box<dyn tool::Function>) {
let mut functions = self.functions.lock().unwrap();
functions.insert(name, function);
}
pub fn get_last_function_call_result(&self) -> Option<Box<dyn Any + Send>> {
let mut result_lock = self.last_function_call_result.lock().unwrap();
result_lock.take()
}
// =========================================================================
// HTTP Transport
// =========================================================================
/// Builds the `User-Agent` header value, e.g. `mistralai-client-rs/0.x.y`.
/// The crate version is baked in at compile time via `CARGO_PKG_VERSION`.
fn user_agent(&self) -> String {
    format!("mistralai-client-rs/{}", env!("CARGO_PKG_VERSION"))
}
fn build_request_sync(
&self,
request: reqwest::blocking::RequestBuilder,
) -> reqwest::blocking::RequestBuilder {
let user_agent = format!(
"ivangabriele/mistralai-client-rs/{}",
env!("CARGO_PKG_VERSION")
);
let request_builder = request
request
.bearer_auth(&self.api_key)
.header("Accept", "application/json")
.header("User-Agent", user_agent);
.header("User-Agent", self.user_agent())
}
request_builder
fn build_request_sync_no_accept(
&self,
request: reqwest::blocking::RequestBuilder,
) -> reqwest::blocking::RequestBuilder {
request
.bearer_auth(&self.api_key)
.header("User-Agent", self.user_agent())
}
fn build_request_async(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
let user_agent = format!(
"ivangabriele/mistralai-client-rs/{}",
env!("CARGO_PKG_VERSION")
);
let request_builder = request
request
.bearer_auth(&self.api_key)
.header("Accept", "application/json")
.header("User-Agent", user_agent);
.header("User-Agent", self.user_agent())
}
request_builder
fn build_request_async_no_accept(
&self,
request: reqwest::RequestBuilder,
) -> reqwest::RequestBuilder {
request
.bearer_auth(&self.api_key)
.header("User-Agent", self.user_agent())
}
fn build_request_stream(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
let user_agent = format!(
"ivangabriele/mistralai-client-rs/{}",
env!("CARGO_PKG_VERSION")
);
let request_builder = request
request
.bearer_auth(&self.api_key)
.header("Accept", "text/event-stream")
.header("User-Agent", user_agent);
request_builder
.header("User-Agent", self.user_agent())
}
fn call_function_if_any(&self, response: chat::ChatResponse) -> () {
let next_result = match response.choices.get(0) {
Some(first_choice) => match first_choice.message.tool_calls.to_owned() {
Some(tool_calls) => match tool_calls.get(0) {
fn call_function_if_any(&self, response: chat::ChatResponse) {
let next_result = match response.choices.first() {
Some(first_choice) => match first_choice.message.tool_calls.as_ref() {
Some(tool_calls) => match tool_calls.first() {
Some(first_tool_call) => {
let functions = self.functions.lock().unwrap();
match functions.get(&first_tool_call.function.name) {
@@ -431,7 +364,6 @@ impl Client {
.execute(first_tool_call.function.arguments.to_owned())
.await
});
Some(result)
}
None => None,
@@ -448,10 +380,10 @@ impl Client {
*last_result_lock = next_result;
}
async fn call_function_if_any_async(&self, response: chat::ChatResponse) -> () {
let next_result = match response.choices.get(0) {
Some(first_choice) => match first_choice.message.tool_calls.to_owned() {
Some(tool_calls) => match tool_calls.get(0) {
async fn call_function_if_any_async(&self, response: chat::ChatResponse) {
let next_result = match response.choices.first() {
Some(first_choice) => match first_choice.message.tool_calls.as_ref() {
Some(tool_calls) => match tool_calls.first() {
Some(first_tool_call) => {
let functions = self.functions.lock().unwrap();
match functions.get(&first_tool_call.function.name) {
@@ -459,7 +391,6 @@ impl Client {
let result = function
.execute(first_tool_call.function.arguments.to_owned())
.await;
Some(result)
}
None => None,
@@ -482,27 +413,8 @@ impl Client {
debug!("Request URL: {}", url);
let request = self.build_request_sync(reqwest_client.get(url));
let result = request.send();
match result {
Ok(response) => {
if response.status().is_success() {
Ok(response)
} else {
let response_status = response.status();
let response_body = response.text().unwrap_or_default();
debug!("Response Status: {}", &response_status);
utils::debug_pretty_json_from_string("Response Data", &response_body);
Err(error::ApiError {
message: format!("{}: {}", response_status, response_body),
})
}
}
Err(error) => Err(error::ApiError {
message: error.to_string(),
}),
}
self.handle_sync_response(result)
}
async fn get_async(&self, path: &str) -> Result<reqwest::Response, error::ApiError> {
@@ -510,29 +422,9 @@ impl Client {
let url = format!("{}{}", self.endpoint, path);
debug!("Request URL: {}", url);
let request_builder = reqwest_client.get(url);
let request = self.build_request_async(request_builder);
let request = self.build_request_async(reqwest_client.get(url));
let result = request.send().await;
match result {
Ok(response) => {
if response.status().is_success() {
Ok(response)
} else {
let response_status = response.status();
let response_body = response.text().await.unwrap_or_default();
debug!("Response Status: {}", &response_status);
utils::debug_pretty_json_from_string("Response Data", &response_body);
Err(error::ApiError {
message: format!("{}: {}", response_status, response_body),
})
}
}
Err(error) => Err(error::ApiError {
message: error.to_string(),
}),
}
self.handle_async_response(result).await
}
fn post_sync<T: std::fmt::Debug + serde::ser::Serialize>(
@@ -545,29 +437,22 @@ impl Client {
debug!("Request URL: {}", url);
utils::debug_pretty_json_from_struct("Request Body", params);
let request_builder = reqwest_client.post(url).json(params);
let request = self.build_request_sync(request_builder);
let request = self.build_request_sync(reqwest_client.post(url).json(params));
let result = request.send();
match result {
Ok(response) => {
if response.status().is_success() {
Ok(response)
} else {
let response_status = response.status();
let response_body = response.text().unwrap_or_default();
debug!("Response Status: {}", &response_status);
utils::debug_pretty_json_from_string("Response Data", &response_body);
self.handle_sync_response(result)
}
Err(error::ApiError {
message: format!("{}: {}", response_body, response_status),
})
}
}
Err(error) => Err(error::ApiError {
message: error.to_string(),
}),
}
/// Sends a body-less synchronous POST to `path` and returns the raw response.
///
/// # Errors
///
/// Returns an [error::ApiError] on transport failure or a non-success status
/// (via `handle_sync_response`).
fn post_sync_empty(
&self,
path: &str,
) -> Result<reqwest::blocking::Response, error::ApiError> {
let reqwest_client = reqwest::blocking::Client::new();
let url = format!("{}{}", self.endpoint, path);
debug!("Request URL: {}", url);
let request = self.build_request_sync(reqwest_client.post(url));
let result = request.send();
self.handle_sync_response(result)
}
async fn post_async<T: serde::ser::Serialize + std::fmt::Debug>(
@@ -580,29 +465,19 @@ impl Client {
debug!("Request URL: {}", url);
utils::debug_pretty_json_from_struct("Request Body", params);
let request_builder = reqwest_client.post(url).json(params);
let request = self.build_request_async(request_builder);
let request = self.build_request_async(reqwest_client.post(url).json(params));
let result = request.send().await;
match result {
Ok(response) => {
if response.status().is_success() {
Ok(response)
} else {
let response_status = response.status();
let response_body = response.text().await.unwrap_or_default();
debug!("Response Status: {}", &response_status);
utils::debug_pretty_json_from_string("Response Data", &response_body);
self.handle_async_response(result).await
}
Err(error::ApiError {
message: format!("{}: {}", response_status, response_body),
})
}
}
Err(error) => Err(error::ApiError {
message: error.to_string(),
}),
}
/// Sends a body-less asynchronous POST to `path` and returns the raw response.
///
/// # Errors
///
/// Returns an [error::ApiError] on transport failure or a non-success status
/// (via `handle_async_response`).
async fn post_async_empty(&self, path: &str) -> Result<reqwest::Response, error::ApiError> {
let reqwest_client = reqwest::Client::new();
let url = format!("{}{}", self.endpoint, path);
debug!("Request URL: {}", url);
let request = self.build_request_async(reqwest_client.post(url));
let result = request.send().await;
self.handle_async_response(result).await
}
async fn post_stream<T: serde::ser::Serialize + std::fmt::Debug>(
@@ -615,22 +490,70 @@ impl Client {
debug!("Request URL: {}", url);
utils::debug_pretty_json_from_struct("Request Body", params);
let request_builder = reqwest_client.post(url).json(params);
let request = self.build_request_stream(request_builder);
let request = self.build_request_stream(reqwest_client.post(url).json(params));
let result = request.send().await;
self.handle_async_response(result).await
}
/// Sends a synchronous DELETE to `path` and returns the raw response.
///
/// # Errors
///
/// Returns an [error::ApiError] on transport failure or a non-success status
/// (via `handle_sync_response`).
fn delete_sync(&self, path: &str) -> Result<reqwest::blocking::Response, error::ApiError> {
let reqwest_client = reqwest::blocking::Client::new();
let url = format!("{}{}", self.endpoint, path);
debug!("Request URL: {}", url);
let request = self.build_request_sync(reqwest_client.delete(url));
let result = request.send();
self.handle_sync_response(result)
}
/// Sends an asynchronous DELETE to `path` and returns the raw response.
///
/// # Errors
///
/// Returns an [error::ApiError] on transport failure or a non-success status
/// (via `handle_async_response`).
async fn delete_async(&self, path: &str) -> Result<reqwest::Response, error::ApiError> {
let reqwest_client = reqwest::Client::new();
let url = format!("{}{}", self.endpoint, path);
debug!("Request URL: {}", url);
let request = self.build_request_async(reqwest_client.delete(url));
let result = request.send().await;
self.handle_async_response(result).await
}
fn handle_sync_response(
&self,
result: Result<reqwest::blocking::Response, reqwest::Error>,
) -> Result<reqwest::blocking::Response, error::ApiError> {
match result {
Ok(response) => {
if response.status().is_success() {
Ok(response)
} else {
let response_status = response.status();
let response_body = response.text().await.unwrap_or_default();
debug!("Response Status: {}", &response_status);
utils::debug_pretty_json_from_string("Response Data", &response_body);
let status = response.status();
let body = response.text().unwrap_or_default();
debug!("Response Status: {}", &status);
utils::debug_pretty_json_from_string("Response Data", &body);
Err(error::ApiError {
message: format!("{}: {}", response_status, response_body),
message: format!("{}: {}", status, body),
})
}
}
Err(error) => Err(error::ApiError {
message: error.to_string(),
}),
}
}
async fn handle_async_response(
&self,
result: Result<reqwest::Response, reqwest::Error>,
) -> Result<reqwest::Response, error::ApiError> {
match result {
Ok(response) => {
if response.status().is_success() {
Ok(response)
} else {
let status = response.status();
let body = response.text().await.unwrap_or_default();
debug!("Response Status: {}", &status);
utils::debug_pretty_json_from_string("Response Data", &body);
Err(error::ApiError {
message: format!("{}: {}", status, body),
})
}
}

View File

@@ -3,6 +3,7 @@ use serde::{Deserialize, Serialize};
/// Token accounting the API reports for a single request.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ResponseUsage {
/// Tokens consumed by the input prompt/messages.
pub prompt_tokens: u32,
/// Tokens generated in the completion. `#[serde(default)]` makes a missing
/// field deserialize to 0 — presumably for responses without a completion
/// (e.g. embeddings); verify against the API.
#[serde(default)]
pub completion_tokens: u32,
/// Total billed tokens — assumed to equal prompt + completion; TODO confirm.
pub total_tokens: u32,
}

View File

@@ -1,35 +1,131 @@
use std::fmt;
use serde::{Deserialize, Serialize};
pub const API_URL_BASE: &str = "https://api.mistral.ai/v1";
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum Model {
#[serde(rename = "open-mistral-7b")]
OpenMistral7b,
#[serde(rename = "open-mixtral-8x7b")]
OpenMixtral8x7b,
#[serde(rename = "open-mixtral-8x22b")]
OpenMixtral8x22b,
#[serde(rename = "open-mistral-nemo", alias = "open-mistral-nemo-2407")]
OpenMistralNemo,
#[serde(rename = "mistral-tiny")]
MistralTiny,
#[serde(rename = "mistral-small-latest", alias = "mistral-small-2402")]
MistralSmallLatest,
#[serde(rename = "mistral-medium-latest", alias = "mistral-medium-2312")]
MistralMediumLatest,
#[serde(rename = "mistral-large-latest", alias = "mistral-large-2407")]
MistralLargeLatest,
#[serde(rename = "mistral-large-2402")]
MistralLarge,
#[serde(rename = "codestral-latest", alias = "codestral-2405")]
CodestralLatest,
#[serde(rename = "open-codestral-mamba")]
CodestralMamba,
/// A Mistral AI model identifier.
///
/// Use the associated constructor methods for known models, or construct with `Model::new()` for any model string.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(transparent)]
pub struct Model(pub String);
impl Model {
/// Wraps an arbitrary model ID string (e.g. a fine-tuned model name).
pub fn new(id: impl Into<String>) -> Self {
Self(id.into())
}
// NOTE(review): the pinned, dated IDs below are assumed to match the live
// API catalogue — verify against the `/models` endpoint when updating.
// Flagship / Premier
pub fn mistral_large_latest() -> Self {
Self::new("mistral-large-latest")
}
pub fn mistral_large_3() -> Self {
Self::new("mistral-large-3-25-12")
}
pub fn mistral_medium_latest() -> Self {
Self::new("mistral-medium-latest")
}
pub fn mistral_medium_3_1() -> Self {
Self::new("mistral-medium-3-1-25-08")
}
pub fn mistral_small_latest() -> Self {
Self::new("mistral-small-latest")
}
pub fn mistral_small_4() -> Self {
Self::new("mistral-small-4-0-26-03")
}
pub fn mistral_small_3_2() -> Self {
Self::new("mistral-small-3-2-25-06")
}
// Ministral
pub fn ministral_3_14b() -> Self {
Self::new("ministral-3-14b-25-12")
}
pub fn ministral_3_8b() -> Self {
Self::new("ministral-3-8b-25-12")
}
pub fn ministral_3_3b() -> Self {
Self::new("ministral-3-3b-25-12")
}
// Reasoning
pub fn magistral_medium_latest() -> Self {
Self::new("magistral-medium-latest")
}
pub fn magistral_small_latest() -> Self {
Self::new("magistral-small-latest")
}
// Code
pub fn codestral_latest() -> Self {
Self::new("codestral-latest")
}
pub fn codestral_2508() -> Self {
Self::new("codestral-2508")
}
/// Code-specialised embedding model.
pub fn codestral_embed() -> Self {
Self::new("codestral-embed-25-05")
}
pub fn devstral_2() -> Self {
Self::new("devstral-2-25-12")
}
pub fn devstral_small_2() -> Self {
Self::new("devstral-small-2-25-12")
}
// Multimodal / Vision
pub fn pixtral_large() -> Self {
Self::new("pixtral-large-2411")
}
// Audio
pub fn voxtral_mini_transcribe() -> Self {
Self::new("voxtral-mini-transcribe-2-26-02")
}
pub fn voxtral_small() -> Self {
Self::new("voxtral-small-25-07")
}
pub fn voxtral_mini() -> Self {
Self::new("voxtral-mini-25-07")
}
// Legacy (kept for backward compatibility)
pub fn open_mistral_nemo() -> Self {
Self::new("open-mistral-nemo")
}
// Embedding
/// Replacement for the removed `EmbedModel::MistralEmbed` variant.
pub fn mistral_embed() -> Self {
Self::new("mistral-embed")
}
// Moderation
/// NOTE(review): named `latest` but pinned to a dated ID — confirm intended.
pub fn mistral_moderation_latest() -> Self {
Self::new("mistral-moderation-26-03")
}
// OCR
pub fn mistral_ocr_latest() -> Self {
Self::new("mistral-ocr-latest")
}
}
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum EmbedModel {
#[serde(rename = "mistral-embed")]
MistralEmbed,
impl fmt::Display for Model {
    /// Renders the model as its raw identifier string.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(&self.0)
    }
}
impl From<&str> for Model {
fn from(s: &str) -> Self {
Self(s.to_string())
}
}
impl From<String> for Model {
fn from(s: String) -> Self {
Self(s)
}
}

View File

@@ -8,42 +8,63 @@ use crate::v1::{common, constants};
/// Optional parameters for an embeddings request.
///
/// All fields are `Option`, so `#[derive(Default)]` replaces the previous
/// hand-written `Default` impl that set every field to `None` by hand.
#[derive(Debug, Default)]
pub struct EmbeddingRequestOptions {
    /// Wire encoding of the returned vectors (`float` or `base64`).
    pub encoding_format: Option<EmbeddingRequestEncodingFormat>,
    /// Requested dimensionality of the output embeddings.
    pub output_dimension: Option<u32>,
    /// Numeric representation of the output embedding values.
    pub output_dtype: Option<EmbeddingOutputDtype>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingRequest {
pub model: constants::EmbedModel,
pub model: constants::Model,
pub input: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub encoding_format: Option<EmbeddingRequestEncodingFormat>,
#[serde(skip_serializing_if = "Option::is_none")]
pub output_dimension: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub output_dtype: Option<EmbeddingOutputDtype>,
}
impl EmbeddingRequest {
pub fn new(
model: constants::EmbedModel,
model: constants::Model,
input: Vec<String>,
options: Option<EmbeddingRequestOptions>,
) -> Self {
let EmbeddingRequestOptions { encoding_format } = options.unwrap_or_default();
let opts = options.unwrap_or_default();
Self {
model,
input,
encoding_format,
encoding_format: opts.encoding_format,
output_dimension: opts.output_dimension,
output_dtype: opts.output_dtype,
}
}
}
/// Wire encoding for returned embedding vectors.
///
/// Fixed: the pre-refactor `#[allow(non_camel_case_types)]` attribute and the
/// lowercase `float` variant were left behind next to the new `Float`
/// variant, duplicating it and conflicting with `rename_all = "lowercase"`.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingRequestEncodingFormat {
    /// Serialized as `"float"`.
    Float,
    /// Serialized as `"base64"`.
    Base64,
}
/// Numeric representation of embedding values in the response.
///
/// `rename_all = "lowercase"` maps each variant to its lowercase wire name.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingOutputDtype {
    /// Wire value `"float"`.
    Float,
    /// Wire value `"int8"`.
    Int8,
    /// Wire value `"uint8"`.
    Uint8,
    /// Wire value `"binary"`.
    Binary,
    /// Wire value `"ubinary"`.
    Ubinary,
}
// -----------------------------------------------------------------------------
@@ -51,9 +72,8 @@ pub enum EmbeddingRequestEncodingFormat {
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct EmbeddingResponse {
pub id: String,
pub object: String,
pub model: constants::EmbedModel,
pub model: constants::Model,
pub data: Vec<EmbeddingResponseDataItem>,
pub usage: common::ResponseUsage,
}

View File

@@ -15,23 +15,44 @@ pub struct ModelListData {
pub id: String,
pub object: String,
/// Unix timestamp (in seconds).
pub created: u32,
pub created: u64,
pub owned_by: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub root: Option<String>,
#[serde(default)]
pub archived: bool,
pub name: String,
pub description: String,
pub capabilities: ModelListDataCapabilies,
pub max_context_length: u32,
#[serde(default)]
pub name: Option<String>,
#[serde(default)]
pub description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub capabilities: Option<ModelListDataCapabilities>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_context_length: Option<u32>,
#[serde(default)]
pub aliases: Vec<String>,
/// ISO 8601 date (`YYYY-MM-DDTHH:MM:SSZ`).
#[serde(skip_serializing_if = "Option::is_none")]
pub deprecation: Option<String>,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ModelListDataCapabilies {
pub struct ModelListDataCapabilities {
#[serde(default)]
pub completion_chat: bool,
#[serde(default)]
pub completion_fim: bool,
#[serde(default)]
pub function_calling: bool,
#[serde(default)]
pub fine_tuning: bool,
#[serde(default)]
pub vision: bool,
}
/// Response body returned when deleting a model.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ModelDeleteResponse {
    /// Identifier of the model the delete targeted.
    pub id: String,
    /// Object type tag returned by the API.
    pub object: String,
    /// Whether the model was deleted.
    pub deleted: bool,
}

View File

@@ -1,12 +1,16 @@
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::{any::Any, collections::HashMap, fmt::Debug};
use std::{any::Any, fmt::Debug};
// -----------------------------------------------------------------------------
// Definitions
/// A tool invocation emitted by the model in an assistant message.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub struct ToolCall {
    /// Server-assigned call ID; echoed back via `tool_call_id` on the
    /// Tool-role message that carries the result.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Call type string — presumably always "function"; TODO confirm
    /// against the API reference.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub r#type: Option<String>,
    /// The function name and arguments the model wants invoked.
    pub function: ToolCallFunction,
}
@@ -22,31 +26,12 @@ pub struct Tool {
pub function: ToolFunction,
}
impl Tool {
/// Create a tool with a JSON Schema parameters object.
pub fn new(
function_name: String,
function_description: String,
function_parameters: Vec<ToolFunctionParameter>,
parameters: serde_json::Value,
) -> Self {
let properties: HashMap<String, ToolFunctionParameterProperty> = function_parameters
.into_iter()
.map(|param| {
(
param.name,
ToolFunctionParameterProperty {
r#type: param.r#type,
description: param.description,
},
)
})
.collect();
let property_names = properties.keys().cloned().collect();
let parameters = ToolFunctionParameters {
r#type: ToolFunctionParametersType::Object,
properties,
required: property_names,
};
Self {
r#type: ToolType::Function,
function: ToolFunction {
@@ -63,50 +48,9 @@ impl Tool {
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolFunction {
name: String,
description: String,
parameters: ToolFunctionParameters,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolFunctionParameter {
name: String,
description: String,
r#type: ToolFunctionParameterType,
}
impl ToolFunctionParameter {
pub fn new(name: String, description: String, r#type: ToolFunctionParameterType) -> Self {
Self {
name,
r#type,
description,
}
}
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolFunctionParameters {
r#type: ToolFunctionParametersType,
properties: HashMap<String, ToolFunctionParameterProperty>,
required: Vec<String>,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolFunctionParameterProperty {
r#type: ToolFunctionParameterType,
description: String,
}
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum ToolFunctionParametersType {
#[serde(rename = "object")]
Object,
}
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum ToolFunctionParameterType {
#[serde(rename = "string")]
String,
pub name: String,
pub description: String,
pub parameters: serde_json::Value,
}
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
@@ -127,6 +71,9 @@ pub enum ToolChoice {
/// The model won't call a function and will generate a message instead.
#[serde(rename = "none")]
None,
/// The model must call at least one tool.
#[serde(rename = "required")]
Required,
}
// -----------------------------------------------------------------------------