Update to latest Mistral AI API (v1.0.0)
Some checks failed
Test / Test Documentation (push) Has been cancelled
Test / Test Examples (push) Has been cancelled
Test / Test (push) Has been cancelled

- Replace closed Model enum with flexible string-based Model type
  with constructor methods for all current models (Mistral Large 3,
  Small 4, Magistral, Codestral, Devstral, Pixtral, Voxtral, etc.)
- Add new API endpoints: FIM completions, Files, Fine-tuning, Batch
  jobs, OCR, Audio transcription, Moderations/Classifications, and
  Agent completions (sync + async for all)
- Add new chat fields: frequency_penalty, presence_penalty, stop,
  n, parallel_tool_calls, reasoning_effort, min_tokens, json_schema
  response format
- Add embedding fields: output_dimension, output_dtype
- Tool parameters now accept raw JSON Schema (serde_json::Value)
  instead of limited enum types
- Add tool call IDs and Required tool choice variant
- Add DELETE HTTP method support and multipart file upload
- Bump thiserror to v2, add reqwest multipart feature
- Remove strum dependency (no longer needed)
- Update all tests and examples for new API
This commit is contained in:
2026-03-20 17:16:26 +00:00
parent 9ad6a1dc84
commit 79bc40bb15
33 changed files with 1977 additions and 622 deletions

View File

@@ -2,7 +2,7 @@
name = "mistralai-client"
description = "Mistral AI API client library for Rust (unofficial)."
license = "Apache-2.0"
version = "0.14.0"
version = "1.0.0"
edition = "2021"
rust-version = "1.76.0"
@@ -15,18 +15,17 @@ readme = "README.md"
repository = "https://github.com/ivangabriele/mistralai-client-rs"
[dependencies]
async-stream = "0.3.5"
async-trait = "0.1.77"
env_logger = "0.11.3"
futures = "0.3.30"
log = "0.4.21"
reqwest = { version = "0.12.0", features = ["json", "blocking", "stream"] }
serde = { version = "1.0.197", features = ["derive"] }
serde_json = "1.0.114"
strum = "0.26.1"
thiserror = "1.0.57"
tokio = { version = "1.36.0", features = ["full"] }
tokio-stream = "0.1.14"
async-stream = "0.3"
async-trait = "0.1"
env_logger = "0.11"
futures = "0.3"
log = "0.4"
reqwest = { version = "0.12", features = ["json", "blocking", "stream", "multipart"] }
serde = { version = "1", features = ["derive"] }
serde_json = "1"
thiserror = "2"
tokio = { version = "1", features = ["full"] }
tokio-stream = "0.1"
[dev-dependencies]
jrest = "0.2.3"
jrest = "0.2"

View File

