refactor!: modernize core types and client for latest Mistral API

BREAKING CHANGE: Model is now a string-based struct with constructor
methods instead of a closed enum. EmbedModel is removed — use
Model::mistral_embed() instead. Tool parameters now accept
serde_json::Value (JSON Schema) instead of limited enum types.

- Replace Model enum with flexible Model(String) supporting all
  current models: Large 3, Small 4, Medium 3.1, Magistral, Codestral,
  Devstral, Pixtral, Voxtral, Ministral, and arbitrary strings
- Remove EmbedModel enum (consolidated into Model)
- Chat: add frequency_penalty, presence_penalty, stop, n, min_tokens,
  parallel_tool_calls, reasoning_effort, json_schema response format
- Embeddings: add output_dimension and output_dtype fields
- Tools: accept raw JSON Schema, add tool call IDs and Required choice
- Stream delta content is now Option<String> for tool call chunks
- Add Length, ModelLength, Error finish reason variants
- DRY HTTP transport with shared response handlers
- Add DELETE method support and model get/delete endpoints
- Make model_list fields more lenient with Option/default for API compat
This commit is contained in:
2026-03-20 17:54:29 +00:00
parent 83396773ce
commit bbb6aaed1c
8 changed files with 518 additions and 472 deletions

View File

@@ -11,13 +11,31 @@ pub struct ChatMessage {
pub content: String, pub content: String,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<tool::ToolCall>>, pub tool_calls: Option<Vec<tool::ToolCall>>,
/// Tool call ID, required when role is Tool.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
/// Function name, used when role is Tool.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
} }
impl ChatMessage { impl ChatMessage {
pub fn new_system_message(content: &str) -> Self {
Self {
role: ChatMessageRole::System,
content: content.to_string(),
tool_calls: None,
tool_call_id: None,
name: None,
}
}
pub fn new_assistant_message(content: &str, tool_calls: Option<Vec<tool::ToolCall>>) -> Self { pub fn new_assistant_message(content: &str, tool_calls: Option<Vec<tool::ToolCall>>) -> Self {
Self { Self {
role: ChatMessageRole::Assistant, role: ChatMessageRole::Assistant,
content: content.to_string(), content: content.to_string(),
tool_calls, tool_calls,
tool_call_id: None,
name: None,
} }
} }
@@ -26,6 +44,18 @@ impl ChatMessage {
role: ChatMessageRole::User, role: ChatMessageRole::User,
content: content.to_string(), content: content.to_string(),
tool_calls: None, tool_calls: None,
tool_call_id: None,
name: None,
}
}
pub fn new_tool_message(content: &str, tool_call_id: &str, name: Option<&str>) -> Self {
Self {
role: ChatMessageRole::Tool,
content: content.to_string(),
tool_calls: None,
tool_call_id: Some(tool_call_id.to_string()),
name: name.map(|n| n.to_string()),
} }
} }
} }
@@ -44,17 +74,32 @@ pub enum ChatMessageRole {
} }
/// The format that the model must output. /// The format that the model must output.
///
/// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information.
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ResponseFormat { pub struct ResponseFormat {
#[serde(rename = "type")] #[serde(rename = "type")]
pub type_: String, pub type_: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub json_schema: Option<serde_json::Value>,
} }
impl ResponseFormat { impl ResponseFormat {
pub fn text() -> Self {
Self {
type_: "text".to_string(),
json_schema: None,
}
}
pub fn json_object() -> Self { pub fn json_object() -> Self {
Self { Self {
type_: "json_object".to_string(), type_: "json_object".to_string(),
json_schema: None,
}
}
pub fn json_schema(schema: serde_json::Value) -> Self {
Self {
type_: "json_schema".to_string(),
json_schema: Some(schema),
} }
} }
} }
@@ -63,91 +108,83 @@ impl ResponseFormat {
// Request // Request
/// The parameters for the chat request. /// The parameters for the chat request.
///
/// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct ChatParams { pub struct ChatParams {
/// The maximum number of tokens to generate in the completion.
///
/// Defaults to `None`.
pub max_tokens: Option<u32>, pub max_tokens: Option<u32>,
/// The seed to use for random sampling. If set, different calls will generate deterministic results. pub min_tokens: Option<u32>,
///
/// Defaults to `None`.
pub random_seed: Option<u32>, pub random_seed: Option<u32>,
/// The format that the model must output.
///
/// Defaults to `None`.
pub response_format: Option<ResponseFormat>, pub response_format: Option<ResponseFormat>,
/// Whether to inject a safety prompt before all conversations.
///
/// Defaults to `false`.
pub safe_prompt: bool, pub safe_prompt: bool,
/// What sampling temperature to use, between `Some(0.0)` and `Some(1.0)`. pub temperature: Option<f32>,
///
/// Defaults to `0.7`.
pub temperature: f32,
/// Specifies if/how functions are called.
///
/// Defaults to `None`.
pub tool_choice: Option<tool::ToolChoice>, pub tool_choice: Option<tool::ToolChoice>,
/// A list of available tools for the model.
///
/// Defaults to `None`.
pub tools: Option<Vec<tool::Tool>>, pub tools: Option<Vec<tool::Tool>>,
/// Nucleus sampling, where the model considers the results of the tokens with `top_p` probability mass. pub top_p: Option<f32>,
/// pub stop: Option<Vec<String>>,
/// Defaults to `1.0`. pub n: Option<u32>,
pub top_p: f32, pub frequency_penalty: Option<f32>,
pub presence_penalty: Option<f32>,
pub parallel_tool_calls: Option<bool>,
/// For reasoning models (Magistral). "high" or "none".
pub reasoning_effort: Option<String>,
} }
impl Default for ChatParams { impl Default for ChatParams {
fn default() -> Self { fn default() -> Self {
Self { Self {
max_tokens: None, max_tokens: None,
min_tokens: None,
random_seed: None, random_seed: None,
safe_prompt: false, safe_prompt: false,
response_format: None, response_format: None,
temperature: 0.7, temperature: None,
tool_choice: None, tool_choice: None,
tools: None, tools: None,
top_p: 1.0, top_p: None,
} stop: None,
} n: None,
} frequency_penalty: None,
impl ChatParams { presence_penalty: None,
pub fn json_default() -> Self { parallel_tool_calls: None,
Self { reasoning_effort: None,
max_tokens: None,
random_seed: None,
safe_prompt: false,
response_format: None,
temperature: 0.7,
tool_choice: None,
tools: None,
top_p: 1.0,
} }
} }
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct ChatRequest { pub struct ChatRequest {
pub messages: Vec<ChatMessage>,
pub model: constants::Model, pub model: constants::Model,
pub messages: Vec<ChatMessage>,
pub stream: bool,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>, pub max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub min_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub random_seed: Option<u32>, pub random_seed: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<ResponseFormat>, pub response_format: Option<ResponseFormat>,
pub safe_prompt: bool, #[serde(skip_serializing_if = "Option::is_none")]
pub stream: bool, pub safe_prompt: Option<bool>,
pub temperature: f32, #[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<tool::ToolChoice>, pub tool_choice: Option<tool::ToolChoice>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<tool::Tool>>, pub tools: Option<Vec<tool::Tool>>,
pub top_p: f32, #[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub n: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub parallel_tool_calls: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_effort: Option<String>,
} }
impl ChatRequest { impl ChatRequest {
pub fn new( pub fn new(
@@ -156,30 +193,28 @@ impl ChatRequest {
stream: bool, stream: bool,
options: Option<ChatParams>, options: Option<ChatParams>,
) -> Self { ) -> Self {
let ChatParams { let opts = options.unwrap_or_default();
max_tokens, let safe_prompt = if opts.safe_prompt { Some(true) } else { None };
random_seed,
safe_prompt,
temperature,
tool_choice,
tools,
top_p,
response_format,
} = options.unwrap_or_default();
Self { Self {
messages,
model, model,
messages,
max_tokens,
random_seed,
safe_prompt,
stream, stream,
temperature, max_tokens: opts.max_tokens,
tool_choice, min_tokens: opts.min_tokens,
tools, random_seed: opts.random_seed,
top_p, safe_prompt,
response_format, temperature: opts.temperature,
tool_choice: opts.tool_choice,
tools: opts.tools,
top_p: opts.top_p,
response_format: opts.response_format,
stop: opts.stop,
n: opts.n,
frequency_penalty: opts.frequency_penalty,
presence_penalty: opts.presence_penalty,
parallel_tool_calls: opts.parallel_tool_calls,
reasoning_effort: opts.reasoning_effort,
} }
} }
} }
@@ -192,7 +227,7 @@ pub struct ChatResponse {
pub id: String, pub id: String,
pub object: String, pub object: String,
/// Unix timestamp (in seconds). /// Unix timestamp (in seconds).
pub created: u32, pub created: u64,
pub model: constants::Model, pub model: constants::Model,
pub choices: Vec<ChatResponseChoice>, pub choices: Vec<ChatResponseChoice>,
pub usage: common::ResponseUsage, pub usage: common::ResponseUsage,
@@ -203,14 +238,18 @@ pub struct ChatResponseChoice {
pub index: u32, pub index: u32,
pub message: ChatMessage, pub message: ChatMessage,
pub finish_reason: ChatResponseChoiceFinishReason, pub finish_reason: ChatResponseChoiceFinishReason,
// TODO Check this prop (seen in API responses but undocumented).
// pub logprobs: ???
} }
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum ChatResponseChoiceFinishReason { pub enum ChatResponseChoiceFinishReason {
#[serde(rename = "stop")] #[serde(rename = "stop")]
Stop, Stop,
#[serde(rename = "length")]
Length,
#[serde(rename = "tool_calls")] #[serde(rename = "tool_calls")]
ToolCalls, ToolCalls,
#[serde(rename = "model_length")]
ModelLength,
#[serde(rename = "error")]
Error,
} }

View File

@@ -1,7 +1,7 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::from_str; use serde_json::from_str;
use crate::v1::{chat, common, constants, error}; use crate::v1::{chat, common, constants, error, tool};
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
// Response // Response
@@ -11,12 +11,11 @@ pub struct ChatStreamChunk {
pub id: String, pub id: String,
pub object: String, pub object: String,
/// Unix timestamp (in seconds). /// Unix timestamp (in seconds).
pub created: u32, pub created: u64,
pub model: constants::Model, pub model: constants::Model,
pub choices: Vec<ChatStreamChunkChoice>, pub choices: Vec<ChatStreamChunkChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<common::ResponseUsage>, pub usage: Option<common::ResponseUsage>,
// TODO Check this prop (seen in API responses but undocumented).
// pub logprobs: ???,
} }
#[derive(Clone, Debug, Deserialize, Serialize)] #[derive(Clone, Debug, Deserialize, Serialize)]
@@ -24,14 +23,15 @@ pub struct ChatStreamChunkChoice {
pub index: u32, pub index: u32,
pub delta: ChatStreamChunkChoiceDelta, pub delta: ChatStreamChunkChoiceDelta,
pub finish_reason: Option<String>, pub finish_reason: Option<String>,
// TODO Check this prop (seen in API responses but undocumented).
// pub logprobs: ???,
} }
#[derive(Clone, Debug, Deserialize, Serialize)] #[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ChatStreamChunkChoiceDelta { pub struct ChatStreamChunkChoiceDelta {
pub role: Option<chat::ChatMessageRole>, pub role: Option<chat::ChatMessageRole>,
pub content: String, #[serde(default)]
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<tool::ToolCall>>,
} }
/// Extracts serialized chunks from a stream message. /// Extracts serialized chunks from a stream message.
@@ -47,7 +47,6 @@ pub fn get_chunk_from_stream_message_line(
return Ok(Some(vec![])); return Ok(Some(vec![]));
} }
// Attempt to deserialize the JSON string into ChatStreamChunk
match from_str::<ChatStreamChunk>(chunk_as_json) { match from_str::<ChatStreamChunk>(chunk_as_json) {
Ok(chunk) => Ok(Some(vec![chunk])), Ok(chunk) => Ok(Some(vec![chunk])),
Err(e) => Err(error::ApiError { Err(e) => Err(error::ApiError {

View File

@@ -26,25 +26,10 @@ impl Client {
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `api_key` - An optional API key. /// * `api_key` - An optional API key. If not provided, uses `MISTRAL_API_KEY` env var.
/// If not provided, the method will try to use the `MISTRAL_API_KEY` environment variable. /// * `endpoint` - An optional custom API endpoint. Defaults to `https://api.mistral.ai/v1`.
/// * `endpoint` - An optional custom API endpoint. Defaults to the official API endpoint if not provided. /// * `max_retries` - Optional maximum number of retries. Defaults to `5`.
/// * `max_retries` - Optional maximum number of retries for failed requests. Defaults to `5`. /// * `timeout` - Optional timeout in seconds. Defaults to `120`.
/// * `timeout` - Optional timeout in seconds for requests. Defaults to `120`.
///
/// # Examples
///
/// ```
/// use mistralai_client::v1::client::Client;
///
/// let client = Client::new(Some("your_api_key_here".to_string()), None, Some(3), Some(60));
/// assert!(client.is_ok());
/// ```
///
/// # Errors
///
/// This method fails whenever neither the `api_key` is provided
/// nor the `MISTRAL_API_KEY` environment variable is set.
pub fn new( pub fn new(
api_key: Option<String>, api_key: Option<String>,
endpoint: Option<String>, endpoint: Option<String>,
@@ -69,43 +54,15 @@ impl Client {
endpoint, endpoint,
max_retries, max_retries,
timeout, timeout,
functions, functions,
last_function_call_result, last_function_call_result,
}) })
} }
/// Synchronously sends a chat completion request and returns the response. // =========================================================================
/// // Chat Completions
/// # Arguments // =========================================================================
///
/// * `model` - The [Model] to use for the chat completion.
/// * `messages` - A vector of [ChatMessage] to send as part of the chat.
/// * `options` - Optional [ChatParams] to customize the request.
///
/// # Returns
///
/// Returns a [Result] containing the `ChatResponse` if the request is successful,
/// or an [ApiError] if there is an error.
///
/// # Examples
///
/// ```
/// use mistralai_client::v1::{
/// chat::{ChatMessage, ChatMessageRole},
/// client::Client,
/// constants::Model,
/// };
///
/// let client = Client::new(None, None, None, None).unwrap();
/// let messages = vec![ChatMessage {
/// role: ChatMessageRole::User,
/// content: "Hello, world!".to_string(),
/// tool_calls: None,
/// }];
/// let response = client.chat(Model::OpenMistral7b, messages, None).unwrap();
/// println!("{:?}: {}", response.choices[0].message.role, response.choices[0].message.content);
/// ```
pub fn chat( pub fn chat(
&self, &self,
model: constants::Model, model: constants::Model,
@@ -119,49 +76,13 @@ impl Client {
match result { match result {
Ok(data) => { Ok(data) => {
utils::debug_pretty_json_from_struct("Response Data", &data); utils::debug_pretty_json_from_struct("Response Data", &data);
self.call_function_if_any(data.clone()); self.call_function_if_any(data.clone());
Ok(data) Ok(data)
} }
Err(error) => Err(self.to_api_error(error)), Err(error) => Err(self.to_api_error(error)),
} }
} }
/// Asynchronously sends a chat completion request and returns the response.
///
/// # Arguments
///
/// * `model` - The [Model] to use for the chat completion.
/// * `messages` - A vector of [ChatMessage] to send as part of the chat.
/// * `options` - Optional [ChatParams] to customize the request.
///
/// # Returns
///
/// Returns a [Result] containing a `Stream` of `ChatStreamChunk` if the request is successful,
/// or an [ApiError] if there is an error.
///
/// # Examples
///
/// ```
/// use mistralai_client::v1::{
/// chat::{ChatMessage, ChatMessageRole},
/// client::Client,
/// constants::Model,
/// };
///
/// #[tokio::main]
/// async fn main() {
/// let client = Client::new(None, None, None, None).unwrap();
/// let messages = vec![ChatMessage {
/// role: ChatMessageRole::User,
/// content: "Hello, world!".to_string(),
/// tool_calls: None,
/// }];
/// let response = client.chat_async(Model::OpenMistral7b, messages, None).await.unwrap();
/// println!("{:?}: {}", response.choices[0].message.role, response.choices[0].message.content);
/// }
/// ```
pub async fn chat_async( pub async fn chat_async(
&self, &self,
model: constants::Model, model: constants::Model,
@@ -175,68 +96,13 @@ impl Client {
match result { match result {
Ok(data) => { Ok(data) => {
utils::debug_pretty_json_from_struct("Response Data", &data); utils::debug_pretty_json_from_struct("Response Data", &data);
self.call_function_if_any_async(data.clone()).await; self.call_function_if_any_async(data.clone()).await;
Ok(data) Ok(data)
} }
Err(error) => Err(self.to_api_error(error)), Err(error) => Err(self.to_api_error(error)),
} }
} }
/// Asynchronously sends a chat completion request and returns a stream of message chunks.
///
/// # Arguments
///
/// * `model` - The [Model] to use for the chat completion.
/// * `messages` - A vector of [ChatMessage] to send as part of the chat.
/// * `options` - Optional [ChatParams] to customize the request.
///
/// # Returns
///
/// Returns a [Result] containing a `Stream` of `ChatStreamChunk` if the request is successful,
/// or an [ApiError] if there is an error.
///
/// # Examples
///
/// ```
/// use futures::stream::StreamExt;
/// use mistralai_client::v1::{
/// chat::{ChatMessage, ChatMessageRole},
/// client::Client,
/// constants::Model,
/// };
/// use std::io::{self, Write};
///
/// #[tokio::main]
/// async fn main() {
/// let client = Client::new(None, None, None, None).unwrap();
/// let messages = vec![ChatMessage {
/// role: ChatMessageRole::User,
/// content: "Hello, world!".to_string(),
/// tool_calls: None,
/// }];
///
/// let stream_result = client
/// .chat_stream(Model::OpenMistral7b,messages, None)
/// .await
/// .unwrap();
/// stream_result
/// .for_each(|chunk_result| async {
/// match chunk_result {
/// Ok(chunks) => chunks.iter().for_each(|chunk| {
/// print!("{}", chunk.choices[0].delta.content);
/// io::stdout().flush().unwrap();
/// // => "Once upon a time, [...]"
/// }),
/// Err(error) => {
/// eprintln!("Error processing chunk: {:?}", error)
/// }
/// }
/// })
/// .await;
/// print!("\n") // To persist the last chunk output.
/// }
pub async fn chat_stream( pub async fn chat_stream(
&self, &self,
model: constants::Model, model: constants::Model,
@@ -292,9 +158,13 @@ impl Client {
Ok(deserialized_stream) Ok(deserialized_stream)
} }
// =========================================================================
// Embeddings
// =========================================================================
pub fn embeddings( pub fn embeddings(
&self, &self,
model: constants::EmbedModel, model: constants::Model,
input: Vec<String>, input: Vec<String>,
options: Option<embedding::EmbeddingRequestOptions>, options: Option<embedding::EmbeddingRequestOptions>,
) -> Result<embedding::EmbeddingResponse, error::ApiError> { ) -> Result<embedding::EmbeddingResponse, error::ApiError> {
@@ -305,7 +175,6 @@ impl Client {
match result { match result {
Ok(data) => { Ok(data) => {
utils::debug_pretty_json_from_struct("Response Data", &data); utils::debug_pretty_json_from_struct("Response Data", &data);
Ok(data) Ok(data)
} }
Err(error) => Err(self.to_api_error(error)), Err(error) => Err(self.to_api_error(error)),
@@ -314,7 +183,7 @@ impl Client {
pub async fn embeddings_async( pub async fn embeddings_async(
&self, &self,
model: constants::EmbedModel, model: constants::Model,
input: Vec<String>, input: Vec<String>,
options: Option<embedding::EmbeddingRequestOptions>, options: Option<embedding::EmbeddingRequestOptions>,
) -> Result<embedding::EmbeddingResponse, error::ApiError> { ) -> Result<embedding::EmbeddingResponse, error::ApiError> {
@@ -325,18 +194,15 @@ impl Client {
match result { match result {
Ok(data) => { Ok(data) => {
utils::debug_pretty_json_from_struct("Response Data", &data); utils::debug_pretty_json_from_struct("Response Data", &data);
Ok(data) Ok(data)
} }
Err(error) => Err(self.to_api_error(error)), Err(error) => Err(self.to_api_error(error)),
} }
} }
pub fn get_last_function_call_result(&self) -> Option<Box<dyn Any + Send>> { // =========================================================================
let mut result_lock = self.last_function_call_result.lock().unwrap(); // Models
// =========================================================================
result_lock.take()
}
pub fn list_models(&self) -> Result<model_list::ModelListResponse, error::ApiError> { pub fn list_models(&self) -> Result<model_list::ModelListResponse, error::ApiError> {
let response = self.get_sync("/models")?; let response = self.get_sync("/models")?;
@@ -344,7 +210,6 @@ impl Client {
match result { match result {
Ok(data) => { Ok(data) => {
utils::debug_pretty_json_from_struct("Response Data", &data); utils::debug_pretty_json_from_struct("Response Data", &data);
Ok(data) Ok(data)
} }
Err(error) => Err(self.to_api_error(error)), Err(error) => Err(self.to_api_error(error)),
@@ -359,68 +224,136 @@ impl Client {
match result { match result {
Ok(data) => { Ok(data) => {
utils::debug_pretty_json_from_struct("Response Data", &data); utils::debug_pretty_json_from_struct("Response Data", &data);
Ok(data) Ok(data)
} }
Err(error) => Err(self.to_api_error(error)), Err(error) => Err(self.to_api_error(error)),
} }
} }
pub fn get_model(&self, model_id: &str) -> Result<model_list::ModelListData, error::ApiError> {
let response = self.get_sync(&format!("/models/{}", model_id))?;
let result = response.json::<model_list::ModelListData>();
match result {
Ok(data) => {
utils::debug_pretty_json_from_struct("Response Data", &data);
Ok(data)
}
Err(error) => Err(self.to_api_error(error)),
}
}
pub async fn get_model_async(
&self,
model_id: &str,
) -> Result<model_list::ModelListData, error::ApiError> {
let response = self.get_async(&format!("/models/{}", model_id)).await?;
let result = response.json::<model_list::ModelListData>().await;
match result {
Ok(data) => {
utils::debug_pretty_json_from_struct("Response Data", &data);
Ok(data)
}
Err(error) => Err(self.to_api_error(error)),
}
}
pub fn delete_model(
&self,
model_id: &str,
) -> Result<model_list::ModelDeleteResponse, error::ApiError> {
let response = self.delete_sync(&format!("/models/{}", model_id))?;
let result = response.json::<model_list::ModelDeleteResponse>();
match result {
Ok(data) => Ok(data),
Err(error) => Err(self.to_api_error(error)),
}
}
pub async fn delete_model_async(
&self,
model_id: &str,
) -> Result<model_list::ModelDeleteResponse, error::ApiError> {
let response = self
.delete_async(&format!("/models/{}", model_id))
.await?;
let result = response.json::<model_list::ModelDeleteResponse>().await;
match result {
Ok(data) => Ok(data),
Err(error) => Err(self.to_api_error(error)),
}
}
// =========================================================================
// Function Calling
// =========================================================================
pub fn register_function(&mut self, name: String, function: Box<dyn tool::Function>) { pub fn register_function(&mut self, name: String, function: Box<dyn tool::Function>) {
let mut functions = self.functions.lock().unwrap(); let mut functions = self.functions.lock().unwrap();
functions.insert(name, function); functions.insert(name, function);
} }
pub fn get_last_function_call_result(&self) -> Option<Box<dyn Any + Send>> {
let mut result_lock = self.last_function_call_result.lock().unwrap();
result_lock.take()
}
// =========================================================================
// HTTP Transport
// =========================================================================
fn user_agent(&self) -> String {
format!(
"mistralai-client-rs/{}",
env!("CARGO_PKG_VERSION")
)
}
fn build_request_sync( fn build_request_sync(
&self, &self,
request: reqwest::blocking::RequestBuilder, request: reqwest::blocking::RequestBuilder,
) -> reqwest::blocking::RequestBuilder { ) -> reqwest::blocking::RequestBuilder {
let user_agent = format!( request
"ivangabriele/mistralai-client-rs/{}",
env!("CARGO_PKG_VERSION")
);
let request_builder = request
.bearer_auth(&self.api_key) .bearer_auth(&self.api_key)
.header("Accept", "application/json") .header("Accept", "application/json")
.header("User-Agent", user_agent); .header("User-Agent", self.user_agent())
}
request_builder fn build_request_sync_no_accept(
&self,
request: reqwest::blocking::RequestBuilder,
) -> reqwest::blocking::RequestBuilder {
request
.bearer_auth(&self.api_key)
.header("User-Agent", self.user_agent())
} }
fn build_request_async(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder { fn build_request_async(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
let user_agent = format!( request
"ivangabriele/mistralai-client-rs/{}",
env!("CARGO_PKG_VERSION")
);
let request_builder = request
.bearer_auth(&self.api_key) .bearer_auth(&self.api_key)
.header("Accept", "application/json") .header("Accept", "application/json")
.header("User-Agent", user_agent); .header("User-Agent", self.user_agent())
}
request_builder fn build_request_async_no_accept(
&self,
request: reqwest::RequestBuilder,
) -> reqwest::RequestBuilder {
request
.bearer_auth(&self.api_key)
.header("User-Agent", self.user_agent())
} }
fn build_request_stream(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder { fn build_request_stream(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
let user_agent = format!( request
"ivangabriele/mistralai-client-rs/{}",
env!("CARGO_PKG_VERSION")
);
let request_builder = request
.bearer_auth(&self.api_key) .bearer_auth(&self.api_key)
.header("Accept", "text/event-stream") .header("Accept", "text/event-stream")
.header("User-Agent", user_agent); .header("User-Agent", self.user_agent())
request_builder
} }
fn call_function_if_any(&self, response: chat::ChatResponse) -> () { fn call_function_if_any(&self, response: chat::ChatResponse) {
let next_result = match response.choices.get(0) { let next_result = match response.choices.first() {
Some(first_choice) => match first_choice.message.tool_calls.to_owned() { Some(first_choice) => match first_choice.message.tool_calls.as_ref() {
Some(tool_calls) => match tool_calls.get(0) { Some(tool_calls) => match tool_calls.first() {
Some(first_tool_call) => { Some(first_tool_call) => {
let functions = self.functions.lock().unwrap(); let functions = self.functions.lock().unwrap();
match functions.get(&first_tool_call.function.name) { match functions.get(&first_tool_call.function.name) {
@@ -431,7 +364,6 @@ impl Client {
.execute(first_tool_call.function.arguments.to_owned()) .execute(first_tool_call.function.arguments.to_owned())
.await .await
}); });
Some(result) Some(result)
} }
None => None, None => None,
@@ -448,10 +380,10 @@ impl Client {
*last_result_lock = next_result; *last_result_lock = next_result;
} }
async fn call_function_if_any_async(&self, response: chat::ChatResponse) -> () { async fn call_function_if_any_async(&self, response: chat::ChatResponse) {
let next_result = match response.choices.get(0) { let next_result = match response.choices.first() {
Some(first_choice) => match first_choice.message.tool_calls.to_owned() { Some(first_choice) => match first_choice.message.tool_calls.as_ref() {
Some(tool_calls) => match tool_calls.get(0) { Some(tool_calls) => match tool_calls.first() {
Some(first_tool_call) => { Some(first_tool_call) => {
let functions = self.functions.lock().unwrap(); let functions = self.functions.lock().unwrap();
match functions.get(&first_tool_call.function.name) { match functions.get(&first_tool_call.function.name) {
@@ -459,7 +391,6 @@ impl Client {
let result = function let result = function
.execute(first_tool_call.function.arguments.to_owned()) .execute(first_tool_call.function.arguments.to_owned())
.await; .await;
Some(result) Some(result)
} }
None => None, None => None,
@@ -482,27 +413,8 @@ impl Client {
debug!("Request URL: {}", url); debug!("Request URL: {}", url);
let request = self.build_request_sync(reqwest_client.get(url)); let request = self.build_request_sync(reqwest_client.get(url));
let result = request.send(); let result = request.send();
match result { self.handle_sync_response(result)
Ok(response) => {
if response.status().is_success() {
Ok(response)
} else {
let response_status = response.status();
let response_body = response.text().unwrap_or_default();
debug!("Response Status: {}", &response_status);
utils::debug_pretty_json_from_string("Response Data", &response_body);
Err(error::ApiError {
message: format!("{}: {}", response_status, response_body),
})
}
}
Err(error) => Err(error::ApiError {
message: error.to_string(),
}),
}
} }
async fn get_async(&self, path: &str) -> Result<reqwest::Response, error::ApiError> { async fn get_async(&self, path: &str) -> Result<reqwest::Response, error::ApiError> {
@@ -510,29 +422,9 @@ impl Client {
let url = format!("{}{}", self.endpoint, path); let url = format!("{}{}", self.endpoint, path);
debug!("Request URL: {}", url); debug!("Request URL: {}", url);
let request_builder = reqwest_client.get(url); let request = self.build_request_async(reqwest_client.get(url));
let request = self.build_request_async(request_builder);
let result = request.send().await; let result = request.send().await;
match result { self.handle_async_response(result).await
Ok(response) => {
if response.status().is_success() {
Ok(response)
} else {
let response_status = response.status();
let response_body = response.text().await.unwrap_or_default();
debug!("Response Status: {}", &response_status);
utils::debug_pretty_json_from_string("Response Data", &response_body);
Err(error::ApiError {
message: format!("{}: {}", response_status, response_body),
})
}
}
Err(error) => Err(error::ApiError {
message: error.to_string(),
}),
}
} }
fn post_sync<T: std::fmt::Debug + serde::ser::Serialize>( fn post_sync<T: std::fmt::Debug + serde::ser::Serialize>(
@@ -545,29 +437,22 @@ impl Client {
debug!("Request URL: {}", url); debug!("Request URL: {}", url);
utils::debug_pretty_json_from_struct("Request Body", params); utils::debug_pretty_json_from_struct("Request Body", params);
let request_builder = reqwest_client.post(url).json(params); let request = self.build_request_sync(reqwest_client.post(url).json(params));
let request = self.build_request_sync(request_builder);
let result = request.send(); let result = request.send();
match result { self.handle_sync_response(result)
Ok(response) => { }
if response.status().is_success() {
Ok(response)
} else {
let response_status = response.status();
let response_body = response.text().unwrap_or_default();
debug!("Response Status: {}", &response_status);
utils::debug_pretty_json_from_string("Response Data", &response_body);
Err(error::ApiError { fn post_sync_empty(
message: format!("{}: {}", response_body, response_status), &self,
}) path: &str,
} ) -> Result<reqwest::blocking::Response, error::ApiError> {
} let reqwest_client = reqwest::blocking::Client::new();
Err(error) => Err(error::ApiError { let url = format!("{}{}", self.endpoint, path);
message: error.to_string(), debug!("Request URL: {}", url);
}),
} let request = self.build_request_sync(reqwest_client.post(url));
let result = request.send();
self.handle_sync_response(result)
} }
async fn post_async<T: serde::ser::Serialize + std::fmt::Debug>( async fn post_async<T: serde::ser::Serialize + std::fmt::Debug>(
@@ -580,29 +465,19 @@ impl Client {
debug!("Request URL: {}", url); debug!("Request URL: {}", url);
utils::debug_pretty_json_from_struct("Request Body", params); utils::debug_pretty_json_from_struct("Request Body", params);
let request_builder = reqwest_client.post(url).json(params); let request = self.build_request_async(reqwest_client.post(url).json(params));
let request = self.build_request_async(request_builder);
let result = request.send().await; let result = request.send().await;
match result { self.handle_async_response(result).await
Ok(response) => { }
if response.status().is_success() {
Ok(response)
} else {
let response_status = response.status();
let response_body = response.text().await.unwrap_or_default();
debug!("Response Status: {}", &response_status);
utils::debug_pretty_json_from_string("Response Data", &response_body);
Err(error::ApiError { async fn post_async_empty(&self, path: &str) -> Result<reqwest::Response, error::ApiError> {
message: format!("{}: {}", response_status, response_body), let reqwest_client = reqwest::Client::new();
}) let url = format!("{}{}", self.endpoint, path);
} debug!("Request URL: {}", url);
}
Err(error) => Err(error::ApiError { let request = self.build_request_async(reqwest_client.post(url));
message: error.to_string(), let result = request.send().await;
}), self.handle_async_response(result).await
}
} }
async fn post_stream<T: serde::ser::Serialize + std::fmt::Debug>( async fn post_stream<T: serde::ser::Serialize + std::fmt::Debug>(
@@ -615,22 +490,70 @@ impl Client {
debug!("Request URL: {}", url); debug!("Request URL: {}", url);
utils::debug_pretty_json_from_struct("Request Body", params); utils::debug_pretty_json_from_struct("Request Body", params);
let request_builder = reqwest_client.post(url).json(params); let request = self.build_request_stream(reqwest_client.post(url).json(params));
let request = self.build_request_stream(request_builder);
let result = request.send().await; let result = request.send().await;
self.handle_async_response(result).await
}
fn delete_sync(&self, path: &str) -> Result<reqwest::blocking::Response, error::ApiError> {
let reqwest_client = reqwest::blocking::Client::new();
let url = format!("{}{}", self.endpoint, path);
debug!("Request URL: {}", url);
let request = self.build_request_sync(reqwest_client.delete(url));
let result = request.send();
self.handle_sync_response(result)
}
async fn delete_async(&self, path: &str) -> Result<reqwest::Response, error::ApiError> {
let reqwest_client = reqwest::Client::new();
let url = format!("{}{}", self.endpoint, path);
debug!("Request URL: {}", url);
let request = self.build_request_async(reqwest_client.delete(url));
let result = request.send().await;
self.handle_async_response(result).await
}
fn handle_sync_response(
&self,
result: Result<reqwest::blocking::Response, reqwest::Error>,
) -> Result<reqwest::blocking::Response, error::ApiError> {
match result { match result {
Ok(response) => { Ok(response) => {
if response.status().is_success() { if response.status().is_success() {
Ok(response) Ok(response)
} else { } else {
let response_status = response.status(); let status = response.status();
let response_body = response.text().await.unwrap_or_default(); let body = response.text().unwrap_or_default();
debug!("Response Status: {}", &response_status); debug!("Response Status: {}", &status);
utils::debug_pretty_json_from_string("Response Data", &response_body); utils::debug_pretty_json_from_string("Response Data", &body);
Err(error::ApiError { Err(error::ApiError {
message: format!("{}: {}", response_status, response_body), message: format!("{}: {}", status, body),
})
}
}
Err(error) => Err(error::ApiError {
message: error.to_string(),
}),
}
}
async fn handle_async_response(
&self,
result: Result<reqwest::Response, reqwest::Error>,
) -> Result<reqwest::Response, error::ApiError> {
match result {
Ok(response) => {
if response.status().is_success() {
Ok(response)
} else {
let status = response.status();
let body = response.text().await.unwrap_or_default();
debug!("Response Status: {}", &status);
utils::debug_pretty_json_from_string("Response Data", &body);
Err(error::ApiError {
message: format!("{}: {}", status, body),
}) })
} }
} }

View File

@@ -3,6 +3,7 @@ use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, Deserialize, Serialize)] #[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ResponseUsage { pub struct ResponseUsage {
pub prompt_tokens: u32, pub prompt_tokens: u32,
#[serde(default)]
pub completion_tokens: u32, pub completion_tokens: u32,
pub total_tokens: u32, pub total_tokens: u32,
} }

View File

@@ -1,35 +1,131 @@
use std::fmt;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
pub const API_URL_BASE: &str = "https://api.mistral.ai/v1"; pub const API_URL_BASE: &str = "https://api.mistral.ai/v1";
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] /// A Mistral AI model identifier.
pub enum Model { ///
#[serde(rename = "open-mistral-7b")] /// Use the associated constants for known models, or construct with `Model::new()` for any model string.
OpenMistral7b, #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename = "open-mixtral-8x7b")] #[serde(transparent)]
OpenMixtral8x7b, pub struct Model(pub String);
#[serde(rename = "open-mixtral-8x22b")]
OpenMixtral8x22b, impl Model {
#[serde(rename = "open-mistral-nemo", alias = "open-mistral-nemo-2407")] pub fn new(id: impl Into<String>) -> Self {
OpenMistralNemo, Self(id.into())
#[serde(rename = "mistral-tiny")] }
MistralTiny,
#[serde(rename = "mistral-small-latest", alias = "mistral-small-2402")] // Flagship / Premier
MistralSmallLatest, pub fn mistral_large_latest() -> Self {
#[serde(rename = "mistral-medium-latest", alias = "mistral-medium-2312")] Self::new("mistral-large-latest")
MistralMediumLatest, }
#[serde(rename = "mistral-large-latest", alias = "mistral-large-2407")] pub fn mistral_large_3() -> Self {
MistralLargeLatest, Self::new("mistral-large-3-25-12")
#[serde(rename = "mistral-large-2402")] }
MistralLarge, pub fn mistral_medium_latest() -> Self {
#[serde(rename = "codestral-latest", alias = "codestral-2405")] Self::new("mistral-medium-latest")
CodestralLatest, }
#[serde(rename = "open-codestral-mamba")] pub fn mistral_medium_3_1() -> Self {
CodestralMamba, Self::new("mistral-medium-3-1-25-08")
}
pub fn mistral_small_latest() -> Self {
Self::new("mistral-small-latest")
}
pub fn mistral_small_4() -> Self {
Self::new("mistral-small-4-0-26-03")
}
pub fn mistral_small_3_2() -> Self {
Self::new("mistral-small-3-2-25-06")
}
// Ministral
pub fn ministral_3_14b() -> Self {
Self::new("ministral-3-14b-25-12")
}
pub fn ministral_3_8b() -> Self {
Self::new("ministral-3-8b-25-12")
}
pub fn ministral_3_3b() -> Self {
Self::new("ministral-3-3b-25-12")
}
// Reasoning
pub fn magistral_medium_latest() -> Self {
Self::new("magistral-medium-latest")
}
pub fn magistral_small_latest() -> Self {
Self::new("magistral-small-latest")
}
// Code
pub fn codestral_latest() -> Self {
Self::new("codestral-latest")
}
pub fn codestral_2508() -> Self {
Self::new("codestral-2508")
}
pub fn codestral_embed() -> Self {
Self::new("codestral-embed-25-05")
}
pub fn devstral_2() -> Self {
Self::new("devstral-2-25-12")
}
pub fn devstral_small_2() -> Self {
Self::new("devstral-small-2-25-12")
}
// Multimodal / Vision
pub fn pixtral_large() -> Self {
Self::new("pixtral-large-2411")
}
// Audio
pub fn voxtral_mini_transcribe() -> Self {
Self::new("voxtral-mini-transcribe-2-26-02")
}
pub fn voxtral_small() -> Self {
Self::new("voxtral-small-25-07")
}
pub fn voxtral_mini() -> Self {
Self::new("voxtral-mini-25-07")
}
// Legacy (kept for backward compatibility)
pub fn open_mistral_nemo() -> Self {
Self::new("open-mistral-nemo")
}
// Embedding
pub fn mistral_embed() -> Self {
Self::new("mistral-embed")
}
// Moderation
pub fn mistral_moderation_latest() -> Self {
Self::new("mistral-moderation-26-03")
}
// OCR
pub fn mistral_ocr_latest() -> Self {
Self::new("mistral-ocr-latest")
}
} }
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] impl fmt::Display for Model {
pub enum EmbedModel { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
#[serde(rename = "mistral-embed")] write!(f, "{}", self.0)
MistralEmbed, }
}
impl From<&str> for Model {
fn from(s: &str) -> Self {
Self(s.to_string())
}
}
impl From<String> for Model {
fn from(s: String) -> Self {
Self(s)
}
} }

View File

@@ -8,42 +8,63 @@ use crate::v1::{common, constants};
#[derive(Debug)] #[derive(Debug)]
pub struct EmbeddingRequestOptions { pub struct EmbeddingRequestOptions {
pub encoding_format: Option<EmbeddingRequestEncodingFormat>, pub encoding_format: Option<EmbeddingRequestEncodingFormat>,
pub output_dimension: Option<u32>,
pub output_dtype: Option<EmbeddingOutputDtype>,
} }
impl Default for EmbeddingRequestOptions { impl Default for EmbeddingRequestOptions {
fn default() -> Self { fn default() -> Self {
Self { Self {
encoding_format: None, encoding_format: None,
output_dimension: None,
output_dtype: None,
} }
} }
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingRequest { pub struct EmbeddingRequest {
pub model: constants::EmbedModel, pub model: constants::Model,
pub input: Vec<String>, pub input: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub encoding_format: Option<EmbeddingRequestEncodingFormat>, pub encoding_format: Option<EmbeddingRequestEncodingFormat>,
#[serde(skip_serializing_if = "Option::is_none")]
pub output_dimension: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub output_dtype: Option<EmbeddingOutputDtype>,
} }
impl EmbeddingRequest { impl EmbeddingRequest {
pub fn new( pub fn new(
model: constants::EmbedModel, model: constants::Model,
input: Vec<String>, input: Vec<String>,
options: Option<EmbeddingRequestOptions>, options: Option<EmbeddingRequestOptions>,
) -> Self { ) -> Self {
let EmbeddingRequestOptions { encoding_format } = options.unwrap_or_default(); let opts = options.unwrap_or_default();
Self { Self {
model, model,
input, input,
encoding_format, encoding_format: opts.encoding_format,
output_dimension: opts.output_dimension,
output_dtype: opts.output_dtype,
} }
} }
} }
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[allow(non_camel_case_types)] #[serde(rename_all = "lowercase")]
pub enum EmbeddingRequestEncodingFormat { pub enum EmbeddingRequestEncodingFormat {
float, Float,
Base64,
}
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingOutputDtype {
Float,
Int8,
Uint8,
Binary,
Ubinary,
} }
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
@@ -51,9 +72,8 @@ pub enum EmbeddingRequestEncodingFormat {
#[derive(Clone, Debug, Deserialize, Serialize)] #[derive(Clone, Debug, Deserialize, Serialize)]
pub struct EmbeddingResponse { pub struct EmbeddingResponse {
pub id: String,
pub object: String, pub object: String,
pub model: constants::EmbedModel, pub model: constants::Model,
pub data: Vec<EmbeddingResponseDataItem>, pub data: Vec<EmbeddingResponseDataItem>,
pub usage: common::ResponseUsage, pub usage: common::ResponseUsage,
} }

View File

@@ -15,23 +15,44 @@ pub struct ModelListData {
pub id: String, pub id: String,
pub object: String, pub object: String,
/// Unix timestamp (in seconds). /// Unix timestamp (in seconds).
pub created: u32, pub created: u64,
pub owned_by: String, pub owned_by: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub root: Option<String>, pub root: Option<String>,
#[serde(default)]
pub archived: bool, pub archived: bool,
pub name: String, #[serde(default)]
pub description: String, pub name: Option<String>,
pub capabilities: ModelListDataCapabilies, #[serde(default)]
pub max_context_length: u32, pub description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub capabilities: Option<ModelListDataCapabilities>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_context_length: Option<u32>,
#[serde(default)]
pub aliases: Vec<String>, pub aliases: Vec<String>,
/// ISO 8601 date (`YYYY-MM-DDTHH:MM:SSZ`). /// ISO 8601 date (`YYYY-MM-DDTHH:MM:SSZ`).
#[serde(skip_serializing_if = "Option::is_none")]
pub deprecation: Option<String>, pub deprecation: Option<String>,
} }
#[derive(Clone, Debug, Deserialize, Serialize)] #[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ModelListDataCapabilies { pub struct ModelListDataCapabilities {
#[serde(default)]
pub completion_chat: bool, pub completion_chat: bool,
#[serde(default)]
pub completion_fim: bool, pub completion_fim: bool,
#[serde(default)]
pub function_calling: bool, pub function_calling: bool,
#[serde(default)]
pub fine_tuning: bool, pub fine_tuning: bool,
#[serde(default)]
pub vision: bool,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ModelDeleteResponse {
pub id: String,
pub object: String,
pub deleted: bool,
} }

View File

@@ -1,12 +1,16 @@
use async_trait::async_trait; use async_trait::async_trait;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::{any::Any, collections::HashMap, fmt::Debug}; use std::{any::Any, fmt::Debug};
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
// Definitions // Definitions
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub struct ToolCall { pub struct ToolCall {
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub r#type: Option<String>,
pub function: ToolCallFunction, pub function: ToolCallFunction,
} }
@@ -22,31 +26,12 @@ pub struct Tool {
pub function: ToolFunction, pub function: ToolFunction,
} }
impl Tool { impl Tool {
/// Create a tool with a JSON Schema parameters object.
pub fn new( pub fn new(
function_name: String, function_name: String,
function_description: String, function_description: String,
function_parameters: Vec<ToolFunctionParameter>, parameters: serde_json::Value,
) -> Self { ) -> Self {
let properties: HashMap<String, ToolFunctionParameterProperty> = function_parameters
.into_iter()
.map(|param| {
(
param.name,
ToolFunctionParameterProperty {
r#type: param.r#type,
description: param.description,
},
)
})
.collect();
let property_names = properties.keys().cloned().collect();
let parameters = ToolFunctionParameters {
r#type: ToolFunctionParametersType::Object,
properties,
required: property_names,
};
Self { Self {
r#type: ToolType::Function, r#type: ToolType::Function,
function: ToolFunction { function: ToolFunction {
@@ -63,50 +48,9 @@ impl Tool {
#[derive(Clone, Debug, Deserialize, Serialize)] #[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolFunction { pub struct ToolFunction {
name: String, pub name: String,
description: String, pub description: String,
parameters: ToolFunctionParameters, pub parameters: serde_json::Value,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolFunctionParameter {
name: String,
description: String,
r#type: ToolFunctionParameterType,
}
impl ToolFunctionParameter {
pub fn new(name: String, description: String, r#type: ToolFunctionParameterType) -> Self {
Self {
name,
r#type,
description,
}
}
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolFunctionParameters {
r#type: ToolFunctionParametersType,
properties: HashMap<String, ToolFunctionParameterProperty>,
required: Vec<String>,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolFunctionParameterProperty {
r#type: ToolFunctionParameterType,
description: String,
}
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum ToolFunctionParametersType {
#[serde(rename = "object")]
Object,
}
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum ToolFunctionParameterType {
#[serde(rename = "string")]
String,
} }
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
@@ -127,6 +71,9 @@ pub enum ToolChoice {
/// The model won't call a function and will generate a message instead. /// The model won't call a function and will generate a message instead.
#[serde(rename = "none")] #[serde(rename = "none")]
None, None,
/// The model must call at least one tool.
#[serde(rename = "required")]
Required,
} }
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------