Update to latest Mistral AI API (v1.0.0)
Some checks failed
Test / Test Documentation (push) Has been cancelled
Test / Test Examples (push) Has been cancelled
Test / Test (push) Has been cancelled

- Replace closed Model enum with flexible string-based Model type
  with constructor methods for all current models (Mistral Large 3,
  Small 4, Magistral, Codestral, Devstral, Pixtral, Voxtral, etc.)
- Add new API endpoints: FIM completions, Files, Fine-tuning, Batch
  jobs, OCR, Audio transcription, Moderations/Classifications, and
  Agent completions (sync + async for all)
- Add new chat fields: frequency_penalty, presence_penalty, stop,
  n, parallel_tool_calls, reasoning_effort, min_tokens, json_schema
  response format
- Add embedding fields: output_dimension, output_dtype
- Tool parameters now accept raw JSON Schema (serde_json::Value)
  instead of limited enum types
- Add tool call IDs and Required tool choice variant
- Add DELETE HTTP method support and multipart file upload
- Bump thiserror to v2, add reqwest multipart feature
- Remove strum dependency (no longer needed)
- Update all tests and examples for new API
This commit is contained in:
2026-03-20 17:16:26 +00:00
parent 9ad6a1dc84
commit 79bc40bb15
33 changed files with 1977 additions and 622 deletions

View File

@@ -2,7 +2,7 @@
name = "mistralai-client" name = "mistralai-client"
description = "Mistral AI API client library for Rust (unofficial)." description = "Mistral AI API client library for Rust (unofficial)."
license = "Apache-2.0" license = "Apache-2.0"
version = "0.14.0" version = "1.0.0"
edition = "2021" edition = "2021"
rust-version = "1.76.0" rust-version = "1.76.0"
@@ -15,18 +15,17 @@ readme = "README.md"
repository = "https://github.com/ivangabriele/mistralai-client-rs" repository = "https://github.com/ivangabriele/mistralai-client-rs"
[dependencies] [dependencies]
async-stream = "0.3.5" async-stream = "0.3"
async-trait = "0.1.77" async-trait = "0.1"
env_logger = "0.11.3" env_logger = "0.11"
futures = "0.3.30" futures = "0.3"
log = "0.4.21" log = "0.4"
reqwest = { version = "0.12.0", features = ["json", "blocking", "stream"] } reqwest = { version = "0.12", features = ["json", "blocking", "stream", "multipart"] }
serde = { version = "1.0.197", features = ["derive"] } serde = { version = "1", features = ["derive"] }
serde_json = "1.0.114" serde_json = "1"
strum = "0.26.1" thiserror = "2"
thiserror = "1.0.57" tokio = { version = "1", features = ["full"] }
tokio = { version = "1.36.0", features = ["full"] } tokio-stream = "0.1"
tokio-stream = "0.1.14"
[dev-dependencies] [dev-dependencies]
jrest = "0.2.3" jrest = "0.2"

View File

@@ -1,5 +1,5 @@
use mistralai_client::v1::{ use mistralai_client::v1::{
chat::{ChatMessage, ChatMessageRole, ChatParams}, chat::{ChatMessage, ChatParams},
client::Client, client::Client,
constants::Model, constants::Model,
}; };
@@ -8,14 +8,12 @@ fn main() {
// This example suppose you have set the `MISTRAL_API_KEY` environment variable. // This example suppose you have set the `MISTRAL_API_KEY` environment variable.
let client = Client::new(None, None, None, None).unwrap(); let client = Client::new(None, None, None, None).unwrap();
let model = Model::OpenMistral7b; let model = Model::mistral_small_latest();
let messages = vec![ChatMessage { let messages = vec![ChatMessage::new_user_message(
role: ChatMessageRole::User, "Just guess the next word: \"Eiffel ...\"?",
content: "Just guess the next word: \"Eiffel ...\"?".to_string(), )];
tool_calls: None,
}];
let options = ChatParams { let options = ChatParams {
temperature: 0.0, temperature: Some(0.0),
random_seed: Some(42), random_seed: Some(42),
..Default::default() ..Default::default()
}; };

View File

@@ -1,5 +1,5 @@
use mistralai_client::v1::{ use mistralai_client::v1::{
chat::{ChatMessage, ChatMessageRole, ChatParams}, chat::{ChatMessage, ChatParams},
client::Client, client::Client,
constants::Model, constants::Model,
}; };
@@ -9,14 +9,12 @@ async fn main() {
// This example suppose you have set the `MISTRAL_API_KEY` environment variable. // This example suppose you have set the `MISTRAL_API_KEY` environment variable.
let client = Client::new(None, None, None, None).unwrap(); let client = Client::new(None, None, None, None).unwrap();
let model = Model::OpenMistral7b; let model = Model::mistral_small_latest();
let messages = vec![ChatMessage { let messages = vec![ChatMessage::new_user_message(
role: ChatMessageRole::User, "Just guess the next word: \"Eiffel ...\"?",
content: "Just guess the next word: \"Eiffel ...\"?".to_string(), )];
tool_calls: None,
}];
let options = ChatParams { let options = ChatParams {
temperature: 0.0, temperature: Some(0.0),
random_seed: Some(42), random_seed: Some(42),
..Default::default() ..Default::default()
}; };
@@ -29,5 +27,4 @@ async fn main() {
"{:?}: {}", "{:?}: {}",
result.choices[0].message.role, result.choices[0].message.content result.choices[0].message.role, result.choices[0].message.content
); );
// => "Assistant: Tower. The Eiffel Tower is a famous landmark in Paris, France."
} }

View File