@@ -1,5 +1,5 @@
use mistralai_client::v1::{
chat::{ChatMessage, ChatMessageRole, ChatParams},
chat::{ChatMessage, ChatParams},
client::Client,
constants::Model,
};
@@ -8,14 +8,12 @@ fn main() {
// This example assumes you have set the `MISTRAL_API_KEY` environment variable.
let client = Client::new(None, None, None, None).unwrap();
let model = Model::OpenMistral7b;
let messages = vec![ChatMessage {
role: ChatMessageRole::User,
content: "Just guess the next word: \"Eiffel ...\"?".to_string(),
tool_calls: None,
}];
let model = Model::mistral_small_latest();
let messages = vec![ChatMessage::new_user_message(
"Just guess the next word: \"Eiffel ...\"?",
)];
let options = ChatParams {
temperature: 0.0,
temperature: Some(0.0),
random_seed: Some(42),
..Default::default()
};

View File

@@ -1,5 +1,5 @@
use mistralai_client::v1::{
chat::{ChatMessage, ChatMessageRole, ChatParams},
chat::{ChatMessage, ChatParams},
client::Client,
constants::Model,
};
@@ -9,14 +9,12 @@ async fn main() {
// This example assumes you have set the `MISTRAL_API_KEY` environment variable.
let client = Client::new(None, None, None, None).unwrap();
let model = Model::OpenMistral7b;
let messages = vec![ChatMessage {
role: ChatMessageRole::User,
content: "Just guess the next word: \"Eiffel ...\"?".to_string(),
tool_calls: None,
}];
let model = Model::mistral_small_latest();
let messages = vec![ChatMessage::new_user_message(
"Just guess the next word: \"Eiffel ...\"?",
)];
let options = ChatParams {
temperature: 0.0,
temperature: Some(0.0),
random_seed: Some(42),
..Default::default()
};
@@ -29,5 +27,4 @@ async fn main() {
"{:?}: {}",
result.choices[0].message.role, result.choices[0].message.content
);
// => "Assistant: Tower. The Eiffel Tower is a famous landmark in Paris, France."
}

View File

@@ -1,8 +1,8 @@
use mistralai_client::v1::{
chat::{ChatMessage, ChatMessageRole, ChatParams},
chat::{ChatMessage, ChatParams},
client::Client,
constants::Model,
tool::{Function, Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType},
tool::{Function, Tool, ToolChoice},
};
use serde::Deserialize;
use std::any::Any;
@@ -16,7 +16,6 @@ struct GetCityTemperatureFunction;
#[async_trait::async_trait]
impl Function for GetCityTemperatureFunction {
async fn execute(&self, arguments: String) -> Box<dyn Any + Send> {
// Deserialize arguments, perform the logic, and return the result
let GetCityTemperatureArguments { city } = serde_json::from_str(&arguments).unwrap();
let temperature = match city.as_str() {
@@ -32,11 +31,16 @@ fn main() {
let tools = vec![Tool::new(
"get_city_temperature".to_string(),
"Get the current temperature in a city.".to_string(),
vec![ToolFunctionParameter::new(
"city".to_string(),
"The name of the city.".to_string(),
ToolFunctionParameterType::String,
)],
serde_json::json!({
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The name of the city."
}
},
"required": ["city"]
}),
)];
// This example assumes you have set the `MISTRAL_API_KEY` environment variable.
@@ -46,14 +50,12 @@ fn main() {
Box::new(GetCityTemperatureFunction),
);
let model = Model::MistralSmallLatest;
let messages = vec![ChatMessage {
role: ChatMessageRole::User,
content: "What's the temperature in Paris?".to_string(),
tool_calls: None,
}];
let model = Model::mistral_small_latest();
let messages = vec![ChatMessage::new_user_message(
"What's the temperature in Paris?",
)];
let options = ChatParams {
temperature: 0.0,
temperature: Some(0.0),
random_seed: Some(42),
tool_choice: Some(ToolChoice::Auto),
tools: Some(tools),

View File

@@ -1,8 +1,8 @@
use mistralai_client::v1::{
chat::{ChatMessage, ChatMessageRole, ChatParams},
chat::{ChatMessage, ChatParams},
client::Client,
constants::Model,
tool::{Function, Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType},
tool::{Function, Tool, ToolChoice},
};
use serde::Deserialize;
use std::any::Any;
@@ -16,7 +16,6 @@ struct GetCityTemperatureFunction;
#[async_trait::async_trait]
impl Function for GetCityTemperatureFunction {
async fn execute(&self, arguments: String) -> Box<dyn Any + Send> {
// Deserialize arguments, perform the logic, and return the result
let GetCityTemperatureArguments { city } = serde_json::from_str(&arguments).unwrap();
let temperature = match city.as_str() {
@@ -33,11 +32,16 @@ async fn main() {
let tools = vec![Tool::new(
"get_city_temperature".to_string(),
"Get the current temperature in a city.".to_string(),
vec![ToolFunctionParameter::new(
"city".to_string(),
"The name of the city.".to_string(),
ToolFunctionParameterType::String,
)],
serde_json::json!({
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The name of the city."
}
},
"required": ["city"]
}),
)];
// This example assumes you have set the `MISTRAL_API_KEY` environment variable.
@@ -47,14 +51,12 @@ async fn main() {
Box::new(GetCityTemperatureFunction),
);
let model = Model::MistralSmallLatest;
let messages = vec![ChatMessage {
role: ChatMessageRole::User,
content: "What's the temperature in Paris?".to_string(),
tool_calls: None,
}];
let model = Model::mistral_small_latest();
let messages = vec![ChatMessage::new_user_message(
"What's the temperature in Paris?",
)];
let options = ChatParams {
temperature: 0.0,
temperature: Some(0.0),
random_seed: Some(42),
tool_choice: Some(ToolChoice::Auto),
tools: Some(tools),

View File

@@ -1,6 +1,6 @@
use futures::stream::StreamExt;
use mistralai_client::v1::{
chat::{ChatMessage, ChatMessageRole, ChatParams},
chat::{ChatMessage, ChatParams},
client::Client,
constants::Model,
};
@@ -11,14 +11,10 @@ async fn main() {
// This example assumes you have set the `MISTRAL_API_KEY` environment variable.
let client = Client::new(None, None, None, None).unwrap();
let model = Model::OpenMistral7b;
let messages = vec![ChatMessage {
role: ChatMessageRole::User,
content: "Tell me a short happy story.".to_string(),
tool_calls: None,
}];
let model = Model::mistral_small_latest();
let messages = vec![ChatMessage::new_user_message("Tell me a short happy story.")];
let options = ChatParams {
temperature: 0.0,
temperature: Some(0.0),
random_seed: Some(42),
..Default::default()
};
@@ -31,9 +27,10 @@ async fn main() {
.for_each(|chunk_result| async {
match chunk_result {
Ok(chunks) => chunks.iter().for_each(|chunk| {
print!("{}", chunk.choices[0].delta.content);
io::stdout().flush().unwrap();
// => "Once upon a time, [...]"
if let Some(content) = &chunk.choices[0].delta.content {
print!("{}", content);
io::stdout().flush().unwrap();
}
}),
Err(error) => {
eprintln!("Error processing chunk: {:?}", error)
@@ -41,5 +38,5 @@ async fn main() {
}
})
.await;
print!("\n") // To persist the last chunk output.
println!();
}

View File

@@ -1,10 +1,10 @@
use mistralai_client::v1::{client::Client, constants::EmbedModel};
use mistralai_client::v1::{client::Client, constants::Model};
fn main() {
// This example assumes you have set the `MISTRAL_API_KEY` environment variable.
let client: Client = Client::new(None, None, None, None).unwrap();
let model = EmbedModel::MistralEmbed;
let model = Model::mistral_embed();
let input = vec!["Embed this sentence.", "As well as this one."]
.iter()
.map(|s| s.to_string())
@@ -13,5 +13,4 @@ fn main() {
let response = client.embeddings(model, input, options).unwrap();
println!("First Embedding: {:?}", response.data[0]);
// => "First Embedding: {...}"
}

View File

@@ -1,11 +1,11 @@
use mistralai_client::v1::{client::Client, constants::EmbedModel};
use mistralai_client::v1::{client::Client, constants::Model};
#[tokio::main]
async fn main() {
// This example assumes you have set the `MISTRAL_API_KEY` environment variable.
let client: Client = Client::new(None, None, None, None).unwrap();
let model = EmbedModel::MistralEmbed;
let model = Model::mistral_embed();
let input = vec!["Embed this sentence.", "As well as this one."]
.iter()
.map(|s| s.to_string())
@@ -17,5 +17,4 @@ async fn main() {
.await
.unwrap();
println!("First Embedding: {:?}", response.data[0]);
// => "First Embedding: {...}"
}

21
examples/fim.rs Normal file
View File

@@ -0,0 +1,21 @@
//! Fill-in-the-middle (FIM) completion example using a Codestral model.
use mistralai_client::v1::{
    client::Client,
    constants::Model,
    fim::FimParams,
};

fn main() {
    // This example assumes you have set the `MISTRAL_API_KEY` environment variable.
    let client = Client::new(None, None, None, None).unwrap();

    let model = Model::codestral_latest();
    // Text before the insertion point; the model generates what goes between
    // `prompt` and `suffix`.
    let prompt = "def fibonacci(n):".to_string();
    let options = FimParams {
        // Text expected after the generated completion.
        suffix: Some("\n    return result".to_string()),
        // Temperature 0 for a (mostly) reproducible example output.
        temperature: Some(0.0),
        ..Default::default()
    };

    let response = client.fim(model, prompt, Some(options)).unwrap();
    println!("Completion: {}", response.choices[0].message.content);
}

25
examples/ocr.rs Normal file
View File

@@ -0,0 +1,25 @@
//! OCR example: extract markdown text from a PDF document fetched by URL.
use mistralai_client::v1::{
    client::Client,
    constants::Model,
    ocr::{OcrDocument, OcrRequest},
};

fn main() {
    // This example assumes you have set the `MISTRAL_API_KEY` environment variable.
    let client = Client::new(None, None, None, None).unwrap();

    let request = OcrRequest {
        model: Model::mistral_ocr_latest(),
        document: OcrDocument::from_url("https://arxiv.org/pdf/2201.04234"),
        // Only process the first page.
        pages: Some(vec![0]),
        table_format: None,
        include_image_base64: None,
        image_limit: None,
    };

    let response = client.ocr(&request).unwrap();
    for page in &response.pages {
        println!("--- Page {} ---", page.index);
        // Preview up to the first 200 characters. Collecting via `chars()`
        // instead of byte-slicing (`&s[..200]`) avoids a panic when byte 200
        // falls inside a multi-byte UTF-8 character — common in OCR output
        // that contains non-ASCII text.
        let preview: String = page.markdown.chars().take(200).collect();
        println!("{}", preview);
    }
}

98
src/v1/agents.rs Normal file
View File

@@ -0,0 +1,98 @@
use serde::{Deserialize, Serialize};
use crate::v1::{chat, common, constants, tool};
// -----------------------------------------------------------------------------
// Request
/// Optional parameters for an agent completion request.
///
/// All fields default to `None`; fields left as `None` are omitted from the
/// serialized request body (see [`AgentCompletionRequest`]). The hand-written
/// `Default` impl was equivalent to the derived one, so it is derived here.
#[derive(Debug, Default)]
pub struct AgentCompletionParams {
    /// Maximum number of tokens to generate.
    pub max_tokens: Option<u32>,
    /// Minimum number of tokens to generate.
    pub min_tokens: Option<u32>,
    /// Sampling temperature.
    pub temperature: Option<f32>,
    /// Nucleus-sampling probability mass.
    pub top_p: Option<f32>,
    /// Seed for deterministic sampling.
    pub random_seed: Option<u32>,
    /// Sequences that stop generation when produced.
    pub stop: Option<Vec<String>>,
    /// Output format constraint (text / JSON object / JSON schema).
    pub response_format: Option<chat::ResponseFormat>,
    /// Tools the agent may call.
    pub tools: Option<Vec<tool::Tool>>,
    /// If/how tools are selected.
    pub tool_choice: Option<tool::ToolChoice>,
}
/// Wire-format request body for the agent completions endpoint.
///
/// Build via [`AgentCompletionRequest::new`]; every optional field left as
/// `None` is omitted from the serialized JSON.
#[derive(Debug, Serialize, Deserialize)]
pub struct AgentCompletionRequest {
    /// Identifier of the agent that handles the completion.
    pub agent_id: String,
    /// Conversation messages sent to the agent.
    pub messages: Vec<chat::ChatMessage>,
    /// Whether the response should be streamed.
    pub stream: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub random_seed: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<chat::ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<tool::Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<tool::ToolChoice>,
}

impl AgentCompletionRequest {
    /// Builds a request from the agent ID, messages, stream flag, and
    /// optional parameters. Passing `options: None` uses all-default
    /// (all-`None`) parameters.
    pub fn new(
        agent_id: String,
        messages: Vec<chat::ChatMessage>,
        stream: bool,
        options: Option<AgentCompletionParams>,
    ) -> Self {
        // Fall back to the all-`None` parameter set when no options are given.
        let opts = options.unwrap_or_default();
        Self {
            agent_id,
            messages,
            stream,
            max_tokens: opts.max_tokens,
            min_tokens: opts.min_tokens,
            temperature: opts.temperature,
            top_p: opts.top_p,
            random_seed: opts.random_seed,
            stop: opts.stop,
            response_format: opts.response_format,
            tools: opts.tools,
            tool_choice: opts.tool_choice,
        }
    }
}
// -----------------------------------------------------------------------------
// Response (same shape as chat completions)
/// Response from the agent completions endpoint.
///
/// Mirrors the chat completion response shape (it reuses
/// [`chat::ChatResponseChoice`] and [`common::ResponseUsage`]).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AgentCompletionResponse {
    pub id: String,
    pub object: String,
    /// Unix timestamp (in seconds), as in `chat::ChatResponse::created`.
    pub created: u64,
    pub model: constants::Model,
    /// One entry per generated completion choice.
    pub choices: Vec<chat::ChatResponseChoice>,
    pub usage: common::ResponseUsage,
}

78
src/v1/audio.rs Normal file
View File

@@ -0,0 +1,78 @@
use serde::{Deserialize, Serialize};
use crate::v1::constants;
// -----------------------------------------------------------------------------
// Request (multipart form, but we define the params struct)
/// Parameters for an audio transcription request (sent as a multipart form).
#[derive(Debug)]
pub struct AudioTranscriptionParams {
    /// Transcription model; defaults to Voxtral Mini Transcribe.
    pub model: constants::Model,
    /// Language hint for the audio — presumably an ISO language code like
    /// "en"; TODO confirm against the API docs.
    pub language: Option<String>,
    /// Sampling temperature.
    pub temperature: Option<f32>,
    // NOTE(review): presumably enables speaker diarization (populating
    // `TranscriptionSegment::speaker`) — confirm.
    pub diarize: Option<bool>,
    /// Which timestamp levels (segment/word) to include in the response.
    pub timestamp_granularities: Option<Vec<TimestampGranularity>>,
}

impl Default for AudioTranscriptionParams {
    // Cannot be derived: `model` has a non-trivial default.
    fn default() -> Self {
        Self {
            model: constants::Model::voxtral_mini_transcribe(),
            language: None,
            temperature: None,
            diarize: None,
            timestamp_granularities: None,
        }
    }
}

/// Granularity of timestamps returned with a transcription.
///
/// Serialized in lowercase: "segment" / "word".
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum TimestampGranularity {
    Segment,
    Word,
}
// -----------------------------------------------------------------------------
// Response
/// Response from the audio transcription endpoint.
///
/// Only `text` is guaranteed; the remaining fields are present depending on
/// the request (e.g. `segments`/`words` when timestamp granularities were
/// asked for).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AudioTranscriptionResponse {
    /// Full transcribed text.
    pub text: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub language: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub segments: Option<Vec<TranscriptionSegment>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub words: Option<Vec<TranscriptionWord>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage: Option<AudioUsage>,
}

/// A contiguous span of transcribed audio.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TranscriptionSegment {
    pub id: u32,
    // Start/end offsets — presumably seconds; TODO confirm units.
    pub start: f32,
    pub end: f32,
    pub text: String,
    /// Speaker label; presumably only set when diarization was requested.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub speaker: Option<String>,
}

/// A single word with its time offsets.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TranscriptionWord {
    pub word: String,
    // Start/end offsets — presumably seconds; TODO confirm units.
    pub start: f32,
    pub end: f32,
}

/// Usage accounting specific to audio requests.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AudioUsage {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_audio_seconds: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prompt_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub total_tokens: Option<u32>,
}

53
src/v1/batch.rs Normal file
View File

@@ -0,0 +1,53 @@
use serde::{Deserialize, Serialize};
use crate::v1::constants;
// -----------------------------------------------------------------------------
// Request
/// Request body for creating a batch job.
#[derive(Debug, Serialize, Deserialize)]
pub struct BatchJobRequest {
    /// Input file identifiers — presumably IDs returned by the Files API;
    /// TODO confirm.
    pub input_files: Vec<String>,
    pub model: constants::Model,
    /// Target API endpoint the batched requests are sent to.
    pub endpoint: String,
    /// Arbitrary user metadata attached to the job; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
}
// -----------------------------------------------------------------------------
// Response
/// A batch job as returned by the batch endpoints.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct BatchJobResponse {
    pub id: String,
    pub object: String,
    pub model: constants::Model,
    /// Endpoint the batched requests target.
    pub endpoint: String,
    pub input_files: Vec<String>,
    /// Job lifecycle state as a raw string (the API's status values are not
    /// enumerated here).
    pub status: String,
    /// Unix timestamp (seconds) of job creation — presumably; TODO confirm.
    pub created_at: u64,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completed_at: Option<u64>,
    /// File ID holding successful results, once available.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub output_file: Option<String>,
    /// File ID holding failed-request details, once available.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error_file: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub metadata: Option<serde_json::Value>,
    // Progress counters; optional because the API may omit them.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub total_requests: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub completed_requests: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub succeeded_requests: Option<u64>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub failed_requests: Option<u64>,
}

/// A list of batch jobs.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct BatchJobListResponse {
    pub data: Vec<BatchJobResponse>,
    pub object: String,
    /// Total job count; falls back to 0 when the API omits the field.
    #[serde(default)]
    pub total: u32,
}

View File

@@ -11,13 +11,31 @@ pub struct ChatMessage {
pub content: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<tool::ToolCall>>,
/// Tool call ID, required when role is Tool.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
/// Function name, used when role is Tool.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
}
impl ChatMessage {
pub fn new_system_message(content: &str) -> Self {
Self {
role: ChatMessageRole::System,
content: content.to_string(),
tool_calls: None,
tool_call_id: None,
name: None,
}
}
pub fn new_assistant_message(content: &str, tool_calls: Option<Vec<tool::ToolCall>>) -> Self {
Self {
role: ChatMessageRole::Assistant,
content: content.to_string(),
tool_calls,
tool_call_id: None,
name: None,
}
}
@@ -26,6 +44,18 @@ impl ChatMessage {
role: ChatMessageRole::User,
content: content.to_string(),
tool_calls: None,
tool_call_id: None,
name: None,
}
}
pub fn new_tool_message(content: &str, tool_call_id: &str, name: Option<&str>) -> Self {
Self {
role: ChatMessageRole::Tool,
content: content.to_string(),
tool_calls: None,
tool_call_id: Some(tool_call_id.to_string()),
name: name.map(|n| n.to_string()),
}
}
}
@@ -44,17 +74,32 @@ pub enum ChatMessageRole {
}
/// The format that the model must output.
///
/// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ResponseFormat {
#[serde(rename = "type")]
pub type_: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub json_schema: Option<serde_json::Value>,
}
impl ResponseFormat {
pub fn text() -> Self {
Self {
type_: "text".to_string(),
json_schema: None,
}
}
pub fn json_object() -> Self {
Self {
type_: "json_object".to_string(),
json_schema: None,
}
}
pub fn json_schema(schema: serde_json::Value) -> Self {
Self {
type_: "json_schema".to_string(),
json_schema: Some(schema),
}
}
}
@@ -63,91 +108,83 @@ impl ResponseFormat {
// Request
/// The parameters for the chat request.
///
/// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information.
#[derive(Clone, Debug)]
pub struct ChatParams {
/// The maximum number of tokens to generate in the completion.
///
/// Defaults to `None`.
pub max_tokens: Option<u32>,
/// The seed to use for random sampling. If set, different calls will generate deterministic results.
///
/// Defaults to `None`.
pub min_tokens: Option<u32>,
pub random_seed: Option<u32>,
/// The format that the model must output.
///
/// Defaults to `None`.
pub response_format: Option<ResponseFormat>,
/// Whether to inject a safety prompt before all conversations.
///
/// Defaults to `false`.
pub safe_prompt: bool,
/// What sampling temperature to use, between `Some(0.0)` and `Some(1.0)`.
///
/// Defaults to `0.7`.
pub temperature: f32,
/// Specifies if/how functions are called.
///
/// Defaults to `None`.
pub temperature: Option<f32>,
pub tool_choice: Option<tool::ToolChoice>,
/// A list of available tools for the model.
///
/// Defaults to `None`.
pub tools: Option<Vec<tool::Tool>>,
/// Nucleus sampling, where the model considers the results of the tokens with `top_p` probability mass.
///
/// Defaults to `1.0`.
pub top_p: f32,
pub top_p: Option<f32>,
pub stop: Option<Vec<String>>,
pub n: Option<u32>,
pub frequency_penalty: Option<f32>,
pub presence_penalty: Option<f32>,
pub parallel_tool_calls: Option<bool>,
/// For reasoning models (Magistral). "high" or "none".
pub reasoning_effort: Option<String>,
}
impl Default for ChatParams {
fn default() -> Self {
Self {
max_tokens: None,
min_tokens: None,
random_seed: None,
safe_prompt: false,
response_format: None,
temperature: 0.7,
temperature: None,
tool_choice: None,
tools: None,
top_p: 1.0,
}
}
}
impl ChatParams {
pub fn json_default() -> Self {
Self {
max_tokens: None,
random_seed: None,
safe_prompt: false,
response_format: None,
temperature: 0.7,
tool_choice: None,
tools: None,
top_p: 1.0,
top_p: None,
stop: None,
n: None,
frequency_penalty: None,
presence_penalty: None,
parallel_tool_calls: None,
reasoning_effort: None,
}
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatRequest {
pub messages: Vec<ChatMessage>,
pub model: constants::Model,
pub messages: Vec<ChatMessage>,
pub stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub min_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub random_seed: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<ResponseFormat>,
pub safe_prompt: bool,
pub stream: bool,
pub temperature: f32,
#[serde(skip_serializing_if = "Option::is_none")]
pub safe_prompt: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<tool::ToolChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<tool::Tool>>,
pub top_p: f32,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub n: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub parallel_tool_calls: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_effort: Option<String>,
}
impl ChatRequest {
pub fn new(
@@ -156,30 +193,28 @@ impl ChatRequest {
stream: bool,
options: Option<ChatParams>,
) -> Self {
let ChatParams {
max_tokens,
random_seed,
safe_prompt,
temperature,
tool_choice,
tools,
top_p,
response_format,
} = options.unwrap_or_default();
let opts = options.unwrap_or_default();
let safe_prompt = if opts.safe_prompt { Some(true) } else { None };
Self {
messages,
model,
max_tokens,
random_seed,
safe_prompt,
messages,
stream,
temperature,
tool_choice,
tools,
top_p,
response_format,
max_tokens: opts.max_tokens,
min_tokens: opts.min_tokens,
random_seed: opts.random_seed,
safe_prompt,
temperature: opts.temperature,
tool_choice: opts.tool_choice,
tools: opts.tools,
top_p: opts.top_p,
response_format: opts.response_format,
stop: opts.stop,
n: opts.n,
frequency_penalty: opts.frequency_penalty,
presence_penalty: opts.presence_penalty,
parallel_tool_calls: opts.parallel_tool_calls,
reasoning_effort: opts.reasoning_effort,
}
}
}
@@ -192,7 +227,7 @@ pub struct ChatResponse {
pub id: String,
pub object: String,
/// Unix timestamp (in seconds).
pub created: u32,
pub created: u64,
pub model: constants::Model,
pub choices: Vec<ChatResponseChoice>,
pub usage: common::ResponseUsage,
@@ -203,14 +238,18 @@ pub struct ChatResponseChoice {
pub index: u32,
pub message: ChatMessage,
pub finish_reason: ChatResponseChoiceFinishReason,
// TODO Check this prop (seen in API responses but undocumented).
// pub logprobs: ???
}
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum ChatResponseChoiceFinishReason {
#[serde(rename = "stop")]
Stop,
#[serde(rename = "length")]
Length,
#[serde(rename = "tool_calls")]
ToolCalls,
#[serde(rename = "model_length")]
ModelLength,
#[serde(rename = "error")]
Error,
}

View File

@@ -1,7 +1,7 @@
use serde::{Deserialize, Serialize};
use serde_json::from_str;
use crate::v1::{chat, common, constants, error};
use crate::v1::{chat, common, constants, error, tool};
// -----------------------------------------------------------------------------
// Response
@@ -11,12 +11,11 @@ pub struct ChatStreamChunk {
pub id: String,
pub object: String,
/// Unix timestamp (in seconds).
pub created: u32,
pub created: u64,
pub model: constants::Model,
pub choices: Vec<ChatStreamChunkChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<common::ResponseUsage>,
// TODO Check this prop (seen in API responses but undocumented).
// pub logprobs: ???,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
@@ -24,14 +23,15 @@ pub struct ChatStreamChunkChoice {
pub index: u32,
pub delta: ChatStreamChunkChoiceDelta,
pub finish_reason: Option<String>,
// TODO Check this prop (seen in API responses but undocumented).
// pub logprobs: ???,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ChatStreamChunkChoiceDelta {
pub role: Option<chat::ChatMessageRole>,
pub content: String,
#[serde(default)]
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<tool::ToolCall>>,
}
/// Extracts serialized chunks from a stream message.
@@ -47,7 +47,6 @@ pub fn get_chunk_from_stream_message_line(
return Ok(Some(vec![]));
}
// Attempt to deserialize the JSON string into ChatStreamChunk
match from_str::<ChatStreamChunk>(chunk_as_json) {
Ok(chunk) => Ok(Some(vec![chunk])),
Err(e) => Err(error::ApiError {

File diff suppressed because it is too large Load Diff

View File

@@ -3,6 +3,7 @@ use serde::{Deserialize, Serialize};
/// Token usage accounting attached to API responses.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ResponseUsage {
    pub prompt_tokens: u32,
    /// Defaults to 0 when the field is absent — presumably some endpoints
    /// (e.g. embeddings) omit it; TODO confirm which.
    #[serde(default)]
    pub completion_tokens: u32,
    pub total_tokens: u32,
}

View File

@@ -1,35 +1,131 @@
use std::fmt;
use serde::{Deserialize, Serialize};
pub const API_URL_BASE: &str = "https://api.mistral.ai/v1";
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum Model {
#[serde(rename = "open-mistral-7b")]
OpenMistral7b,
#[serde(rename = "open-mixtral-8x7b")]
OpenMixtral8x7b,
#[serde(rename = "open-mixtral-8x22b")]
OpenMixtral8x22b,
#[serde(rename = "open-mistral-nemo", alias = "open-mistral-nemo-2407")]
OpenMistralNemo,
#[serde(rename = "mistral-tiny")]
MistralTiny,
#[serde(rename = "mistral-small-latest", alias = "mistral-small-2402")]
MistralSmallLatest,
#[serde(rename = "mistral-medium-latest", alias = "mistral-medium-2312")]
MistralMediumLatest,
#[serde(rename = "mistral-large-latest", alias = "mistral-large-2407")]
MistralLargeLatest,
#[serde(rename = "mistral-large-2402")]
MistralLarge,
#[serde(rename = "codestral-latest", alias = "codestral-2405")]
CodestralLatest,
#[serde(rename = "open-codestral-mamba")]
CodestralMamba,
/// A Mistral AI model identifier.
///
/// Use the constructor methods below for known models, or build one from any
/// model string with [`Model::new`] (or via `From<&str>` / `From<String>`).
///
/// `#[serde(transparent)]` makes the type (de)serialize as a bare string.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(transparent)]
pub struct Model(pub String);

impl Model {
    /// Wraps an arbitrary model identifier string.
    pub fn new(id: impl Into<String>) -> Self {
        Self(id.into())
    }

    // Flagship / Premier
    pub fn mistral_large_latest() -> Self {
        Self::new("mistral-large-latest")
    }
    pub fn mistral_large_3() -> Self {
        Self::new("mistral-large-3-25-12")
    }
    pub fn mistral_medium_latest() -> Self {
        Self::new("mistral-medium-latest")
    }
    pub fn mistral_medium_3_1() -> Self {
        Self::new("mistral-medium-3-1-25-08")
    }
    pub fn mistral_small_latest() -> Self {
        Self::new("mistral-small-latest")
    }
    pub fn mistral_small_4() -> Self {
        Self::new("mistral-small-4-0-26-03")
    }
    pub fn mistral_small_3_2() -> Self {
        Self::new("mistral-small-3-2-25-06")
    }

    // Ministral
    pub fn ministral_3_14b() -> Self {
        Self::new("ministral-3-14b-25-12")
    }
    pub fn ministral_3_8b() -> Self {
        Self::new("ministral-3-8b-25-12")
    }
    pub fn ministral_3_3b() -> Self {
        Self::new("ministral-3-3b-25-12")
    }

    // Reasoning
    pub fn magistral_medium_latest() -> Self {
        Self::new("magistral-medium-latest")
    }
    pub fn magistral_small_latest() -> Self {
        Self::new("magistral-small-latest")
    }

    // Code
    pub fn codestral_latest() -> Self {
        Self::new("codestral-latest")
    }
    pub fn codestral_2508() -> Self {
        Self::new("codestral-2508")
    }
    // NOTE(review): ID is version-pinned although the name carries no
    // version — confirm this tracks the intended model.
    pub fn codestral_embed() -> Self {
        Self::new("codestral-embed-25-05")
    }
    pub fn devstral_2() -> Self {
        Self::new("devstral-2-25-12")
    }
    pub fn devstral_small_2() -> Self {
        Self::new("devstral-small-2-25-12")
    }

    // Multimodal / Vision
    pub fn pixtral_large() -> Self {
        Self::new("pixtral-large-2411")
    }

    // Audio
    pub fn voxtral_mini_transcribe() -> Self {
        Self::new("voxtral-mini-transcribe-2-26-02")
    }
    pub fn voxtral_small() -> Self {
        Self::new("voxtral-small-25-07")
    }
    pub fn voxtral_mini() -> Self {
        Self::new("voxtral-mini-25-07")
    }

    // Legacy (kept for backward compatibility)
    pub fn open_mistral_nemo() -> Self {
        Self::new("open-mistral-nemo")
    }

    // Embedding
    pub fn mistral_embed() -> Self {
        Self::new("mistral-embed")
    }

    // Moderation
    // NOTE(review): name says "latest" but the ID is version-pinned —
    // confirm this is intended (cf. `mistral_ocr_latest`, which uses a
    // "-latest" alias).
    pub fn mistral_moderation_latest() -> Self {
        Self::new("mistral-moderation-26-03")
    }

    // OCR
    pub fn mistral_ocr_latest() -> Self {
        Self::new("mistral-ocr-latest")
    }
}
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum EmbedModel {
#[serde(rename = "mistral-embed")]
MistralEmbed,
impl fmt::Display for Model {
    // Display renders the raw model identifier string unchanged.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(&self.0)
    }
}
impl From<&str> for Model {
fn from(s: &str) -> Self {
Self(s.to_string())
}
}
impl From<String> for Model {
fn from(s: String) -> Self {
Self(s)
}
}

View File

@@ -8,42 +8,63 @@ use crate::v1::{common, constants};
/// Optional parameters for an embeddings request.
///
/// All fields default to `None`; `None` fields are omitted from the
/// serialized request. The hand-written `Default` impl was equivalent to the
/// derived one, so it is derived here.
#[derive(Debug, Default)]
pub struct EmbeddingRequestOptions {
    /// How the embedding vectors are encoded in the response.
    pub encoding_format: Option<EmbeddingRequestEncodingFormat>,
    /// Requested dimensionality of the output embeddings.
    pub output_dimension: Option<u32>,
    /// Requested numeric type of the output embeddings.
    pub output_dtype: Option<EmbeddingOutputDtype>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingRequest {
pub model: constants::EmbedModel,
pub model: constants::Model,
pub input: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub encoding_format: Option<EmbeddingRequestEncodingFormat>,
#[serde(skip_serializing_if = "Option::is_none")]
pub output_dimension: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub output_dtype: Option<EmbeddingOutputDtype>,
}
impl EmbeddingRequest {
pub fn new(
model: constants::EmbedModel,
model: constants::Model,
input: Vec<String>,
options: Option<EmbeddingRequestOptions>,
) -> Self {
let EmbeddingRequestOptions { encoding_format } = options.unwrap_or_default();
let opts = options.unwrap_or_default();
Self {
model,
input,
encoding_format,
encoding_format: opts.encoding_format,
output_dimension: opts.output_dimension,
output_dtype: opts.output_dtype,
}
}
}
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[allow(non_camel_case_types)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingRequestEncodingFormat {
float,
Float,
Base64,
}
/// Numeric type for returned embedding vectors.
///
/// Serialized in lowercase: `"float"`, `"int8"`, `"uint8"`, `"binary"`,
/// `"ubinary"`.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingOutputDtype {
    Float,
    Int8,
    Uint8,
    Binary,
    Ubinary,
}
// -----------------------------------------------------------------------------
@@ -51,9 +72,8 @@ pub enum EmbeddingRequestEncodingFormat {
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct EmbeddingResponse {
pub id: String,
pub object: String,
pub model: constants::EmbedModel,
pub model: constants::Model,
pub data: Vec<EmbeddingResponseDataItem>,
pub usage: common::ResponseUsage,
}

55
src/v1/files.rs Normal file
View File

@@ -0,0 +1,55 @@
use serde::{Deserialize, Serialize};
// -----------------------------------------------------------------------------
// Request
/// Purpose of an uploaded file.
///
/// Serialized in kebab-case: `"fine-tune"`, `"batch"`, `"ocr"`.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum FilePurpose {
    FineTune,
    Batch,
    Ocr,
}
// -----------------------------------------------------------------------------
// Response
/// Response payload for a file-listing request.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FileListResponse {
    /// The files returned by this listing.
    pub data: Vec<FileObject>,
    /// Object type discriminator returned by the API.
    pub object: String,
    /// Total number of files; falls back to 0 when the API omits the field.
    #[serde(default)]
    pub total: u32,
}
/// Metadata for a single uploaded file.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FileObject {
    /// Unique file identifier.
    pub id: String,
    /// Object type discriminator returned by the API.
    pub object: String,
    /// File size in bytes.
    pub bytes: u64,
    /// Creation time as a Unix timestamp (presumably seconds — matches the
    /// convention used elsewhere in this crate).
    pub created_at: u64,
    /// Original name of the uploaded file.
    pub filename: String,
    /// Purpose string (see `FilePurpose` for the known values).
    pub purpose: String,
    /// Sample type, when reported by the API.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub sample_type: Option<String>,
    /// Upload source, when reported by the API.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub source: Option<String>,
    /// Number of lines in the file, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub num_lines: Option<u64>,
    /// MIME type, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub mimetype: Option<String>,
}
/// Response payload for a file deletion.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FileDeleteResponse {
    /// Identifier of the file that was targeted.
    pub id: String,
    /// Object type discriminator returned by the API.
    pub object: String,
    /// Whether the file was actually deleted.
    pub deleted: bool,
}
/// Response payload carrying a (typically pre-signed) download URL for a file.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FileUrlResponse {
    /// The download URL.
    pub url: String,
    /// Expiry of the URL, when reported (presumably a Unix timestamp — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub expires_at: Option<u64>,
}

101
src/v1/fim.rs Normal file
View File

@@ -0,0 +1,101 @@
use serde::{Deserialize, Serialize};
use crate::v1::{common, constants};
// -----------------------------------------------------------------------------
// Request
/// Optional parameters for a FIM (fill-in-the-middle) completion request.
///
/// Every field defaults to `None`, deferring to the API server's defaults.
/// `Default` is derived instead of hand-written: for `Option` fields the
/// derive produces exactly the all-`None` value the manual impl built.
#[derive(Debug, Default)]
pub struct FimParams {
    /// Text that must appear after the generated completion.
    pub suffix: Option<String>,
    /// Hard upper bound on the number of generated tokens.
    pub max_tokens: Option<u32>,
    /// Lower bound on the number of generated tokens.
    pub min_tokens: Option<u32>,
    /// Sampling temperature; higher values increase randomness.
    pub temperature: Option<f32>,
    /// Nucleus-sampling probability-mass cutoff.
    pub top_p: Option<f32>,
    /// Sequences that stop generation when produced.
    pub stop: Option<Vec<String>>,
    /// Seed for reproducible sampling.
    pub random_seed: Option<u32>,
}
/// Request body for a FIM (fill-in-the-middle) completion.
///
/// Optional fields are omitted from the serialized JSON when `None`.
#[derive(Debug, Serialize, Deserialize)]
pub struct FimRequest {
    /// Model to use (typically a Codestral variant).
    pub model: constants::Model,
    /// Code/text preceding the span to fill in.
    pub prompt: String,
    /// Whether the response should be streamed.
    pub stream: bool,
    /// Code/text following the span to fill in.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
    /// Hard upper bound on generated tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    /// Lower bound on generated tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_tokens: Option<u32>,
    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Nucleus-sampling probability-mass cutoff.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    /// Sequences that stop generation when produced.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    /// Seed for reproducible sampling.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub random_seed: Option<u32>,
}
impl FimRequest {
pub fn new(
model: constants::Model,
prompt: String,
stream: bool,
options: Option<FimParams>,
) -> Self {
let opts = options.unwrap_or_default();
Self {
model,
prompt,
stream,
suffix: opts.suffix,
max_tokens: opts.max_tokens,
min_tokens: opts.min_tokens,
temperature: opts.temperature,
top_p: opts.top_p,
stop: opts.stop,
random_seed: opts.random_seed,
}
}
}
// -----------------------------------------------------------------------------
// Response
/// Response payload for a (non-streamed) FIM completion.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FimResponse {
    /// Unique response identifier.
    pub id: String,
    /// Object type discriminator returned by the API.
    pub object: String,
    /// Creation time as a Unix timestamp.
    pub created: u64,
    /// Model that produced the completion.
    pub model: constants::Model,
    /// Generated completion choices.
    pub choices: Vec<FimResponseChoice>,
    /// Token usage accounting for this request.
    pub usage: common::ResponseUsage,
}
/// One generated completion within a [`FimResponse`].
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FimResponseChoice {
    /// Zero-based index of this choice.
    pub index: u32,
    /// The generated message.
    pub message: FimResponseMessage,
    /// Why generation stopped (e.g. `"stop"`, `"length"` — confirm against API).
    pub finish_reason: String,
}
/// The message body of a FIM completion choice.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FimResponseMessage {
    /// Role of the author as reported by the API.
    pub role: String,
    /// The generated text filling the gap between prompt and suffix.
    pub content: String,
}

101
src/v1/fine_tuning.rs Normal file
View File

@@ -0,0 +1,101 @@
use serde::{Deserialize, Serialize};
use crate::v1::constants;
// -----------------------------------------------------------------------------
// Request
/// Request body for creating a fine-tuning job.
///
/// Optional fields are omitted from the serialized JSON when `None`.
#[derive(Debug, Serialize, Deserialize)]
pub struct FineTuningJobRequest {
    /// Base model to fine-tune.
    pub model: constants::Model,
    /// Training files (by file id), each with an optional weight.
    pub training_files: Vec<TrainingFile>,
    /// Validation file ids.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub validation_files: Option<Vec<String>>,
    /// Training hyperparameters.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hyperparameters: Option<Hyperparameters>,
    /// Suffix appended to the fine-tuned model's name.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
    /// Whether the job starts automatically after validation.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub auto_start: Option<bool>,
    /// Job type string, when the API distinguishes several.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub job_type: Option<String>,
    /// Third-party integrations to notify (e.g. experiment trackers).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub integrations: Option<Vec<Integration>>,
}
/// A training file reference with an optional sampling weight.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TrainingFile {
    /// Identifier of a previously uploaded file.
    pub file_id: String,
    /// Relative weight of this file during training, when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub weight: Option<f32>,
}
/// Fine-tuning hyperparameters; unset fields use the server's defaults.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Hyperparameters {
    /// Optimizer learning rate.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub learning_rate: Option<f64>,
    /// Total number of training steps.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub training_steps: Option<u32>,
    /// Fraction of steps used for learning-rate warmup.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub warmup_fraction: Option<f64>,
    /// Number of passes over the training data (fractional values allowed).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub epochs: Option<f64>,
}
/// A third-party integration attached to a fine-tuning job
/// (presumably an experiment tracker such as Weights & Biases — confirm).
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Integration {
    /// Integration type discriminator.
    pub r#type: String,
    /// Project name within the integration.
    pub project: String,
    /// Display name, when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    /// API key for the integration, when required.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub api_key: Option<String>,
}
// -----------------------------------------------------------------------------
// Response
/// A fine-tuning job as returned by the API.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FineTuningJobResponse {
    /// Unique job identifier.
    pub id: String,
    /// Object type discriminator returned by the API.
    pub object: String,
    /// Base model being fine-tuned.
    pub model: constants::Model,
    /// Current lifecycle status of the job.
    pub status: FineTuningJobStatus,
    /// Creation time as a Unix timestamp.
    pub created_at: u64,
    /// Last-modification time, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub modified_at: Option<u64>,
    /// Training file ids (note: plain ids here, unlike the request's
    /// weighted `TrainingFile` entries).
    pub training_files: Vec<String>,
    /// Validation file ids, when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub validation_files: Option<Vec<String>>,
    /// Hyperparameters the job was created with, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub hyperparameters: Option<Hyperparameters>,
    /// Name of the resulting fine-tuned model, once available.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fine_tuned_model: Option<String>,
    /// Name suffix requested at creation, when set.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub suffix: Option<String>,
    /// Attached integrations, when any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub integrations: Option<Vec<Integration>>,
    /// Number of tokens consumed by training, once known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub trained_tokens: Option<u64>,
}
/// Lifecycle status of a fine-tuning job.
///
/// Serialized in SCREAMING_SNAKE_CASE (e.g. `"QUEUED"`, `"TIMEOUT_EXCEEDED"`).
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FineTuningJobStatus {
    Queued,
    Running,
    Success,
    Failed,
    TimeoutExceeded,
    CancellationRequested,
    Cancelled,
}
/// Response payload for listing fine-tuning jobs.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FineTuningJobListResponse {
    /// The jobs returned by this listing.
    pub data: Vec<FineTuningJobResponse>,
    /// Object type discriminator returned by the API.
    pub object: String,
    /// Total number of jobs; falls back to 0 when the API omits the field.
    #[serde(default)]
    pub total: u32,
}

View File

@@ -1,3 +1,6 @@
pub mod agents;
pub mod audio;
pub mod batch;
pub mod chat;
pub mod chat_stream;
pub mod client;
@@ -5,6 +8,11 @@ pub mod common;
pub mod constants;
pub mod embedding;
pub mod error;
pub mod files;
pub mod fim;
pub mod fine_tuning;
pub mod model_list;
pub mod moderation;
pub mod ocr;
pub mod tool;
pub mod utils;

View File

@@ -15,23 +15,44 @@ pub struct ModelListData {
pub id: String,
pub object: String,
/// Unix timestamp (in seconds).
pub created: u32,
pub created: u64,
pub owned_by: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub root: Option<String>,
#[serde(default)]
pub archived: bool,
pub name: String,
pub description: String,
pub capabilities: ModelListDataCapabilies,
pub max_context_length: u32,
#[serde(default)]
pub name: Option<String>,
#[serde(default)]
pub description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub capabilities: Option<ModelListDataCapabilities>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_context_length: Option<u32>,
#[serde(default)]
pub aliases: Vec<String>,
/// ISO 8601 date (`YYYY-MM-DDTHH:MM:SSZ`).
#[serde(skip_serializing_if = "Option::is_none")]
pub deprecation: Option<String>,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ModelListDataCapabilies {
pub struct ModelListDataCapabilities {
#[serde(default)]
pub completion_chat: bool,
#[serde(default)]
pub completion_fim: bool,
#[serde(default)]
pub function_calling: bool,
#[serde(default)]
pub fine_tuning: bool,
#[serde(default)]
pub vision: bool,
}
/// Response payload for deleting a (fine-tuned) model.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ModelDeleteResponse {
    /// Identifier of the model that was targeted.
    pub id: String,
    /// Object type discriminator returned by the API.
    pub object: String,
    /// Whether the model was actually deleted.
    pub deleted: bool,
}

70
src/v1/moderation.rs Normal file
View File

@@ -0,0 +1,70 @@
use serde::{Deserialize, Serialize};
use crate::v1::constants;
// -----------------------------------------------------------------------------
// Request
/// Request body for moderating a batch of raw text inputs.
#[derive(Debug, Serialize, Deserialize)]
pub struct ModerationRequest {
    /// Moderation model to run.
    pub model: constants::Model,
    /// Texts to moderate; one result is returned per input.
    pub input: Vec<String>,
}
impl ModerationRequest {
pub fn new(model: constants::Model, input: Vec<String>) -> Self {
Self { model, input }
}
}
/// Request body for moderating conversational (role + content) inputs.
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatModerationRequest {
    /// Moderation model to run.
    pub model: constants::Model,
    /// Conversation turns to moderate.
    pub input: Vec<ChatModerationInput>,
}
/// A single conversation turn submitted for moderation/classification.
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatModerationInput {
    /// Author role of the turn (e.g. `"user"`, `"assistant"` — confirm
    /// accepted values against the API).
    pub role: String,
    /// Text content of the turn.
    pub content: String,
}
/// Request body for classifying a batch of raw text inputs.
#[derive(Debug, Serialize, Deserialize)]
pub struct ClassificationRequest {
    /// Classifier model to run.
    pub model: constants::Model,
    /// Texts to classify; one result is returned per input.
    pub input: Vec<String>,
}

impl ClassificationRequest {
    /// Convenience constructor, mirroring [`ModerationRequest::new`] so the
    /// two request types are built the same way.
    pub fn new(model: constants::Model, input: Vec<String>) -> Self {
        Self { model, input }
    }
}
/// Request body for classifying conversational (role + content) inputs.
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatClassificationRequest {
    /// Classifier model to run.
    pub model: constants::Model,
    /// Conversation turns to classify.
    pub input: Vec<ChatModerationInput>,
}
// -----------------------------------------------------------------------------
// Response
/// Response payload for a moderation request.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ModerationResponse {
    /// Unique response identifier.
    pub id: String,
    /// Model that produced the results.
    pub model: constants::Model,
    /// One result per submitted input, in order.
    pub results: Vec<ModerationResult>,
}
/// Per-input moderation verdicts.
///
/// Kept as raw JSON objects because the set of category names is defined
/// server-side and may change between API versions.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ModerationResult {
    /// Map of category name to boolean verdict.
    pub categories: serde_json::Value,
    /// Map of category name to confidence score.
    pub category_scores: serde_json::Value,
}
/// Response payload for a classification request.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ClassificationResponse {
    /// Unique response identifier.
    pub id: String,
    /// Model that produced the results.
    pub model: constants::Model,
    /// One result per submitted input, in order.
    pub results: Vec<ClassificationResult>,
}
/// Per-input classification verdicts.
///
/// Kept as raw JSON objects because the set of category names is defined
/// server-side and may change between API versions.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ClassificationResult {
    /// Map of category name to boolean verdict.
    pub categories: serde_json::Value,
    /// Map of category name to confidence score.
    pub category_scores: serde_json::Value,
}

