Update to latest Mistral AI API (v1.0.0)
- Replace closed Model enum with flexible string-based Model type with constructor methods for all current models (Mistral Large 3, Small 4, Magistral, Codestral, Devstral, Pixtral, Voxtral, etc.) - Add new API endpoints: FIM completions, Files, Fine-tuning, Batch jobs, OCR, Audio transcription, Moderations/Classifications, and Agent completions (sync + async for all) - Add new chat fields: frequency_penalty, presence_penalty, stop, n, parallel_tool_calls, reasoning_effort, min_tokens, json_schema response format - Add embedding fields: output_dimension, output_dtype - Tool parameters now accept raw JSON Schema (serde_json::Value) instead of limited enum types - Add tool call IDs and Required tool choice variant - Add DELETE HTTP method support and multipart file upload - Bump thiserror to v2, add reqwest multipart feature - Remove strum dependency (no longer needed) - Update all tests and examples for new API
This commit is contained in:
27
Cargo.toml
27
Cargo.toml
@@ -2,7 +2,7 @@
|
|||||||
name = "mistralai-client"
|
name = "mistralai-client"
|
||||||
description = "Mistral AI API client library for Rust (unofficial)."
|
description = "Mistral AI API client library for Rust (unofficial)."
|
||||||
license = "Apache-2.0"
|
license = "Apache-2.0"
|
||||||
version = "0.14.0"
|
version = "1.0.0"
|
||||||
|
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
rust-version = "1.76.0"
|
rust-version = "1.76.0"
|
||||||
@@ -15,18 +15,17 @@ readme = "README.md"
|
|||||||
repository = "https://github.com/ivangabriele/mistralai-client-rs"
|
repository = "https://github.com/ivangabriele/mistralai-client-rs"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
async-stream = "0.3.5"
|
async-stream = "0.3"
|
||||||
async-trait = "0.1.77"
|
async-trait = "0.1"
|
||||||
env_logger = "0.11.3"
|
env_logger = "0.11"
|
||||||
futures = "0.3.30"
|
futures = "0.3"
|
||||||
log = "0.4.21"
|
log = "0.4"
|
||||||
reqwest = { version = "0.12.0", features = ["json", "blocking", "stream"] }
|
reqwest = { version = "0.12", features = ["json", "blocking", "stream", "multipart"] }
|
||||||
serde = { version = "1.0.197", features = ["derive"] }
|
serde = { version = "1", features = ["derive"] }
|
||||||
serde_json = "1.0.114"
|
serde_json = "1"
|
||||||
strum = "0.26.1"
|
thiserror = "2"
|
||||||
thiserror = "1.0.57"
|
tokio = { version = "1", features = ["full"] }
|
||||||
tokio = { version = "1.36.0", features = ["full"] }
|
tokio-stream = "0.1"
|
||||||
tokio-stream = "0.1.14"
|
|
||||||
|
|
||||||
[dev-dependencies]
|
[dev-dependencies]
|
||||||
jrest = "0.2.3"
|
jrest = "0.2"
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
use mistralai_client::v1::{
|
use mistralai_client::v1::{
|
||||||
chat::{ChatMessage, ChatMessageRole, ChatParams},
|
chat::{ChatMessage, ChatParams},
|
||||||
client::Client,
|
client::Client,
|
||||||
constants::Model,
|
constants::Model,
|
||||||
};
|
};
|
||||||
@@ -8,14 +8,12 @@ fn main() {
|
|||||||
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
||||||
let client = Client::new(None, None, None, None).unwrap();
|
let client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
let model = Model::OpenMistral7b;
|
let model = Model::mistral_small_latest();
|
||||||
let messages = vec![ChatMessage {
|
let messages = vec![ChatMessage::new_user_message(
|
||||||
role: ChatMessageRole::User,
|
"Just guess the next word: \"Eiffel ...\"?",
|
||||||
content: "Just guess the next word: \"Eiffel ...\"?".to_string(),
|
)];
|
||||||
tool_calls: None,
|
|
||||||
}];
|
|
||||||
let options = ChatParams {
|
let options = ChatParams {
|
||||||
temperature: 0.0,
|
temperature: Some(0.0),
|
||||||
random_seed: Some(42),
|
random_seed: Some(42),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
use mistralai_client::v1::{
|
use mistralai_client::v1::{
|
||||||
chat::{ChatMessage, ChatMessageRole, ChatParams},
|
chat::{ChatMessage, ChatParams},
|
||||||
client::Client,
|
client::Client,
|
||||||
constants::Model,
|
constants::Model,
|
||||||
};
|
};
|
||||||
@@ -9,14 +9,12 @@ async fn main() {
|
|||||||
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
||||||
let client = Client::new(None, None, None, None).unwrap();
|
let client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
let model = Model::OpenMistral7b;
|
let model = Model::mistral_small_latest();
|
||||||
let messages = vec![ChatMessage {
|
let messages = vec![ChatMessage::new_user_message(
|
||||||
role: ChatMessageRole::User,
|
"Just guess the next word: \"Eiffel ...\"?",
|
||||||
content: "Just guess the next word: \"Eiffel ...\"?".to_string(),
|
)];
|
||||||
tool_calls: None,
|
|
||||||
}];
|
|
||||||
let options = ChatParams {
|
let options = ChatParams {
|
||||||
temperature: 0.0,
|
temperature: Some(0.0),
|
||||||
random_seed: Some(42),
|
random_seed: Some(42),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
@@ -29,5 +27,4 @@ async fn main() {
|
|||||||
"{:?}: {}",
|
"{:?}: {}",
|
||||||
result.choices[0].message.role, result.choices[0].message.content
|
result.choices[0].message.role, result.choices[0].message.content
|
||||||
);
|
);
|
||||||
// => "Assistant: Tower. The Eiffel Tower is a famous landmark in Paris, France."
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
use mistralai_client::v1::{
|
use mistralai_client::v1::{
|
||||||
chat::{ChatMessage, ChatMessageRole, ChatParams},
|
chat::{ChatMessage, ChatParams},
|
||||||
client::Client,
|
client::Client,
|
||||||
constants::Model,
|
constants::Model,
|
||||||
tool::{Function, Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType},
|
tool::{Function, Tool, ToolChoice},
|
||||||
};
|
};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use std::any::Any;
|
use std::any::Any;
|
||||||
@@ -16,7 +16,6 @@ struct GetCityTemperatureFunction;
|
|||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
impl Function for GetCityTemperatureFunction {
|
impl Function for GetCityTemperatureFunction {
|
||||||
async fn execute(&self, arguments: String) -> Box<dyn Any + Send> {
|
async fn execute(&self, arguments: String) -> Box<dyn Any + Send> {
|
||||||
// Deserialize arguments, perform the logic, and return the result
|
|
||||||
let GetCityTemperatureArguments { city } = serde_json::from_str(&arguments).unwrap();
|
let GetCityTemperatureArguments { city } = serde_json::from_str(&arguments).unwrap();
|
||||||
|
|
||||||
let temperature = match city.as_str() {
|
let temperature = match city.as_str() {
|
||||||
@@ -32,11 +31,16 @@ fn main() {
|
|||||||
let tools = vec![Tool::new(
|
let tools = vec![Tool::new(
|
||||||
"get_city_temperature".to_string(),
|
"get_city_temperature".to_string(),
|
||||||
"Get the current temperature in a city.".to_string(),
|
"Get the current temperature in a city.".to_string(),
|
||||||
vec![ToolFunctionParameter::new(
|
serde_json::json!({
|
||||||
"city".to_string(),
|
"type": "object",
|
||||||
"The name of the city.".to_string(),
|
"properties": {
|
||||||
ToolFunctionParameterType::String,
|
"city": {
|
||||||
)],
|
"type": "string",
|
||||||
|
"description": "The name of the city."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["city"]
|
||||||
|
}),
|
||||||
)];
|
)];
|
||||||
|
|
||||||
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
||||||
@@ -46,14 +50,12 @@ fn main() {
|
|||||||
Box::new(GetCityTemperatureFunction),
|
Box::new(GetCityTemperatureFunction),
|
||||||
);
|
);
|
||||||
|
|
||||||
let model = Model::MistralSmallLatest;
|
let model = Model::mistral_small_latest();
|
||||||
let messages = vec![ChatMessage {
|
let messages = vec![ChatMessage::new_user_message(
|
||||||
role: ChatMessageRole::User,
|
"What's the temperature in Paris?",
|
||||||
content: "What's the temperature in Paris?".to_string(),
|
)];
|
||||||
tool_calls: None,
|
|
||||||
}];
|
|
||||||
let options = ChatParams {
|
let options = ChatParams {
|
||||||
temperature: 0.0,
|
temperature: Some(0.0),
|
||||||
random_seed: Some(42),
|
random_seed: Some(42),
|
||||||
tool_choice: Some(ToolChoice::Auto),
|
tool_choice: Some(ToolChoice::Auto),
|
||||||
tools: Some(tools),
|
tools: Some(tools),
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
use mistralai_client::v1::{
|
use mistralai_client::v1::{
|
||||||
chat::{ChatMessage, ChatMessageRole, ChatParams},
|
chat::{ChatMessage, ChatParams},
|
||||||
client::Client,
|
client::Client,
|
||||||
constants::Model,
|
constants::Model,
|
||||||
tool::{Function, Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType},
|
tool::{Function, Tool, ToolChoice},
|
||||||
};
|
};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use std::any::Any;
|
use std::any::Any;
|
||||||
@@ -16,7 +16,6 @@ struct GetCityTemperatureFunction;
|
|||||||
#[async_trait::async_trait]
|
#[async_trait::async_trait]
|
||||||
impl Function for GetCityTemperatureFunction {
|
impl Function for GetCityTemperatureFunction {
|
||||||
async fn execute(&self, arguments: String) -> Box<dyn Any + Send> {
|
async fn execute(&self, arguments: String) -> Box<dyn Any + Send> {
|
||||||
// Deserialize arguments, perform the logic, and return the result
|
|
||||||
let GetCityTemperatureArguments { city } = serde_json::from_str(&arguments).unwrap();
|
let GetCityTemperatureArguments { city } = serde_json::from_str(&arguments).unwrap();
|
||||||
|
|
||||||
let temperature = match city.as_str() {
|
let temperature = match city.as_str() {
|
||||||
@@ -33,11 +32,16 @@ async fn main() {
|
|||||||
let tools = vec![Tool::new(
|
let tools = vec![Tool::new(
|
||||||
"get_city_temperature".to_string(),
|
"get_city_temperature".to_string(),
|
||||||
"Get the current temperature in a city.".to_string(),
|
"Get the current temperature in a city.".to_string(),
|
||||||
vec![ToolFunctionParameter::new(
|
serde_json::json!({
|
||||||
"city".to_string(),
|
"type": "object",
|
||||||
"The name of the city.".to_string(),
|
"properties": {
|
||||||
ToolFunctionParameterType::String,
|
"city": {
|
||||||
)],
|
"type": "string",
|
||||||
|
"description": "The name of the city."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["city"]
|
||||||
|
}),
|
||||||
)];
|
)];
|
||||||
|
|
||||||
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
||||||
@@ -47,14 +51,12 @@ async fn main() {
|
|||||||
Box::new(GetCityTemperatureFunction),
|
Box::new(GetCityTemperatureFunction),
|
||||||
);
|
);
|
||||||
|
|
||||||
let model = Model::MistralSmallLatest;
|
let model = Model::mistral_small_latest();
|
||||||
let messages = vec![ChatMessage {
|
let messages = vec![ChatMessage::new_user_message(
|
||||||
role: ChatMessageRole::User,
|
"What's the temperature in Paris?",
|
||||||
content: "What's the temperature in Paris?".to_string(),
|
)];
|
||||||
tool_calls: None,
|
|
||||||
}];
|
|
||||||
let options = ChatParams {
|
let options = ChatParams {
|
||||||
temperature: 0.0,
|
temperature: Some(0.0),
|
||||||
random_seed: Some(42),
|
random_seed: Some(42),
|
||||||
tool_choice: Some(ToolChoice::Auto),
|
tool_choice: Some(ToolChoice::Auto),
|
||||||
tools: Some(tools),
|
tools: Some(tools),
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
use futures::stream::StreamExt;
|
use futures::stream::StreamExt;
|
||||||
use mistralai_client::v1::{
|
use mistralai_client::v1::{
|
||||||
chat::{ChatMessage, ChatMessageRole, ChatParams},
|
chat::{ChatMessage, ChatParams},
|
||||||
client::Client,
|
client::Client,
|
||||||
constants::Model,
|
constants::Model,
|
||||||
};
|
};
|
||||||
@@ -11,14 +11,10 @@ async fn main() {
|
|||||||
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
||||||
let client = Client::new(None, None, None, None).unwrap();
|
let client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
let model = Model::OpenMistral7b;
|
let model = Model::mistral_small_latest();
|
||||||
let messages = vec![ChatMessage {
|
let messages = vec![ChatMessage::new_user_message("Tell me a short happy story.")];
|
||||||
role: ChatMessageRole::User,
|
|
||||||
content: "Tell me a short happy story.".to_string(),
|
|
||||||
tool_calls: None,
|
|
||||||
}];
|
|
||||||
let options = ChatParams {
|
let options = ChatParams {
|
||||||
temperature: 0.0,
|
temperature: Some(0.0),
|
||||||
random_seed: Some(42),
|
random_seed: Some(42),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
@@ -31,9 +27,10 @@ async fn main() {
|
|||||||
.for_each(|chunk_result| async {
|
.for_each(|chunk_result| async {
|
||||||
match chunk_result {
|
match chunk_result {
|
||||||
Ok(chunks) => chunks.iter().for_each(|chunk| {
|
Ok(chunks) => chunks.iter().for_each(|chunk| {
|
||||||
print!("{}", chunk.choices[0].delta.content);
|
if let Some(content) = &chunk.choices[0].delta.content {
|
||||||
io::stdout().flush().unwrap();
|
print!("{}", content);
|
||||||
// => "Once upon a time, [...]"
|
io::stdout().flush().unwrap();
|
||||||
|
}
|
||||||
}),
|
}),
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
eprintln!("Error processing chunk: {:?}", error)
|
eprintln!("Error processing chunk: {:?}", error)
|
||||||
@@ -41,5 +38,5 @@ async fn main() {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
.await;
|
.await;
|
||||||
print!("\n") // To persist the last chunk output.
|
println!();
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
use mistralai_client::v1::{client::Client, constants::EmbedModel};
|
use mistralai_client::v1::{client::Client, constants::Model};
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
||||||
let client: Client = Client::new(None, None, None, None).unwrap();
|
let client: Client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
let model = EmbedModel::MistralEmbed;
|
let model = Model::mistral_embed();
|
||||||
let input = vec!["Embed this sentence.", "As well as this one."]
|
let input = vec!["Embed this sentence.", "As well as this one."]
|
||||||
.iter()
|
.iter()
|
||||||
.map(|s| s.to_string())
|
.map(|s| s.to_string())
|
||||||
@@ -13,5 +13,4 @@ fn main() {
|
|||||||
|
|
||||||
let response = client.embeddings(model, input, options).unwrap();
|
let response = client.embeddings(model, input, options).unwrap();
|
||||||
println!("First Embedding: {:?}", response.data[0]);
|
println!("First Embedding: {:?}", response.data[0]);
|
||||||
// => "First Embedding: {...}"
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
use mistralai_client::v1::{client::Client, constants::EmbedModel};
|
use mistralai_client::v1::{client::Client, constants::Model};
|
||||||
|
|
||||||
#[tokio::main]
|
#[tokio::main]
|
||||||
async fn main() {
|
async fn main() {
|
||||||
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
||||||
let client: Client = Client::new(None, None, None, None).unwrap();
|
let client: Client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
let model = EmbedModel::MistralEmbed;
|
let model = Model::mistral_embed();
|
||||||
let input = vec!["Embed this sentence.", "As well as this one."]
|
let input = vec!["Embed this sentence.", "As well as this one."]
|
||||||
.iter()
|
.iter()
|
||||||
.map(|s| s.to_string())
|
.map(|s| s.to_string())
|
||||||
@@ -17,5 +17,4 @@ async fn main() {
|
|||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
println!("First Embedding: {:?}", response.data[0]);
|
println!("First Embedding: {:?}", response.data[0]);
|
||||||
// => "First Embedding: {...}"
|
|
||||||
}
|
}
|
||||||
|
|||||||
21
examples/fim.rs
Normal file
21
examples/fim.rs
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
use mistralai_client::v1::{
|
||||||
|
client::Client,
|
||||||
|
constants::Model,
|
||||||
|
fim::FimParams,
|
||||||
|
};
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
||||||
|
let client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
|
let model = Model::codestral_latest();
|
||||||
|
let prompt = "def fibonacci(n):".to_string();
|
||||||
|
let options = FimParams {
|
||||||
|
suffix: Some("\n return result".to_string()),
|
||||||
|
temperature: Some(0.0),
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = client.fim(model, prompt, Some(options)).unwrap();
|
||||||
|
println!("Completion: {}", response.choices[0].message.content);
|
||||||
|
}
|
||||||
25
examples/ocr.rs
Normal file
25
examples/ocr.rs
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
use mistralai_client::v1::{
|
||||||
|
client::Client,
|
||||||
|
constants::Model,
|
||||||
|
ocr::{OcrDocument, OcrRequest},
|
||||||
|
};
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
||||||
|
let client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
|
let request = OcrRequest {
|
||||||
|
model: Model::mistral_ocr_latest(),
|
||||||
|
document: OcrDocument::from_url("https://arxiv.org/pdf/2201.04234"),
|
||||||
|
pages: Some(vec![0]),
|
||||||
|
table_format: None,
|
||||||
|
include_image_base64: None,
|
||||||
|
image_limit: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = client.ocr(&request).unwrap();
|
||||||
|
for page in &response.pages {
|
||||||
|
println!("--- Page {} ---", page.index);
|
||||||
|
println!("{}", &page.markdown[..200.min(page.markdown.len())]);
|
||||||
|
}
|
||||||
|
}
|
||||||
98
src/v1/agents.rs
Normal file
98
src/v1/agents.rs
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::v1::{chat, common, constants, tool};
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Request
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct AgentCompletionParams {
|
||||||
|
pub max_tokens: Option<u32>,
|
||||||
|
pub min_tokens: Option<u32>,
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
pub top_p: Option<f32>,
|
||||||
|
pub random_seed: Option<u32>,
|
||||||
|
pub stop: Option<Vec<String>>,
|
||||||
|
pub response_format: Option<chat::ResponseFormat>,
|
||||||
|
pub tools: Option<Vec<tool::Tool>>,
|
||||||
|
pub tool_choice: Option<tool::ToolChoice>,
|
||||||
|
}
|
||||||
|
impl Default for AgentCompletionParams {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
max_tokens: None,
|
||||||
|
min_tokens: None,
|
||||||
|
temperature: None,
|
||||||
|
top_p: None,
|
||||||
|
random_seed: None,
|
||||||
|
stop: None,
|
||||||
|
response_format: None,
|
||||||
|
tools: None,
|
||||||
|
tool_choice: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct AgentCompletionRequest {
|
||||||
|
pub agent_id: String,
|
||||||
|
pub messages: Vec<chat::ChatMessage>,
|
||||||
|
pub stream: bool,
|
||||||
|
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub max_tokens: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub min_tokens: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_p: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub random_seed: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stop: Option<Vec<String>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub response_format: Option<chat::ResponseFormat>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tools: Option<Vec<tool::Tool>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tool_choice: Option<tool::ToolChoice>,
|
||||||
|
}
|
||||||
|
impl AgentCompletionRequest {
|
||||||
|
pub fn new(
|
||||||
|
agent_id: String,
|
||||||
|
messages: Vec<chat::ChatMessage>,
|
||||||
|
stream: bool,
|
||||||
|
options: Option<AgentCompletionParams>,
|
||||||
|
) -> Self {
|
||||||
|
let opts = options.unwrap_or_default();
|
||||||
|
|
||||||
|
Self {
|
||||||
|
agent_id,
|
||||||
|
messages,
|
||||||
|
stream,
|
||||||
|
max_tokens: opts.max_tokens,
|
||||||
|
min_tokens: opts.min_tokens,
|
||||||
|
temperature: opts.temperature,
|
||||||
|
top_p: opts.top_p,
|
||||||
|
random_seed: opts.random_seed,
|
||||||
|
stop: opts.stop,
|
||||||
|
response_format: opts.response_format,
|
||||||
|
tools: opts.tools,
|
||||||
|
tool_choice: opts.tool_choice,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Response (same shape as chat completions)
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct AgentCompletionResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
pub created: u64,
|
||||||
|
pub model: constants::Model,
|
||||||
|
pub choices: Vec<chat::ChatResponseChoice>,
|
||||||
|
pub usage: common::ResponseUsage,
|
||||||
|
}
|
||||||
78
src/v1/audio.rs
Normal file
78
src/v1/audio.rs
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::v1::constants;
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Request (multipart form, but we define the params struct)
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct AudioTranscriptionParams {
|
||||||
|
pub model: constants::Model,
|
||||||
|
pub language: Option<String>,
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
pub diarize: Option<bool>,
|
||||||
|
pub timestamp_granularities: Option<Vec<TimestampGranularity>>,
|
||||||
|
}
|
||||||
|
impl Default for AudioTranscriptionParams {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
model: constants::Model::voxtral_mini_transcribe(),
|
||||||
|
language: None,
|
||||||
|
temperature: None,
|
||||||
|
diarize: None,
|
||||||
|
timestamp_granularities: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum TimestampGranularity {
|
||||||
|
Segment,
|
||||||
|
Word,
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Response
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct AudioTranscriptionResponse {
|
||||||
|
pub text: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub model: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub language: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub segments: Option<Vec<TranscriptionSegment>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub words: Option<Vec<TranscriptionWord>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub usage: Option<AudioUsage>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct TranscriptionSegment {
|
||||||
|
pub id: u32,
|
||||||
|
pub start: f32,
|
||||||
|
pub end: f32,
|
||||||
|
pub text: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub speaker: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct TranscriptionWord {
|
||||||
|
pub word: String,
|
||||||
|
pub start: f32,
|
||||||
|
pub end: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct AudioUsage {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub prompt_audio_seconds: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub prompt_tokens: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub total_tokens: Option<u32>,
|
||||||
|
}
|
||||||
53
src/v1/batch.rs
Normal file
53
src/v1/batch.rs
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::v1::constants;
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Request
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct BatchJobRequest {
|
||||||
|
pub input_files: Vec<String>,
|
||||||
|
pub model: constants::Model,
|
||||||
|
pub endpoint: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub metadata: Option<serde_json::Value>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Response
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct BatchJobResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
pub model: constants::Model,
|
||||||
|
pub endpoint: String,
|
||||||
|
pub input_files: Vec<String>,
|
||||||
|
pub status: String,
|
||||||
|
pub created_at: u64,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub completed_at: Option<u64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub output_file: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub error_file: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub metadata: Option<serde_json::Value>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub total_requests: Option<u64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub completed_requests: Option<u64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub succeeded_requests: Option<u64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub failed_requests: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct BatchJobListResponse {
|
||||||
|
pub data: Vec<BatchJobResponse>,
|
||||||
|
pub object: String,
|
||||||
|
#[serde(default)]
|
||||||
|
pub total: u32,
|
||||||
|
}
|
||||||
187
src/v1/chat.rs
187
src/v1/chat.rs
@@ -11,13 +11,31 @@ pub struct ChatMessage {
|
|||||||
pub content: String,
|
pub content: String,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub tool_calls: Option<Vec<tool::ToolCall>>,
|
pub tool_calls: Option<Vec<tool::ToolCall>>,
|
||||||
|
/// Tool call ID, required when role is Tool.
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tool_call_id: Option<String>,
|
||||||
|
/// Function name, used when role is Tool.
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub name: Option<String>,
|
||||||
}
|
}
|
||||||
impl ChatMessage {
|
impl ChatMessage {
|
||||||
|
pub fn new_system_message(content: &str) -> Self {
|
||||||
|
Self {
|
||||||
|
role: ChatMessageRole::System,
|
||||||
|
content: content.to_string(),
|
||||||
|
tool_calls: None,
|
||||||
|
tool_call_id: None,
|
||||||
|
name: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn new_assistant_message(content: &str, tool_calls: Option<Vec<tool::ToolCall>>) -> Self {
|
pub fn new_assistant_message(content: &str, tool_calls: Option<Vec<tool::ToolCall>>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
role: ChatMessageRole::Assistant,
|
role: ChatMessageRole::Assistant,
|
||||||
content: content.to_string(),
|
content: content.to_string(),
|
||||||
tool_calls,
|
tool_calls,
|
||||||
|
tool_call_id: None,
|
||||||
|
name: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -26,6 +44,18 @@ impl ChatMessage {
|
|||||||
role: ChatMessageRole::User,
|
role: ChatMessageRole::User,
|
||||||
content: content.to_string(),
|
content: content.to_string(),
|
||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
|
tool_call_id: None,
|
||||||
|
name: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_tool_message(content: &str, tool_call_id: &str, name: Option<&str>) -> Self {
|
||||||
|
Self {
|
||||||
|
role: ChatMessageRole::Tool,
|
||||||
|
content: content.to_string(),
|
||||||
|
tool_calls: None,
|
||||||
|
tool_call_id: Some(tool_call_id.to_string()),
|
||||||
|
name: name.map(|n| n.to_string()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -44,17 +74,32 @@ pub enum ChatMessageRole {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// The format that the model must output.
|
/// The format that the model must output.
|
||||||
///
|
|
||||||
/// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information.
|
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
pub struct ResponseFormat {
|
pub struct ResponseFormat {
|
||||||
#[serde(rename = "type")]
|
#[serde(rename = "type")]
|
||||||
pub type_: String,
|
pub type_: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub json_schema: Option<serde_json::Value>,
|
||||||
}
|
}
|
||||||
impl ResponseFormat {
|
impl ResponseFormat {
|
||||||
|
pub fn text() -> Self {
|
||||||
|
Self {
|
||||||
|
type_: "text".to_string(),
|
||||||
|
json_schema: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn json_object() -> Self {
|
pub fn json_object() -> Self {
|
||||||
Self {
|
Self {
|
||||||
type_: "json_object".to_string(),
|
type_: "json_object".to_string(),
|
||||||
|
json_schema: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn json_schema(schema: serde_json::Value) -> Self {
|
||||||
|
Self {
|
||||||
|
type_: "json_schema".to_string(),
|
||||||
|
json_schema: Some(schema),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -63,91 +108,83 @@ impl ResponseFormat {
|
|||||||
// Request
|
// Request
|
||||||
|
|
||||||
/// The parameters for the chat request.
|
/// The parameters for the chat request.
|
||||||
///
|
|
||||||
/// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information.
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct ChatParams {
|
pub struct ChatParams {
|
||||||
/// The maximum number of tokens to generate in the completion.
|
|
||||||
///
|
|
||||||
/// Defaults to `None`.
|
|
||||||
pub max_tokens: Option<u32>,
|
pub max_tokens: Option<u32>,
|
||||||
/// The seed to use for random sampling. If set, different calls will generate deterministic results.
|
pub min_tokens: Option<u32>,
|
||||||
///
|
|
||||||
/// Defaults to `None`.
|
|
||||||
pub random_seed: Option<u32>,
|
pub random_seed: Option<u32>,
|
||||||
/// The format that the model must output.
|
|
||||||
///
|
|
||||||
/// Defaults to `None`.
|
|
||||||
pub response_format: Option<ResponseFormat>,
|
pub response_format: Option<ResponseFormat>,
|
||||||
/// Whether to inject a safety prompt before all conversations.
|
|
||||||
///
|
|
||||||
/// Defaults to `false`.
|
|
||||||
pub safe_prompt: bool,
|
pub safe_prompt: bool,
|
||||||
/// What sampling temperature to use, between `Some(0.0)` and `Some(1.0)`.
|
pub temperature: Option<f32>,
|
||||||
///
|
|
||||||
/// Defaults to `0.7`.
|
|
||||||
pub temperature: f32,
|
|
||||||
/// Specifies if/how functions are called.
|
|
||||||
///
|
|
||||||
/// Defaults to `None`.
|
|
||||||
pub tool_choice: Option<tool::ToolChoice>,
|
pub tool_choice: Option<tool::ToolChoice>,
|
||||||
/// A list of available tools for the model.
|
|
||||||
///
|
|
||||||
/// Defaults to `None`.
|
|
||||||
pub tools: Option<Vec<tool::Tool>>,
|
pub tools: Option<Vec<tool::Tool>>,
|
||||||
/// Nucleus sampling, where the model considers the results of the tokens with `top_p` probability mass.
|
pub top_p: Option<f32>,
|
||||||
///
|
pub stop: Option<Vec<String>>,
|
||||||
/// Defaults to `1.0`.
|
pub n: Option<u32>,
|
||||||
pub top_p: f32,
|
pub frequency_penalty: Option<f32>,
|
||||||
|
pub presence_penalty: Option<f32>,
|
||||||
|
pub parallel_tool_calls: Option<bool>,
|
||||||
|
/// For reasoning models (Magistral). "high" or "none".
|
||||||
|
pub reasoning_effort: Option<String>,
|
||||||
}
|
}
|
||||||
impl Default for ChatParams {
|
impl Default for ChatParams {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
max_tokens: None,
|
max_tokens: None,
|
||||||
|
min_tokens: None,
|
||||||
random_seed: None,
|
random_seed: None,
|
||||||
safe_prompt: false,
|
safe_prompt: false,
|
||||||
response_format: None,
|
response_format: None,
|
||||||
temperature: 0.7,
|
temperature: None,
|
||||||
tool_choice: None,
|
tool_choice: None,
|
||||||
tools: None,
|
tools: None,
|
||||||
top_p: 1.0,
|
top_p: None,
|
||||||
}
|
stop: None,
|
||||||
}
|
n: None,
|
||||||
}
|
frequency_penalty: None,
|
||||||
impl ChatParams {
|
presence_penalty: None,
|
||||||
pub fn json_default() -> Self {
|
parallel_tool_calls: None,
|
||||||
Self {
|
reasoning_effort: None,
|
||||||
max_tokens: None,
|
|
||||||
random_seed: None,
|
|
||||||
safe_prompt: false,
|
|
||||||
response_format: None,
|
|
||||||
temperature: 0.7,
|
|
||||||
tool_choice: None,
|
|
||||||
tools: None,
|
|
||||||
top_p: 1.0,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub struct ChatRequest {
|
pub struct ChatRequest {
|
||||||
pub messages: Vec<ChatMessage>,
|
|
||||||
pub model: constants::Model,
|
pub model: constants::Model,
|
||||||
|
pub messages: Vec<ChatMessage>,
|
||||||
|
pub stream: bool,
|
||||||
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub max_tokens: Option<u32>,
|
pub max_tokens: Option<u32>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub min_tokens: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub random_seed: Option<u32>,
|
pub random_seed: Option<u32>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub response_format: Option<ResponseFormat>,
|
pub response_format: Option<ResponseFormat>,
|
||||||
pub safe_prompt: bool,
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub stream: bool,
|
pub safe_prompt: Option<bool>,
|
||||||
pub temperature: f32,
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub temperature: Option<f32>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub tool_choice: Option<tool::ToolChoice>,
|
pub tool_choice: Option<tool::ToolChoice>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub tools: Option<Vec<tool::Tool>>,
|
pub tools: Option<Vec<tool::Tool>>,
|
||||||
pub top_p: f32,
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_p: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stop: Option<Vec<String>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub n: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub frequency_penalty: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub presence_penalty: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub parallel_tool_calls: Option<bool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub reasoning_effort: Option<String>,
|
||||||
}
|
}
|
||||||
impl ChatRequest {
|
impl ChatRequest {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
@@ -156,30 +193,28 @@ impl ChatRequest {
|
|||||||
stream: bool,
|
stream: bool,
|
||||||
options: Option<ChatParams>,
|
options: Option<ChatParams>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let ChatParams {
|
let opts = options.unwrap_or_default();
|
||||||
max_tokens,
|
let safe_prompt = if opts.safe_prompt { Some(true) } else { None };
|
||||||
random_seed,
|
|
||||||
safe_prompt,
|
|
||||||
temperature,
|
|
||||||
tool_choice,
|
|
||||||
tools,
|
|
||||||
top_p,
|
|
||||||
response_format,
|
|
||||||
} = options.unwrap_or_default();
|
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
messages,
|
|
||||||
model,
|
model,
|
||||||
|
messages,
|
||||||
max_tokens,
|
|
||||||
random_seed,
|
|
||||||
safe_prompt,
|
|
||||||
stream,
|
stream,
|
||||||
temperature,
|
max_tokens: opts.max_tokens,
|
||||||
tool_choice,
|
min_tokens: opts.min_tokens,
|
||||||
tools,
|
random_seed: opts.random_seed,
|
||||||
top_p,
|
safe_prompt,
|
||||||
response_format,
|
temperature: opts.temperature,
|
||||||
|
tool_choice: opts.tool_choice,
|
||||||
|
tools: opts.tools,
|
||||||
|
top_p: opts.top_p,
|
||||||
|
response_format: opts.response_format,
|
||||||
|
stop: opts.stop,
|
||||||
|
n: opts.n,
|
||||||
|
frequency_penalty: opts.frequency_penalty,
|
||||||
|
presence_penalty: opts.presence_penalty,
|
||||||
|
parallel_tool_calls: opts.parallel_tool_calls,
|
||||||
|
reasoning_effort: opts.reasoning_effort,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -192,7 +227,7 @@ pub struct ChatResponse {
|
|||||||
pub id: String,
|
pub id: String,
|
||||||
pub object: String,
|
pub object: String,
|
||||||
/// Unix timestamp (in seconds).
|
/// Unix timestamp (in seconds).
|
||||||
pub created: u32,
|
pub created: u64,
|
||||||
pub model: constants::Model,
|
pub model: constants::Model,
|
||||||
pub choices: Vec<ChatResponseChoice>,
|
pub choices: Vec<ChatResponseChoice>,
|
||||||
pub usage: common::ResponseUsage,
|
pub usage: common::ResponseUsage,
|
||||||
@@ -203,14 +238,18 @@ pub struct ChatResponseChoice {
|
|||||||
pub index: u32,
|
pub index: u32,
|
||||||
pub message: ChatMessage,
|
pub message: ChatMessage,
|
||||||
pub finish_reason: ChatResponseChoiceFinishReason,
|
pub finish_reason: ChatResponseChoiceFinishReason,
|
||||||
// TODO Check this prop (seen in API responses but undocumented).
|
|
||||||
// pub logprobs: ???
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||||
pub enum ChatResponseChoiceFinishReason {
|
pub enum ChatResponseChoiceFinishReason {
|
||||||
#[serde(rename = "stop")]
|
#[serde(rename = "stop")]
|
||||||
Stop,
|
Stop,
|
||||||
|
#[serde(rename = "length")]
|
||||||
|
Length,
|
||||||
#[serde(rename = "tool_calls")]
|
#[serde(rename = "tool_calls")]
|
||||||
ToolCalls,
|
ToolCalls,
|
||||||
|
#[serde(rename = "model_length")]
|
||||||
|
ModelLength,
|
||||||
|
#[serde(rename = "error")]
|
||||||
|
Error,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::from_str;
|
use serde_json::from_str;
|
||||||
|
|
||||||
use crate::v1::{chat, common, constants, error};
|
use crate::v1::{chat, common, constants, error, tool};
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
// Response
|
// Response
|
||||||
@@ -11,12 +11,11 @@ pub struct ChatStreamChunk {
|
|||||||
pub id: String,
|
pub id: String,
|
||||||
pub object: String,
|
pub object: String,
|
||||||
/// Unix timestamp (in seconds).
|
/// Unix timestamp (in seconds).
|
||||||
pub created: u32,
|
pub created: u64,
|
||||||
pub model: constants::Model,
|
pub model: constants::Model,
|
||||||
pub choices: Vec<ChatStreamChunkChoice>,
|
pub choices: Vec<ChatStreamChunkChoice>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub usage: Option<common::ResponseUsage>,
|
pub usage: Option<common::ResponseUsage>,
|
||||||
// TODO Check this prop (seen in API responses but undocumented).
|
|
||||||
// pub logprobs: ???,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
@@ -24,14 +23,15 @@ pub struct ChatStreamChunkChoice {
|
|||||||
pub index: u32,
|
pub index: u32,
|
||||||
pub delta: ChatStreamChunkChoiceDelta,
|
pub delta: ChatStreamChunkChoiceDelta,
|
||||||
pub finish_reason: Option<String>,
|
pub finish_reason: Option<String>,
|
||||||
// TODO Check this prop (seen in API responses but undocumented).
|
|
||||||
// pub logprobs: ???,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
pub struct ChatStreamChunkChoiceDelta {
|
pub struct ChatStreamChunkChoiceDelta {
|
||||||
pub role: Option<chat::ChatMessageRole>,
|
pub role: Option<chat::ChatMessageRole>,
|
||||||
pub content: String,
|
#[serde(default)]
|
||||||
|
pub content: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tool_calls: Option<Vec<tool::ToolCall>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Extracts serialized chunks from a stream message.
|
/// Extracts serialized chunks from a stream message.
|
||||||
@@ -47,7 +47,6 @@ pub fn get_chunk_from_stream_message_line(
|
|||||||
return Ok(Some(vec![]));
|
return Ok(Some(vec![]));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Attempt to deserialize the JSON string into ChatStreamChunk
|
|
||||||
match from_str::<ChatStreamChunk>(chunk_as_json) {
|
match from_str::<ChatStreamChunk>(chunk_as_json) {
|
||||||
Ok(chunk) => Ok(Some(vec![chunk])),
|
Ok(chunk) => Ok(Some(vec![chunk])),
|
||||||
Err(e) => Err(error::ApiError {
|
Err(e) => Err(error::ApiError {
|
||||||
|
|||||||
1102
src/v1/client.rs
1102
src/v1/client.rs
File diff suppressed because it is too large
Load Diff
@@ -3,6 +3,7 @@ use serde::{Deserialize, Serialize};
|
|||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
pub struct ResponseUsage {
|
pub struct ResponseUsage {
|
||||||
pub prompt_tokens: u32,
|
pub prompt_tokens: u32,
|
||||||
|
#[serde(default)]
|
||||||
pub completion_tokens: u32,
|
pub completion_tokens: u32,
|
||||||
pub total_tokens: u32,
|
pub total_tokens: u32,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,35 +1,131 @@
|
|||||||
|
use std::fmt;
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
pub const API_URL_BASE: &str = "https://api.mistral.ai/v1";
|
pub const API_URL_BASE: &str = "https://api.mistral.ai/v1";
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
/// A Mistral AI model identifier.
|
||||||
pub enum Model {
|
///
|
||||||
#[serde(rename = "open-mistral-7b")]
|
/// Use the associated constants for known models, or construct with `Model::new()` for any model string.
|
||||||
OpenMistral7b,
|
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
|
||||||
#[serde(rename = "open-mixtral-8x7b")]
|
#[serde(transparent)]
|
||||||
OpenMixtral8x7b,
|
pub struct Model(pub String);
|
||||||
#[serde(rename = "open-mixtral-8x22b")]
|
|
||||||
OpenMixtral8x22b,
|
impl Model {
|
||||||
#[serde(rename = "open-mistral-nemo", alias = "open-mistral-nemo-2407")]
|
pub fn new(id: impl Into<String>) -> Self {
|
||||||
OpenMistralNemo,
|
Self(id.into())
|
||||||
#[serde(rename = "mistral-tiny")]
|
}
|
||||||
MistralTiny,
|
|
||||||
#[serde(rename = "mistral-small-latest", alias = "mistral-small-2402")]
|
// Flagship / Premier
|
||||||
MistralSmallLatest,
|
pub fn mistral_large_latest() -> Self {
|
||||||
#[serde(rename = "mistral-medium-latest", alias = "mistral-medium-2312")]
|
Self::new("mistral-large-latest")
|
||||||
MistralMediumLatest,
|
}
|
||||||
#[serde(rename = "mistral-large-latest", alias = "mistral-large-2407")]
|
pub fn mistral_large_3() -> Self {
|
||||||
MistralLargeLatest,
|
Self::new("mistral-large-3-25-12")
|
||||||
#[serde(rename = "mistral-large-2402")]
|
}
|
||||||
MistralLarge,
|
pub fn mistral_medium_latest() -> Self {
|
||||||
#[serde(rename = "codestral-latest", alias = "codestral-2405")]
|
Self::new("mistral-medium-latest")
|
||||||
CodestralLatest,
|
}
|
||||||
#[serde(rename = "open-codestral-mamba")]
|
pub fn mistral_medium_3_1() -> Self {
|
||||||
CodestralMamba,
|
Self::new("mistral-medium-3-1-25-08")
|
||||||
|
}
|
||||||
|
pub fn mistral_small_latest() -> Self {
|
||||||
|
Self::new("mistral-small-latest")
|
||||||
|
}
|
||||||
|
pub fn mistral_small_4() -> Self {
|
||||||
|
Self::new("mistral-small-4-0-26-03")
|
||||||
|
}
|
||||||
|
pub fn mistral_small_3_2() -> Self {
|
||||||
|
Self::new("mistral-small-3-2-25-06")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ministral
|
||||||
|
pub fn ministral_3_14b() -> Self {
|
||||||
|
Self::new("ministral-3-14b-25-12")
|
||||||
|
}
|
||||||
|
pub fn ministral_3_8b() -> Self {
|
||||||
|
Self::new("ministral-3-8b-25-12")
|
||||||
|
}
|
||||||
|
pub fn ministral_3_3b() -> Self {
|
||||||
|
Self::new("ministral-3-3b-25-12")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reasoning
|
||||||
|
pub fn magistral_medium_latest() -> Self {
|
||||||
|
Self::new("magistral-medium-latest")
|
||||||
|
}
|
||||||
|
pub fn magistral_small_latest() -> Self {
|
||||||
|
Self::new("magistral-small-latest")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Code
|
||||||
|
pub fn codestral_latest() -> Self {
|
||||||
|
Self::new("codestral-latest")
|
||||||
|
}
|
||||||
|
pub fn codestral_2508() -> Self {
|
||||||
|
Self::new("codestral-2508")
|
||||||
|
}
|
||||||
|
pub fn codestral_embed() -> Self {
|
||||||
|
Self::new("codestral-embed-25-05")
|
||||||
|
}
|
||||||
|
pub fn devstral_2() -> Self {
|
||||||
|
Self::new("devstral-2-25-12")
|
||||||
|
}
|
||||||
|
pub fn devstral_small_2() -> Self {
|
||||||
|
Self::new("devstral-small-2-25-12")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multimodal / Vision
|
||||||
|
pub fn pixtral_large() -> Self {
|
||||||
|
Self::new("pixtral-large-2411")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Audio
|
||||||
|
pub fn voxtral_mini_transcribe() -> Self {
|
||||||
|
Self::new("voxtral-mini-transcribe-2-26-02")
|
||||||
|
}
|
||||||
|
pub fn voxtral_small() -> Self {
|
||||||
|
Self::new("voxtral-small-25-07")
|
||||||
|
}
|
||||||
|
pub fn voxtral_mini() -> Self {
|
||||||
|
Self::new("voxtral-mini-25-07")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Legacy (kept for backward compatibility)
|
||||||
|
pub fn open_mistral_nemo() -> Self {
|
||||||
|
Self::new("open-mistral-nemo")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Embedding
|
||||||
|
pub fn mistral_embed() -> Self {
|
||||||
|
Self::new("mistral-embed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Moderation
|
||||||
|
pub fn mistral_moderation_latest() -> Self {
|
||||||
|
Self::new("mistral-moderation-26-03")
|
||||||
|
}
|
||||||
|
|
||||||
|
// OCR
|
||||||
|
pub fn mistral_ocr_latest() -> Self {
|
||||||
|
Self::new("mistral-ocr-latest")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
impl fmt::Display for Model {
|
||||||
pub enum EmbedModel {
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
#[serde(rename = "mistral-embed")]
|
write!(f, "{}", self.0)
|
||||||
MistralEmbed,
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&str> for Model {
|
||||||
|
fn from(s: &str) -> Self {
|
||||||
|
Self(s.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<String> for Model {
|
||||||
|
fn from(s: String) -> Self {
|
||||||
|
Self(s)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,42 +8,63 @@ use crate::v1::{common, constants};
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct EmbeddingRequestOptions {
|
pub struct EmbeddingRequestOptions {
|
||||||
pub encoding_format: Option<EmbeddingRequestEncodingFormat>,
|
pub encoding_format: Option<EmbeddingRequestEncodingFormat>,
|
||||||
|
pub output_dimension: Option<u32>,
|
||||||
|
pub output_dtype: Option<EmbeddingOutputDtype>,
|
||||||
}
|
}
|
||||||
impl Default for EmbeddingRequestOptions {
|
impl Default for EmbeddingRequestOptions {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
encoding_format: None,
|
encoding_format: None,
|
||||||
|
output_dimension: None,
|
||||||
|
output_dtype: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub struct EmbeddingRequest {
|
pub struct EmbeddingRequest {
|
||||||
pub model: constants::EmbedModel,
|
pub model: constants::Model,
|
||||||
pub input: Vec<String>,
|
pub input: Vec<String>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub encoding_format: Option<EmbeddingRequestEncodingFormat>,
|
pub encoding_format: Option<EmbeddingRequestEncodingFormat>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub output_dimension: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub output_dtype: Option<EmbeddingOutputDtype>,
|
||||||
}
|
}
|
||||||
impl EmbeddingRequest {
|
impl EmbeddingRequest {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
model: constants::EmbedModel,
|
model: constants::Model,
|
||||||
input: Vec<String>,
|
input: Vec<String>,
|
||||||
options: Option<EmbeddingRequestOptions>,
|
options: Option<EmbeddingRequestOptions>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let EmbeddingRequestOptions { encoding_format } = options.unwrap_or_default();
|
let opts = options.unwrap_or_default();
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
model,
|
model,
|
||||||
input,
|
input,
|
||||||
encoding_format,
|
encoding_format: opts.encoding_format,
|
||||||
|
output_dimension: opts.output_dimension,
|
||||||
|
output_dtype: opts.output_dtype,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||||
#[allow(non_camel_case_types)]
|
#[serde(rename_all = "lowercase")]
|
||||||
pub enum EmbeddingRequestEncodingFormat {
|
pub enum EmbeddingRequestEncodingFormat {
|
||||||
float,
|
Float,
|
||||||
|
Base64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum EmbeddingOutputDtype {
|
||||||
|
Float,
|
||||||
|
Int8,
|
||||||
|
Uint8,
|
||||||
|
Binary,
|
||||||
|
Ubinary,
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
@@ -51,9 +72,8 @@ pub enum EmbeddingRequestEncodingFormat {
|
|||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
pub struct EmbeddingResponse {
|
pub struct EmbeddingResponse {
|
||||||
pub id: String,
|
|
||||||
pub object: String,
|
pub object: String,
|
||||||
pub model: constants::EmbedModel,
|
pub model: constants::Model,
|
||||||
pub data: Vec<EmbeddingResponseDataItem>,
|
pub data: Vec<EmbeddingResponseDataItem>,
|
||||||
pub usage: common::ResponseUsage,
|
pub usage: common::ResponseUsage,
|
||||||
}
|
}
|
||||||
|
|||||||
55
src/v1/files.rs
Normal file
55
src/v1/files.rs
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Request
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "kebab-case")]
|
||||||
|
pub enum FilePurpose {
|
||||||
|
FineTune,
|
||||||
|
Batch,
|
||||||
|
Ocr,
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Response
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct FileListResponse {
|
||||||
|
pub data: Vec<FileObject>,
|
||||||
|
pub object: String,
|
||||||
|
#[serde(default)]
|
||||||
|
pub total: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct FileObject {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
pub bytes: u64,
|
||||||
|
pub created_at: u64,
|
||||||
|
pub filename: String,
|
||||||
|
pub purpose: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub sample_type: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub source: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub num_lines: Option<u64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub mimetype: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct FileDeleteResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
pub deleted: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct FileUrlResponse {
|
||||||
|
pub url: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub expires_at: Option<u64>,
|
||||||
|
}
|
||||||
101
src/v1/fim.rs
Normal file
101
src/v1/fim.rs
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::v1::{common, constants};
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Request
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct FimParams {
|
||||||
|
pub suffix: Option<String>,
|
||||||
|
pub max_tokens: Option<u32>,
|
||||||
|
pub min_tokens: Option<u32>,
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
pub top_p: Option<f32>,
|
||||||
|
pub stop: Option<Vec<String>>,
|
||||||
|
pub random_seed: Option<u32>,
|
||||||
|
}
|
||||||
|
impl Default for FimParams {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
suffix: None,
|
||||||
|
max_tokens: None,
|
||||||
|
min_tokens: None,
|
||||||
|
temperature: None,
|
||||||
|
top_p: None,
|
||||||
|
stop: None,
|
||||||
|
random_seed: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct FimRequest {
|
||||||
|
pub model: constants::Model,
|
||||||
|
pub prompt: String,
|
||||||
|
pub stream: bool,
|
||||||
|
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub suffix: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub max_tokens: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub min_tokens: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub temperature: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_p: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stop: Option<Vec<String>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub random_seed: Option<u32>,
|
||||||
|
}
|
||||||
|
impl FimRequest {
|
||||||
|
pub fn new(
|
||||||
|
model: constants::Model,
|
||||||
|
prompt: String,
|
||||||
|
stream: bool,
|
||||||
|
options: Option<FimParams>,
|
||||||
|
) -> Self {
|
||||||
|
let opts = options.unwrap_or_default();
|
||||||
|
|
||||||
|
Self {
|
||||||
|
model,
|
||||||
|
prompt,
|
||||||
|
stream,
|
||||||
|
suffix: opts.suffix,
|
||||||
|
max_tokens: opts.max_tokens,
|
||||||
|
min_tokens: opts.min_tokens,
|
||||||
|
temperature: opts.temperature,
|
||||||
|
top_p: opts.top_p,
|
||||||
|
stop: opts.stop,
|
||||||
|
random_seed: opts.random_seed,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Response
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct FimResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
pub created: u64,
|
||||||
|
pub model: constants::Model,
|
||||||
|
pub choices: Vec<FimResponseChoice>,
|
||||||
|
pub usage: common::ResponseUsage,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct FimResponseChoice {
|
||||||
|
pub index: u32,
|
||||||
|
pub message: FimResponseMessage,
|
||||||
|
pub finish_reason: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct FimResponseMessage {
|
||||||
|
pub role: String,
|
||||||
|
pub content: String,
|
||||||
|
}
|
||||||
101
src/v1/fine_tuning.rs
Normal file
101
src/v1/fine_tuning.rs
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::v1::constants;
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Request
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct FineTuningJobRequest {
|
||||||
|
pub model: constants::Model,
|
||||||
|
pub training_files: Vec<TrainingFile>,
|
||||||
|
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub validation_files: Option<Vec<String>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub hyperparameters: Option<Hyperparameters>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub suffix: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub auto_start: Option<bool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub job_type: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub integrations: Option<Vec<Integration>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
pub struct TrainingFile {
|
||||||
|
pub file_id: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub weight: Option<f32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
pub struct Hyperparameters {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub learning_rate: Option<f64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub training_steps: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub warmup_fraction: Option<f64>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub epochs: Option<f64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
|
pub struct Integration {
|
||||||
|
pub r#type: String,
|
||||||
|
pub project: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub name: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub api_key: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Response
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct FineTuningJobResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
pub model: constants::Model,
|
||||||
|
pub status: FineTuningJobStatus,
|
||||||
|
pub created_at: u64,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub modified_at: Option<u64>,
|
||||||
|
pub training_files: Vec<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub validation_files: Option<Vec<String>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub hyperparameters: Option<Hyperparameters>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub fine_tuned_model: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub suffix: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub integrations: Option<Vec<Integration>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub trained_tokens: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||||
|
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
|
||||||
|
pub enum FineTuningJobStatus {
|
||||||
|
Queued,
|
||||||
|
Running,
|
||||||
|
Success,
|
||||||
|
Failed,
|
||||||
|
TimeoutExceeded,
|
||||||
|
CancellationRequested,
|
||||||
|
Cancelled,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct FineTuningJobListResponse {
|
||||||
|
pub data: Vec<FineTuningJobResponse>,
|
||||||
|
pub object: String,
|
||||||
|
#[serde(default)]
|
||||||
|
pub total: u32,
|
||||||
|
}
|
||||||
@@ -1,3 +1,6 @@
|
|||||||
|
pub mod agents;
|
||||||
|
pub mod audio;
|
||||||
|
pub mod batch;
|
||||||
pub mod chat;
|
pub mod chat;
|
||||||
pub mod chat_stream;
|
pub mod chat_stream;
|
||||||
pub mod client;
|
pub mod client;
|
||||||
@@ -5,6 +8,11 @@ pub mod common;
|
|||||||
pub mod constants;
|
pub mod constants;
|
||||||
pub mod embedding;
|
pub mod embedding;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
|
pub mod files;
|
||||||
|
pub mod fim;
|
||||||
|
pub mod fine_tuning;
|
||||||
pub mod model_list;
|
pub mod model_list;
|
||||||
|
pub mod moderation;
|
||||||
|
pub mod ocr;
|
||||||
pub mod tool;
|
pub mod tool;
|
||||||
pub mod utils;
|
pub mod utils;
|
||||||
|
|||||||
@@ -15,23 +15,44 @@ pub struct ModelListData {
|
|||||||
pub id: String,
|
pub id: String,
|
||||||
pub object: String,
|
pub object: String,
|
||||||
/// Unix timestamp (in seconds).
|
/// Unix timestamp (in seconds).
|
||||||
pub created: u32,
|
pub created: u64,
|
||||||
pub owned_by: String,
|
pub owned_by: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub root: Option<String>,
|
pub root: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
pub archived: bool,
|
pub archived: bool,
|
||||||
pub name: String,
|
#[serde(default)]
|
||||||
pub description: String,
|
pub name: Option<String>,
|
||||||
pub capabilities: ModelListDataCapabilies,
|
#[serde(default)]
|
||||||
pub max_context_length: u32,
|
pub description: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub capabilities: Option<ModelListDataCapabilities>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub max_context_length: Option<u32>,
|
||||||
|
#[serde(default)]
|
||||||
pub aliases: Vec<String>,
|
pub aliases: Vec<String>,
|
||||||
/// ISO 8601 date (`YYYY-MM-DDTHH:MM:SSZ`).
|
/// ISO 8601 date (`YYYY-MM-DDTHH:MM:SSZ`).
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub deprecation: Option<String>,
|
pub deprecation: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
pub struct ModelListDataCapabilies {
|
pub struct ModelListDataCapabilities {
|
||||||
|
#[serde(default)]
|
||||||
pub completion_chat: bool,
|
pub completion_chat: bool,
|
||||||
|
#[serde(default)]
|
||||||
pub completion_fim: bool,
|
pub completion_fim: bool,
|
||||||
|
#[serde(default)]
|
||||||
pub function_calling: bool,
|
pub function_calling: bool,
|
||||||
|
#[serde(default)]
|
||||||
pub fine_tuning: bool,
|
pub fine_tuning: bool,
|
||||||
|
#[serde(default)]
|
||||||
|
pub vision: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct ModelDeleteResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
pub deleted: bool,
|
||||||
}
|
}
|
||||||
|
|||||||
70
src/v1/moderation.rs
Normal file
70
src/v1/moderation.rs
Normal file
@@ -0,0 +1,70 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::v1::constants;
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Request
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct ModerationRequest {
|
||||||
|
pub model: constants::Model,
|
||||||
|
pub input: Vec<String>,
|
||||||
|
}
|
||||||
|
impl ModerationRequest {
|
||||||
|
pub fn new(model: constants::Model, input: Vec<String>) -> Self {
|
||||||
|
Self { model, input }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct ChatModerationRequest {
|
||||||
|
pub model: constants::Model,
|
||||||
|
pub input: Vec<ChatModerationInput>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct ChatModerationInput {
|
||||||
|
pub role: String,
|
||||||
|
pub content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct ClassificationRequest {
|
||||||
|
pub model: constants::Model,
|
||||||
|
pub input: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct ChatClassificationRequest {
|
||||||
|
pub model: constants::Model,
|
||||||
|
pub input: Vec<ChatModerationInput>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Response
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct ModerationResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub model: constants::Model,
|
||||||
|
pub results: Vec<ModerationResult>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct ModerationResult {
|
||||||
|
pub categories: serde_json::Value,
|
||||||
|
pub category_scores: serde_json::Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct ClassificationResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub model: constants::Model,
|
||||||
|
pub results: Vec<ClassificationResult>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct ClassificationResult {
|
||||||
|
pub categories: serde_json::Value,
|
||||||
|
pub category_scores: serde_json::Value,
|
||||||
|
}
|
||||||
96
src/v1/ocr.rs
Normal file
96
src/v1/ocr.rs
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::v1::constants;
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Request
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct OcrRequest {
|
||||||
|
pub model: constants::Model,
|
||||||
|
pub document: OcrDocument,
|
||||||
|
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub pages: Option<Vec<u32>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub table_format: Option<OcrTableFormat>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub include_image_base64: Option<bool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub image_limit: Option<u32>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct OcrDocument {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub type_: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub document_url: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub file_id: Option<String>,
|
||||||
|
}
|
||||||
|
impl OcrDocument {
|
||||||
|
pub fn from_url(url: &str) -> Self {
|
||||||
|
Self {
|
||||||
|
type_: "document_url".to_string(),
|
||||||
|
document_url: Some(url.to_string()),
|
||||||
|
file_id: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn from_file_id(file_id: &str) -> Self {
|
||||||
|
Self {
|
||||||
|
type_: "file_id".to_string(),
|
||||||
|
document_url: None,
|
||||||
|
file_id: Some(file_id.to_string()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum OcrTableFormat {
|
||||||
|
Markdown,
|
||||||
|
Html,
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----------------------------------------------------------------------------
|
||||||
|
// Response
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct OcrResponse {
|
||||||
|
pub pages: Vec<OcrPage>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub usage_info: Option<OcrUsageInfo>,
|
||||||
|
pub model: constants::Model,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct OcrPage {
|
||||||
|
pub index: u32,
|
||||||
|
pub markdown: String,
|
||||||
|
#[serde(default)]
|
||||||
|
pub images: Vec<OcrImage>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub dimensions: Option<OcrPageDimensions>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct OcrImage {
|
||||||
|
pub id: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub image_base64: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct OcrPageDimensions {
|
||||||
|
pub width: f32,
|
||||||
|
pub height: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct OcrUsageInfo {
|
||||||
|
pub pages_processed: u32,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub doc_size_bytes: Option<u64>,
|
||||||
|
}
|
||||||
@@ -1,12 +1,16 @@
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::{any::Any, collections::HashMap, fmt::Debug};
|
use std::{any::Any, fmt::Debug};
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
// Definitions
|
// Definitions
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||||
pub struct ToolCall {
|
pub struct ToolCall {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub id: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub r#type: Option<String>,
|
||||||
pub function: ToolCallFunction,
|
pub function: ToolCallFunction,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -22,31 +26,12 @@ pub struct Tool {
|
|||||||
pub function: ToolFunction,
|
pub function: ToolFunction,
|
||||||
}
|
}
|
||||||
impl Tool {
|
impl Tool {
|
||||||
|
/// Create a tool with a JSON Schema parameters object.
|
||||||
pub fn new(
|
pub fn new(
|
||||||
function_name: String,
|
function_name: String,
|
||||||
function_description: String,
|
function_description: String,
|
||||||
function_parameters: Vec<ToolFunctionParameter>,
|
parameters: serde_json::Value,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let properties: HashMap<String, ToolFunctionParameterProperty> = function_parameters
|
|
||||||
.into_iter()
|
|
||||||
.map(|param| {
|
|
||||||
(
|
|
||||||
param.name,
|
|
||||||
ToolFunctionParameterProperty {
|
|
||||||
r#type: param.r#type,
|
|
||||||
description: param.description,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
let property_names = properties.keys().cloned().collect();
|
|
||||||
|
|
||||||
let parameters = ToolFunctionParameters {
|
|
||||||
r#type: ToolFunctionParametersType::Object,
|
|
||||||
properties,
|
|
||||||
required: property_names,
|
|
||||||
};
|
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
r#type: ToolType::Function,
|
r#type: ToolType::Function,
|
||||||
function: ToolFunction {
|
function: ToolFunction {
|
||||||
@@ -63,50 +48,9 @@ impl Tool {
|
|||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
pub struct ToolFunction {
|
pub struct ToolFunction {
|
||||||
name: String,
|
pub name: String,
|
||||||
description: String,
|
pub description: String,
|
||||||
parameters: ToolFunctionParameters,
|
pub parameters: serde_json::Value,
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
|
||||||
pub struct ToolFunctionParameter {
|
|
||||||
name: String,
|
|
||||||
description: String,
|
|
||||||
r#type: ToolFunctionParameterType,
|
|
||||||
}
|
|
||||||
impl ToolFunctionParameter {
|
|
||||||
pub fn new(name: String, description: String, r#type: ToolFunctionParameterType) -> Self {
|
|
||||||
Self {
|
|
||||||
name,
|
|
||||||
r#type,
|
|
||||||
description,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
|
||||||
pub struct ToolFunctionParameters {
|
|
||||||
r#type: ToolFunctionParametersType,
|
|
||||||
properties: HashMap<String, ToolFunctionParameterProperty>,
|
|
||||||
required: Vec<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
|
||||||
pub struct ToolFunctionParameterProperty {
|
|
||||||
r#type: ToolFunctionParameterType,
|
|
||||||
description: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
|
||||||
pub enum ToolFunctionParametersType {
|
|
||||||
#[serde(rename = "object")]
|
|
||||||
Object,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
|
||||||
pub enum ToolFunctionParameterType {
|
|
||||||
#[serde(rename = "string")]
|
|
||||||
String,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||||
@@ -127,6 +71,9 @@ pub enum ToolChoice {
|
|||||||
/// The model won't call a function and will generate a message instead.
|
/// The model won't call a function and will generate a message instead.
|
||||||
#[serde(rename = "none")]
|
#[serde(rename = "none")]
|
||||||
None,
|
None,
|
||||||
|
/// The model must call at least one tool.
|
||||||
|
#[serde(rename = "required")]
|
||||||
|
Required,
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ use mistralai_client::v1::{
|
|||||||
chat::{ChatMessage, ChatMessageRole, ChatParams, ChatResponseChoiceFinishReason},
|
chat::{ChatMessage, ChatMessageRole, ChatParams, ChatResponseChoiceFinishReason},
|
||||||
client::Client,
|
client::Client,
|
||||||
constants::Model,
|
constants::Model,
|
||||||
tool::{Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType},
|
tool::{Tool, ToolChoice},
|
||||||
};
|
};
|
||||||
|
|
||||||
mod setup;
|
mod setup;
|
||||||
@@ -14,12 +14,12 @@ async fn test_client_chat_async() {
|
|||||||
|
|
||||||
let client = Client::new(None, None, None, None).unwrap();
|
let client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
let model = Model::OpenMistral7b;
|
let model = Model::mistral_small_latest();
|
||||||
let messages = vec![ChatMessage::new_user_message(
|
let messages = vec![ChatMessage::new_user_message(
|
||||||
"Guess the next word: \"Eiffel ...\"?",
|
"Guess the next word: \"Eiffel ...\"?",
|
||||||
)];
|
)];
|
||||||
let options = ChatParams {
|
let options = ChatParams {
|
||||||
temperature: 0.0,
|
temperature: Some(0.0),
|
||||||
random_seed: Some(42),
|
random_seed: Some(42),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
@@ -29,7 +29,6 @@ async fn test_client_chat_async() {
|
|||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
expect!(response.model).to_be(Model::OpenMistral7b);
|
|
||||||
expect!(response.object).to_be("chat.completion".to_string());
|
expect!(response.object).to_be("chat.completion".to_string());
|
||||||
|
|
||||||
expect!(response.choices.len()).to_be(1);
|
expect!(response.choices.len()).to_be(1);
|
||||||
@@ -56,21 +55,26 @@ async fn test_client_chat_async_with_function_calling() {
|
|||||||
let tools = vec![Tool::new(
|
let tools = vec![Tool::new(
|
||||||
"get_city_temperature".to_string(),
|
"get_city_temperature".to_string(),
|
||||||
"Get the current temperature in a city.".to_string(),
|
"Get the current temperature in a city.".to_string(),
|
||||||
vec![ToolFunctionParameter::new(
|
serde_json::json!({
|
||||||
"city".to_string(),
|
"type": "object",
|
||||||
"The name of the city.".to_string(),
|
"properties": {
|
||||||
ToolFunctionParameterType::String,
|
"city": {
|
||||||
)],
|
"type": "string",
|
||||||
|
"description": "The name of the city."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["city"]
|
||||||
|
}),
|
||||||
)];
|
)];
|
||||||
|
|
||||||
let client = Client::new(None, None, None, None).unwrap();
|
let client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
let model = Model::MistralSmallLatest;
|
let model = Model::mistral_small_latest();
|
||||||
let messages = vec![ChatMessage::new_user_message(
|
let messages = vec![ChatMessage::new_user_message(
|
||||||
"What's the current temperature in Paris?",
|
"What's the current temperature in Paris?",
|
||||||
)];
|
)];
|
||||||
let options = ChatParams {
|
let options = ChatParams {
|
||||||
temperature: 0.0,
|
temperature: Some(0.0),
|
||||||
random_seed: Some(42),
|
random_seed: Some(42),
|
||||||
tool_choice: Some(ToolChoice::Any),
|
tool_choice: Some(ToolChoice::Any),
|
||||||
tools: Some(tools),
|
tools: Some(tools),
|
||||||
@@ -82,7 +86,6 @@ async fn test_client_chat_async_with_function_calling() {
|
|||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
expect!(response.model).to_be(Model::MistralSmallLatest);
|
|
||||||
expect!(response.object).to_be("chat.completion".to_string());
|
expect!(response.object).to_be("chat.completion".to_string());
|
||||||
|
|
||||||
expect!(response.choices.len()).to_be(1);
|
expect!(response.choices.len()).to_be(1);
|
||||||
@@ -91,13 +94,6 @@ async fn test_client_chat_async_with_function_calling() {
|
|||||||
.to_be(ChatResponseChoiceFinishReason::ToolCalls);
|
.to_be(ChatResponseChoiceFinishReason::ToolCalls);
|
||||||
|
|
||||||
expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant);
|
expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant);
|
||||||
expect!(response.choices[0].message.content.clone()).to_be("".to_string());
|
|
||||||
// expect!(response.choices[0].message.tool_calls.clone()).to_be(Some(vec![ToolCall {
|
|
||||||
// function: ToolCallFunction {
|
|
||||||
// name: "get_city_temperature".to_string(),
|
|
||||||
// arguments: "{\"city\": \"Paris\"}".to_string(),
|
|
||||||
// },
|
|
||||||
// }]));
|
|
||||||
|
|
||||||
expect!(response.usage.prompt_tokens).to_be_greater_than(0);
|
expect!(response.usage.prompt_tokens).to_be_greater_than(0);
|
||||||
expect!(response.usage.completion_tokens).to_be_greater_than(0);
|
expect!(response.usage.completion_tokens).to_be_greater_than(0);
|
||||||
|
|||||||
@@ -1,16 +1,18 @@
|
|||||||
|
// Streaming tests require a live API key and are not run in CI.
|
||||||
|
// Uncomment to test locally.
|
||||||
|
|
||||||
// use futures::stream::StreamExt;
|
// use futures::stream::StreamExt;
|
||||||
// use jrest::expect;
|
|
||||||
// use mistralai_client::v1::{
|
// use mistralai_client::v1::{
|
||||||
// chat_completion::{ChatParams, ChatMessage, ChatMessageRole},
|
// chat::{ChatMessage, ChatParams},
|
||||||
// client::Client,
|
// client::Client,
|
||||||
// constants::Model,
|
// constants::Model,
|
||||||
// };
|
// };
|
||||||
|
//
|
||||||
// #[tokio::test]
|
// #[tokio::test]
|
||||||
// async fn test_client_chat_stream() {
|
// async fn test_client_chat_stream() {
|
||||||
// let client = Client::new(None, None, None, None).unwrap();
|
// let client = Client::new(None, None, None, None).unwrap();
|
||||||
|
//
|
||||||
// let model = Model::OpenMistral7b;
|
// let model = Model::mistral_small_latest();
|
||||||
// let messages = vec![ChatMessage::new_user_message(
|
// let messages = vec![ChatMessage::new_user_message(
|
||||||
// "Just guess the next word: \"Eiffel ...\"?",
|
// "Just guess the next word: \"Eiffel ...\"?",
|
||||||
// )];
|
// )];
|
||||||
@@ -19,22 +21,24 @@
|
|||||||
// random_seed: Some(42),
|
// random_seed: Some(42),
|
||||||
// ..Default::default()
|
// ..Default::default()
|
||||||
// };
|
// };
|
||||||
|
//
|
||||||
// let stream_result = client.chat_stream(model, messages, Some(options)).await;
|
// let stream = client
|
||||||
// let mut stream = stream_result.expect("Failed to create stream.");
|
// .chat_stream(model, messages, Some(options))
|
||||||
// while let Some(maybe_chunk_result) = stream.next().await {
|
// .await
|
||||||
// match maybe_chunk_result {
|
// .expect("Failed to create stream.");
|
||||||
// Some(Ok(chunk)) => {
|
//
|
||||||
// if chunk.choices[0].delta.role == Some(ChatMessageRole::Assistant)
|
// stream
|
||||||
// || chunk.choices[0].finish_reason == Some("stop".to_string())
|
// .for_each(|chunk_result| async {
|
||||||
// {
|
// match chunk_result {
|
||||||
// expect!(chunk.choices[0].delta.content.len()).to_be(0);
|
// Ok(chunks) => {
|
||||||
// } else {
|
// for chunk in &chunks {
|
||||||
// expect!(chunk.choices[0].delta.content.len()).to_be_greater_than(0);
|
// if let Some(content) = &chunk.choices[0].delta.content {
|
||||||
|
// print!("{}", content);
|
||||||
|
// }
|
||||||
|
// }
|
||||||
// }
|
// }
|
||||||
|
// Err(error) => eprintln!("Error: {:?}", error),
|
||||||
// }
|
// }
|
||||||
// Some(Err(error)) => eprintln!("Error processing chunk: {:?}", error),
|
// })
|
||||||
// None => (),
|
// .await;
|
||||||
// }
|
|
||||||
// }
|
|
||||||
// }
|
// }
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ use mistralai_client::v1::{
|
|||||||
chat::{ChatMessage, ChatMessageRole, ChatParams, ChatResponseChoiceFinishReason},
|
chat::{ChatMessage, ChatMessageRole, ChatParams, ChatResponseChoiceFinishReason},
|
||||||
client::Client,
|
client::Client,
|
||||||
constants::Model,
|
constants::Model,
|
||||||
tool::{Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType},
|
tool::{Tool, ToolChoice},
|
||||||
};
|
};
|
||||||
|
|
||||||
mod setup;
|
mod setup;
|
||||||
@@ -14,19 +14,18 @@ fn test_client_chat() {
|
|||||||
|
|
||||||
let client = Client::new(None, None, None, None).unwrap();
|
let client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
let model = Model::OpenMistral7b;
|
let model = Model::mistral_small_latest();
|
||||||
let messages = vec![ChatMessage::new_user_message(
|
let messages = vec![ChatMessage::new_user_message(
|
||||||
"Guess the next word: \"Eiffel ...\"?",
|
"Guess the next word: \"Eiffel ...\"?",
|
||||||
)];
|
)];
|
||||||
let options = ChatParams {
|
let options = ChatParams {
|
||||||
temperature: 0.0,
|
temperature: Some(0.0),
|
||||||
random_seed: Some(42),
|
random_seed: Some(42),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|
||||||
let response = client.chat(model, messages, Some(options)).unwrap();
|
let response = client.chat(model, messages, Some(options)).unwrap();
|
||||||
|
|
||||||
expect!(response.model).to_be(Model::OpenMistral7b);
|
|
||||||
expect!(response.object).to_be("chat.completion".to_string());
|
expect!(response.object).to_be("chat.completion".to_string());
|
||||||
expect!(response.choices.len()).to_be(1);
|
expect!(response.choices.len()).to_be(1);
|
||||||
expect!(response.choices[0].index).to_be(0);
|
expect!(response.choices[0].index).to_be(0);
|
||||||
@@ -50,21 +49,26 @@ fn test_client_chat_with_function_calling() {
|
|||||||
let tools = vec![Tool::new(
|
let tools = vec![Tool::new(
|
||||||
"get_city_temperature".to_string(),
|
"get_city_temperature".to_string(),
|
||||||
"Get the current temperature in a city.".to_string(),
|
"Get the current temperature in a city.".to_string(),
|
||||||
vec![ToolFunctionParameter::new(
|
serde_json::json!({
|
||||||
"city".to_string(),
|
"type": "object",
|
||||||
"The name of the city.".to_string(),
|
"properties": {
|
||||||
ToolFunctionParameterType::String,
|
"city": {
|
||||||
)],
|
"type": "string",
|
||||||
|
"description": "The name of the city."
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["city"]
|
||||||
|
}),
|
||||||
)];
|
)];
|
||||||
|
|
||||||
let client = Client::new(None, None, None, None).unwrap();
|
let client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
let model = Model::MistralSmallLatest;
|
let model = Model::mistral_small_latest();
|
||||||
let messages = vec![ChatMessage::new_user_message(
|
let messages = vec![ChatMessage::new_user_message(
|
||||||
"What's the current temperature in Paris?",
|
"What's the current temperature in Paris?",
|
||||||
)];
|
)];
|
||||||
let options = ChatParams {
|
let options = ChatParams {
|
||||||
temperature: 0.0,
|
temperature: Some(0.0),
|
||||||
random_seed: Some(42),
|
random_seed: Some(42),
|
||||||
tool_choice: Some(ToolChoice::Auto),
|
tool_choice: Some(ToolChoice::Auto),
|
||||||
tools: Some(tools),
|
tools: Some(tools),
|
||||||
@@ -73,12 +77,10 @@ fn test_client_chat_with_function_calling() {
|
|||||||
|
|
||||||
let response = client.chat(model, messages, Some(options)).unwrap();
|
let response = client.chat(model, messages, Some(options)).unwrap();
|
||||||
|
|
||||||
expect!(response.model).to_be(Model::MistralSmallLatest);
|
|
||||||
expect!(response.object).to_be("chat.completion".to_string());
|
expect!(response.object).to_be("chat.completion".to_string());
|
||||||
expect!(response.choices.len()).to_be(1);
|
expect!(response.choices.len()).to_be(1);
|
||||||
expect!(response.choices[0].index).to_be(0);
|
expect!(response.choices[0].index).to_be(0);
|
||||||
expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant);
|
expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant);
|
||||||
expect!(response.choices[0].message.content.clone()).to_be("".to_string());
|
|
||||||
expect!(response.choices[0].finish_reason.clone())
|
expect!(response.choices[0].finish_reason.clone())
|
||||||
.to_be(ChatResponseChoiceFinishReason::ToolCalls);
|
.to_be(ChatResponseChoiceFinishReason::ToolCalls);
|
||||||
expect!(response.usage.prompt_tokens).to_be_greater_than(0);
|
expect!(response.usage.prompt_tokens).to_be_greater_than(0);
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
use jrest::expect;
|
use jrest::expect;
|
||||||
use mistralai_client::v1::{client::Client, constants::EmbedModel};
|
use mistralai_client::v1::{client::Client, constants::Model};
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn test_client_embeddings_async() {
|
async fn test_client_embeddings_async() {
|
||||||
let client: Client = Client::new(None, None, None, None).unwrap();
|
let client: Client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
let model = EmbedModel::MistralEmbed;
|
let model = Model::mistral_embed();
|
||||||
let input = vec!["Embed this sentence.", "As well as this one."]
|
let input = vec!["Embed this sentence.", "As well as this one."]
|
||||||
.iter()
|
.iter()
|
||||||
.map(|s| s.to_string())
|
.map(|s| s.to_string())
|
||||||
@@ -17,7 +17,6 @@ async fn test_client_embeddings_async() {
|
|||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
expect!(response.model).to_be(EmbedModel::MistralEmbed);
|
|
||||||
expect!(response.object).to_be("list".to_string());
|
expect!(response.object).to_be("list".to_string());
|
||||||
expect!(response.data.len()).to_be(2);
|
expect!(response.data.len()).to_be(2);
|
||||||
expect!(response.data[0].index).to_be(0);
|
expect!(response.data[0].index).to_be(0);
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
use jrest::expect;
|
use jrest::expect;
|
||||||
use mistralai_client::v1::{client::Client, constants::EmbedModel};
|
use mistralai_client::v1::{client::Client, constants::Model};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_client_embeddings() {
|
fn test_client_embeddings() {
|
||||||
let client: Client = Client::new(None, None, None, None).unwrap();
|
let client: Client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
let model = EmbedModel::MistralEmbed;
|
let model = Model::mistral_embed();
|
||||||
let input = vec!["Embed this sentence.", "As well as this one."]
|
let input = vec!["Embed this sentence.", "As well as this one."]
|
||||||
.iter()
|
.iter()
|
||||||
.map(|s| s.to_string())
|
.map(|s| s.to_string())
|
||||||
@@ -14,7 +14,6 @@ fn test_client_embeddings() {
|
|||||||
|
|
||||||
let response = client.embeddings(model, input, options).unwrap();
|
let response = client.embeddings(model, input, options).unwrap();
|
||||||
|
|
||||||
expect!(response.model).to_be(EmbedModel::MistralEmbed);
|
|
||||||
expect!(response.object).to_be("list".to_string());
|
expect!(response.object).to_be("list".to_string());
|
||||||
expect!(response.data.len()).to_be(2);
|
expect!(response.data.len()).to_be(2);
|
||||||
expect!(response.data[0].index).to_be(0);
|
expect!(response.data[0].index).to_be(0);
|
||||||
|
|||||||
@@ -6,26 +6,19 @@ use mistralai_client::v1::{
|
|||||||
};
|
};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_model_constant() {
|
fn test_model_constants() {
|
||||||
let models = vec![
|
let models = vec![
|
||||||
Model::OpenMistral7b,
|
Model::mistral_small_latest(),
|
||||||
Model::OpenMixtral8x7b,
|
Model::mistral_large_latest(),
|
||||||
Model::OpenMixtral8x22b,
|
Model::open_mistral_nemo(),
|
||||||
Model::OpenMistralNemo,
|
Model::codestral_latest(),
|
||||||
Model::MistralTiny,
|
|
||||||
Model::MistralSmallLatest,
|
|
||||||
Model::MistralMediumLatest,
|
|
||||||
Model::MistralLargeLatest,
|
|
||||||
Model::MistralLarge,
|
|
||||||
Model::CodestralLatest,
|
|
||||||
Model::CodestralMamba,
|
|
||||||
];
|
];
|
||||||
|
|
||||||
let client = Client::new(None, None, None, None).unwrap();
|
let client = Client::new(None, None, None, None).unwrap();
|
||||||
|
|
||||||
let messages = vec![ChatMessage::new_user_message("A number between 0 and 100?")];
|
let messages = vec![ChatMessage::new_user_message("A number between 0 and 100?")];
|
||||||
let options = ChatParams {
|
let options = ChatParams {
|
||||||
temperature: 0.0,
|
temperature: Some(0.0),
|
||||||
random_seed: Some(42),
|
random_seed: Some(42),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|||||||
Reference in New Issue
Block a user