refactor!: modernize core types and client for latest Mistral API
BREAKING CHANGE: `Model` is now a string-based struct with constructor methods instead of a closed enum. `EmbedModel` is removed — use `Model::mistral_embed()` instead. Tool parameters now accept `serde_json::Value` (JSON Schema) instead of limited enum types.

- Replace the `Model` enum with a flexible `Model(String)` supporting all current models (Large 3, Small 4, Medium 3.1, Magistral, Codestral, Devstral, Pixtral, Voxtral, Ministral) as well as arbitrary strings
- Remove the `EmbedModel` enum (consolidated into `Model`)
- Chat: add `frequency_penalty`, `presence_penalty`, `stop`, `n`, `min_tokens`, `parallel_tool_calls`, `reasoning_effort`, and the `json_schema` response format
- Embeddings: add `output_dimension` and `output_dtype` fields
- Tools: accept raw JSON Schema, add tool call IDs and the `Required` choice
- Stream delta `content` is now `Option<String>` to allow for tool call chunks
- Add `Length`, `ModelLength`, and `Error` finish reason variants
- Deduplicate (DRY) the HTTP transport with shared response handlers
- Add DELETE method support and model get/delete endpoints
- Make `model_list` fields more lenient with `Option`/`default` for API compatibility
This commit is contained in:
187
src/v1/chat.rs
187
src/v1/chat.rs
@@ -11,13 +11,31 @@ pub struct ChatMessage {
|
||||
pub content: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<tool::ToolCall>>,
|
||||
/// Tool call ID, required when role is Tool.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
/// Function name, used when role is Tool.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub name: Option<String>,
|
||||
}
|
||||
impl ChatMessage {
|
||||
pub fn new_system_message(content: &str) -> Self {
|
||||
Self {
|
||||
role: ChatMessageRole::System,
|
||||
content: content.to_string(),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_assistant_message(content: &str, tool_calls: Option<Vec<tool::ToolCall>>) -> Self {
|
||||
Self {
|
||||
role: ChatMessageRole::Assistant,
|
||||
content: content.to_string(),
|
||||
tool_calls,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,6 +44,18 @@ impl ChatMessage {
|
||||
role: ChatMessageRole::User,
|
||||
content: content.to_string(),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_tool_message(content: &str, tool_call_id: &str, name: Option<&str>) -> Self {
|
||||
Self {
|
||||
role: ChatMessageRole::Tool,
|
||||
content: content.to_string(),
|
||||
tool_calls: None,
|
||||
tool_call_id: Some(tool_call_id.to_string()),
|
||||
name: name.map(|n| n.to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -44,17 +74,32 @@ pub enum ChatMessageRole {
|
||||
}
|
||||
|
||||
/// The format that the model must output.
|
||||
///
|
||||
/// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ResponseFormat {
|
||||
#[serde(rename = "type")]
|
||||
pub type_: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub json_schema: Option<serde_json::Value>,
|
||||
}
|
||||
impl ResponseFormat {
|
||||
pub fn text() -> Self {
|
||||
Self {
|
||||
type_: "text".to_string(),
|
||||
json_schema: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn json_object() -> Self {
|
||||
Self {
|
||||
type_: "json_object".to_string(),
|
||||
json_schema: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn json_schema(schema: serde_json::Value) -> Self {
|
||||
Self {
|
||||
type_: "json_schema".to_string(),
|
||||
json_schema: Some(schema),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -63,91 +108,83 @@ impl ResponseFormat {
|
||||
// Request
|
||||
|
||||
/// The parameters for the chat request.
|
||||
///
|
||||
/// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ChatParams {
|
||||
/// The maximum number of tokens to generate in the completion.
|
||||
///
|
||||
/// Defaults to `None`.
|
||||
pub max_tokens: Option<u32>,
|
||||
/// The seed to use for random sampling. If set, different calls will generate deterministic results.
|
||||
///
|
||||
/// Defaults to `None`.
|
||||
pub min_tokens: Option<u32>,
|
||||
pub random_seed: Option<u32>,
|
||||
/// The format that the model must output.
|
||||
///
|
||||
/// Defaults to `None`.
|
||||
pub response_format: Option<ResponseFormat>,
|
||||
/// Whether to inject a safety prompt before all conversations.
|
||||
///
|
||||
/// Defaults to `false`.
|
||||
pub safe_prompt: bool,
|
||||
/// What sampling temperature to use, between `Some(0.0)` and `Some(1.0)`.
|
||||
///
|
||||
/// Defaults to `0.7`.
|
||||
pub temperature: f32,
|
||||
/// Specifies if/how functions are called.
|
||||
///
|
||||
/// Defaults to `None`.
|
||||
pub temperature: Option<f32>,
|
||||
pub tool_choice: Option<tool::ToolChoice>,
|
||||
/// A list of available tools for the model.
|
||||
///
|
||||
/// Defaults to `None`.
|
||||
pub tools: Option<Vec<tool::Tool>>,
|
||||
/// Nucleus sampling, where the model considers the results of the tokens with `top_p` probability mass.
|
||||
///
|
||||
/// Defaults to `1.0`.
|
||||
pub top_p: f32,
|
||||
pub top_p: Option<f32>,
|
||||
pub stop: Option<Vec<String>>,
|
||||
pub n: Option<u32>,
|
||||
pub frequency_penalty: Option<f32>,
|
||||
pub presence_penalty: Option<f32>,
|
||||
pub parallel_tool_calls: Option<bool>,
|
||||
/// For reasoning models (Magistral). "high" or "none".
|
||||
pub reasoning_effort: Option<String>,
|
||||
}
|
||||
impl Default for ChatParams {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_tokens: None,
|
||||
min_tokens: None,
|
||||
random_seed: None,
|
||||
safe_prompt: false,
|
||||
response_format: None,
|
||||
temperature: 0.7,
|
||||
temperature: None,
|
||||
tool_choice: None,
|
||||
tools: None,
|
||||
top_p: 1.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl ChatParams {
|
||||
pub fn json_default() -> Self {
|
||||
Self {
|
||||
max_tokens: None,
|
||||
random_seed: None,
|
||||
safe_prompt: false,
|
||||
response_format: None,
|
||||
temperature: 0.7,
|
||||
tool_choice: None,
|
||||
tools: None,
|
||||
top_p: 1.0,
|
||||
top_p: None,
|
||||
stop: None,
|
||||
n: None,
|
||||
frequency_penalty: None,
|
||||
presence_penalty: None,
|
||||
parallel_tool_calls: None,
|
||||
reasoning_effort: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ChatRequest {
|
||||
pub messages: Vec<ChatMessage>,
|
||||
pub model: constants::Model,
|
||||
pub messages: Vec<ChatMessage>,
|
||||
pub stream: bool,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub min_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub random_seed: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub response_format: Option<ResponseFormat>,
|
||||
pub safe_prompt: bool,
|
||||
pub stream: bool,
|
||||
pub temperature: f32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub safe_prompt: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_choice: Option<tool::ToolChoice>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tools: Option<Vec<tool::Tool>>,
|
||||
pub top_p: f32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stop: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub n: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub frequency_penalty: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub presence_penalty: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub parallel_tool_calls: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub reasoning_effort: Option<String>,
|
||||
}
|
||||
impl ChatRequest {
|
||||
pub fn new(
|
||||
@@ -156,30 +193,28 @@ impl ChatRequest {
|
||||
stream: bool,
|
||||
options: Option<ChatParams>,
|
||||
) -> Self {
|
||||
let ChatParams {
|
||||
max_tokens,
|
||||
random_seed,
|
||||
safe_prompt,
|
||||
temperature,
|
||||
tool_choice,
|
||||
tools,
|
||||
top_p,
|
||||
response_format,
|
||||
} = options.unwrap_or_default();
|
||||
let opts = options.unwrap_or_default();
|
||||
let safe_prompt = if opts.safe_prompt { Some(true) } else { None };
|
||||
|
||||
Self {
|
||||
messages,
|
||||
model,
|
||||
|
||||
max_tokens,
|
||||
random_seed,
|
||||
safe_prompt,
|
||||
messages,
|
||||
stream,
|
||||
temperature,
|
||||
tool_choice,
|
||||
tools,
|
||||
top_p,
|
||||
response_format,
|
||||
max_tokens: opts.max_tokens,
|
||||
min_tokens: opts.min_tokens,
|
||||
random_seed: opts.random_seed,
|
||||
safe_prompt,
|
||||
temperature: opts.temperature,
|
||||
tool_choice: opts.tool_choice,
|
||||
tools: opts.tools,
|
||||
top_p: opts.top_p,
|
||||
response_format: opts.response_format,
|
||||
stop: opts.stop,
|
||||
n: opts.n,
|
||||
frequency_penalty: opts.frequency_penalty,
|
||||
presence_penalty: opts.presence_penalty,
|
||||
parallel_tool_calls: opts.parallel_tool_calls,
|
||||
reasoning_effort: opts.reasoning_effort,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -192,7 +227,7 @@ pub struct ChatResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
/// Unix timestamp (in seconds).
|
||||
pub created: u32,
|
||||
pub created: u64,
|
||||
pub model: constants::Model,
|
||||
pub choices: Vec<ChatResponseChoice>,
|
||||
pub usage: common::ResponseUsage,
|
||||
@@ -203,14 +238,18 @@ pub struct ChatResponseChoice {
|
||||
pub index: u32,
|
||||
pub message: ChatMessage,
|
||||
pub finish_reason: ChatResponseChoiceFinishReason,
|
||||
// TODO Check this prop (seen in API responses but undocumented).
|
||||
// pub logprobs: ???
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||
pub enum ChatResponseChoiceFinishReason {
|
||||
#[serde(rename = "stop")]
|
||||
Stop,
|
||||
#[serde(rename = "length")]
|
||||
Length,
|
||||
#[serde(rename = "tool_calls")]
|
||||
ToolCalls,
|
||||
#[serde(rename = "model_length")]
|
||||
ModelLength,
|
||||
#[serde(rename = "error")]
|
||||
Error,
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::from_str;
|
||||
|
||||
use crate::v1::{chat, common, constants, error};
|
||||
use crate::v1::{chat, common, constants, error, tool};
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Response
|
||||
@@ -11,12 +11,11 @@ pub struct ChatStreamChunk {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
/// Unix timestamp (in seconds).
|
||||
pub created: u32,
|
||||
pub created: u64,
|
||||
pub model: constants::Model,
|
||||
pub choices: Vec<ChatStreamChunkChoice>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub usage: Option<common::ResponseUsage>,
|
||||
// TODO Check this prop (seen in API responses but undocumented).
|
||||
// pub logprobs: ???,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
@@ -24,14 +23,15 @@ pub struct ChatStreamChunkChoice {
|
||||
pub index: u32,
|
||||
pub delta: ChatStreamChunkChoiceDelta,
|
||||
pub finish_reason: Option<String>,
|
||||
// TODO Check this prop (seen in API responses but undocumented).
|
||||
// pub logprobs: ???,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct ChatStreamChunkChoiceDelta {
|
||||
pub role: Option<chat::ChatMessageRole>,
|
||||
pub content: String,
|
||||
#[serde(default)]
|
||||
pub content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<tool::ToolCall>>,
|
||||
}
|
||||
|
||||
/// Extracts serialized chunks from a stream message.
|
||||
@@ -47,7 +47,6 @@ pub fn get_chunk_from_stream_message_line(
|
||||
return Ok(Some(vec![]));
|
||||
}
|
||||
|
||||
// Attempt to deserialize the JSON string into ChatStreamChunk
|
||||
match from_str::<ChatStreamChunk>(chunk_as_json) {
|
||||
Ok(chunk) => Ok(Some(vec![chunk])),
|
||||
Err(e) => Err(error::ApiError {
|
||||
|
||||
487
src/v1/client.rs
487
src/v1/client.rs
@@ -26,25 +26,10 @@ impl Client {
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `api_key` - An optional API key.
|
||||
/// If not provided, the method will try to use the `MISTRAL_API_KEY` environment variable.
|
||||
/// * `endpoint` - An optional custom API endpoint. Defaults to the official API endpoint if not provided.
|
||||
/// * `max_retries` - Optional maximum number of retries for failed requests. Defaults to `5`.
|
||||
/// * `timeout` - Optional timeout in seconds for requests. Defaults to `120`.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use mistralai_client::v1::client::Client;
|
||||
///
|
||||
/// let client = Client::new(Some("your_api_key_here".to_string()), None, Some(3), Some(60));
|
||||
/// assert!(client.is_ok());
|
||||
/// ```
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// This method fails whenever neither the `api_key` is provided
|
||||
/// nor the `MISTRAL_API_KEY` environment variable is set.
|
||||
/// * `api_key` - An optional API key. If not provided, uses `MISTRAL_API_KEY` env var.
|
||||
/// * `endpoint` - An optional custom API endpoint. Defaults to `https://api.mistral.ai/v1`.
|
||||
/// * `max_retries` - Optional maximum number of retries. Defaults to `5`.
|
||||
/// * `timeout` - Optional timeout in seconds. Defaults to `120`.
|
||||
pub fn new(
|
||||
api_key: Option<String>,
|
||||
endpoint: Option<String>,
|
||||
@@ -69,43 +54,15 @@ impl Client {
|
||||
endpoint,
|
||||
max_retries,
|
||||
timeout,
|
||||
|
||||
functions,
|
||||
last_function_call_result,
|
||||
})
|
||||
}
|
||||
|
||||
/// Synchronously sends a chat completion request and returns the response.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `model` - The [Model] to use for the chat completion.
|
||||
/// * `messages` - A vector of [ChatMessage] to send as part of the chat.
|
||||
/// * `options` - Optional [ChatParams] to customize the request.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Returns a [Result] containing the `ChatResponse` if the request is successful,
|
||||
/// or an [ApiError] if there is an error.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use mistralai_client::v1::{
|
||||
/// chat::{ChatMessage, ChatMessageRole},
|
||||
/// client::Client,
|
||||
/// constants::Model,
|
||||
/// };
|
||||
///
|
||||
/// let client = Client::new(None, None, None, None).unwrap();
|
||||
/// let messages = vec![ChatMessage {
|
||||
/// role: ChatMessageRole::User,
|
||||
/// content: "Hello, world!".to_string(),
|
||||
/// tool_calls: None,
|
||||
/// }];
|
||||
/// let response = client.chat(Model::OpenMistral7b, messages, None).unwrap();
|
||||
/// println!("{:?}: {}", response.choices[0].message.role, response.choices[0].message.content);
|
||||
/// ```
|
||||
// =========================================================================
|
||||
// Chat Completions
|
||||
// =========================================================================
|
||||
|
||||
pub fn chat(
|
||||
&self,
|
||||
model: constants::Model,
|
||||
@@ -119,49 +76,13 @@ impl Client {
|
||||
match result {
|
||||
Ok(data) => {
|
||||
utils::debug_pretty_json_from_struct("Response Data", &data);
|
||||
|
||||
self.call_function_if_any(data.clone());
|
||||
|
||||
Ok(data)
|
||||
}
|
||||
Err(error) => Err(self.to_api_error(error)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Asynchronously sends a chat completion request and returns the response.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `model` - The [Model] to use for the chat completion.
|
||||
/// * `messages` - A vector of [ChatMessage] to send as part of the chat.
|
||||
/// * `options` - Optional [ChatParams] to customize the request.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Returns a [Result] containing a `Stream` of `ChatStreamChunk` if the request is successful,
|
||||
/// or an [ApiError] if there is an error.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use mistralai_client::v1::{
|
||||
/// chat::{ChatMessage, ChatMessageRole},
|
||||
/// client::Client,
|
||||
/// constants::Model,
|
||||
/// };
|
||||
///
|
||||
/// #[tokio::main]
|
||||
/// async fn main() {
|
||||
/// let client = Client::new(None, None, None, None).unwrap();
|
||||
/// let messages = vec![ChatMessage {
|
||||
/// role: ChatMessageRole::User,
|
||||
/// content: "Hello, world!".to_string(),
|
||||
/// tool_calls: None,
|
||||
/// }];
|
||||
/// let response = client.chat_async(Model::OpenMistral7b, messages, None).await.unwrap();
|
||||
/// println!("{:?}: {}", response.choices[0].message.role, response.choices[0].message.content);
|
||||
/// }
|
||||
/// ```
|
||||
pub async fn chat_async(
|
||||
&self,
|
||||
model: constants::Model,
|
||||
@@ -175,68 +96,13 @@ impl Client {
|
||||
match result {
|
||||
Ok(data) => {
|
||||
utils::debug_pretty_json_from_struct("Response Data", &data);
|
||||
|
||||
self.call_function_if_any_async(data.clone()).await;
|
||||
|
||||
Ok(data)
|
||||
}
|
||||
Err(error) => Err(self.to_api_error(error)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Asynchronously sends a chat completion request and returns a stream of message chunks.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `model` - The [Model] to use for the chat completion.
|
||||
/// * `messages` - A vector of [ChatMessage] to send as part of the chat.
|
||||
/// * `options` - Optional [ChatParams] to customize the request.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Returns a [Result] containing a `Stream` of `ChatStreamChunk` if the request is successful,
|
||||
/// or an [ApiError] if there is an error.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use futures::stream::StreamExt;
|
||||
/// use mistralai_client::v1::{
|
||||
/// chat::{ChatMessage, ChatMessageRole},
|
||||
/// client::Client,
|
||||
/// constants::Model,
|
||||
/// };
|
||||
/// use std::io::{self, Write};
|
||||
///
|
||||
/// #[tokio::main]
|
||||
/// async fn main() {
|
||||
/// let client = Client::new(None, None, None, None).unwrap();
|
||||
/// let messages = vec![ChatMessage {
|
||||
/// role: ChatMessageRole::User,
|
||||
/// content: "Hello, world!".to_string(),
|
||||
/// tool_calls: None,
|
||||
/// }];
|
||||
///
|
||||
/// let stream_result = client
|
||||
/// .chat_stream(Model::OpenMistral7b,messages, None)
|
||||
/// .await
|
||||
/// .unwrap();
|
||||
/// stream_result
|
||||
/// .for_each(|chunk_result| async {
|
||||
/// match chunk_result {
|
||||
/// Ok(chunks) => chunks.iter().for_each(|chunk| {
|
||||
/// print!("{}", chunk.choices[0].delta.content);
|
||||
/// io::stdout().flush().unwrap();
|
||||
/// // => "Once upon a time, [...]"
|
||||
/// }),
|
||||
/// Err(error) => {
|
||||
/// eprintln!("Error processing chunk: {:?}", error)
|
||||
/// }
|
||||
/// }
|
||||
/// })
|
||||
/// .await;
|
||||
/// print!("\n") // To persist the last chunk output.
|
||||
/// }
|
||||
pub async fn chat_stream(
|
||||
&self,
|
||||
model: constants::Model,
|
||||
@@ -292,9 +158,13 @@ impl Client {
|
||||
Ok(deserialized_stream)
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Embeddings
|
||||
// =========================================================================
|
||||
|
||||
pub fn embeddings(
|
||||
&self,
|
||||
model: constants::EmbedModel,
|
||||
model: constants::Model,
|
||||
input: Vec<String>,
|
||||
options: Option<embedding::EmbeddingRequestOptions>,
|
||||
) -> Result<embedding::EmbeddingResponse, error::ApiError> {
|
||||
@@ -305,7 +175,6 @@ impl Client {
|
||||
match result {
|
||||
Ok(data) => {
|
||||
utils::debug_pretty_json_from_struct("Response Data", &data);
|
||||
|
||||
Ok(data)
|
||||
}
|
||||
Err(error) => Err(self.to_api_error(error)),
|
||||
@@ -314,7 +183,7 @@ impl Client {
|
||||
|
||||
pub async fn embeddings_async(
|
||||
&self,
|
||||
model: constants::EmbedModel,
|
||||
model: constants::Model,
|
||||
input: Vec<String>,
|
||||
options: Option<embedding::EmbeddingRequestOptions>,
|
||||
) -> Result<embedding::EmbeddingResponse, error::ApiError> {
|
||||
@@ -325,18 +194,15 @@ impl Client {
|
||||
match result {
|
||||
Ok(data) => {
|
||||
utils::debug_pretty_json_from_struct("Response Data", &data);
|
||||
|
||||
Ok(data)
|
||||
}
|
||||
Err(error) => Err(self.to_api_error(error)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_last_function_call_result(&self) -> Option<Box<dyn Any + Send>> {
|
||||
let mut result_lock = self.last_function_call_result.lock().unwrap();
|
||||
|
||||
result_lock.take()
|
||||
}
|
||||
// =========================================================================
|
||||
// Models
|
||||
// =========================================================================
|
||||
|
||||
pub fn list_models(&self) -> Result<model_list::ModelListResponse, error::ApiError> {
|
||||
let response = self.get_sync("/models")?;
|
||||
@@ -344,7 +210,6 @@ impl Client {
|
||||
match result {
|
||||
Ok(data) => {
|
||||
utils::debug_pretty_json_from_struct("Response Data", &data);
|
||||
|
||||
Ok(data)
|
||||
}
|
||||
Err(error) => Err(self.to_api_error(error)),
|
||||
@@ -359,68 +224,136 @@ impl Client {
|
||||
match result {
|
||||
Ok(data) => {
|
||||
utils::debug_pretty_json_from_struct("Response Data", &data);
|
||||
|
||||
Ok(data)
|
||||
}
|
||||
Err(error) => Err(self.to_api_error(error)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_model(&self, model_id: &str) -> Result<model_list::ModelListData, error::ApiError> {
|
||||
let response = self.get_sync(&format!("/models/{}", model_id))?;
|
||||
let result = response.json::<model_list::ModelListData>();
|
||||
match result {
|
||||
Ok(data) => {
|
||||
utils::debug_pretty_json_from_struct("Response Data", &data);
|
||||
Ok(data)
|
||||
}
|
||||
Err(error) => Err(self.to_api_error(error)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_model_async(
|
||||
&self,
|
||||
model_id: &str,
|
||||
) -> Result<model_list::ModelListData, error::ApiError> {
|
||||
let response = self.get_async(&format!("/models/{}", model_id)).await?;
|
||||
let result = response.json::<model_list::ModelListData>().await;
|
||||
match result {
|
||||
Ok(data) => {
|
||||
utils::debug_pretty_json_from_struct("Response Data", &data);
|
||||
Ok(data)
|
||||
}
|
||||
Err(error) => Err(self.to_api_error(error)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn delete_model(
|
||||
&self,
|
||||
model_id: &str,
|
||||
) -> Result<model_list::ModelDeleteResponse, error::ApiError> {
|
||||
let response = self.delete_sync(&format!("/models/{}", model_id))?;
|
||||
let result = response.json::<model_list::ModelDeleteResponse>();
|
||||
match result {
|
||||
Ok(data) => Ok(data),
|
||||
Err(error) => Err(self.to_api_error(error)),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn delete_model_async(
|
||||
&self,
|
||||
model_id: &str,
|
||||
) -> Result<model_list::ModelDeleteResponse, error::ApiError> {
|
||||
let response = self
|
||||
.delete_async(&format!("/models/{}", model_id))
|
||||
.await?;
|
||||
let result = response.json::<model_list::ModelDeleteResponse>().await;
|
||||
match result {
|
||||
Ok(data) => Ok(data),
|
||||
Err(error) => Err(self.to_api_error(error)),
|
||||
}
|
||||
}
|
||||
|
||||
// =========================================================================
|
||||
// Function Calling
|
||||
// =========================================================================
|
||||
|
||||
/// Stores `function` in the client's `functions` registry under `name`.
///
/// # Panics
///
/// Panics if locking `functions` returns an error.
pub fn register_function(&mut self, name: String, function: Box<dyn tool::Function>) {
    self.functions.lock().unwrap().insert(name, function);
}
|
||||
|
||||
/// Takes the stored result of the last function call, leaving `None` behind.
///
/// # Panics
///
/// Panics if locking `last_function_call_result` returns an error.
pub fn get_last_function_call_result(&self) -> Option<Box<dyn Any + Send>> {
    self.last_function_call_result.lock().unwrap().take()
}
|
||||
|
||||
// =========================================================================
|
||||
// HTTP Transport
|
||||
// =========================================================================
|
||||
|
||||
fn user_agent(&self) -> String {
|
||||
format!(
|
||||
"mistralai-client-rs/{}",
|
||||
env!("CARGO_PKG_VERSION")
|
||||
)
|
||||
}
|
||||
|
||||
fn build_request_sync(
|
||||
&self,
|
||||
request: reqwest::blocking::RequestBuilder,
|
||||
) -> reqwest::blocking::RequestBuilder {
|
||||
let user_agent = format!(
|
||||
"ivangabriele/mistralai-client-rs/{}",
|
||||
env!("CARGO_PKG_VERSION")
|
||||
);
|
||||
|
||||
let request_builder = request
|
||||
request
|
||||
.bearer_auth(&self.api_key)
|
||||
.header("Accept", "application/json")
|
||||
.header("User-Agent", user_agent);
|
||||
.header("User-Agent", self.user_agent())
|
||||
}
|
||||
|
||||
request_builder
|
||||
fn build_request_sync_no_accept(
|
||||
&self,
|
||||
request: reqwest::blocking::RequestBuilder,
|
||||
) -> reqwest::blocking::RequestBuilder {
|
||||
request
|
||||
.bearer_auth(&self.api_key)
|
||||
.header("User-Agent", self.user_agent())
|
||||
}
|
||||
|
||||
fn build_request_async(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
|
||||
let user_agent = format!(
|
||||
"ivangabriele/mistralai-client-rs/{}",
|
||||
env!("CARGO_PKG_VERSION")
|
||||
);
|
||||
|
||||
let request_builder = request
|
||||
request
|
||||
.bearer_auth(&self.api_key)
|
||||
.header("Accept", "application/json")
|
||||
.header("User-Agent", user_agent);
|
||||
.header("User-Agent", self.user_agent())
|
||||
}
|
||||
|
||||
request_builder
|
||||
fn build_request_async_no_accept(
|
||||
&self,
|
||||
request: reqwest::RequestBuilder,
|
||||
) -> reqwest::RequestBuilder {
|
||||
request
|
||||
.bearer_auth(&self.api_key)
|
||||
.header("User-Agent", self.user_agent())
|
||||
}
|
||||
|
||||
fn build_request_stream(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
|
||||
let user_agent = format!(
|
||||
"ivangabriele/mistralai-client-rs/{}",
|
||||
env!("CARGO_PKG_VERSION")
|
||||
);
|
||||
|
||||
let request_builder = request
|
||||
request
|
||||
.bearer_auth(&self.api_key)
|
||||
.header("Accept", "text/event-stream")
|
||||
.header("User-Agent", user_agent);
|
||||
|
||||
request_builder
|
||||
.header("User-Agent", self.user_agent())
|
||||
}
|
||||
|
||||
fn call_function_if_any(&self, response: chat::ChatResponse) -> () {
|
||||
let next_result = match response.choices.get(0) {
|
||||
Some(first_choice) => match first_choice.message.tool_calls.to_owned() {
|
||||
Some(tool_calls) => match tool_calls.get(0) {
|
||||
fn call_function_if_any(&self, response: chat::ChatResponse) {
|
||||
let next_result = match response.choices.first() {
|
||||
Some(first_choice) => match first_choice.message.tool_calls.as_ref() {
|
||||
Some(tool_calls) => match tool_calls.first() {
|
||||
Some(first_tool_call) => {
|
||||
let functions = self.functions.lock().unwrap();
|
||||
match functions.get(&first_tool_call.function.name) {
|
||||
@@ -431,7 +364,6 @@ impl Client {
|
||||
.execute(first_tool_call.function.arguments.to_owned())
|
||||
.await
|
||||
});
|
||||
|
||||
Some(result)
|
||||
}
|
||||
None => None,
|
||||
@@ -448,10 +380,10 @@ impl Client {
|
||||
*last_result_lock = next_result;
|
||||
}
|
||||
|
||||
async fn call_function_if_any_async(&self, response: chat::ChatResponse) -> () {
|
||||
let next_result = match response.choices.get(0) {
|
||||
Some(first_choice) => match first_choice.message.tool_calls.to_owned() {
|
||||
Some(tool_calls) => match tool_calls.get(0) {
|
||||
async fn call_function_if_any_async(&self, response: chat::ChatResponse) {
|
||||
let next_result = match response.choices.first() {
|
||||
Some(first_choice) => match first_choice.message.tool_calls.as_ref() {
|
||||
Some(tool_calls) => match tool_calls.first() {
|
||||
Some(first_tool_call) => {
|
||||
let functions = self.functions.lock().unwrap();
|
||||
match functions.get(&first_tool_call.function.name) {
|
||||
@@ -459,7 +391,6 @@ impl Client {
|
||||
let result = function
|
||||
.execute(first_tool_call.function.arguments.to_owned())
|
||||
.await;
|
||||
|
||||
Some(result)
|
||||
}
|
||||
None => None,
|
||||
@@ -482,27 +413,8 @@ impl Client {
|
||||
debug!("Request URL: {}", url);
|
||||
|
||||
let request = self.build_request_sync(reqwest_client.get(url));
|
||||
|
||||
let result = request.send();
|
||||
match result {
|
||||
Ok(response) => {
|
||||
if response.status().is_success() {
|
||||
Ok(response)
|
||||
} else {
|
||||
let response_status = response.status();
|
||||
let response_body = response.text().unwrap_or_default();
|
||||
debug!("Response Status: {}", &response_status);
|
||||
utils::debug_pretty_json_from_string("Response Data", &response_body);
|
||||
|
||||
Err(error::ApiError {
|
||||
message: format!("{}: {}", response_status, response_body),
|
||||
})
|
||||
}
|
||||
}
|
||||
Err(error) => Err(error::ApiError {
|
||||
message: error.to_string(),
|
||||
}),
|
||||
}
|
||||
self.handle_sync_response(result)
|
||||
}
|
||||
|
||||
async fn get_async(&self, path: &str) -> Result<reqwest::Response, error::ApiError> {
|
||||
@@ -510,29 +422,9 @@ impl Client {
|
||||
let url = format!("{}{}", self.endpoint, path);
|
||||
debug!("Request URL: {}", url);
|
||||
|
||||
let request_builder = reqwest_client.get(url);
|
||||
let request = self.build_request_async(request_builder);
|
||||
|
||||
let request = self.build_request_async(reqwest_client.get(url));
|
||||
let result = request.send().await;
|
||||
match result {
|
||||
Ok(response) => {
|
||||
if response.status().is_success() {
|
||||
Ok(response)
|
||||
} else {
|
||||
let response_status = response.status();
|
||||
let response_body = response.text().await.unwrap_or_default();
|
||||
debug!("Response Status: {}", &response_status);
|
||||
utils::debug_pretty_json_from_string("Response Data", &response_body);
|
||||
|
||||
Err(error::ApiError {
|
||||
message: format!("{}: {}", response_status, response_body),
|
||||
})
|
||||
}
|
||||
}
|
||||
Err(error) => Err(error::ApiError {
|
||||
message: error.to_string(),
|
||||
}),
|
||||
}
|
||||
self.handle_async_response(result).await
|
||||
}
|
||||
|
||||
fn post_sync<T: std::fmt::Debug + serde::ser::Serialize>(
|
||||
@@ -545,29 +437,22 @@ impl Client {
|
||||
debug!("Request URL: {}", url);
|
||||
utils::debug_pretty_json_from_struct("Request Body", params);
|
||||
|
||||
let request_builder = reqwest_client.post(url).json(params);
|
||||
let request = self.build_request_sync(request_builder);
|
||||
|
||||
let request = self.build_request_sync(reqwest_client.post(url).json(params));
|
||||
let result = request.send();
|
||||
match result {
|
||||
Ok(response) => {
|
||||
if response.status().is_success() {
|
||||
Ok(response)
|
||||
} else {
|
||||
let response_status = response.status();
|
||||
let response_body = response.text().unwrap_or_default();
|
||||
debug!("Response Status: {}", &response_status);
|
||||
utils::debug_pretty_json_from_string("Response Data", &response_body);
|
||||
self.handle_sync_response(result)
|
||||
}
|
||||
|
||||
Err(error::ApiError {
|
||||
message: format!("{}: {}", response_body, response_status),
|
||||
})
|
||||
}
|
||||
}
|
||||
Err(error) => Err(error::ApiError {
|
||||
message: error.to_string(),
|
||||
}),
|
||||
}
|
||||
fn post_sync_empty(
|
||||
&self,
|
||||
path: &str,
|
||||
) -> Result<reqwest::blocking::Response, error::ApiError> {
|
||||
let reqwest_client = reqwest::blocking::Client::new();
|
||||
let url = format!("{}{}", self.endpoint, path);
|
||||
debug!("Request URL: {}", url);
|
||||
|
||||
let request = self.build_request_sync(reqwest_client.post(url));
|
||||
let result = request.send();
|
||||
self.handle_sync_response(result)
|
||||
}
|
||||
|
||||
async fn post_async<T: serde::ser::Serialize + std::fmt::Debug>(
|
||||
@@ -580,29 +465,19 @@ impl Client {
|
||||
debug!("Request URL: {}", url);
|
||||
utils::debug_pretty_json_from_struct("Request Body", params);
|
||||
|
||||
let request_builder = reqwest_client.post(url).json(params);
|
||||
let request = self.build_request_async(request_builder);
|
||||
|
||||
let request = self.build_request_async(reqwest_client.post(url).json(params));
|
||||
let result = request.send().await;
|
||||
match result {
|
||||
Ok(response) => {
|
||||
if response.status().is_success() {
|
||||
Ok(response)
|
||||
} else {
|
||||
let response_status = response.status();
|
||||
let response_body = response.text().await.unwrap_or_default();
|
||||
debug!("Response Status: {}", &response_status);
|
||||
utils::debug_pretty_json_from_string("Response Data", &response_body);
|
||||
self.handle_async_response(result).await
|
||||
}
|
||||
|
||||
Err(error::ApiError {
|
||||
message: format!("{}: {}", response_status, response_body),
|
||||
})
|
||||
}
|
||||
}
|
||||
Err(error) => Err(error::ApiError {
|
||||
message: error.to_string(),
|
||||
}),
|
||||
}
|
||||
async fn post_async_empty(&self, path: &str) -> Result<reqwest::Response, error::ApiError> {
|
||||
let reqwest_client = reqwest::Client::new();
|
||||
let url = format!("{}{}", self.endpoint, path);
|
||||
debug!("Request URL: {}", url);
|
||||
|
||||
let request = self.build_request_async(reqwest_client.post(url));
|
||||
let result = request.send().await;
|
||||
self.handle_async_response(result).await
|
||||
}
|
||||
|
||||
async fn post_stream<T: serde::ser::Serialize + std::fmt::Debug>(
|
||||
@@ -615,22 +490,70 @@ impl Client {
|
||||
debug!("Request URL: {}", url);
|
||||
utils::debug_pretty_json_from_struct("Request Body", params);
|
||||
|
||||
let request_builder = reqwest_client.post(url).json(params);
|
||||
let request = self.build_request_stream(request_builder);
|
||||
|
||||
let request = self.build_request_stream(reqwest_client.post(url).json(params));
|
||||
let result = request.send().await;
|
||||
self.handle_async_response(result).await
|
||||
}
|
||||
|
||||
fn delete_sync(&self, path: &str) -> Result<reqwest::blocking::Response, error::ApiError> {
|
||||
let reqwest_client = reqwest::blocking::Client::new();
|
||||
let url = format!("{}{}", self.endpoint, path);
|
||||
debug!("Request URL: {}", url);
|
||||
|
||||
let request = self.build_request_sync(reqwest_client.delete(url));
|
||||
let result = request.send();
|
||||
self.handle_sync_response(result)
|
||||
}
|
||||
|
||||
async fn delete_async(&self, path: &str) -> Result<reqwest::Response, error::ApiError> {
|
||||
let reqwest_client = reqwest::Client::new();
|
||||
let url = format!("{}{}", self.endpoint, path);
|
||||
debug!("Request URL: {}", url);
|
||||
|
||||
let request = self.build_request_async(reqwest_client.delete(url));
|
||||
let result = request.send().await;
|
||||
self.handle_async_response(result).await
|
||||
}
|
||||
|
||||
fn handle_sync_response(
|
||||
&self,
|
||||
result: Result<reqwest::blocking::Response, reqwest::Error>,
|
||||
) -> Result<reqwest::blocking::Response, error::ApiError> {
|
||||
match result {
|
||||
Ok(response) => {
|
||||
if response.status().is_success() {
|
||||
Ok(response)
|
||||
} else {
|
||||
let response_status = response.status();
|
||||
let response_body = response.text().await.unwrap_or_default();
|
||||
debug!("Response Status: {}", &response_status);
|
||||
utils::debug_pretty_json_from_string("Response Data", &response_body);
|
||||
|
||||
let status = response.status();
|
||||
let body = response.text().unwrap_or_default();
|
||||
debug!("Response Status: {}", &status);
|
||||
utils::debug_pretty_json_from_string("Response Data", &body);
|
||||
Err(error::ApiError {
|
||||
message: format!("{}: {}", response_status, response_body),
|
||||
message: format!("{}: {}", status, body),
|
||||
})
|
||||
}
|
||||
}
|
||||
Err(error) => Err(error::ApiError {
|
||||
message: error.to_string(),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_async_response(
|
||||
&self,
|
||||
result: Result<reqwest::Response, reqwest::Error>,
|
||||
) -> Result<reqwest::Response, error::ApiError> {
|
||||
match result {
|
||||
Ok(response) => {
|
||||
if response.status().is_success() {
|
||||
Ok(response)
|
||||
} else {
|
||||
let status = response.status();
|
||||
let body = response.text().await.unwrap_or_default();
|
||||
debug!("Response Status: {}", &status);
|
||||
utils::debug_pretty_json_from_string("Response Data", &body);
|
||||
Err(error::ApiError {
|
||||
message: format!("{}: {}", status, body),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ use serde::{Deserialize, Serialize};
|
||||
/// Usage counters (`*_tokens`) carried in API responses, e.g. `EmbeddingResponse::usage`.
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct ResponseUsage {
|
||||
pub prompt_tokens: u32,
|
||||
// `serde(default)`: deserializes to 0 when the field is missing from the payload.
#[serde(default)]
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
||||
@@ -1,35 +1,131 @@
|
||||
use std::fmt;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub const API_URL_BASE: &str = "https://api.mistral.ai/v1";
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||
pub enum Model {
|
||||
#[serde(rename = "open-mistral-7b")]
|
||||
OpenMistral7b,
|
||||
#[serde(rename = "open-mixtral-8x7b")]
|
||||
OpenMixtral8x7b,
|
||||
#[serde(rename = "open-mixtral-8x22b")]
|
||||
OpenMixtral8x22b,
|
||||
#[serde(rename = "open-mistral-nemo", alias = "open-mistral-nemo-2407")]
|
||||
OpenMistralNemo,
|
||||
#[serde(rename = "mistral-tiny")]
|
||||
MistralTiny,
|
||||
#[serde(rename = "mistral-small-latest", alias = "mistral-small-2402")]
|
||||
MistralSmallLatest,
|
||||
#[serde(rename = "mistral-medium-latest", alias = "mistral-medium-2312")]
|
||||
MistralMediumLatest,
|
||||
#[serde(rename = "mistral-large-latest", alias = "mistral-large-2407")]
|
||||
MistralLargeLatest,
|
||||
#[serde(rename = "mistral-large-2402")]
|
||||
MistralLarge,
|
||||
#[serde(rename = "codestral-latest", alias = "codestral-2405")]
|
||||
CodestralLatest,
|
||||
#[serde(rename = "open-codestral-mamba")]
|
||||
CodestralMamba,
|
||||
/// A Mistral AI model identifier.
|
||||
///
|
||||
/// Use the associated constructor functions (e.g. `Model::mistral_large_latest()`) for known
/// models, or construct with `Model::new()` for any model string.
///
/// `#[serde(transparent)]` makes the type (de)serialize as the bare identifier string.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct Model(pub String);
|
||||
|
||||
impl Model {
|
||||
pub fn new(id: impl Into<String>) -> Self {
|
||||
Self(id.into())
|
||||
}
|
||||
|
||||
// Flagship / Premier
|
||||
pub fn mistral_large_latest() -> Self {
|
||||
Self::new("mistral-large-latest")
|
||||
}
|
||||
pub fn mistral_large_3() -> Self {
|
||||
Self::new("mistral-large-3-25-12")
|
||||
}
|
||||
pub fn mistral_medium_latest() -> Self {
|
||||
Self::new("mistral-medium-latest")
|
||||
}
|
||||
pub fn mistral_medium_3_1() -> Self {
|
||||
Self::new("mistral-medium-3-1-25-08")
|
||||
}
|
||||
pub fn mistral_small_latest() -> Self {
|
||||
Self::new("mistral-small-latest")
|
||||
}
|
||||
pub fn mistral_small_4() -> Self {
|
||||
Self::new("mistral-small-4-0-26-03")
|
||||
}
|
||||
pub fn mistral_small_3_2() -> Self {
|
||||
Self::new("mistral-small-3-2-25-06")
|
||||
}
|
||||
|
||||
// Ministral
|
||||
pub fn ministral_3_14b() -> Self {
|
||||
Self::new("ministral-3-14b-25-12")
|
||||
}
|
||||
pub fn ministral_3_8b() -> Self {
|
||||
Self::new("ministral-3-8b-25-12")
|
||||
}
|
||||
pub fn ministral_3_3b() -> Self {
|
||||
Self::new("ministral-3-3b-25-12")
|
||||
}
|
||||
|
||||
// Reasoning
|
||||
pub fn magistral_medium_latest() -> Self {
|
||||
Self::new("magistral-medium-latest")
|
||||
}
|
||||
pub fn magistral_small_latest() -> Self {
|
||||
Self::new("magistral-small-latest")
|
||||
}
|
||||
|
||||
// Code
|
||||
pub fn codestral_latest() -> Self {
|
||||
Self::new("codestral-latest")
|
||||
}
|
||||
pub fn codestral_2508() -> Self {
|
||||
Self::new("codestral-2508")
|
||||
}
|
||||
pub fn codestral_embed() -> Self {
|
||||
Self::new("codestral-embed-25-05")
|
||||
}
|
||||
pub fn devstral_2() -> Self {
|
||||
Self::new("devstral-2-25-12")
|
||||
}
|
||||
pub fn devstral_small_2() -> Self {
|
||||
Self::new("devstral-small-2-25-12")
|
||||
}
|
||||
|
||||
// Multimodal / Vision
|
||||
pub fn pixtral_large() -> Self {
|
||||
Self::new("pixtral-large-2411")
|
||||
}
|
||||
|
||||
// Audio
|
||||
pub fn voxtral_mini_transcribe() -> Self {
|
||||
Self::new("voxtral-mini-transcribe-2-26-02")
|
||||
}
|
||||
pub fn voxtral_small() -> Self {
|
||||
Self::new("voxtral-small-25-07")
|
||||
}
|
||||
pub fn voxtral_mini() -> Self {
|
||||
Self::new("voxtral-mini-25-07")
|
||||
}
|
||||
|
||||
// Legacy (kept for backward compatibility)
|
||||
pub fn open_mistral_nemo() -> Self {
|
||||
Self::new("open-mistral-nemo")
|
||||
}
|
||||
|
||||
// Embedding
|
||||
pub fn mistral_embed() -> Self {
|
||||
Self::new("mistral-embed")
|
||||
}
|
||||
|
||||
// Moderation
|
||||
pub fn mistral_moderation_latest() -> Self {
|
||||
Self::new("mistral-moderation-26-03")
|
||||
}
|
||||
|
||||
// OCR
|
||||
pub fn mistral_ocr_latest() -> Self {
|
||||
Self::new("mistral-ocr-latest")
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||
pub enum EmbedModel {
|
||||
#[serde(rename = "mistral-embed")]
|
||||
MistralEmbed,
|
||||
impl fmt::Display for Model {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for Model {
|
||||
fn from(s: &str) -> Self {
|
||||
Self(s.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for Model {
|
||||
fn from(s: String) -> Self {
|
||||
Self(s)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,42 +8,63 @@ use crate::v1::{common, constants};
|
||||
#[derive(Debug)]
|
||||
pub struct EmbeddingRequestOptions {
|
||||
pub encoding_format: Option<EmbeddingRequestEncodingFormat>,
|
||||
pub output_dimension: Option<u32>,
|
||||
pub output_dtype: Option<EmbeddingOutputDtype>,
|
||||
}
|
||||
impl Default for EmbeddingRequestOptions {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
encoding_format: None,
|
||||
output_dimension: None,
|
||||
output_dtype: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct EmbeddingRequest {
|
||||
pub model: constants::EmbedModel,
|
||||
pub model: constants::Model,
|
||||
pub input: Vec<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub encoding_format: Option<EmbeddingRequestEncodingFormat>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub output_dimension: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub output_dtype: Option<EmbeddingOutputDtype>,
|
||||
}
|
||||
impl EmbeddingRequest {
|
||||
pub fn new(
|
||||
model: constants::EmbedModel,
|
||||
model: constants::Model,
|
||||
input: Vec<String>,
|
||||
options: Option<EmbeddingRequestOptions>,
|
||||
) -> Self {
|
||||
let EmbeddingRequestOptions { encoding_format } = options.unwrap_or_default();
|
||||
let opts = options.unwrap_or_default();
|
||||
|
||||
Self {
|
||||
model,
|
||||
input,
|
||||
encoding_format,
|
||||
encoding_format: opts.encoding_format,
|
||||
output_dimension: opts.output_dimension,
|
||||
output_dtype: opts.output_dtype,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||
#[allow(non_camel_case_types)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum EmbeddingRequestEncodingFormat {
|
||||
float,
|
||||
Float,
|
||||
Base64,
|
||||
}
|
||||
|
||||
/// Value type for the `output_dtype` option of an embedding request.
///
/// With `rename_all = "lowercase"` the variants (de)serialize as
/// `float`, `int8`, `uint8`, `binary` and `ubinary`.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum EmbeddingOutputDtype {
|
||||
Float,
|
||||
Int8,
|
||||
Uint8,
|
||||
Binary,
|
||||
Ubinary,
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
@@ -51,9 +72,8 @@ pub enum EmbeddingRequestEncodingFormat {
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct EmbeddingResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub model: constants::EmbedModel,
|
||||
pub model: constants::Model,
|
||||
pub data: Vec<EmbeddingResponseDataItem>,
|
||||
pub usage: common::ResponseUsage,
|
||||
}
|
||||
|
||||
@@ -15,23 +15,44 @@ pub struct ModelListData {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
/// Unix timestamp (in seconds).
|
||||
pub created: u32,
|
||||
pub created: u64,
|
||||
pub owned_by: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub root: Option<String>,
|
||||
#[serde(default)]
|
||||
pub archived: bool,
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub capabilities: ModelListDataCapabilies,
|
||||
pub max_context_length: u32,
|
||||
#[serde(default)]
|
||||
pub name: Option<String>,
|
||||
#[serde(default)]
|
||||
pub description: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub capabilities: Option<ModelListDataCapabilities>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_context_length: Option<u32>,
|
||||
#[serde(default)]
|
||||
pub aliases: Vec<String>,
|
||||
/// ISO 8601 date (`YYYY-MM-DDTHH:MM:SSZ`).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub deprecation: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct ModelListDataCapabilies {
|
||||
pub struct ModelListDataCapabilities {
|
||||
#[serde(default)]
|
||||
pub completion_chat: bool,
|
||||
#[serde(default)]
|
||||
pub completion_fim: bool,
|
||||
#[serde(default)]
|
||||
pub function_calling: bool,
|
||||
#[serde(default)]
|
||||
pub fine_tuning: bool,
|
||||
#[serde(default)]
|
||||
pub vision: bool,
|
||||
}
|
||||
|
||||
/// Payload carrying an `id`, an `object` tag and a `deleted` flag.
// NOTE(review): presumably the body returned by the model delete endpoint — confirm against the client method that deserializes it.
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct ModelDeleteResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub deleted: bool,
|
||||
}
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{any::Any, collections::HashMap, fmt::Debug};
|
||||
use std::{any::Any, fmt::Debug};
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Definitions
|
||||
|
||||
/// A tool invocation attached to a chat message (see `ChatMessage::tool_calls`).
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||
pub struct ToolCall {
|
||||
// Optional call identifier; left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<String>,
|
||||
// Optional call type tag; left out of the serialized form when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub r#type: Option<String>,
|
||||
// The function the call refers to.
pub function: ToolCallFunction,
|
||||
}
|
||||
|
||||
@@ -22,31 +26,12 @@ pub struct Tool {
|
||||
pub function: ToolFunction,
|
||||
}
|
||||
impl Tool {
|
||||
/// Create a tool with a JSON Schema parameters object.
|
||||
pub fn new(
|
||||
function_name: String,
|
||||
function_description: String,
|
||||
function_parameters: Vec<ToolFunctionParameter>,
|
||||
parameters: serde_json::Value,
|
||||
) -> Self {
|
||||
let properties: HashMap<String, ToolFunctionParameterProperty> = function_parameters
|
||||
.into_iter()
|
||||
.map(|param| {
|
||||
(
|
||||
param.name,
|
||||
ToolFunctionParameterProperty {
|
||||
r#type: param.r#type,
|
||||
description: param.description,
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
let property_names = properties.keys().cloned().collect();
|
||||
|
||||
let parameters = ToolFunctionParameters {
|
||||
r#type: ToolFunctionParametersType::Object,
|
||||
properties,
|
||||
required: property_names,
|
||||
};
|
||||
|
||||
Self {
|
||||
r#type: ToolType::Function,
|
||||
function: ToolFunction {
|
||||
@@ -63,50 +48,9 @@ impl Tool {
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct ToolFunction {
|
||||
name: String,
|
||||
description: String,
|
||||
parameters: ToolFunctionParameters,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct ToolFunctionParameter {
|
||||
name: String,
|
||||
description: String,
|
||||
r#type: ToolFunctionParameterType,
|
||||
}
|
||||
impl ToolFunctionParameter {
|
||||
pub fn new(name: String, description: String, r#type: ToolFunctionParameterType) -> Self {
|
||||
Self {
|
||||
name,
|
||||
r#type,
|
||||
description,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct ToolFunctionParameters {
|
||||
r#type: ToolFunctionParametersType,
|
||||
properties: HashMap<String, ToolFunctionParameterProperty>,
|
||||
required: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct ToolFunctionParameterProperty {
|
||||
r#type: ToolFunctionParameterType,
|
||||
description: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||
pub enum ToolFunctionParametersType {
|
||||
#[serde(rename = "object")]
|
||||
Object,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||
pub enum ToolFunctionParameterType {
|
||||
#[serde(rename = "string")]
|
||||
String,
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||
@@ -127,6 +71,9 @@ pub enum ToolChoice {
|
||||
/// The model won't call a function and will generate a message instead.
|
||||
#[serde(rename = "none")]
|
||||
None,
|
||||
/// The model must call at least one tool.
|
||||
#[serde(rename = "required")]
|
||||
Required,
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
Reference in New Issue
Block a user