96
src/v1/ocr.rs Normal file
View File

@@ -0,0 +1,96 @@
use serde::{Deserialize, Serialize};
use crate::v1::constants;
// -----------------------------------------------------------------------------
// Request
/// Request body for running OCR over a document.
///
/// Optional fields are omitted from the serialized JSON when `None`.
#[derive(Debug, Serialize, Deserialize)]
pub struct OcrRequest {
    /// OCR model to run.
    pub model: constants::Model,
    /// The document to process (by URL or by uploaded file id).
    pub document: OcrDocument,
    /// Specific page indices to process; all pages when unset.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pages: Option<Vec<u32>>,
    /// Output format for detected tables (markdown or HTML).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub table_format: Option<OcrTableFormat>,
    /// Whether extracted images are returned base64-encoded.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub include_image_base64: Option<bool>,
    /// Maximum number of images to extract.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_limit: Option<u32>,
}
/// Reference to the document an OCR request should process.
///
/// Exactly one of `document_url` / `file_id` is expected to be set, with
/// `type_` selecting which (`"document_url"` or `"file_id"`); use
/// `OcrDocument::from_url` / `OcrDocument::from_file_id` to build it.
#[derive(Debug, Serialize, Deserialize)]
pub struct OcrDocument {
    // Serialized as `type`; named `type_` because `type` is a Rust keyword.
    #[serde(rename = "type")]
    pub type_: String,
    /// Publicly reachable URL of the document, for the `document_url` variant.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub document_url: Option<String>,
    /// Identifier of an uploaded file, for the `file_id` variant.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub file_id: Option<String>,
}
impl OcrDocument {
pub fn from_url(url: &str) -> Self {
Self {
type_: "document_url".to_string(),
document_url: Some(url.to_string()),
file_id: None,
}
}
pub fn from_file_id(file_id: &str) -> Self {
Self {
type_: "file_id".to_string(),
document_url: None,
file_id: Some(file_id.to_string()),
}
}
}
/// Output format for tables detected by OCR.
///
/// Serialized in lowercase: `"markdown"`, `"html"`.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum OcrTableFormat {
    Markdown,
    Html,
}
// -----------------------------------------------------------------------------
// Response
/// Response payload for an OCR request.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct OcrResponse {
    /// Recognized pages, in document order.
    pub pages: Vec<OcrPage>,
    /// Processing-usage accounting, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub usage_info: Option<OcrUsageInfo>,
    /// Model that processed the document.
    pub model: constants::Model,
}
/// A single recognized page.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct OcrPage {
    /// Zero-based page index.
    pub index: u32,
    /// Recognized page content rendered as Markdown.
    pub markdown: String,
    /// Images extracted from the page; empty when the API omits the field.
    #[serde(default)]
    pub images: Vec<OcrImage>,
    /// Page dimensions, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub dimensions: Option<OcrPageDimensions>,
}
/// An image extracted from a recognized page.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct OcrImage {
    /// Identifier of the image (referenced from the page's Markdown).
    pub id: String,
    /// Base64-encoded image data; present only when the request asked for it.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub image_base64: Option<String>,
}
/// Dimensions of a recognized page (units not specified by this struct —
/// presumably pixels or points; confirm against the API reference).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct OcrPageDimensions {
    /// Page width.
    pub width: f32,
    /// Page height.
    pub height: f32,
}
/// Processing-usage accounting for an OCR request.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct OcrUsageInfo {
    /// Number of pages that were processed.
    pub pages_processed: u32,
    /// Size of the processed document in bytes, when reported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub doc_size_bytes: Option<u64>,
}