@@ -1,8 +1,8 @@
use mistralai_client::v1::{ use mistralai_client::v1::{
chat::{ChatMessage, ChatMessageRole, ChatParams}, chat::{ChatMessage, ChatParams},
client::Client, client::Client,
constants::Model, constants::Model,
tool::{Function, Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType}, tool::{Function, Tool, ToolChoice},
}; };
use serde::Deserialize; use serde::Deserialize;
use std::any::Any; use std::any::Any;
@@ -16,7 +16,6 @@ struct GetCityTemperatureFunction;
#[async_trait::async_trait] #[async_trait::async_trait]
impl Function for GetCityTemperatureFunction { impl Function for GetCityTemperatureFunction {
async fn execute(&self, arguments: String) -> Box<dyn Any + Send> { async fn execute(&self, arguments: String) -> Box<dyn Any + Send> {
// Deserialize arguments, perform the logic, and return the result
let GetCityTemperatureArguments { city } = serde_json::from_str(&arguments).unwrap(); let GetCityTemperatureArguments { city } = serde_json::from_str(&arguments).unwrap();
let temperature = match city.as_str() { let temperature = match city.as_str() {
@@ -32,11 +31,16 @@ fn main() {
let tools = vec![Tool::new( let tools = vec![Tool::new(
"get_city_temperature".to_string(), "get_city_temperature".to_string(),
"Get the current temperature in a city.".to_string(), "Get the current temperature in a city.".to_string(),
vec![ToolFunctionParameter::new( serde_json::json!({
"city".to_string(), "type": "object",
"The name of the city.".to_string(), "properties": {
ToolFunctionParameterType::String, "city": {
)], "type": "string",
"description": "The name of the city."
}
},
"required": ["city"]
}),
)]; )];
// This example suppose you have set the `MISTRAL_API_KEY` environment variable. // This example suppose you have set the `MISTRAL_API_KEY` environment variable.
@@ -46,14 +50,12 @@ fn main() {
Box::new(GetCityTemperatureFunction), Box::new(GetCityTemperatureFunction),
); );
let model = Model::MistralSmallLatest; let model = Model::mistral_small_latest();
let messages = vec![ChatMessage { let messages = vec![ChatMessage::new_user_message(
role: ChatMessageRole::User, "What's the temperature in Paris?",
content: "What's the temperature in Paris?".to_string(), )];
tool_calls: None,
}];
let options = ChatParams { let options = ChatParams {
temperature: 0.0, temperature: Some(0.0),
random_seed: Some(42), random_seed: Some(42),
tool_choice: Some(ToolChoice::Auto), tool_choice: Some(ToolChoice::Auto),
tools: Some(tools), tools: Some(tools),

View File

@@ -1,8 +1,8 @@
use mistralai_client::v1::{ use mistralai_client::v1::{
chat::{ChatMessage, ChatMessageRole, ChatParams}, chat::{ChatMessage, ChatParams},
client::Client, client::Client,
constants::Model, constants::Model,
tool::{Function, Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType}, tool::{Function, Tool, ToolChoice},
}; };
use serde::Deserialize; use serde::Deserialize;
use std::any::Any; use std::any::Any;
@@ -16,7 +16,6 @@ struct GetCityTemperatureFunction;
#[async_trait::async_trait] #[async_trait::async_trait]
impl Function for GetCityTemperatureFunction { impl Function for GetCityTemperatureFunction {
async fn execute(&self, arguments: String) -> Box<dyn Any + Send> { async fn execute(&self, arguments: String) -> Box<dyn Any + Send> {
// Deserialize arguments, perform the logic, and return the result
let GetCityTemperatureArguments { city } = serde_json::from_str(&arguments).unwrap(); let GetCityTemperatureArguments { city } = serde_json::from_str(&arguments).unwrap();
let temperature = match city.as_str() { let temperature = match city.as_str() {
@@ -33,11 +32,16 @@ async fn main() {
let tools = vec![Tool::new( let tools = vec![Tool::new(
"get_city_temperature".to_string(), "get_city_temperature".to_string(),
"Get the current temperature in a city.".to_string(), "Get the current temperature in a city.".to_string(),
vec![ToolFunctionParameter::new( serde_json::json!({
"city".to_string(), "type": "object",
"The name of the city.".to_string(), "properties": {
ToolFunctionParameterType::String, "city": {
)], "type": "string",
"description": "The name of the city."
}
},
"required": ["city"]
}),
)]; )];
// This example suppose you have set the `MISTRAL_API_KEY` environment variable. // This example suppose you have set the `MISTRAL_API_KEY` environment variable.
@@ -47,14 +51,12 @@ async fn main() {
Box::new(GetCityTemperatureFunction), Box::new(GetCityTemperatureFunction),
); );
let model = Model::MistralSmallLatest; let model = Model::mistral_small_latest();
let messages = vec![ChatMessage { let messages = vec![ChatMessage::new_user_message(
role: ChatMessageRole::User, "What's the temperature in Paris?",
content: "What's the temperature in Paris?".to_string(), )];
tool_calls: None,
}];
let options = ChatParams { let options = ChatParams {
temperature: 0.0, temperature: Some(0.0),
random_seed: Some(42), random_seed: Some(42),
tool_choice: Some(ToolChoice::Auto), tool_choice: Some(ToolChoice::Auto),
tools: Some(tools), tools: Some(tools),

View File

@@ -1,6 +1,6 @@
use futures::stream::StreamExt; use futures::stream::StreamExt;
use mistralai_client::v1::{ use mistralai_client::v1::{
chat::{ChatMessage, ChatMessageRole, ChatParams}, chat::{ChatMessage, ChatParams},
client::Client, client::Client,
constants::Model, constants::Model,
}; };
@@ -11,14 +11,10 @@ async fn main() {
// This example suppose you have set the `MISTRAL_API_KEY` environment variable. // This example suppose you have set the `MISTRAL_API_KEY` environment variable.
let client = Client::new(None, None, None, None).unwrap(); let client = Client::new(None, None, None, None).unwrap();
let model = Model::OpenMistral7b; let model = Model::mistral_small_latest();
let messages = vec![ChatMessage { let messages = vec![ChatMessage::new_user_message("Tell me a short happy story.")];
role: ChatMessageRole::User,
content: "Tell me a short happy story.".to_string(),
tool_calls: None,
}];
let options = ChatParams { let options = ChatParams {
temperature: 0.0, temperature: Some(0.0),
random_seed: Some(42), random_seed: Some(42),
..Default::default() ..Default::default()
}; };
@@ -31,9 +27,10 @@ async fn main() {
.for_each(|chunk_result| async { .for_each(|chunk_result| async {
match chunk_result { match chunk_result {
Ok(chunks) => chunks.iter().for_each(|chunk| { Ok(chunks) => chunks.iter().for_each(|chunk| {
print!("{}", chunk.choices[0].delta.content); if let Some(content) = &chunk.choices[0].delta.content {
io::stdout().flush().unwrap(); print!("{}", content);
// => "Once upon a time, [...]" io::stdout().flush().unwrap();
}
}), }),
Err(error) => { Err(error) => {
eprintln!("Error processing chunk: {:?}", error) eprintln!("Error processing chunk: {:?}", error)
@@ -41,5 +38,5 @@ async fn main() {
} }
}) })
.await; .await;
print!("\n") // To persist the last chunk output. println!();
} }

View File

@@ -1,10 +1,10 @@
use mistralai_client::v1::{client::Client, constants::EmbedModel}; use mistralai_client::v1::{client::Client, constants::Model};
fn main() { fn main() {
// This example suppose you have set the `MISTRAL_API_KEY` environment variable. // This example suppose you have set the `MISTRAL_API_KEY` environment variable.
let client: Client = Client::new(None, None, None, None).unwrap(); let client: Client = Client::new(None, None, None, None).unwrap();
let model = EmbedModel::MistralEmbed; let model = Model::mistral_embed();
let input = vec!["Embed this sentence.", "As well as this one."] let input = vec!["Embed this sentence.", "As well as this one."]
.iter() .iter()
.map(|s| s.to_string()) .map(|s| s.to_string())
@@ -13,5 +13,4 @@ fn main() {
let response = client.embeddings(model, input, options).unwrap(); let response = client.embeddings(model, input, options).unwrap();
println!("First Embedding: {:?}", response.data[0]); println!("First Embedding: {:?}", response.data[0]);
// => "First Embedding: {...}"
} }

View File

@@ -1,11 +1,11 @@
use mistralai_client::v1::{client::Client, constants::EmbedModel}; use mistralai_client::v1::{client::Client, constants::Model};
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
// This example suppose you have set the `MISTRAL_API_KEY` environment variable. // This example suppose you have set the `MISTRAL_API_KEY` environment variable.
let client: Client = Client::new(None, None, None, None).unwrap(); let client: Client = Client::new(None, None, None, None).unwrap();
let model = EmbedModel::MistralEmbed; let model = Model::mistral_embed();
let input = vec!["Embed this sentence.", "As well as this one."] let input = vec!["Embed this sentence.", "As well as this one."]
.iter() .iter()
.map(|s| s.to_string()) .map(|s| s.to_string())
@@ -17,5 +17,4 @@ async fn main() {
.await .await
.unwrap(); .unwrap();
println!("First Embedding: {:?}", response.data[0]); println!("First Embedding: {:?}", response.data[0]);
// => "First Embedding: {...}"
} }

21
examples/fim.rs Normal file
View File

@@ -0,0 +1,21 @@
use mistralai_client::v1::{
client::Client,
constants::Model,
fim::FimParams,
};
fn main() {
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
let client = Client::new(None, None, None, None).unwrap();
let model = Model::codestral_latest();
let prompt = "def fibonacci(n):".to_string();
let options = FimParams {
suffix: Some("\n return result".to_string()),
temperature: Some(0.0),
..Default::default()
};
let response = client.fim(model, prompt, Some(options)).unwrap();
println!("Completion: {}", response.choices[0].message.content);
}

25
examples/ocr.rs Normal file
View File

@@ -0,0 +1,25 @@
use mistralai_client::v1::{
client::Client,
constants::Model,
ocr::{OcrDocument, OcrRequest},
};
fn main() {
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
let client = Client::new(None, None, None, None).unwrap();
let request = OcrRequest {
model: Model::mistral_ocr_latest(),
document: OcrDocument::from_url("https://arxiv.org/pdf/2201.04234"),
pages: Some(vec![0]),
table_format: None,
include_image_base64: None,
image_limit: None,
};
let response = client.ocr(&request).unwrap();
for page in &response.pages {
println!("--- Page {} ---", page.index);
println!("{}", &page.markdown[..200.min(page.markdown.len())]);
}
}

98
src/v1/agents.rs Normal file
View File

@@ -0,0 +1,98 @@
use serde::{Deserialize, Serialize};
use crate::v1::{chat, common, constants, tool};
// -----------------------------------------------------------------------------
// Request
#[derive(Debug)]
pub struct AgentCompletionParams {
pub max_tokens: Option<u32>,
pub min_tokens: Option<u32>,
pub temperature: Option<f32>,
pub top_p: Option<f32>,
pub random_seed: Option<u32>,
pub stop: Option<Vec<String>>,
pub response_format: Option<chat::ResponseFormat>,
pub tools: Option<Vec<tool::Tool>>,
pub tool_choice: Option<tool::ToolChoice>,
}
impl Default for AgentCompletionParams {
fn default() -> Self {
Self {
max_tokens: None,
min_tokens: None,
temperature: None,
top_p: None,
random_seed: None,
stop: None,
response_format: None,
tools: None,
tool_choice: None,
}
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct AgentCompletionRequest {
pub agent_id: String,
pub messages: Vec<chat::ChatMessage>,
pub stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub min_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub random_seed: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<chat::ResponseFormat>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<tool::Tool>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<tool::ToolChoice>,
}
impl AgentCompletionRequest {
pub fn new(
agent_id: String,
messages: Vec<chat::ChatMessage>,
stream: bool,
options: Option<AgentCompletionParams>,
) -> Self {
let opts = options.unwrap_or_default();
Self {
agent_id,
messages,
stream,
max_tokens: opts.max_tokens,
min_tokens: opts.min_tokens,
temperature: opts.temperature,
top_p: opts.top_p,
random_seed: opts.random_seed,
stop: opts.stop,
response_format: opts.response_format,
tools: opts.tools,
tool_choice: opts.tool_choice,
}
}
}
// -----------------------------------------------------------------------------
// Response (same shape as chat completions)
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AgentCompletionResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: constants::Model,
pub choices: Vec<chat::ChatResponseChoice>,
pub usage: common::ResponseUsage,
}

78
src/v1/audio.rs Normal file
View File

@@ -0,0 +1,78 @@
use serde::{Deserialize, Serialize};
use crate::v1::constants;
// -----------------------------------------------------------------------------
// Request (multipart form, but we define the params struct)
#[derive(Debug)]
pub struct AudioTranscriptionParams {
pub model: constants::Model,
pub language: Option<String>,
pub temperature: Option<f32>,
pub diarize: Option<bool>,
pub timestamp_granularities: Option<Vec<TimestampGranularity>>,
}
impl Default for AudioTranscriptionParams {
fn default() -> Self {
Self {
model: constants::Model::voxtral_mini_transcribe(),
language: None,
temperature: None,
diarize: None,
timestamp_granularities: None,
}
}
}
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum TimestampGranularity {
Segment,
Word,
}
// -----------------------------------------------------------------------------
// Response
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AudioTranscriptionResponse {
pub text: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub language: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub segments: Option<Vec<TranscriptionSegment>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub words: Option<Vec<TranscriptionWord>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<AudioUsage>,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TranscriptionSegment {
pub id: u32,
pub start: f32,
pub end: f32,
pub text: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub speaker: Option<String>,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct TranscriptionWord {
pub word: String,
pub start: f32,
pub end: f32,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct AudioUsage {
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt_audio_seconds: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub prompt_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub total_tokens: Option<u32>,
}

53
src/v1/batch.rs Normal file
View File

@@ -0,0 +1,53 @@
use serde::{Deserialize, Serialize};
use crate::v1::constants;
// -----------------------------------------------------------------------------
// Request
#[derive(Debug, Serialize, Deserialize)]
pub struct BatchJobRequest {
pub input_files: Vec<String>,
pub model: constants::Model,
pub endpoint: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<serde_json::Value>,
}
// -----------------------------------------------------------------------------
// Response
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct BatchJobResponse {
pub id: String,
pub object: String,
pub model: constants::Model,
pub endpoint: String,
pub input_files: Vec<String>,
pub status: String,
pub created_at: u64,
#[serde(skip_serializing_if = "Option::is_none")]
pub completed_at: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub output_file: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub error_file: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub metadata: Option<serde_json::Value>,
#[serde(skip_serializing_if = "Option::is_none")]
pub total_requests: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub completed_requests: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub succeeded_requests: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub failed_requests: Option<u64>,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct BatchJobListResponse {
pub data: Vec<BatchJobResponse>,
pub object: String,
#[serde(default)]
pub total: u32,
}

View File

@@ -11,13 +11,31 @@ pub struct ChatMessage {
pub content: String, pub content: String,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<tool::ToolCall>>, pub tool_calls: Option<Vec<tool::ToolCall>>,
/// Tool call ID, required when role is Tool.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_call_id: Option<String>,
/// Function name, used when role is Tool.
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
} }
impl ChatMessage { impl ChatMessage {
pub fn new_system_message(content: &str) -> Self {
Self {
role: ChatMessageRole::System,
content: content.to_string(),
tool_calls: None,
tool_call_id: None,
name: None,
}
}
pub fn new_assistant_message(content: &str, tool_calls: Option<Vec<tool::ToolCall>>) -> Self { pub fn new_assistant_message(content: &str, tool_calls: Option<Vec<tool::ToolCall>>) -> Self {
Self { Self {
role: ChatMessageRole::Assistant, role: ChatMessageRole::Assistant,
content: content.to_string(), content: content.to_string(),
tool_calls, tool_calls,
tool_call_id: None,
name: None,
} }
} }
@@ -26,6 +44,18 @@ impl ChatMessage {
role: ChatMessageRole::User, role: ChatMessageRole::User,
content: content.to_string(), content: content.to_string(),
tool_calls: None, tool_calls: None,
tool_call_id: None,
name: None,
}
}
pub fn new_tool_message(content: &str, tool_call_id: &str, name: Option<&str>) -> Self {
Self {
role: ChatMessageRole::Tool,
content: content.to_string(),
tool_calls: None,
tool_call_id: Some(tool_call_id.to_string()),
name: name.map(|n| n.to_string()),
} }
} }
} }
@@ -44,17 +74,32 @@ pub enum ChatMessageRole {
} }
/// The format that the model must output. /// The format that the model must output.
///
/// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information.
#[derive(Clone, Debug, Serialize, Deserialize)] #[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ResponseFormat { pub struct ResponseFormat {
#[serde(rename = "type")] #[serde(rename = "type")]
pub type_: String, pub type_: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub json_schema: Option<serde_json::Value>,
} }
impl ResponseFormat { impl ResponseFormat {
pub fn text() -> Self {
Self {
type_: "text".to_string(),
json_schema: None,
}
}
pub fn json_object() -> Self { pub fn json_object() -> Self {
Self { Self {
type_: "json_object".to_string(), type_: "json_object".to_string(),
json_schema: None,
}
}
pub fn json_schema(schema: serde_json::Value) -> Self {
Self {
type_: "json_schema".to_string(),
json_schema: Some(schema),
} }
} }
} }
@@ -63,91 +108,83 @@ impl ResponseFormat {
// Request // Request
/// The parameters for the chat request. /// The parameters for the chat request.
///
/// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
pub struct ChatParams { pub struct ChatParams {
/// The maximum number of tokens to generate in the completion.
///
/// Defaults to `None`.
pub max_tokens: Option<u32>, pub max_tokens: Option<u32>,
/// The seed to use for random sampling. If set, different calls will generate deterministic results. pub min_tokens: Option<u32>,
///
/// Defaults to `None`.
pub random_seed: Option<u32>, pub random_seed: Option<u32>,
/// The format that the model must output.
///
/// Defaults to `None`.
pub response_format: Option<ResponseFormat>, pub response_format: Option<ResponseFormat>,
/// Whether to inject a safety prompt before all conversations.
///
/// Defaults to `false`.
pub safe_prompt: bool, pub safe_prompt: bool,
/// What sampling temperature to use, between `Some(0.0)` and `Some(1.0)`. pub temperature: Option<f32>,
///
/// Defaults to `0.7`.
pub temperature: f32,
/// Specifies if/how functions are called.
///
/// Defaults to `None`.
pub tool_choice: Option<tool::ToolChoice>, pub tool_choice: Option<tool::ToolChoice>,
/// A list of available tools for the model.
///
/// Defaults to `None`.
pub tools: Option<Vec<tool::Tool>>, pub tools: Option<Vec<tool::Tool>>,
/// Nucleus sampling, where the model considers the results of the tokens with `top_p` probability mass. pub top_p: Option<f32>,
/// pub stop: Option<Vec<String>>,
/// Defaults to `1.0`. pub n: Option<u32>,
pub top_p: f32, pub frequency_penalty: Option<f32>,
pub presence_penalty: Option<f32>,
pub parallel_tool_calls: Option<bool>,
/// For reasoning models (Magistral). "high" or "none".
pub reasoning_effort: Option<String>,
} }
impl Default for ChatParams { impl Default for ChatParams {
fn default() -> Self { fn default() -> Self {
Self { Self {
max_tokens: None, max_tokens: None,
min_tokens: None,
random_seed: None, random_seed: None,
safe_prompt: false, safe_prompt: false,
response_format: None, response_format: None,
temperature: 0.7, temperature: None,
tool_choice: None, tool_choice: None,
tools: None, tools: None,
top_p: 1.0, top_p: None,
} stop: None,
} n: None,
} frequency_penalty: None,
impl ChatParams { presence_penalty: None,
pub fn json_default() -> Self { parallel_tool_calls: None,
Self { reasoning_effort: None,
max_tokens: None,
random_seed: None,
safe_prompt: false,
response_format: None,
temperature: 0.7,
tool_choice: None,
tools: None,
top_p: 1.0,
} }
} }
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct ChatRequest { pub struct ChatRequest {
pub messages: Vec<ChatMessage>,
pub model: constants::Model, pub model: constants::Model,
pub messages: Vec<ChatMessage>,
pub stream: bool,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>, pub max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub min_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub random_seed: Option<u32>, pub random_seed: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<ResponseFormat>, pub response_format: Option<ResponseFormat>,
pub safe_prompt: bool, #[serde(skip_serializing_if = "Option::is_none")]
pub stream: bool, pub safe_prompt: Option<bool>,
pub temperature: f32, #[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<tool::ToolChoice>, pub tool_choice: Option<tool::ToolChoice>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<tool::Tool>>, pub tools: Option<Vec<tool::Tool>>,
pub top_p: f32, #[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub n: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub frequency_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub presence_penalty: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub parallel_tool_calls: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub reasoning_effort: Option<String>,
} }
impl ChatRequest { impl ChatRequest {
pub fn new( pub fn new(
@@ -156,30 +193,28 @@ impl ChatRequest {
stream: bool, stream: bool,
options: Option<ChatParams>, options: Option<ChatParams>,
) -> Self { ) -> Self {
let ChatParams { let opts = options.unwrap_or_default();
max_tokens, let safe_prompt = if opts.safe_prompt { Some(true) } else { None };
random_seed,
safe_prompt,
temperature,
tool_choice,
tools,
top_p,
response_format,
} = options.unwrap_or_default();
Self { Self {
messages,
model, model,
messages,
max_tokens,
random_seed,
safe_prompt,
stream, stream,
temperature, max_tokens: opts.max_tokens,
tool_choice, min_tokens: opts.min_tokens,
tools, random_seed: opts.random_seed,
top_p, safe_prompt,
response_format, temperature: opts.temperature,
tool_choice: opts.tool_choice,
tools: opts.tools,
top_p: opts.top_p,
response_format: opts.response_format,
stop: opts.stop,
n: opts.n,
frequency_penalty: opts.frequency_penalty,
presence_penalty: opts.presence_penalty,
parallel_tool_calls: opts.parallel_tool_calls,
reasoning_effort: opts.reasoning_effort,
} }
} }
} }
@@ -192,7 +227,7 @@ pub struct ChatResponse {
pub id: String, pub id: String,
pub object: String, pub object: String,
/// Unix timestamp (in seconds). /// Unix timestamp (in seconds).
pub created: u32, pub created: u64,
pub model: constants::Model, pub model: constants::Model,
pub choices: Vec<ChatResponseChoice>, pub choices: Vec<ChatResponseChoice>,
pub usage: common::ResponseUsage, pub usage: common::ResponseUsage,
@@ -203,14 +238,18 @@ pub struct ChatResponseChoice {
pub index: u32, pub index: u32,
pub message: ChatMessage, pub message: ChatMessage,
pub finish_reason: ChatResponseChoiceFinishReason, pub finish_reason: ChatResponseChoiceFinishReason,
// TODO Check this prop (seen in API responses but undocumented).
// pub logprobs: ???
} }
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum ChatResponseChoiceFinishReason { pub enum ChatResponseChoiceFinishReason {
#[serde(rename = "stop")] #[serde(rename = "stop")]
Stop, Stop,
#[serde(rename = "length")]
Length,
#[serde(rename = "tool_calls")] #[serde(rename = "tool_calls")]
ToolCalls, ToolCalls,
#[serde(rename = "model_length")]
ModelLength,
#[serde(rename = "error")]
Error,
} }

View File

@@ -1,7 +1,7 @@
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::from_str; use serde_json::from_str;
use crate::v1::{chat, common, constants, error}; use crate::v1::{chat, common, constants, error, tool};
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
// Response // Response
@@ -11,12 +11,11 @@ pub struct ChatStreamChunk {
pub id: String, pub id: String,
pub object: String, pub object: String,
/// Unix timestamp (in seconds). /// Unix timestamp (in seconds).
pub created: u32, pub created: u64,
pub model: constants::Model, pub model: constants::Model,
pub choices: Vec<ChatStreamChunkChoice>, pub choices: Vec<ChatStreamChunkChoice>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage: Option<common::ResponseUsage>, pub usage: Option<common::ResponseUsage>,
// TODO Check this prop (seen in API responses but undocumented).
// pub logprobs: ???,
} }
#[derive(Clone, Debug, Deserialize, Serialize)] #[derive(Clone, Debug, Deserialize, Serialize)]
@@ -24,14 +23,15 @@ pub struct ChatStreamChunkChoice {
pub index: u32, pub index: u32,
pub delta: ChatStreamChunkChoiceDelta, pub delta: ChatStreamChunkChoiceDelta,
pub finish_reason: Option<String>, pub finish_reason: Option<String>,
// TODO Check this prop (seen in API responses but undocumented).
// pub logprobs: ???,
} }
#[derive(Clone, Debug, Deserialize, Serialize)] #[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ChatStreamChunkChoiceDelta { pub struct ChatStreamChunkChoiceDelta {
pub role: Option<chat::ChatMessageRole>, pub role: Option<chat::ChatMessageRole>,
pub content: String, #[serde(default)]
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<tool::ToolCall>>,
} }
/// Extracts serialized chunks from a stream message. /// Extracts serialized chunks from a stream message.
@@ -47,7 +47,6 @@ pub fn get_chunk_from_stream_message_line(
return Ok(Some(vec![])); return Ok(Some(vec![]));
} }
// Attempt to deserialize the JSON string into ChatStreamChunk
match from_str::<ChatStreamChunk>(chunk_as_json) { match from_str::<ChatStreamChunk>(chunk_as_json) {
Ok(chunk) => Ok(Some(vec![chunk])), Ok(chunk) => Ok(Some(vec![chunk])),
Err(e) => Err(error::ApiError { Err(e) => Err(error::ApiError {

File diff suppressed because it is too large Load Diff

View File

@@ -3,6 +3,7 @@ use serde::{Deserialize, Serialize};
#[derive(Clone, Debug, Deserialize, Serialize)] #[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ResponseUsage { pub struct ResponseUsage {
pub prompt_tokens: u32, pub prompt_tokens: u32,
#[serde(default)]
pub completion_tokens: u32, pub completion_tokens: u32,
pub total_tokens: u32, pub total_tokens: u32,
} }

View File

@@ -1,35 +1,131 @@
use std::fmt;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
pub const API_URL_BASE: &str = "https://api.mistral.ai/v1"; pub const API_URL_BASE: &str = "https://api.mistral.ai/v1";
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] /// A Mistral AI model identifier.
pub enum Model { ///
#[serde(rename = "open-mistral-7b")] /// Use the associated constants for known models, or construct with `Model::new()` for any model string.
OpenMistral7b, #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename = "open-mixtral-8x7b")] #[serde(transparent)]
OpenMixtral8x7b, pub struct Model(pub String);
#[serde(rename = "open-mixtral-8x22b")]
OpenMixtral8x22b, impl Model {
#[serde(rename = "open-mistral-nemo", alias = "open-mistral-nemo-2407")] pub fn new(id: impl Into<String>) -> Self {
OpenMistralNemo, Self(id.into())
#[serde(rename = "mistral-tiny")] }
MistralTiny,
#[serde(rename = "mistral-small-latest", alias = "mistral-small-2402")] // Flagship / Premier
MistralSmallLatest, pub fn mistral_large_latest() -> Self {
#[serde(rename = "mistral-medium-latest", alias = "mistral-medium-2312")] Self::new("mistral-large-latest")
MistralMediumLatest, }
#[serde(rename = "mistral-large-latest", alias = "mistral-large-2407")] pub fn mistral_large_3() -> Self {
MistralLargeLatest, Self::new("mistral-large-3-25-12")
#[serde(rename = "mistral-large-2402")] }
MistralLarge, pub fn mistral_medium_latest() -> Self {
#[serde(rename = "codestral-latest", alias = "codestral-2405")] Self::new("mistral-medium-latest")
CodestralLatest, }
#[serde(rename = "open-codestral-mamba")] pub fn mistral_medium_3_1() -> Self {
CodestralMamba, Self::new("mistral-medium-3-1-25-08")
}
pub fn mistral_small_latest() -> Self {
Self::new("mistral-small-latest")
}
pub fn mistral_small_4() -> Self {
Self::new("mistral-small-4-0-26-03")
}
pub fn mistral_small_3_2() -> Self {
Self::new("mistral-small-3-2-25-06")
}
// Ministral
pub fn ministral_3_14b() -> Self {
Self::new("ministral-3-14b-25-12")
}
pub fn ministral_3_8b() -> Self {
Self::new("ministral-3-8b-25-12")
}
pub fn ministral_3_3b() -> Self {
Self::new("ministral-3-3b-25-12")
}
// Reasoning
pub fn magistral_medium_latest() -> Self {
Self::new("magistral-medium-latest")
}
pub fn magistral_small_latest() -> Self {
Self::new("magistral-small-latest")
}
// Code
pub fn codestral_latest() -> Self {
Self::new("codestral-latest")
}
pub fn codestral_2508() -> Self {
Self::new("codestral-2508")
}
pub fn codestral_embed() -> Self {
Self::new("codestral-embed-25-05")
}
pub fn devstral_2() -> Self {
Self::new("devstral-2-25-12")
}
pub fn devstral_small_2() -> Self {
Self::new("devstral-small-2-25-12")
}
// Multimodal / Vision
pub fn pixtral_large() -> Self {
Self::new("pixtral-large-2411")
}
// Audio
pub fn voxtral_mini_transcribe() -> Self {
Self::new("voxtral-mini-transcribe-2-26-02")
}
pub fn voxtral_small() -> Self {
Self::new("voxtral-small-25-07")
}
pub fn voxtral_mini() -> Self {
Self::new("voxtral-mini-25-07")
}
// Legacy (kept for backward compatibility)
pub fn open_mistral_nemo() -> Self {
Self::new("open-mistral-nemo")
}
// Embedding
pub fn mistral_embed() -> Self {
Self::new("mistral-embed")
}
// Moderation
pub fn mistral_moderation_latest() -> Self {
Self::new("mistral-moderation-26-03")
}
// OCR
pub fn mistral_ocr_latest() -> Self {
Self::new("mistral-ocr-latest")
}
} }
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] impl fmt::Display for Model {
pub enum EmbedModel { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
#[serde(rename = "mistral-embed")] write!(f, "{}", self.0)
MistralEmbed, }
}
impl From<&str> for Model {
fn from(s: &str) -> Self {
Self(s.to_string())
}
}
impl From<String> for Model {
fn from(s: String) -> Self {
Self(s)
}
} }

View File

@@ -8,42 +8,63 @@ use crate::v1::{common, constants};
#[derive(Debug)] #[derive(Debug)]
pub struct EmbeddingRequestOptions { pub struct EmbeddingRequestOptions {
pub encoding_format: Option<EmbeddingRequestEncodingFormat>, pub encoding_format: Option<EmbeddingRequestEncodingFormat>,
pub output_dimension: Option<u32>,
pub output_dtype: Option<EmbeddingOutputDtype>,
} }
impl Default for EmbeddingRequestOptions { impl Default for EmbeddingRequestOptions {
fn default() -> Self { fn default() -> Self {
Self { Self {
encoding_format: None, encoding_format: None,
output_dimension: None,
output_dtype: None,
} }
} }
} }
#[derive(Debug, Serialize, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct EmbeddingRequest { pub struct EmbeddingRequest {
pub model: constants::EmbedModel, pub model: constants::Model,
pub input: Vec<String>, pub input: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub encoding_format: Option<EmbeddingRequestEncodingFormat>, pub encoding_format: Option<EmbeddingRequestEncodingFormat>,
#[serde(skip_serializing_if = "Option::is_none")]
pub output_dimension: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub output_dtype: Option<EmbeddingOutputDtype>,
} }
impl EmbeddingRequest { impl EmbeddingRequest {
pub fn new( pub fn new(
model: constants::EmbedModel, model: constants::Model,
input: Vec<String>, input: Vec<String>,
options: Option<EmbeddingRequestOptions>, options: Option<EmbeddingRequestOptions>,
) -> Self { ) -> Self {
let EmbeddingRequestOptions { encoding_format } = options.unwrap_or_default(); let opts = options.unwrap_or_default();
Self { Self {
model, model,
input, input,
encoding_format, encoding_format: opts.encoding_format,
output_dimension: opts.output_dimension,
output_dtype: opts.output_dtype,
} }
} }
} }
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[allow(non_camel_case_types)] #[serde(rename_all = "lowercase")]
pub enum EmbeddingRequestEncodingFormat { pub enum EmbeddingRequestEncodingFormat {
float, Float,
Base64,
}
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum EmbeddingOutputDtype {
Float,
Int8,
Uint8,
Binary,
Ubinary,
} }
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
@@ -51,9 +72,8 @@ pub enum EmbeddingRequestEncodingFormat {
#[derive(Clone, Debug, Deserialize, Serialize)] #[derive(Clone, Debug, Deserialize, Serialize)]
pub struct EmbeddingResponse { pub struct EmbeddingResponse {
pub id: String,
pub object: String, pub object: String,
pub model: constants::EmbedModel, pub model: constants::Model,
pub data: Vec<EmbeddingResponseDataItem>, pub data: Vec<EmbeddingResponseDataItem>,
pub usage: common::ResponseUsage, pub usage: common::ResponseUsage,
} }

55
src/v1/files.rs Normal file
View File

@@ -0,0 +1,55 @@
use serde::{Deserialize, Serialize};
// -----------------------------------------------------------------------------
// Request
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum FilePurpose {
FineTune,
Batch,
Ocr,
}
// -----------------------------------------------------------------------------
// Response
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FileListResponse {
pub data: Vec<FileObject>,
pub object: String,
#[serde(default)]
pub total: u32,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FileObject {
pub id: String,
pub object: String,
pub bytes: u64,
pub created_at: u64,
pub filename: String,
pub purpose: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub sample_type: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub source: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub num_lines: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub mimetype: Option<String>,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FileDeleteResponse {
pub id: String,
pub object: String,
pub deleted: bool,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FileUrlResponse {
pub url: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub expires_at: Option<u64>,
}

101
src/v1/fim.rs Normal file
View File

@@ -0,0 +1,101 @@
use serde::{Deserialize, Serialize};
use crate::v1::{common, constants};
// -----------------------------------------------------------------------------
// Request
#[derive(Debug)]
pub struct FimParams {
pub suffix: Option<String>,
pub max_tokens: Option<u32>,
pub min_tokens: Option<u32>,
pub temperature: Option<f32>,
pub top_p: Option<f32>,
pub stop: Option<Vec<String>>,
pub random_seed: Option<u32>,
}
impl Default for FimParams {
fn default() -> Self {
Self {
suffix: None,
max_tokens: None,
min_tokens: None,
temperature: None,
top_p: None,
stop: None,
random_seed: None,
}
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct FimRequest {
pub model: constants::Model,
pub prompt: String,
pub stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub suffix: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub min_tokens: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub top_p: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub stop: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub random_seed: Option<u32>,
}
impl FimRequest {
pub fn new(
model: constants::Model,
prompt: String,
stream: bool,
options: Option<FimParams>,
) -> Self {
let opts = options.unwrap_or_default();
Self {
model,
prompt,
stream,
suffix: opts.suffix,
max_tokens: opts.max_tokens,
min_tokens: opts.min_tokens,
temperature: opts.temperature,
top_p: opts.top_p,
stop: opts.stop,
random_seed: opts.random_seed,
}
}
}
// -----------------------------------------------------------------------------
// Response
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FimResponse {
pub id: String,
pub object: String,
pub created: u64,
pub model: constants::Model,
pub choices: Vec<FimResponseChoice>,
pub usage: common::ResponseUsage,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FimResponseChoice {
pub index: u32,
pub message: FimResponseMessage,
pub finish_reason: String,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FimResponseMessage {
pub role: String,
pub content: String,
}

101
src/v1/fine_tuning.rs Normal file
View File

@@ -0,0 +1,101 @@
use serde::{Deserialize, Serialize};
use crate::v1::constants;
// -----------------------------------------------------------------------------
// Request
#[derive(Debug, Serialize, Deserialize)]
pub struct FineTuningJobRequest {
pub model: constants::Model,
pub training_files: Vec<TrainingFile>,
#[serde(skip_serializing_if = "Option::is_none")]
pub validation_files: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub hyperparameters: Option<Hyperparameters>,
#[serde(skip_serializing_if = "Option::is_none")]
pub suffix: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub auto_start: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub job_type: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub integrations: Option<Vec<Integration>>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct TrainingFile {
pub file_id: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub weight: Option<f32>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Hyperparameters {
#[serde(skip_serializing_if = "Option::is_none")]
pub learning_rate: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub training_steps: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")]
pub warmup_fraction: Option<f64>,
#[serde(skip_serializing_if = "Option::is_none")]
pub epochs: Option<f64>,
}
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Integration {
pub r#type: String,
pub project: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub name: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub api_key: Option<String>,
}
// -----------------------------------------------------------------------------
// Response
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FineTuningJobResponse {
pub id: String,
pub object: String,
pub model: constants::Model,
pub status: FineTuningJobStatus,
pub created_at: u64,
#[serde(skip_serializing_if = "Option::is_none")]
pub modified_at: Option<u64>,
pub training_files: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub validation_files: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub hyperparameters: Option<Hyperparameters>,
#[serde(skip_serializing_if = "Option::is_none")]
pub fine_tuned_model: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub suffix: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub integrations: Option<Vec<Integration>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub trained_tokens: Option<u64>,
}
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum FineTuningJobStatus {
Queued,
Running,
Success,
Failed,
TimeoutExceeded,
CancellationRequested,
Cancelled,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct FineTuningJobListResponse {
pub data: Vec<FineTuningJobResponse>,
pub object: String,
#[serde(default)]
pub total: u32,
}

View File

@@ -1,3 +1,6 @@
pub mod agents;
pub mod audio;
pub mod batch;
pub mod chat; pub mod chat;
pub mod chat_stream; pub mod chat_stream;
pub mod client; pub mod client;
@@ -5,6 +8,11 @@ pub mod common;
pub mod constants; pub mod constants;
pub mod embedding; pub mod embedding;
pub mod error; pub mod error;
pub mod files;
pub mod fim;
pub mod fine_tuning;
pub mod model_list; pub mod model_list;
pub mod moderation;
pub mod ocr;
pub mod tool; pub mod tool;
pub mod utils; pub mod utils;

View File

@@ -15,23 +15,44 @@ pub struct ModelListData {
pub id: String, pub id: String,
pub object: String, pub object: String,
/// Unix timestamp (in seconds). /// Unix timestamp (in seconds).
pub created: u32, pub created: u64,
pub owned_by: String, pub owned_by: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub root: Option<String>, pub root: Option<String>,
#[serde(default)]
pub archived: bool, pub archived: bool,
pub name: String, #[serde(default)]
pub description: String, pub name: Option<String>,
pub capabilities: ModelListDataCapabilies, #[serde(default)]
pub max_context_length: u32, pub description: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub capabilities: Option<ModelListDataCapabilities>,
#[serde(skip_serializing_if = "Option::is_none")]
pub max_context_length: Option<u32>,
#[serde(default)]
pub aliases: Vec<String>, pub aliases: Vec<String>,
/// ISO 8601 date (`YYYY-MM-DDTHH:MM:SSZ`). /// ISO 8601 date (`YYYY-MM-DDTHH:MM:SSZ`).
#[serde(skip_serializing_if = "Option::is_none")]
pub deprecation: Option<String>, pub deprecation: Option<String>,
} }
#[derive(Clone, Debug, Deserialize, Serialize)] #[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ModelListDataCapabilies { pub struct ModelListDataCapabilities {
#[serde(default)]
pub completion_chat: bool, pub completion_chat: bool,
#[serde(default)]
pub completion_fim: bool, pub completion_fim: bool,
#[serde(default)]
pub function_calling: bool, pub function_calling: bool,
#[serde(default)]
pub fine_tuning: bool, pub fine_tuning: bool,
#[serde(default)]
pub vision: bool,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ModelDeleteResponse {
pub id: String,
pub object: String,
pub deleted: bool,
} }

70
src/v1/moderation.rs Normal file
View File

@@ -0,0 +1,70 @@
use serde::{Deserialize, Serialize};
use crate::v1::constants;
// -----------------------------------------------------------------------------
// Request
#[derive(Debug, Serialize, Deserialize)]
pub struct ModerationRequest {
pub model: constants::Model,
pub input: Vec<String>,
}
impl ModerationRequest {
pub fn new(model: constants::Model, input: Vec<String>) -> Self {
Self { model, input }
}
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatModerationRequest {
pub model: constants::Model,
pub input: Vec<ChatModerationInput>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatModerationInput {
pub role: String,
pub content: String,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ClassificationRequest {
pub model: constants::Model,
pub input: Vec<String>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatClassificationRequest {
pub model: constants::Model,
pub input: Vec<ChatModerationInput>,
}
// -----------------------------------------------------------------------------
// Response
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ModerationResponse {
pub id: String,
pub model: constants::Model,
pub results: Vec<ModerationResult>,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ModerationResult {
pub categories: serde_json::Value,
pub category_scores: serde_json::Value,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ClassificationResponse {
pub id: String,
pub model: constants::Model,
pub results: Vec<ClassificationResult>,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ClassificationResult {
pub categories: serde_json::Value,
pub category_scores: serde_json::Value,
}

96
src/v1/ocr.rs Normal file
View File

@@ -0,0 +1,96 @@
use serde::{Deserialize, Serialize};
use crate::v1::constants;
// -----------------------------------------------------------------------------
// Request
#[derive(Debug, Serialize, Deserialize)]
pub struct OcrRequest {
pub model: constants::Model,
pub document: OcrDocument,
#[serde(skip_serializing_if = "Option::is_none")]
pub pages: Option<Vec<u32>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub table_format: Option<OcrTableFormat>,
#[serde(skip_serializing_if = "Option::is_none")]
pub include_image_base64: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub image_limit: Option<u32>,
}
#[derive(Debug, Serialize, Deserialize)]
pub struct OcrDocument {
#[serde(rename = "type")]
pub type_: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub document_url: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub file_id: Option<String>,
}
impl OcrDocument {
pub fn from_url(url: &str) -> Self {
Self {
type_: "document_url".to_string(),
document_url: Some(url.to_string()),
file_id: None,
}
}
pub fn from_file_id(file_id: &str) -> Self {
Self {
type_: "file_id".to_string(),
document_url: None,
file_id: Some(file_id.to_string()),
}
}
}
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum OcrTableFormat {
Markdown,
Html,
}
// -----------------------------------------------------------------------------
// Response
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct OcrResponse {
pub pages: Vec<OcrPage>,
#[serde(skip_serializing_if = "Option::is_none")]
pub usage_info: Option<OcrUsageInfo>,
pub model: constants::Model,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct OcrPage {
pub index: u32,
pub markdown: String,
#[serde(default)]
pub images: Vec<OcrImage>,
#[serde(skip_serializing_if = "Option::is_none")]
pub dimensions: Option<OcrPageDimensions>,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct OcrImage {
pub id: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub image_base64: Option<String>,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct OcrPageDimensions {
pub width: f32,
pub height: f32,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct OcrUsageInfo {
pub pages_processed: u32,
#[serde(skip_serializing_if = "Option::is_none")]
pub doc_size_bytes: Option<u64>,
}

View File

@@ -1,12 +1,16 @@
use async_trait::async_trait; use async_trait::async_trait;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::{any::Any, collections::HashMap, fmt::Debug}; use std::{any::Any, fmt::Debug};
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
// Definitions // Definitions
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub struct ToolCall { pub struct ToolCall {
#[serde(skip_serializing_if = "Option::is_none")]
pub id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub r#type: Option<String>,
pub function: ToolCallFunction, pub function: ToolCallFunction,
} }
@@ -22,31 +26,12 @@ pub struct Tool {
pub function: ToolFunction, pub function: ToolFunction,
} }
impl Tool { impl Tool {
/// Create a tool with a JSON Schema parameters object.
pub fn new( pub fn new(
function_name: String, function_name: String,
function_description: String, function_description: String,
function_parameters: Vec<ToolFunctionParameter>, parameters: serde_json::Value,
) -> Self { ) -> Self {
let properties: HashMap<String, ToolFunctionParameterProperty> = function_parameters
.into_iter()
.map(|param| {
(
param.name,
ToolFunctionParameterProperty {
r#type: param.r#type,
description: param.description,
},
)
})
.collect();
let property_names = properties.keys().cloned().collect();
let parameters = ToolFunctionParameters {
r#type: ToolFunctionParametersType::Object,
properties,
required: property_names,
};
Self { Self {
r#type: ToolType::Function, r#type: ToolType::Function,
function: ToolFunction { function: ToolFunction {
@@ -63,50 +48,9 @@ impl Tool {
#[derive(Clone, Debug, Deserialize, Serialize)] #[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolFunction { pub struct ToolFunction {
name: String, pub name: String,
description: String, pub description: String,
parameters: ToolFunctionParameters, pub parameters: serde_json::Value,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolFunctionParameter {
name: String,
description: String,
r#type: ToolFunctionParameterType,
}
impl ToolFunctionParameter {
pub fn new(name: String, description: String, r#type: ToolFunctionParameterType) -> Self {
Self {
name,
r#type,
description,
}
}
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolFunctionParameters {
r#type: ToolFunctionParametersType,
properties: HashMap<String, ToolFunctionParameterProperty>,
required: Vec<String>,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolFunctionParameterProperty {
r#type: ToolFunctionParameterType,
description: String,
}
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum ToolFunctionParametersType {
#[serde(rename = "object")]
Object,
}
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum ToolFunctionParameterType {
#[serde(rename = "string")]
String,
} }
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
@@ -127,6 +71,9 @@ pub enum ToolChoice {
/// The model won't call a function and will generate a message instead. /// The model won't call a function and will generate a message instead.
#[serde(rename = "none")] #[serde(rename = "none")]
None, None,
/// The model must call at least one tool.
#[serde(rename = "required")]
Required,
} }
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------

View File

@@ -3,7 +3,7 @@ use mistralai_client::v1::{
chat::{ChatMessage, ChatMessageRole, ChatParams, ChatResponseChoiceFinishReason}, chat::{ChatMessage, ChatMessageRole, ChatParams, ChatResponseChoiceFinishReason},
client::Client, client::Client,
constants::Model, constants::Model,
tool::{Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType}, tool::{Tool, ToolChoice},
}; };
mod setup; mod setup;
@@ -14,12 +14,12 @@ async fn test_client_chat_async() {
let client = Client::new(None, None, None, None).unwrap(); let client = Client::new(None, None, None, None).unwrap();
let model = Model::OpenMistral7b; let model = Model::mistral_small_latest();
let messages = vec![ChatMessage::new_user_message( let messages = vec![ChatMessage::new_user_message(
"Guess the next word: \"Eiffel ...\"?", "Guess the next word: \"Eiffel ...\"?",
)]; )];
let options = ChatParams { let options = ChatParams {
temperature: 0.0, temperature: Some(0.0),
random_seed: Some(42), random_seed: Some(42),
..Default::default() ..Default::default()
}; };
@@ -29,7 +29,6 @@ async fn test_client_chat_async() {
.await .await
.unwrap(); .unwrap();
expect!(response.model).to_be(Model::OpenMistral7b);
expect!(response.object).to_be("chat.completion".to_string()); expect!(response.object).to_be("chat.completion".to_string());
expect!(response.choices.len()).to_be(1); expect!(response.choices.len()).to_be(1);
@@ -56,21 +55,26 @@ async fn test_client_chat_async_with_function_calling() {
let tools = vec![Tool::new( let tools = vec![Tool::new(
"get_city_temperature".to_string(), "get_city_temperature".to_string(),
"Get the current temperature in a city.".to_string(), "Get the current temperature in a city.".to_string(),
vec![ToolFunctionParameter::new( serde_json::json!({
"city".to_string(), "type": "object",
"The name of the city.".to_string(), "properties": {
ToolFunctionParameterType::String, "city": {
)], "type": "string",
"description": "The name of the city."
}
},
"required": ["city"]
}),
)]; )];
let client = Client::new(None, None, None, None).unwrap(); let client = Client::new(None, None, None, None).unwrap();
let model = Model::MistralSmallLatest; let model = Model::mistral_small_latest();
let messages = vec![ChatMessage::new_user_message( let messages = vec![ChatMessage::new_user_message(
"What's the current temperature in Paris?", "What's the current temperature in Paris?",
)]; )];
let options = ChatParams { let options = ChatParams {
temperature: 0.0, temperature: Some(0.0),
random_seed: Some(42), random_seed: Some(42),
tool_choice: Some(ToolChoice::Any), tool_choice: Some(ToolChoice::Any),
tools: Some(tools), tools: Some(tools),
@@ -82,7 +86,6 @@ async fn test_client_chat_async_with_function_calling() {
.await .await
.unwrap(); .unwrap();
expect!(response.model).to_be(Model::MistralSmallLatest);
expect!(response.object).to_be("chat.completion".to_string()); expect!(response.object).to_be("chat.completion".to_string());
expect!(response.choices.len()).to_be(1); expect!(response.choices.len()).to_be(1);
@@ -91,13 +94,6 @@ async fn test_client_chat_async_with_function_calling() {
.to_be(ChatResponseChoiceFinishReason::ToolCalls); .to_be(ChatResponseChoiceFinishReason::ToolCalls);
expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant); expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant);
expect!(response.choices[0].message.content.clone()).to_be("".to_string());
// expect!(response.choices[0].message.tool_calls.clone()).to_be(Some(vec![ToolCall {
// function: ToolCallFunction {
// name: "get_city_temperature".to_string(),
// arguments: "{\"city\": \"Paris\"}".to_string(),
// },
// }]));
expect!(response.usage.prompt_tokens).to_be_greater_than(0); expect!(response.usage.prompt_tokens).to_be_greater_than(0);
expect!(response.usage.completion_tokens).to_be_greater_than(0); expect!(response.usage.completion_tokens).to_be_greater_than(0);

View File

@@ -1,16 +1,18 @@
// Streaming tests require a live API key and are not run in CI.
// Uncomment to test locally.
// use futures::stream::StreamExt; // use futures::stream::StreamExt;
// use jrest::expect;
// use mistralai_client::v1::{ // use mistralai_client::v1::{
// chat_completion::{ChatParams, ChatMessage, ChatMessageRole}, // chat::{ChatMessage, ChatParams},
// client::Client, // client::Client,
// constants::Model, // constants::Model,
// }; // };
//
// #[tokio::test] // #[tokio::test]
// async fn test_client_chat_stream() { // async fn test_client_chat_stream() {
// let client = Client::new(None, None, None, None).unwrap(); // let client = Client::new(None, None, None, None).unwrap();
//
// let model = Model::OpenMistral7b; // let model = Model::mistral_small_latest();
// let messages = vec![ChatMessage::new_user_message( // let messages = vec![ChatMessage::new_user_message(
// "Just guess the next word: \"Eiffel ...\"?", // "Just guess the next word: \"Eiffel ...\"?",
// )]; // )];
@@ -19,22 +21,24 @@
// random_seed: Some(42), // random_seed: Some(42),
// ..Default::default() // ..Default::default()
// }; // };
//
// let stream_result = client.chat_stream(model, messages, Some(options)).await; // let stream = client
// let mut stream = stream_result.expect("Failed to create stream."); // .chat_stream(model, messages, Some(options))
// while let Some(maybe_chunk_result) = stream.next().await { // .await
// match maybe_chunk_result { // .expect("Failed to create stream.");
// Some(Ok(chunk)) => { //
// if chunk.choices[0].delta.role == Some(ChatMessageRole::Assistant) // stream
// || chunk.choices[0].finish_reason == Some("stop".to_string()) // .for_each(|chunk_result| async {
// { // match chunk_result {
// expect!(chunk.choices[0].delta.content.len()).to_be(0); // Ok(chunks) => {
// } else { // for chunk in &chunks {
// expect!(chunk.choices[0].delta.content.len()).to_be_greater_than(0); // if let Some(content) = &chunk.choices[0].delta.content {
// print!("{}", content);
// }
// }
// } // }
// Err(error) => eprintln!("Error: {:?}", error),
// } // }
// Some(Err(error)) => eprintln!("Error processing chunk: {:?}", error), // })
// None => (), // .await;
// }
// }
// } // }

View File

@@ -3,7 +3,7 @@ use mistralai_client::v1::{
chat::{ChatMessage, ChatMessageRole, ChatParams, ChatResponseChoiceFinishReason}, chat::{ChatMessage, ChatMessageRole, ChatParams, ChatResponseChoiceFinishReason},
client::Client, client::Client,
constants::Model, constants::Model,
tool::{Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType}, tool::{Tool, ToolChoice},
}; };
mod setup; mod setup;
@@ -14,19 +14,18 @@ fn test_client_chat() {
let client = Client::new(None, None, None, None).unwrap(); let client = Client::new(None, None, None, None).unwrap();
let model = Model::OpenMistral7b; let model = Model::mistral_small_latest();
let messages = vec![ChatMessage::new_user_message( let messages = vec![ChatMessage::new_user_message(
"Guess the next word: \"Eiffel ...\"?", "Guess the next word: \"Eiffel ...\"?",
)]; )];
let options = ChatParams { let options = ChatParams {
temperature: 0.0, temperature: Some(0.0),
random_seed: Some(42), random_seed: Some(42),
..Default::default() ..Default::default()
}; };
let response = client.chat(model, messages, Some(options)).unwrap(); let response = client.chat(model, messages, Some(options)).unwrap();
expect!(response.model).to_be(Model::OpenMistral7b);
expect!(response.object).to_be("chat.completion".to_string()); expect!(response.object).to_be("chat.completion".to_string());
expect!(response.choices.len()).to_be(1); expect!(response.choices.len()).to_be(1);
expect!(response.choices[0].index).to_be(0); expect!(response.choices[0].index).to_be(0);
@@ -50,21 +49,26 @@ fn test_client_chat_with_function_calling() {
let tools = vec![Tool::new( let tools = vec![Tool::new(
"get_city_temperature".to_string(), "get_city_temperature".to_string(),
"Get the current temperature in a city.".to_string(), "Get the current temperature in a city.".to_string(),
vec![ToolFunctionParameter::new( serde_json::json!({
"city".to_string(), "type": "object",
"The name of the city.".to_string(), "properties": {
ToolFunctionParameterType::String, "city": {
)], "type": "string",
"description": "The name of the city."
}
},
"required": ["city"]
}),
)]; )];
let client = Client::new(None, None, None, None).unwrap(); let client = Client::new(None, None, None, None).unwrap();
let model = Model::MistralSmallLatest; let model = Model::mistral_small_latest();
let messages = vec![ChatMessage::new_user_message( let messages = vec![ChatMessage::new_user_message(
"What's the current temperature in Paris?", "What's the current temperature in Paris?",
)]; )];
let options = ChatParams { let options = ChatParams {
temperature: 0.0, temperature: Some(0.0),
random_seed: Some(42), random_seed: Some(42),
tool_choice: Some(ToolChoice::Auto), tool_choice: Some(ToolChoice::Auto),
tools: Some(tools), tools: Some(tools),
@@ -73,12 +77,10 @@ fn test_client_chat_with_function_calling() {
let response = client.chat(model, messages, Some(options)).unwrap(); let response = client.chat(model, messages, Some(options)).unwrap();
expect!(response.model).to_be(Model::MistralSmallLatest);
expect!(response.object).to_be("chat.completion".to_string()); expect!(response.object).to_be("chat.completion".to_string());
expect!(response.choices.len()).to_be(1); expect!(response.choices.len()).to_be(1);
expect!(response.choices[0].index).to_be(0); expect!(response.choices[0].index).to_be(0);
expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant); expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant);
expect!(response.choices[0].message.content.clone()).to_be("".to_string());
expect!(response.choices[0].finish_reason.clone()) expect!(response.choices[0].finish_reason.clone())
.to_be(ChatResponseChoiceFinishReason::ToolCalls); .to_be(ChatResponseChoiceFinishReason::ToolCalls);
expect!(response.usage.prompt_tokens).to_be_greater_than(0); expect!(response.usage.prompt_tokens).to_be_greater_than(0);

View File

@@ -1,11 +1,11 @@
use jrest::expect; use jrest::expect;
use mistralai_client::v1::{client::Client, constants::EmbedModel}; use mistralai_client::v1::{client::Client, constants::Model};
#[tokio::test] #[tokio::test]
async fn test_client_embeddings_async() { async fn test_client_embeddings_async() {
let client: Client = Client::new(None, None, None, None).unwrap(); let client: Client = Client::new(None, None, None, None).unwrap();
let model = EmbedModel::MistralEmbed; let model = Model::mistral_embed();
let input = vec!["Embed this sentence.", "As well as this one."] let input = vec!["Embed this sentence.", "As well as this one."]
.iter() .iter()
.map(|s| s.to_string()) .map(|s| s.to_string())
@@ -17,7 +17,6 @@ async fn test_client_embeddings_async() {
.await .await
.unwrap(); .unwrap();
expect!(response.model).to_be(EmbedModel::MistralEmbed);
expect!(response.object).to_be("list".to_string()); expect!(response.object).to_be("list".to_string());
expect!(response.data.len()).to_be(2); expect!(response.data.len()).to_be(2);
expect!(response.data[0].index).to_be(0); expect!(response.data[0].index).to_be(0);

View File

@@ -1,11 +1,11 @@
use jrest::expect; use jrest::expect;
use mistralai_client::v1::{client::Client, constants::EmbedModel}; use mistralai_client::v1::{client::Client, constants::Model};
#[test] #[test]
fn test_client_embeddings() { fn test_client_embeddings() {
let client: Client = Client::new(None, None, None, None).unwrap(); let client: Client = Client::new(None, None, None, None).unwrap();
let model = EmbedModel::MistralEmbed; let model = Model::mistral_embed();
let input = vec!["Embed this sentence.", "As well as this one."] let input = vec!["Embed this sentence.", "As well as this one."]
.iter() .iter()
.map(|s| s.to_string()) .map(|s| s.to_string())
@@ -14,7 +14,6 @@ fn test_client_embeddings() {
let response = client.embeddings(model, input, options).unwrap(); let response = client.embeddings(model, input, options).unwrap();
expect!(response.model).to_be(EmbedModel::MistralEmbed);
expect!(response.object).to_be("list".to_string()); expect!(response.object).to_be("list".to_string());
expect!(response.data.len()).to_be(2); expect!(response.data.len()).to_be(2);
expect!(response.data[0].index).to_be(0); expect!(response.data[0].index).to_be(0);

View File

@@ -6,26 +6,19 @@ use mistralai_client::v1::{
}; };
#[test] #[test]
fn test_model_constant() { fn test_model_constants() {
let models = vec![ let models = vec![
Model::OpenMistral7b, Model::mistral_small_latest(),
Model::OpenMixtral8x7b, Model::mistral_large_latest(),
Model::OpenMixtral8x22b, Model::open_mistral_nemo(),
Model::OpenMistralNemo, Model::codestral_latest(),
Model::MistralTiny,
Model::MistralSmallLatest,
Model::MistralMediumLatest,
Model::MistralLargeLatest,
Model::MistralLarge,
Model::CodestralLatest,
Model::CodestralMamba,
]; ];
let client = Client::new(None, None, None, None).unwrap(); let client = Client::new(None, None, None, None).unwrap();
let messages = vec![ChatMessage::new_user_message("A number between 0 and 100?")]; let messages = vec![ChatMessage::new_user_message("A number between 0 and 100?")];
let options = ChatParams { let options = ChatParams {
temperature: 0.0, temperature: Some(0.0),
random_seed: Some(42), random_seed: Some(42),
..Default::default() ..Default::default()
}; };