View File

@@ -1,12 +1,16 @@
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::{any::Any, collections::HashMap, fmt::Debug};
use std::{any::Any, fmt::Debug};
// -----------------------------------------------------------------------------
// Definitions
/// A tool (function) call emitted by the model in an assistant message.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub struct ToolCall {
    /// Call identifier assigned by the API; echo it back when returning the
    /// tool's result so the API can correlate call and response.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Call type discriminator (presumably `"function"` — confirm).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub r#type: Option<String>,
    /// The function name and its JSON-encoded arguments.
    pub function: ToolCallFunction,
}
@@ -22,31 +26,12 @@ pub struct Tool {
pub function: ToolFunction,
}
impl Tool {
/// Create a tool with a JSON Schema parameters object.
pub fn new(
function_name: String,
function_description: String,
function_parameters: Vec<ToolFunctionParameter>,
parameters: serde_json::Value,
) -> Self {
let properties: HashMap<String, ToolFunctionParameterProperty> = function_parameters
.into_iter()
.map(|param| {
(
param.name,
ToolFunctionParameterProperty {
r#type: param.r#type,
description: param.description,
},
)
})
.collect();
let property_names = properties.keys().cloned().collect();
let parameters = ToolFunctionParameters {
r#type: ToolFunctionParametersType::Object,
properties,
required: property_names,
};
Self {
r#type: ToolType::Function,
function: ToolFunction {
@@ -63,50 +48,9 @@ impl Tool {
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolFunction {
name: String,
description: String,
parameters: ToolFunctionParameters,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolFunctionParameter {
name: String,
description: String,
r#type: ToolFunctionParameterType,
}
impl ToolFunctionParameter {
pub fn new(name: String, description: String, r#type: ToolFunctionParameterType) -> Self {
Self {
name,
r#type,
description,
}
}
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolFunctionParameters {
r#type: ToolFunctionParametersType,
properties: HashMap<String, ToolFunctionParameterProperty>,
required: Vec<String>,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolFunctionParameterProperty {
r#type: ToolFunctionParameterType,
description: String,
}
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum ToolFunctionParametersType {
#[serde(rename = "object")]
Object,
}
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum ToolFunctionParameterType {
#[serde(rename = "string")]
String,
pub name: String,
pub description: String,
pub parameters: serde_json::Value,
}
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
@@ -127,6 +71,9 @@ pub enum ToolChoice {
/// The model won't call a function and will generate a message instead.
#[serde(rename = "none")]
None,
/// The model must call at least one tool.
#[serde(rename = "required")]
Required,
}
// -----------------------------------------------------------------------------

View File

@@ -3,7 +3,7 @@ use mistralai_client::v1::{
chat::{ChatMessage, ChatMessageRole, ChatParams, ChatResponseChoiceFinishReason},
client::Client,
constants::Model,
tool::{Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType},
tool::{Tool, ToolChoice},
};
mod setup;
@@ -14,12 +14,12 @@ async fn test_client_chat_async() {
let client = Client::new(None, None, None, None).unwrap();
let model = Model::OpenMistral7b;
let model = Model::mistral_small_latest();
let messages = vec![ChatMessage::new_user_message(
"Guess the next word: \"Eiffel ...\"?",
)];
let options = ChatParams {
temperature: 0.0,
temperature: Some(0.0),
random_seed: Some(42),
..Default::default()
};
@@ -29,7 +29,6 @@ async fn test_client_chat_async() {
.await
.unwrap();
expect!(response.model).to_be(Model::OpenMistral7b);
expect!(response.object).to_be("chat.completion".to_string());
expect!(response.choices.len()).to_be(1);
@@ -56,21 +55,26 @@ async fn test_client_chat_async_with_function_calling() {
let tools = vec![Tool::new(
"get_city_temperature".to_string(),
"Get the current temperature in a city.".to_string(),
vec![ToolFunctionParameter::new(
"city".to_string(),
"The name of the city.".to_string(),
ToolFunctionParameterType::String,
)],
serde_json::json!({
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The name of the city."
}
},
"required": ["city"]
}),
)];
let client = Client::new(None, None, None, None).unwrap();
let model = Model::MistralSmallLatest;
let model = Model::mistral_small_latest();
let messages = vec![ChatMessage::new_user_message(
"What's the current temperature in Paris?",
)];
let options = ChatParams {
temperature: 0.0,
temperature: Some(0.0),
random_seed: Some(42),
tool_choice: Some(ToolChoice::Any),
tools: Some(tools),
@@ -82,7 +86,6 @@ async fn test_client_chat_async_with_function_calling() {
.await
.unwrap();
expect!(response.model).to_be(Model::MistralSmallLatest);
expect!(response.object).to_be("chat.completion".to_string());
expect!(response.choices.len()).to_be(1);
@@ -91,13 +94,6 @@ async fn test_client_chat_async_with_function_calling() {
.to_be(ChatResponseChoiceFinishReason::ToolCalls);
expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant);
expect!(response.choices[0].message.content.clone()).to_be("".to_string());
// expect!(response.choices[0].message.tool_calls.clone()).to_be(Some(vec![ToolCall {
// function: ToolCallFunction {
// name: "get_city_temperature".to_string(),
// arguments: "{\"city\": \"Paris\"}".to_string(),
// },
// }]));
expect!(response.usage.prompt_tokens).to_be_greater_than(0);
expect!(response.usage.completion_tokens).to_be_greater_than(0);

View File

@@ -1,16 +1,18 @@
// Streaming tests require a live API key and are not run in CI.
// Uncomment to test locally.
// use futures::stream::StreamExt;
// use jrest::expect;
// use mistralai_client::v1::{
// chat_completion::{ChatParams, ChatMessage, ChatMessageRole},
// chat::{ChatMessage, ChatParams},
// client::Client,
// constants::Model,
// };
//
// #[tokio::test]
// async fn test_client_chat_stream() {
// let client = Client::new(None, None, None, None).unwrap();
// let model = Model::OpenMistral7b;
//
// let model = Model::mistral_small_latest();
// let messages = vec![ChatMessage::new_user_message(
// "Just guess the next word: \"Eiffel ...\"?",
// )];
@@ -19,22 +21,24 @@
// random_seed: Some(42),
// ..Default::default()
// };
// let stream_result = client.chat_stream(model, messages, Some(options)).await;
// let mut stream = stream_result.expect("Failed to create stream.");
// while let Some(maybe_chunk_result) = stream.next().await {
// match maybe_chunk_result {
// Some(Ok(chunk)) => {
// if chunk.choices[0].delta.role == Some(ChatMessageRole::Assistant)
// || chunk.choices[0].finish_reason == Some("stop".to_string())
// {
// expect!(chunk.choices[0].delta.content.len()).to_be(0);
// } else {
// expect!(chunk.choices[0].delta.content.len()).to_be_greater_than(0);
//
// let stream = client
// .chat_stream(model, messages, Some(options))
// .await
// .expect("Failed to create stream.");
//
// stream
// .for_each(|chunk_result| async {
// match chunk_result {
// Ok(chunks) => {
// for chunk in &chunks {
// if let Some(content) = &chunk.choices[0].delta.content {
// print!("{}", content);
// }
// }
// }
// Err(error) => eprintln!("Error: {:?}", error),
// }
// Some(Err(error)) => eprintln!("Error processing chunk: {:?}", error),
// None => (),
// }
// }
// })
// .await;
// }

View File

@@ -3,7 +3,7 @@ use mistralai_client::v1::{
chat::{ChatMessage, ChatMessageRole, ChatParams, ChatResponseChoiceFinishReason},
client::Client,
constants::Model,
tool::{Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType},
tool::{Tool, ToolChoice},
};
mod setup;
@@ -14,19 +14,18 @@ fn test_client_chat() {
let client = Client::new(None, None, None, None).unwrap();
let model = Model::OpenMistral7b;
let model = Model::mistral_small_latest();
let messages = vec![ChatMessage::new_user_message(
"Guess the next word: \"Eiffel ...\"?",
)];
let options = ChatParams {
temperature: 0.0,
temperature: Some(0.0),
random_seed: Some(42),
..Default::default()
};
let response = client.chat(model, messages, Some(options)).unwrap();
expect!(response.model).to_be(Model::OpenMistral7b);
expect!(response.object).to_be("chat.completion".to_string());
expect!(response.choices.len()).to_be(1);
expect!(response.choices[0].index).to_be(0);
@@ -50,21 +49,26 @@ fn test_client_chat_with_function_calling() {
let tools = vec![Tool::new(
"get_city_temperature".to_string(),
"Get the current temperature in a city.".to_string(),
vec![ToolFunctionParameter::new(
"city".to_string(),
"The name of the city.".to_string(),
ToolFunctionParameterType::String,
)],
serde_json::json!({
"type": "object",
"properties": {
"city": {
"type": "string",
"description": "The name of the city."
}
},
"required": ["city"]
}),
)];
let client = Client::new(None, None, None, None).unwrap();
let model = Model::MistralSmallLatest;
let model = Model::mistral_small_latest();
let messages = vec![ChatMessage::new_user_message(
"What's the current temperature in Paris?",
)];
let options = ChatParams {
temperature: 0.0,
temperature: Some(0.0),
random_seed: Some(42),
tool_choice: Some(ToolChoice::Auto),
tools: Some(tools),
@@ -73,12 +77,10 @@ fn test_client_chat_with_function_calling() {
let response = client.chat(model, messages, Some(options)).unwrap();
expect!(response.model).to_be(Model::MistralSmallLatest);
expect!(response.object).to_be("chat.completion".to_string());
expect!(response.choices.len()).to_be(1);
expect!(response.choices[0].index).to_be(0);
expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant);
expect!(response.choices[0].message.content.clone()).to_be("".to_string());
expect!(response.choices[0].finish_reason.clone())
.to_be(ChatResponseChoiceFinishReason::ToolCalls);
expect!(response.usage.prompt_tokens).to_be_greater_than(0);

View File

@@ -1,11 +1,11 @@
use jrest::expect;
use mistralai_client::v1::{client::Client, constants::EmbedModel};
use mistralai_client::v1::{client::Client, constants::Model};
#[tokio::test]
async fn test_client_embeddings_async() {
let client: Client = Client::new(None, None, None, None).unwrap();
let model = EmbedModel::MistralEmbed;
let model = Model::mistral_embed();
let input = vec!["Embed this sentence.", "As well as this one."]
.iter()
.map(|s| s.to_string())
@@ -17,7 +17,6 @@ async fn test_client_embeddings_async() {
.await
.unwrap();
expect!(response.model).to_be(EmbedModel::MistralEmbed);
expect!(response.object).to_be("list".to_string());
expect!(response.data.len()).to_be(2);
expect!(response.data[0].index).to_be(0);

View File

@@ -1,11 +1,11 @@
use jrest::expect;
use mistralai_client::v1::{client::Client, constants::EmbedModel};
use mistralai_client::v1::{client::Client, constants::Model};
#[test]
fn test_client_embeddings() {
let client: Client = Client::new(None, None, None, None).unwrap();
let model = EmbedModel::MistralEmbed;
let model = Model::mistral_embed();
let input = vec!["Embed this sentence.", "As well as this one."]
.iter()
.map(|s| s.to_string())
@@ -14,7 +14,6 @@ fn test_client_embeddings() {
let response = client.embeddings(model, input, options).unwrap();
expect!(response.model).to_be(EmbedModel::MistralEmbed);
expect!(response.object).to_be("list".to_string());
expect!(response.data.len()).to_be(2);
expect!(response.data[0].index).to_be(0);

View File

@@ -6,26 +6,19 @@ use mistralai_client::v1::{
};
#[test]
fn test_model_constant() {
fn test_model_constants() {
let models = vec![
Model::OpenMistral7b,
Model::OpenMixtral8x7b,
Model::OpenMixtral8x22b,
Model::OpenMistralNemo,
Model::MistralTiny,
Model::MistralSmallLatest,
Model::MistralMediumLatest,
Model::MistralLargeLatest,
Model::MistralLarge,
Model::CodestralLatest,
Model::CodestralMamba,
Model::mistral_small_latest(),
Model::mistral_large_latest(),
Model::open_mistral_nemo(),
Model::codestral_latest(),
];
let client = Client::new(None, None, None, None).unwrap();
let messages = vec![ChatMessage::new_user_message("A number between 0 and 100?")];
let options = ChatParams {
temperature: 0.0,
temperature: Some(0.0),
random_seed: Some(42),
..Default::default()
};