Update to latest Mistral AI API (v1.0.0)
- Replace closed Model enum with flexible string-based Model type with constructor methods for all current models (Mistral Large 3, Small 4, Magistral, Codestral, Devstral, Pixtral, Voxtral, etc.) - Add new API endpoints: FIM completions, Files, Fine-tuning, Batch jobs, OCR, Audio transcription, Moderations/Classifications, and Agent completions (sync + async for all) - Add new chat fields: frequency_penalty, presence_penalty, stop, n, parallel_tool_calls, reasoning_effort, min_tokens, json_schema response format - Add embedding fields: output_dimension, output_dtype - Tool parameters now accept raw JSON Schema (serde_json::Value) instead of limited enum types - Add tool call IDs and Required tool choice variant - Add DELETE HTTP method support and multipart file upload - Bump thiserror to v2, add reqwest multipart feature - Remove strum dependency (no longer needed) - Update all tests and examples for new API
This commit is contained in:
27
Cargo.toml
27
Cargo.toml
@@ -2,7 +2,7 @@
|
||||
name = "mistralai-client"
|
||||
description = "Mistral AI API client library for Rust (unofficial)."
|
||||
license = "Apache-2.0"
|
||||
version = "0.14.0"
|
||||
version = "1.0.0"
|
||||
|
||||
edition = "2021"
|
||||
rust-version = "1.76.0"
|
||||
@@ -15,18 +15,17 @@ readme = "README.md"
|
||||
repository = "https://github.com/ivangabriele/mistralai-client-rs"
|
||||
|
||||
[dependencies]
|
||||
async-stream = "0.3.5"
|
||||
async-trait = "0.1.77"
|
||||
env_logger = "0.11.3"
|
||||
futures = "0.3.30"
|
||||
log = "0.4.21"
|
||||
reqwest = { version = "0.12.0", features = ["json", "blocking", "stream"] }
|
||||
serde = { version = "1.0.197", features = ["derive"] }
|
||||
serde_json = "1.0.114"
|
||||
strum = "0.26.1"
|
||||
thiserror = "1.0.57"
|
||||
tokio = { version = "1.36.0", features = ["full"] }
|
||||
tokio-stream = "0.1.14"
|
||||
async-stream = "0.3"
|
||||
async-trait = "0.1"
|
||||
env_logger = "0.11"
|
||||
futures = "0.3"
|
||||
log = "0.4"
|
||||
reqwest = { version = "0.12", features = ["json", "blocking", "stream", "multipart"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
thiserror = "2"
|
||||
tokio = { version = "1", features = ["full"] }
|
||||
tokio-stream = "0.1"
|
||||
|
||||
[dev-dependencies]
|
||||
jrest = "0.2.3"
|
||||
jrest = "0.2"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use mistralai_client::v1::{
|
||||
chat::{ChatMessage, ChatMessageRole, ChatParams},
|
||||
chat::{ChatMessage, ChatParams},
|
||||
client::Client,
|
||||
constants::Model,
|
||||
};
|
||||
@@ -8,14 +8,12 @@ fn main() {
|
||||
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
||||
let client = Client::new(None, None, None, None).unwrap();
|
||||
|
||||
let model = Model::OpenMistral7b;
|
||||
let messages = vec![ChatMessage {
|
||||
role: ChatMessageRole::User,
|
||||
content: "Just guess the next word: \"Eiffel ...\"?".to_string(),
|
||||
tool_calls: None,
|
||||
}];
|
||||
let model = Model::mistral_small_latest();
|
||||
let messages = vec![ChatMessage::new_user_message(
|
||||
"Just guess the next word: \"Eiffel ...\"?",
|
||||
)];
|
||||
let options = ChatParams {
|
||||
temperature: 0.0,
|
||||
temperature: Some(0.0),
|
||||
random_seed: Some(42),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
use mistralai_client::v1::{
|
||||
chat::{ChatMessage, ChatMessageRole, ChatParams},
|
||||
chat::{ChatMessage, ChatParams},
|
||||
client::Client,
|
||||
constants::Model,
|
||||
};
|
||||
@@ -9,14 +9,12 @@ async fn main() {
|
||||
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
||||
let client = Client::new(None, None, None, None).unwrap();
|
||||
|
||||
let model = Model::OpenMistral7b;
|
||||
let messages = vec![ChatMessage {
|
||||
role: ChatMessageRole::User,
|
||||
content: "Just guess the next word: \"Eiffel ...\"?".to_string(),
|
||||
tool_calls: None,
|
||||
}];
|
||||
let model = Model::mistral_small_latest();
|
||||
let messages = vec![ChatMessage::new_user_message(
|
||||
"Just guess the next word: \"Eiffel ...\"?",
|
||||
)];
|
||||
let options = ChatParams {
|
||||
temperature: 0.0,
|
||||
temperature: Some(0.0),
|
||||
random_seed: Some(42),
|
||||
..Default::default()
|
||||
};
|
||||
@@ -29,5 +27,4 @@ async fn main() {
|
||||
"{:?}: {}",
|
||||
result.choices[0].message.role, result.choices[0].message.content
|
||||
);
|
||||
// => "Assistant: Tower. The Eiffel Tower is a famous landmark in Paris, France."
|
||||
}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use mistralai_client::v1::{
|
||||
chat::{ChatMessage, ChatMessageRole, ChatParams},
|
||||
chat::{ChatMessage, ChatParams},
|
||||
client::Client,
|
||||
constants::Model,
|
||||
tool::{Function, Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType},
|
||||
tool::{Function, Tool, ToolChoice},
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use std::any::Any;
|
||||
@@ -16,7 +16,6 @@ struct GetCityTemperatureFunction;
|
||||
#[async_trait::async_trait]
|
||||
impl Function for GetCityTemperatureFunction {
|
||||
async fn execute(&self, arguments: String) -> Box<dyn Any + Send> {
|
||||
// Deserialize arguments, perform the logic, and return the result
|
||||
let GetCityTemperatureArguments { city } = serde_json::from_str(&arguments).unwrap();
|
||||
|
||||
let temperature = match city.as_str() {
|
||||
@@ -32,11 +31,16 @@ fn main() {
|
||||
let tools = vec![Tool::new(
|
||||
"get_city_temperature".to_string(),
|
||||
"Get the current temperature in a city.".to_string(),
|
||||
vec![ToolFunctionParameter::new(
|
||||
"city".to_string(),
|
||||
"The name of the city.".to_string(),
|
||||
ToolFunctionParameterType::String,
|
||||
)],
|
||||
serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The name of the city."
|
||||
}
|
||||
},
|
||||
"required": ["city"]
|
||||
}),
|
||||
)];
|
||||
|
||||
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
||||
@@ -46,14 +50,12 @@ fn main() {
|
||||
Box::new(GetCityTemperatureFunction),
|
||||
);
|
||||
|
||||
let model = Model::MistralSmallLatest;
|
||||
let messages = vec![ChatMessage {
|
||||
role: ChatMessageRole::User,
|
||||
content: "What's the temperature in Paris?".to_string(),
|
||||
tool_calls: None,
|
||||
}];
|
||||
let model = Model::mistral_small_latest();
|
||||
let messages = vec![ChatMessage::new_user_message(
|
||||
"What's the temperature in Paris?",
|
||||
)];
|
||||
let options = ChatParams {
|
||||
temperature: 0.0,
|
||||
temperature: Some(0.0),
|
||||
random_seed: Some(42),
|
||||
tool_choice: Some(ToolChoice::Auto),
|
||||
tools: Some(tools),
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
use mistralai_client::v1::{
|
||||
chat::{ChatMessage, ChatMessageRole, ChatParams},
|
||||
chat::{ChatMessage, ChatParams},
|
||||
client::Client,
|
||||
constants::Model,
|
||||
tool::{Function, Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType},
|
||||
tool::{Function, Tool, ToolChoice},
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use std::any::Any;
|
||||
@@ -16,7 +16,6 @@ struct GetCityTemperatureFunction;
|
||||
#[async_trait::async_trait]
|
||||
impl Function for GetCityTemperatureFunction {
|
||||
async fn execute(&self, arguments: String) -> Box<dyn Any + Send> {
|
||||
// Deserialize arguments, perform the logic, and return the result
|
||||
let GetCityTemperatureArguments { city } = serde_json::from_str(&arguments).unwrap();
|
||||
|
||||
let temperature = match city.as_str() {
|
||||
@@ -33,11 +32,16 @@ async fn main() {
|
||||
let tools = vec![Tool::new(
|
||||
"get_city_temperature".to_string(),
|
||||
"Get the current temperature in a city.".to_string(),
|
||||
vec![ToolFunctionParameter::new(
|
||||
"city".to_string(),
|
||||
"The name of the city.".to_string(),
|
||||
ToolFunctionParameterType::String,
|
||||
)],
|
||||
serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The name of the city."
|
||||
}
|
||||
},
|
||||
"required": ["city"]
|
||||
}),
|
||||
)];
|
||||
|
||||
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
||||
@@ -47,14 +51,12 @@ async fn main() {
|
||||
Box::new(GetCityTemperatureFunction),
|
||||
);
|
||||
|
||||
let model = Model::MistralSmallLatest;
|
||||
let messages = vec![ChatMessage {
|
||||
role: ChatMessageRole::User,
|
||||
content: "What's the temperature in Paris?".to_string(),
|
||||
tool_calls: None,
|
||||
}];
|
||||
let model = Model::mistral_small_latest();
|
||||
let messages = vec![ChatMessage::new_user_message(
|
||||
"What's the temperature in Paris?",
|
||||
)];
|
||||
let options = ChatParams {
|
||||
temperature: 0.0,
|
||||
temperature: Some(0.0),
|
||||
random_seed: Some(42),
|
||||
tool_choice: Some(ToolChoice::Auto),
|
||||
tools: Some(tools),
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
use futures::stream::StreamExt;
|
||||
use mistralai_client::v1::{
|
||||
chat::{ChatMessage, ChatMessageRole, ChatParams},
|
||||
chat::{ChatMessage, ChatParams},
|
||||
client::Client,
|
||||
constants::Model,
|
||||
};
|
||||
@@ -11,14 +11,10 @@ async fn main() {
|
||||
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
||||
let client = Client::new(None, None, None, None).unwrap();
|
||||
|
||||
let model = Model::OpenMistral7b;
|
||||
let messages = vec![ChatMessage {
|
||||
role: ChatMessageRole::User,
|
||||
content: "Tell me a short happy story.".to_string(),
|
||||
tool_calls: None,
|
||||
}];
|
||||
let model = Model::mistral_small_latest();
|
||||
let messages = vec![ChatMessage::new_user_message("Tell me a short happy story.")];
|
||||
let options = ChatParams {
|
||||
temperature: 0.0,
|
||||
temperature: Some(0.0),
|
||||
random_seed: Some(42),
|
||||
..Default::default()
|
||||
};
|
||||
@@ -31,9 +27,10 @@ async fn main() {
|
||||
.for_each(|chunk_result| async {
|
||||
match chunk_result {
|
||||
Ok(chunks) => chunks.iter().for_each(|chunk| {
|
||||
print!("{}", chunk.choices[0].delta.content);
|
||||
io::stdout().flush().unwrap();
|
||||
// => "Once upon a time, [...]"
|
||||
if let Some(content) = &chunk.choices[0].delta.content {
|
||||
print!("{}", content);
|
||||
io::stdout().flush().unwrap();
|
||||
}
|
||||
}),
|
||||
Err(error) => {
|
||||
eprintln!("Error processing chunk: {:?}", error)
|
||||
@@ -41,5 +38,5 @@ async fn main() {
|
||||
}
|
||||
})
|
||||
.await;
|
||||
print!("\n") // To persist the last chunk output.
|
||||
println!();
|
||||
}
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
use mistralai_client::v1::{client::Client, constants::EmbedModel};
|
||||
use mistralai_client::v1::{client::Client, constants::Model};
|
||||
|
||||
fn main() {
|
||||
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
||||
let client: Client = Client::new(None, None, None, None).unwrap();
|
||||
|
||||
let model = EmbedModel::MistralEmbed;
|
||||
let model = Model::mistral_embed();
|
||||
let input = vec!["Embed this sentence.", "As well as this one."]
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
@@ -13,5 +13,4 @@ fn main() {
|
||||
|
||||
let response = client.embeddings(model, input, options).unwrap();
|
||||
println!("First Embedding: {:?}", response.data[0]);
|
||||
// => "First Embedding: {...}"
|
||||
}
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
use mistralai_client::v1::{client::Client, constants::EmbedModel};
|
||||
use mistralai_client::v1::{client::Client, constants::Model};
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
||||
let client: Client = Client::new(None, None, None, None).unwrap();
|
||||
|
||||
let model = EmbedModel::MistralEmbed;
|
||||
let model = Model::mistral_embed();
|
||||
let input = vec!["Embed this sentence.", "As well as this one."]
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
@@ -17,5 +17,4 @@ async fn main() {
|
||||
.await
|
||||
.unwrap();
|
||||
println!("First Embedding: {:?}", response.data[0]);
|
||||
// => "First Embedding: {...}"
|
||||
}
|
||||
|
||||
21
examples/fim.rs
Normal file
21
examples/fim.rs
Normal file
@@ -0,0 +1,21 @@
|
||||
use mistralai_client::v1::{
|
||||
client::Client,
|
||||
constants::Model,
|
||||
fim::FimParams,
|
||||
};
|
||||
|
||||
fn main() {
|
||||
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
||||
let client = Client::new(None, None, None, None).unwrap();
|
||||
|
||||
let model = Model::codestral_latest();
|
||||
let prompt = "def fibonacci(n):".to_string();
|
||||
let options = FimParams {
|
||||
suffix: Some("\n return result".to_string()),
|
||||
temperature: Some(0.0),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let response = client.fim(model, prompt, Some(options)).unwrap();
|
||||
println!("Completion: {}", response.choices[0].message.content);
|
||||
}
|
||||
25
examples/ocr.rs
Normal file
25
examples/ocr.rs
Normal file
@@ -0,0 +1,25 @@
|
||||
use mistralai_client::v1::{
|
||||
client::Client,
|
||||
constants::Model,
|
||||
ocr::{OcrDocument, OcrRequest},
|
||||
};
|
||||
|
||||
fn main() {
|
||||
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
|
||||
let client = Client::new(None, None, None, None).unwrap();
|
||||
|
||||
let request = OcrRequest {
|
||||
model: Model::mistral_ocr_latest(),
|
||||
document: OcrDocument::from_url("https://arxiv.org/pdf/2201.04234"),
|
||||
pages: Some(vec![0]),
|
||||
table_format: None,
|
||||
include_image_base64: None,
|
||||
image_limit: None,
|
||||
};
|
||||
|
||||
let response = client.ocr(&request).unwrap();
|
||||
for page in &response.pages {
|
||||
println!("--- Page {} ---", page.index);
|
||||
println!("{}", &page.markdown[..200.min(page.markdown.len())]);
|
||||
}
|
||||
}
|
||||
98
src/v1/agents.rs
Normal file
98
src/v1/agents.rs
Normal file
@@ -0,0 +1,98 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::v1::{chat, common, constants, tool};
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Request
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AgentCompletionParams {
|
||||
pub max_tokens: Option<u32>,
|
||||
pub min_tokens: Option<u32>,
|
||||
pub temperature: Option<f32>,
|
||||
pub top_p: Option<f32>,
|
||||
pub random_seed: Option<u32>,
|
||||
pub stop: Option<Vec<String>>,
|
||||
pub response_format: Option<chat::ResponseFormat>,
|
||||
pub tools: Option<Vec<tool::Tool>>,
|
||||
pub tool_choice: Option<tool::ToolChoice>,
|
||||
}
|
||||
impl Default for AgentCompletionParams {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_tokens: None,
|
||||
min_tokens: None,
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
random_seed: None,
|
||||
stop: None,
|
||||
response_format: None,
|
||||
tools: None,
|
||||
tool_choice: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct AgentCompletionRequest {
|
||||
pub agent_id: String,
|
||||
pub messages: Vec<chat::ChatMessage>,
|
||||
pub stream: bool,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub min_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub random_seed: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stop: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub response_format: Option<chat::ResponseFormat>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tools: Option<Vec<tool::Tool>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_choice: Option<tool::ToolChoice>,
|
||||
}
|
||||
impl AgentCompletionRequest {
|
||||
pub fn new(
|
||||
agent_id: String,
|
||||
messages: Vec<chat::ChatMessage>,
|
||||
stream: bool,
|
||||
options: Option<AgentCompletionParams>,
|
||||
) -> Self {
|
||||
let opts = options.unwrap_or_default();
|
||||
|
||||
Self {
|
||||
agent_id,
|
||||
messages,
|
||||
stream,
|
||||
max_tokens: opts.max_tokens,
|
||||
min_tokens: opts.min_tokens,
|
||||
temperature: opts.temperature,
|
||||
top_p: opts.top_p,
|
||||
random_seed: opts.random_seed,
|
||||
stop: opts.stop,
|
||||
response_format: opts.response_format,
|
||||
tools: opts.tools,
|
||||
tool_choice: opts.tool_choice,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Response (same shape as chat completions)
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct AgentCompletionResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub model: constants::Model,
|
||||
pub choices: Vec<chat::ChatResponseChoice>,
|
||||
pub usage: common::ResponseUsage,
|
||||
}
|
||||
78
src/v1/audio.rs
Normal file
78
src/v1/audio.rs
Normal file
@@ -0,0 +1,78 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::v1::constants;
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Request (multipart form, but we define the params struct)
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AudioTranscriptionParams {
|
||||
pub model: constants::Model,
|
||||
pub language: Option<String>,
|
||||
pub temperature: Option<f32>,
|
||||
pub diarize: Option<bool>,
|
||||
pub timestamp_granularities: Option<Vec<TimestampGranularity>>,
|
||||
}
|
||||
impl Default for AudioTranscriptionParams {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
model: constants::Model::voxtral_mini_transcribe(),
|
||||
language: None,
|
||||
temperature: None,
|
||||
diarize: None,
|
||||
timestamp_granularities: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum TimestampGranularity {
|
||||
Segment,
|
||||
Word,
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Response
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct AudioTranscriptionResponse {
|
||||
pub text: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub model: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub language: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub segments: Option<Vec<TranscriptionSegment>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub words: Option<Vec<TranscriptionWord>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub usage: Option<AudioUsage>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct TranscriptionSegment {
|
||||
pub id: u32,
|
||||
pub start: f32,
|
||||
pub end: f32,
|
||||
pub text: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub speaker: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct TranscriptionWord {
|
||||
pub word: String,
|
||||
pub start: f32,
|
||||
pub end: f32,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct AudioUsage {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub prompt_audio_seconds: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub prompt_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub total_tokens: Option<u32>,
|
||||
}
|
||||
53
src/v1/batch.rs
Normal file
53
src/v1/batch.rs
Normal file
@@ -0,0 +1,53 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::v1::constants;
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Request
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct BatchJobRequest {
|
||||
pub input_files: Vec<String>,
|
||||
pub model: constants::Model,
|
||||
pub endpoint: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Response
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct BatchJobResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub model: constants::Model,
|
||||
pub endpoint: String,
|
||||
pub input_files: Vec<String>,
|
||||
pub status: String,
|
||||
pub created_at: u64,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub completed_at: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub output_file: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub error_file: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub metadata: Option<serde_json::Value>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub total_requests: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub completed_requests: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub succeeded_requests: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub failed_requests: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct BatchJobListResponse {
|
||||
pub data: Vec<BatchJobResponse>,
|
||||
pub object: String,
|
||||
#[serde(default)]
|
||||
pub total: u32,
|
||||
}
|
||||
187
src/v1/chat.rs
187
src/v1/chat.rs
@@ -11,13 +11,31 @@ pub struct ChatMessage {
|
||||
pub content: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<tool::ToolCall>>,
|
||||
/// Tool call ID, required when role is Tool.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_call_id: Option<String>,
|
||||
/// Function name, used when role is Tool.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub name: Option<String>,
|
||||
}
|
||||
impl ChatMessage {
|
||||
pub fn new_system_message(content: &str) -> Self {
|
||||
Self {
|
||||
role: ChatMessageRole::System,
|
||||
content: content.to_string(),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_assistant_message(content: &str, tool_calls: Option<Vec<tool::ToolCall>>) -> Self {
|
||||
Self {
|
||||
role: ChatMessageRole::Assistant,
|
||||
content: content.to_string(),
|
||||
tool_calls,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,6 +44,18 @@ impl ChatMessage {
|
||||
role: ChatMessageRole::User,
|
||||
content: content.to_string(),
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
name: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new_tool_message(content: &str, tool_call_id: &str, name: Option<&str>) -> Self {
|
||||
Self {
|
||||
role: ChatMessageRole::Tool,
|
||||
content: content.to_string(),
|
||||
tool_calls: None,
|
||||
tool_call_id: Some(tool_call_id.to_string()),
|
||||
name: name.map(|n| n.to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -44,17 +74,32 @@ pub enum ChatMessageRole {
|
||||
}
|
||||
|
||||
/// The format that the model must output.
|
||||
///
|
||||
/// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information.
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct ResponseFormat {
|
||||
#[serde(rename = "type")]
|
||||
pub type_: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub json_schema: Option<serde_json::Value>,
|
||||
}
|
||||
impl ResponseFormat {
|
||||
pub fn text() -> Self {
|
||||
Self {
|
||||
type_: "text".to_string(),
|
||||
json_schema: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn json_object() -> Self {
|
||||
Self {
|
||||
type_: "json_object".to_string(),
|
||||
json_schema: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn json_schema(schema: serde_json::Value) -> Self {
|
||||
Self {
|
||||
type_: "json_schema".to_string(),
|
||||
json_schema: Some(schema),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -63,91 +108,83 @@ impl ResponseFormat {
|
||||
// Request
|
||||
|
||||
/// The parameters for the chat request.
|
||||
///
|
||||
/// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct ChatParams {
|
||||
/// The maximum number of tokens to generate in the completion.
|
||||
///
|
||||
/// Defaults to `None`.
|
||||
pub max_tokens: Option<u32>,
|
||||
/// The seed to use for random sampling. If set, different calls will generate deterministic results.
|
||||
///
|
||||
/// Defaults to `None`.
|
||||
pub min_tokens: Option<u32>,
|
||||
pub random_seed: Option<u32>,
|
||||
/// The format that the model must output.
|
||||
///
|
||||
/// Defaults to `None`.
|
||||
pub response_format: Option<ResponseFormat>,
|
||||
/// Whether to inject a safety prompt before all conversations.
|
||||
///
|
||||
/// Defaults to `false`.
|
||||
pub safe_prompt: bool,
|
||||
/// What sampling temperature to use, between `Some(0.0)` and `Some(1.0)`.
|
||||
///
|
||||
/// Defaults to `0.7`.
|
||||
pub temperature: f32,
|
||||
/// Specifies if/how functions are called.
|
||||
///
|
||||
/// Defaults to `None`.
|
||||
pub temperature: Option<f32>,
|
||||
pub tool_choice: Option<tool::ToolChoice>,
|
||||
/// A list of available tools for the model.
|
||||
///
|
||||
/// Defaults to `None`.
|
||||
pub tools: Option<Vec<tool::Tool>>,
|
||||
/// Nucleus sampling, where the model considers the results of the tokens with `top_p` probability mass.
|
||||
///
|
||||
/// Defaults to `1.0`.
|
||||
pub top_p: f32,
|
||||
pub top_p: Option<f32>,
|
||||
pub stop: Option<Vec<String>>,
|
||||
pub n: Option<u32>,
|
||||
pub frequency_penalty: Option<f32>,
|
||||
pub presence_penalty: Option<f32>,
|
||||
pub parallel_tool_calls: Option<bool>,
|
||||
/// For reasoning models (Magistral). "high" or "none".
|
||||
pub reasoning_effort: Option<String>,
|
||||
}
|
||||
impl Default for ChatParams {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_tokens: None,
|
||||
min_tokens: None,
|
||||
random_seed: None,
|
||||
safe_prompt: false,
|
||||
response_format: None,
|
||||
temperature: 0.7,
|
||||
temperature: None,
|
||||
tool_choice: None,
|
||||
tools: None,
|
||||
top_p: 1.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl ChatParams {
|
||||
pub fn json_default() -> Self {
|
||||
Self {
|
||||
max_tokens: None,
|
||||
random_seed: None,
|
||||
safe_prompt: false,
|
||||
response_format: None,
|
||||
temperature: 0.7,
|
||||
tool_choice: None,
|
||||
tools: None,
|
||||
top_p: 1.0,
|
||||
top_p: None,
|
||||
stop: None,
|
||||
n: None,
|
||||
frequency_penalty: None,
|
||||
presence_penalty: None,
|
||||
parallel_tool_calls: None,
|
||||
reasoning_effort: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ChatRequest {
|
||||
pub messages: Vec<ChatMessage>,
|
||||
pub model: constants::Model,
|
||||
pub messages: Vec<ChatMessage>,
|
||||
pub stream: bool,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub min_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub random_seed: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub response_format: Option<ResponseFormat>,
|
||||
pub safe_prompt: bool,
|
||||
pub stream: bool,
|
||||
pub temperature: f32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub safe_prompt: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_choice: Option<tool::ToolChoice>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tools: Option<Vec<tool::Tool>>,
|
||||
pub top_p: f32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stop: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub n: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub frequency_penalty: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub presence_penalty: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub parallel_tool_calls: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub reasoning_effort: Option<String>,
|
||||
}
|
||||
impl ChatRequest {
|
||||
pub fn new(
|
||||
@@ -156,30 +193,28 @@ impl ChatRequest {
|
||||
stream: bool,
|
||||
options: Option<ChatParams>,
|
||||
) -> Self {
|
||||
let ChatParams {
|
||||
max_tokens,
|
||||
random_seed,
|
||||
safe_prompt,
|
||||
temperature,
|
||||
tool_choice,
|
||||
tools,
|
||||
top_p,
|
||||
response_format,
|
||||
} = options.unwrap_or_default();
|
||||
let opts = options.unwrap_or_default();
|
||||
let safe_prompt = if opts.safe_prompt { Some(true) } else { None };
|
||||
|
||||
Self {
|
||||
messages,
|
||||
model,
|
||||
|
||||
max_tokens,
|
||||
random_seed,
|
||||
safe_prompt,
|
||||
messages,
|
||||
stream,
|
||||
temperature,
|
||||
tool_choice,
|
||||
tools,
|
||||
top_p,
|
||||
response_format,
|
||||
max_tokens: opts.max_tokens,
|
||||
min_tokens: opts.min_tokens,
|
||||
random_seed: opts.random_seed,
|
||||
safe_prompt,
|
||||
temperature: opts.temperature,
|
||||
tool_choice: opts.tool_choice,
|
||||
tools: opts.tools,
|
||||
top_p: opts.top_p,
|
||||
response_format: opts.response_format,
|
||||
stop: opts.stop,
|
||||
n: opts.n,
|
||||
frequency_penalty: opts.frequency_penalty,
|
||||
presence_penalty: opts.presence_penalty,
|
||||
parallel_tool_calls: opts.parallel_tool_calls,
|
||||
reasoning_effort: opts.reasoning_effort,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -192,7 +227,7 @@ pub struct ChatResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
/// Unix timestamp (in seconds).
|
||||
pub created: u32,
|
||||
pub created: u64,
|
||||
pub model: constants::Model,
|
||||
pub choices: Vec<ChatResponseChoice>,
|
||||
pub usage: common::ResponseUsage,
|
||||
@@ -203,14 +238,18 @@ pub struct ChatResponseChoice {
|
||||
pub index: u32,
|
||||
pub message: ChatMessage,
|
||||
pub finish_reason: ChatResponseChoiceFinishReason,
|
||||
// TODO Check this prop (seen in API responses but undocumented).
|
||||
// pub logprobs: ???
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||
pub enum ChatResponseChoiceFinishReason {
|
||||
#[serde(rename = "stop")]
|
||||
Stop,
|
||||
#[serde(rename = "length")]
|
||||
Length,
|
||||
#[serde(rename = "tool_calls")]
|
||||
ToolCalls,
|
||||
#[serde(rename = "model_length")]
|
||||
ModelLength,
|
||||
#[serde(rename = "error")]
|
||||
Error,
|
||||
}
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::from_str;
|
||||
|
||||
use crate::v1::{chat, common, constants, error};
|
||||
use crate::v1::{chat, common, constants, error, tool};
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Response
|
||||
@@ -11,12 +11,11 @@ pub struct ChatStreamChunk {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
/// Unix timestamp (in seconds).
|
||||
pub created: u32,
|
||||
pub created: u64,
|
||||
pub model: constants::Model,
|
||||
pub choices: Vec<ChatStreamChunkChoice>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub usage: Option<common::ResponseUsage>,
|
||||
// TODO Check this prop (seen in API responses but undocumented).
|
||||
// pub logprobs: ???,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
@@ -24,14 +23,15 @@ pub struct ChatStreamChunkChoice {
|
||||
pub index: u32,
|
||||
pub delta: ChatStreamChunkChoiceDelta,
|
||||
pub finish_reason: Option<String>,
|
||||
// TODO Check this prop (seen in API responses but undocumented).
|
||||
// pub logprobs: ???,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct ChatStreamChunkChoiceDelta {
|
||||
pub role: Option<chat::ChatMessageRole>,
|
||||
pub content: String,
|
||||
#[serde(default)]
|
||||
pub content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_calls: Option<Vec<tool::ToolCall>>,
|
||||
}
|
||||
|
||||
/// Extracts serialized chunks from a stream message.
|
||||
@@ -47,7 +47,6 @@ pub fn get_chunk_from_stream_message_line(
|
||||
return Ok(Some(vec![]));
|
||||
}
|
||||
|
||||
// Attempt to deserialize the JSON string into ChatStreamChunk
|
||||
match from_str::<ChatStreamChunk>(chunk_as_json) {
|
||||
Ok(chunk) => Ok(Some(vec![chunk])),
|
||||
Err(e) => Err(error::ApiError {
|
||||
|
||||
1102
src/v1/client.rs
1102
src/v1/client.rs
File diff suppressed because it is too large
Load Diff
@@ -3,6 +3,7 @@ use serde::{Deserialize, Serialize};
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct ResponseUsage {
|
||||
pub prompt_tokens: u32,
|
||||
#[serde(default)]
|
||||
pub completion_tokens: u32,
|
||||
pub total_tokens: u32,
|
||||
}
|
||||
|
||||
@@ -1,35 +1,131 @@
|
||||
use std::fmt;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub const API_URL_BASE: &str = "https://api.mistral.ai/v1";
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||
pub enum Model {
|
||||
#[serde(rename = "open-mistral-7b")]
|
||||
OpenMistral7b,
|
||||
#[serde(rename = "open-mixtral-8x7b")]
|
||||
OpenMixtral8x7b,
|
||||
#[serde(rename = "open-mixtral-8x22b")]
|
||||
OpenMixtral8x22b,
|
||||
#[serde(rename = "open-mistral-nemo", alias = "open-mistral-nemo-2407")]
|
||||
OpenMistralNemo,
|
||||
#[serde(rename = "mistral-tiny")]
|
||||
MistralTiny,
|
||||
#[serde(rename = "mistral-small-latest", alias = "mistral-small-2402")]
|
||||
MistralSmallLatest,
|
||||
#[serde(rename = "mistral-medium-latest", alias = "mistral-medium-2312")]
|
||||
MistralMediumLatest,
|
||||
#[serde(rename = "mistral-large-latest", alias = "mistral-large-2407")]
|
||||
MistralLargeLatest,
|
||||
#[serde(rename = "mistral-large-2402")]
|
||||
MistralLarge,
|
||||
#[serde(rename = "codestral-latest", alias = "codestral-2405")]
|
||||
CodestralLatest,
|
||||
#[serde(rename = "open-codestral-mamba")]
|
||||
CodestralMamba,
|
||||
/// A Mistral AI model identifier.
|
||||
///
|
||||
/// Use the associated constants for known models, or construct with `Model::new()` for any model string.
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct Model(pub String);
|
||||
|
||||
impl Model {
|
||||
pub fn new(id: impl Into<String>) -> Self {
|
||||
Self(id.into())
|
||||
}
|
||||
|
||||
// Flagship / Premier
|
||||
pub fn mistral_large_latest() -> Self {
|
||||
Self::new("mistral-large-latest")
|
||||
}
|
||||
pub fn mistral_large_3() -> Self {
|
||||
Self::new("mistral-large-3-25-12")
|
||||
}
|
||||
pub fn mistral_medium_latest() -> Self {
|
||||
Self::new("mistral-medium-latest")
|
||||
}
|
||||
pub fn mistral_medium_3_1() -> Self {
|
||||
Self::new("mistral-medium-3-1-25-08")
|
||||
}
|
||||
pub fn mistral_small_latest() -> Self {
|
||||
Self::new("mistral-small-latest")
|
||||
}
|
||||
pub fn mistral_small_4() -> Self {
|
||||
Self::new("mistral-small-4-0-26-03")
|
||||
}
|
||||
pub fn mistral_small_3_2() -> Self {
|
||||
Self::new("mistral-small-3-2-25-06")
|
||||
}
|
||||
|
||||
// Ministral
|
||||
pub fn ministral_3_14b() -> Self {
|
||||
Self::new("ministral-3-14b-25-12")
|
||||
}
|
||||
pub fn ministral_3_8b() -> Self {
|
||||
Self::new("ministral-3-8b-25-12")
|
||||
}
|
||||
pub fn ministral_3_3b() -> Self {
|
||||
Self::new("ministral-3-3b-25-12")
|
||||
}
|
||||
|
||||
// Reasoning
|
||||
pub fn magistral_medium_latest() -> Self {
|
||||
Self::new("magistral-medium-latest")
|
||||
}
|
||||
pub fn magistral_small_latest() -> Self {
|
||||
Self::new("magistral-small-latest")
|
||||
}
|
||||
|
||||
// Code
|
||||
pub fn codestral_latest() -> Self {
|
||||
Self::new("codestral-latest")
|
||||
}
|
||||
pub fn codestral_2508() -> Self {
|
||||
Self::new("codestral-2508")
|
||||
}
|
||||
pub fn codestral_embed() -> Self {
|
||||
Self::new("codestral-embed-25-05")
|
||||
}
|
||||
pub fn devstral_2() -> Self {
|
||||
Self::new("devstral-2-25-12")
|
||||
}
|
||||
pub fn devstral_small_2() -> Self {
|
||||
Self::new("devstral-small-2-25-12")
|
||||
}
|
||||
|
||||
// Multimodal / Vision
|
||||
pub fn pixtral_large() -> Self {
|
||||
Self::new("pixtral-large-2411")
|
||||
}
|
||||
|
||||
// Audio
|
||||
pub fn voxtral_mini_transcribe() -> Self {
|
||||
Self::new("voxtral-mini-transcribe-2-26-02")
|
||||
}
|
||||
pub fn voxtral_small() -> Self {
|
||||
Self::new("voxtral-small-25-07")
|
||||
}
|
||||
pub fn voxtral_mini() -> Self {
|
||||
Self::new("voxtral-mini-25-07")
|
||||
}
|
||||
|
||||
// Legacy (kept for backward compatibility)
|
||||
pub fn open_mistral_nemo() -> Self {
|
||||
Self::new("open-mistral-nemo")
|
||||
}
|
||||
|
||||
// Embedding
|
||||
pub fn mistral_embed() -> Self {
|
||||
Self::new("mistral-embed")
|
||||
}
|
||||
|
||||
// Moderation
|
||||
pub fn mistral_moderation_latest() -> Self {
|
||||
Self::new("mistral-moderation-26-03")
|
||||
}
|
||||
|
||||
// OCR
|
||||
pub fn mistral_ocr_latest() -> Self {
|
||||
Self::new("mistral-ocr-latest")
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||
pub enum EmbedModel {
|
||||
#[serde(rename = "mistral-embed")]
|
||||
MistralEmbed,
|
||||
impl fmt::Display for Model {
|
||||
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||
write!(f, "{}", self.0)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for Model {
|
||||
fn from(s: &str) -> Self {
|
||||
Self(s.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for Model {
|
||||
fn from(s: String) -> Self {
|
||||
Self(s)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,42 +8,63 @@ use crate::v1::{common, constants};
|
||||
#[derive(Debug)]
|
||||
pub struct EmbeddingRequestOptions {
|
||||
pub encoding_format: Option<EmbeddingRequestEncodingFormat>,
|
||||
pub output_dimension: Option<u32>,
|
||||
pub output_dtype: Option<EmbeddingOutputDtype>,
|
||||
}
|
||||
impl Default for EmbeddingRequestOptions {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
encoding_format: None,
|
||||
output_dimension: None,
|
||||
output_dtype: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct EmbeddingRequest {
|
||||
pub model: constants::EmbedModel,
|
||||
pub model: constants::Model,
|
||||
pub input: Vec<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub encoding_format: Option<EmbeddingRequestEncodingFormat>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub output_dimension: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub output_dtype: Option<EmbeddingOutputDtype>,
|
||||
}
|
||||
impl EmbeddingRequest {
|
||||
pub fn new(
|
||||
model: constants::EmbedModel,
|
||||
model: constants::Model,
|
||||
input: Vec<String>,
|
||||
options: Option<EmbeddingRequestOptions>,
|
||||
) -> Self {
|
||||
let EmbeddingRequestOptions { encoding_format } = options.unwrap_or_default();
|
||||
let opts = options.unwrap_or_default();
|
||||
|
||||
Self {
|
||||
model,
|
||||
input,
|
||||
encoding_format,
|
||||
encoding_format: opts.encoding_format,
|
||||
output_dimension: opts.output_dimension,
|
||||
output_dtype: opts.output_dtype,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||
#[allow(non_camel_case_types)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum EmbeddingRequestEncodingFormat {
|
||||
float,
|
||||
Float,
|
||||
Base64,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum EmbeddingOutputDtype {
|
||||
Float,
|
||||
Int8,
|
||||
Uint8,
|
||||
Binary,
|
||||
Ubinary,
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
@@ -51,9 +72,8 @@ pub enum EmbeddingRequestEncodingFormat {
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct EmbeddingResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub model: constants::EmbedModel,
|
||||
pub model: constants::Model,
|
||||
pub data: Vec<EmbeddingResponseDataItem>,
|
||||
pub usage: common::ResponseUsage,
|
||||
}
|
||||
|
||||
55
src/v1/files.rs
Normal file
55
src/v1/files.rs
Normal file
@@ -0,0 +1,55 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Request
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "kebab-case")]
|
||||
pub enum FilePurpose {
|
||||
FineTune,
|
||||
Batch,
|
||||
Ocr,
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Response
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct FileListResponse {
|
||||
pub data: Vec<FileObject>,
|
||||
pub object: String,
|
||||
#[serde(default)]
|
||||
pub total: u32,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct FileObject {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub bytes: u64,
|
||||
pub created_at: u64,
|
||||
pub filename: String,
|
||||
pub purpose: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub sample_type: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub source: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub num_lines: Option<u64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub mimetype: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct FileDeleteResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub deleted: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct FileUrlResponse {
|
||||
pub url: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub expires_at: Option<u64>,
|
||||
}
|
||||
101
src/v1/fim.rs
Normal file
101
src/v1/fim.rs
Normal file
@@ -0,0 +1,101 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::v1::{common, constants};
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Request
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FimParams {
|
||||
pub suffix: Option<String>,
|
||||
pub max_tokens: Option<u32>,
|
||||
pub min_tokens: Option<u32>,
|
||||
pub temperature: Option<f32>,
|
||||
pub top_p: Option<f32>,
|
||||
pub stop: Option<Vec<String>>,
|
||||
pub random_seed: Option<u32>,
|
||||
}
|
||||
impl Default for FimParams {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
suffix: None,
|
||||
max_tokens: None,
|
||||
min_tokens: None,
|
||||
temperature: None,
|
||||
top_p: None,
|
||||
stop: None,
|
||||
random_seed: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct FimRequest {
|
||||
pub model: constants::Model,
|
||||
pub prompt: String,
|
||||
pub stream: bool,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub suffix: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub min_tokens: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub top_p: Option<f32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stop: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub random_seed: Option<u32>,
|
||||
}
|
||||
impl FimRequest {
|
||||
pub fn new(
|
||||
model: constants::Model,
|
||||
prompt: String,
|
||||
stream: bool,
|
||||
options: Option<FimParams>,
|
||||
) -> Self {
|
||||
let opts = options.unwrap_or_default();
|
||||
|
||||
Self {
|
||||
model,
|
||||
prompt,
|
||||
stream,
|
||||
suffix: opts.suffix,
|
||||
max_tokens: opts.max_tokens,
|
||||
min_tokens: opts.min_tokens,
|
||||
temperature: opts.temperature,
|
||||
top_p: opts.top_p,
|
||||
stop: opts.stop,
|
||||
random_seed: opts.random_seed,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Response
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct FimResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub created: u64,
|
||||
pub model: constants::Model,
|
||||
pub choices: Vec<FimResponseChoice>,
|
||||
pub usage: common::ResponseUsage,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct FimResponseChoice {
|
||||
pub index: u32,
|
||||
pub message: FimResponseMessage,
|
||||
pub finish_reason: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct FimResponseMessage {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
}
|
||||
101
src/v1/fine_tuning.rs
Normal file
101
src/v1/fine_tuning.rs
Normal file
@@ -0,0 +1,101 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::v1::constants;
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Request
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct FineTuningJobRequest {
|
||||
pub model: constants::Model,
|
||||
pub training_files: Vec<TrainingFile>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub validation_files: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub hyperparameters: Option<Hyperparameters>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub suffix: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub auto_start: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub job_type: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub integrations: Option<Vec<Integration>>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct TrainingFile {
|
||||
pub file_id: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub weight: Option<f32>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Hyperparameters {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub learning_rate: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub training_steps: Option<u32>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub warmup_fraction: Option<f64>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub epochs: Option<f64>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||
pub struct Integration {
|
||||
pub r#type: String,
|
||||
pub project: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub name: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub api_key: Option<String>,
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Response
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct FineTuningJobResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub model: constants::Model,
|
||||
pub status: FineTuningJobStatus,
|
||||
pub created_at: u64,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub modified_at: Option<u64>,
|
||||
pub training_files: Vec<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub validation_files: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub hyperparameters: Option<Hyperparameters>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub fine_tuned_model: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub suffix: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub integrations: Option<Vec<Integration>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub trained_tokens: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
|
||||
pub enum FineTuningJobStatus {
|
||||
Queued,
|
||||
Running,
|
||||
Success,
|
||||
Failed,
|
||||
TimeoutExceeded,
|
||||
CancellationRequested,
|
||||
Cancelled,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct FineTuningJobListResponse {
|
||||
pub data: Vec<FineTuningJobResponse>,
|
||||
pub object: String,
|
||||
#[serde(default)]
|
||||
pub total: u32,
|
||||
}
|
||||
@@ -1,3 +1,6 @@
|
||||
pub mod agents;
|
||||
pub mod audio;
|
||||
pub mod batch;
|
||||
pub mod chat;
|
||||
pub mod chat_stream;
|
||||
pub mod client;
|
||||
@@ -5,6 +8,11 @@ pub mod common;
|
||||
pub mod constants;
|
||||
pub mod embedding;
|
||||
pub mod error;
|
||||
pub mod files;
|
||||
pub mod fim;
|
||||
pub mod fine_tuning;
|
||||
pub mod model_list;
|
||||
pub mod moderation;
|
||||
pub mod ocr;
|
||||
pub mod tool;
|
||||
pub mod utils;
|
||||
|
||||
@@ -15,23 +15,44 @@ pub struct ModelListData {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
/// Unix timestamp (in seconds).
|
||||
pub created: u32,
|
||||
pub created: u64,
|
||||
pub owned_by: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub root: Option<String>,
|
||||
#[serde(default)]
|
||||
pub archived: bool,
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub capabilities: ModelListDataCapabilies,
|
||||
pub max_context_length: u32,
|
||||
#[serde(default)]
|
||||
pub name: Option<String>,
|
||||
#[serde(default)]
|
||||
pub description: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub capabilities: Option<ModelListDataCapabilities>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub max_context_length: Option<u32>,
|
||||
#[serde(default)]
|
||||
pub aliases: Vec<String>,
|
||||
/// ISO 8601 date (`YYYY-MM-DDTHH:MM:SSZ`).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub deprecation: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct ModelListDataCapabilies {
|
||||
pub struct ModelListDataCapabilities {
|
||||
#[serde(default)]
|
||||
pub completion_chat: bool,
|
||||
#[serde(default)]
|
||||
pub completion_fim: bool,
|
||||
#[serde(default)]
|
||||
pub function_calling: bool,
|
||||
#[serde(default)]
|
||||
pub fine_tuning: bool,
|
||||
#[serde(default)]
|
||||
pub vision: bool,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct ModelDeleteResponse {
|
||||
pub id: String,
|
||||
pub object: String,
|
||||
pub deleted: bool,
|
||||
}
|
||||
|
||||
70
src/v1/moderation.rs
Normal file
70
src/v1/moderation.rs
Normal file
@@ -0,0 +1,70 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::v1::constants;
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Request
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ModerationRequest {
|
||||
pub model: constants::Model,
|
||||
pub input: Vec<String>,
|
||||
}
|
||||
impl ModerationRequest {
|
||||
pub fn new(model: constants::Model, input: Vec<String>) -> Self {
|
||||
Self { model, input }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ChatModerationRequest {
|
||||
pub model: constants::Model,
|
||||
pub input: Vec<ChatModerationInput>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ChatModerationInput {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ClassificationRequest {
|
||||
pub model: constants::Model,
|
||||
pub input: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct ChatClassificationRequest {
|
||||
pub model: constants::Model,
|
||||
pub input: Vec<ChatModerationInput>,
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Response
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct ModerationResponse {
|
||||
pub id: String,
|
||||
pub model: constants::Model,
|
||||
pub results: Vec<ModerationResult>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct ModerationResult {
|
||||
pub categories: serde_json::Value,
|
||||
pub category_scores: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct ClassificationResponse {
|
||||
pub id: String,
|
||||
pub model: constants::Model,
|
||||
pub results: Vec<ClassificationResult>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct ClassificationResult {
|
||||
pub categories: serde_json::Value,
|
||||
pub category_scores: serde_json::Value,
|
||||
}
|
||||
96
src/v1/ocr.rs
Normal file
96
src/v1/ocr.rs
Normal file
@@ -0,0 +1,96 @@
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::v1::constants;
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Request
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct OcrRequest {
|
||||
pub model: constants::Model,
|
||||
pub document: OcrDocument,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub pages: Option<Vec<u32>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub table_format: Option<OcrTableFormat>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub include_image_base64: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub image_limit: Option<u32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
pub struct OcrDocument {
|
||||
#[serde(rename = "type")]
|
||||
pub type_: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub document_url: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub file_id: Option<String>,
|
||||
}
|
||||
impl OcrDocument {
|
||||
pub fn from_url(url: &str) -> Self {
|
||||
Self {
|
||||
type_: "document_url".to_string(),
|
||||
document_url: Some(url.to_string()),
|
||||
file_id: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_file_id(file_id: &str) -> Self {
|
||||
Self {
|
||||
type_: "file_id".to_string(),
|
||||
document_url: None,
|
||||
file_id: Some(file_id.to_string()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum OcrTableFormat {
|
||||
Markdown,
|
||||
Html,
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Response
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct OcrResponse {
|
||||
pub pages: Vec<OcrPage>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub usage_info: Option<OcrUsageInfo>,
|
||||
pub model: constants::Model,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct OcrPage {
|
||||
pub index: u32,
|
||||
pub markdown: String,
|
||||
#[serde(default)]
|
||||
pub images: Vec<OcrImage>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub dimensions: Option<OcrPageDimensions>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct OcrImage {
|
||||
pub id: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub image_base64: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct OcrPageDimensions {
|
||||
pub width: f32,
|
||||
pub height: f32,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct OcrUsageInfo {
|
||||
pub pages_processed: u32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub doc_size_bytes: Option<u64>,
|
||||
}
|
||||
@@ -1,12 +1,16 @@
|
||||
use async_trait::async_trait;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{any::Any, collections::HashMap, fmt::Debug};
|
||||
use std::{any::Any, fmt::Debug};
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// Definitions
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||
pub struct ToolCall {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub r#type: Option<String>,
|
||||
pub function: ToolCallFunction,
|
||||
}
|
||||
|
||||
@@ -22,31 +26,12 @@ pub struct Tool {
|
||||
pub function: ToolFunction,
|
||||
}
|
||||
impl Tool {
|
||||
/// Create a tool with a JSON Schema parameters object.
|
||||
pub fn new(
|
||||
function_name: String,
|
||||
function_description: String,
|
||||
function_parameters: Vec<ToolFunctionParameter>,
|
||||
parameters: serde_json::Value,
|
||||
) -> Self {
|
||||
let properties: HashMap<String, ToolFunctionParameterProperty> = function_parameters
|
||||
.into_iter()
|
||||
.map(|param| {
|
||||
(
|
||||
param.name,
|
||||
ToolFunctionParameterProperty {
|
||||
r#type: param.r#type,
|
||||
description: param.description,
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
let property_names = properties.keys().cloned().collect();
|
||||
|
||||
let parameters = ToolFunctionParameters {
|
||||
r#type: ToolFunctionParametersType::Object,
|
||||
properties,
|
||||
required: property_names,
|
||||
};
|
||||
|
||||
Self {
|
||||
r#type: ToolType::Function,
|
||||
function: ToolFunction {
|
||||
@@ -63,50 +48,9 @@ impl Tool {
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct ToolFunction {
|
||||
name: String,
|
||||
description: String,
|
||||
parameters: ToolFunctionParameters,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct ToolFunctionParameter {
|
||||
name: String,
|
||||
description: String,
|
||||
r#type: ToolFunctionParameterType,
|
||||
}
|
||||
impl ToolFunctionParameter {
|
||||
pub fn new(name: String, description: String, r#type: ToolFunctionParameterType) -> Self {
|
||||
Self {
|
||||
name,
|
||||
r#type,
|
||||
description,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct ToolFunctionParameters {
|
||||
r#type: ToolFunctionParametersType,
|
||||
properties: HashMap<String, ToolFunctionParameterProperty>,
|
||||
required: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||
pub struct ToolFunctionParameterProperty {
|
||||
r#type: ToolFunctionParameterType,
|
||||
description: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||
pub enum ToolFunctionParametersType {
|
||||
#[serde(rename = "object")]
|
||||
Object,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||
pub enum ToolFunctionParameterType {
|
||||
#[serde(rename = "string")]
|
||||
String,
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||
@@ -127,6 +71,9 @@ pub enum ToolChoice {
|
||||
/// The model won't call a function and will generate a message instead.
|
||||
#[serde(rename = "none")]
|
||||
None,
|
||||
/// The model must call at least one tool.
|
||||
#[serde(rename = "required")]
|
||||
Required,
|
||||
}
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
|
||||
@@ -3,7 +3,7 @@ use mistralai_client::v1::{
|
||||
chat::{ChatMessage, ChatMessageRole, ChatParams, ChatResponseChoiceFinishReason},
|
||||
client::Client,
|
||||
constants::Model,
|
||||
tool::{Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType},
|
||||
tool::{Tool, ToolChoice},
|
||||
};
|
||||
|
||||
mod setup;
|
||||
@@ -14,12 +14,12 @@ async fn test_client_chat_async() {
|
||||
|
||||
let client = Client::new(None, None, None, None).unwrap();
|
||||
|
||||
let model = Model::OpenMistral7b;
|
||||
let model = Model::mistral_small_latest();
|
||||
let messages = vec![ChatMessage::new_user_message(
|
||||
"Guess the next word: \"Eiffel ...\"?",
|
||||
)];
|
||||
let options = ChatParams {
|
||||
temperature: 0.0,
|
||||
temperature: Some(0.0),
|
||||
random_seed: Some(42),
|
||||
..Default::default()
|
||||
};
|
||||
@@ -29,7 +29,6 @@ async fn test_client_chat_async() {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
expect!(response.model).to_be(Model::OpenMistral7b);
|
||||
expect!(response.object).to_be("chat.completion".to_string());
|
||||
|
||||
expect!(response.choices.len()).to_be(1);
|
||||
@@ -56,21 +55,26 @@ async fn test_client_chat_async_with_function_calling() {
|
||||
let tools = vec![Tool::new(
|
||||
"get_city_temperature".to_string(),
|
||||
"Get the current temperature in a city.".to_string(),
|
||||
vec![ToolFunctionParameter::new(
|
||||
"city".to_string(),
|
||||
"The name of the city.".to_string(),
|
||||
ToolFunctionParameterType::String,
|
||||
)],
|
||||
serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The name of the city."
|
||||
}
|
||||
},
|
||||
"required": ["city"]
|
||||
}),
|
||||
)];
|
||||
|
||||
let client = Client::new(None, None, None, None).unwrap();
|
||||
|
||||
let model = Model::MistralSmallLatest;
|
||||
let model = Model::mistral_small_latest();
|
||||
let messages = vec![ChatMessage::new_user_message(
|
||||
"What's the current temperature in Paris?",
|
||||
)];
|
||||
let options = ChatParams {
|
||||
temperature: 0.0,
|
||||
temperature: Some(0.0),
|
||||
random_seed: Some(42),
|
||||
tool_choice: Some(ToolChoice::Any),
|
||||
tools: Some(tools),
|
||||
@@ -82,7 +86,6 @@ async fn test_client_chat_async_with_function_calling() {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
expect!(response.model).to_be(Model::MistralSmallLatest);
|
||||
expect!(response.object).to_be("chat.completion".to_string());
|
||||
|
||||
expect!(response.choices.len()).to_be(1);
|
||||
@@ -91,13 +94,6 @@ async fn test_client_chat_async_with_function_calling() {
|
||||
.to_be(ChatResponseChoiceFinishReason::ToolCalls);
|
||||
|
||||
expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant);
|
||||
expect!(response.choices[0].message.content.clone()).to_be("".to_string());
|
||||
// expect!(response.choices[0].message.tool_calls.clone()).to_be(Some(vec![ToolCall {
|
||||
// function: ToolCallFunction {
|
||||
// name: "get_city_temperature".to_string(),
|
||||
// arguments: "{\"city\": \"Paris\"}".to_string(),
|
||||
// },
|
||||
// }]));
|
||||
|
||||
expect!(response.usage.prompt_tokens).to_be_greater_than(0);
|
||||
expect!(response.usage.completion_tokens).to_be_greater_than(0);
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
// Streaming tests require a live API key and are not run in CI.
|
||||
// Uncomment to test locally.
|
||||
|
||||
// use futures::stream::StreamExt;
|
||||
// use jrest::expect;
|
||||
// use mistralai_client::v1::{
|
||||
// chat_completion::{ChatParams, ChatMessage, ChatMessageRole},
|
||||
// chat::{ChatMessage, ChatParams},
|
||||
// client::Client,
|
||||
// constants::Model,
|
||||
// };
|
||||
|
||||
//
|
||||
// #[tokio::test]
|
||||
// async fn test_client_chat_stream() {
|
||||
// let client = Client::new(None, None, None, None).unwrap();
|
||||
|
||||
// let model = Model::OpenMistral7b;
|
||||
//
|
||||
// let model = Model::mistral_small_latest();
|
||||
// let messages = vec![ChatMessage::new_user_message(
|
||||
// "Just guess the next word: \"Eiffel ...\"?",
|
||||
// )];
|
||||
@@ -19,22 +21,24 @@
|
||||
// random_seed: Some(42),
|
||||
// ..Default::default()
|
||||
// };
|
||||
|
||||
// let stream_result = client.chat_stream(model, messages, Some(options)).await;
|
||||
// let mut stream = stream_result.expect("Failed to create stream.");
|
||||
// while let Some(maybe_chunk_result) = stream.next().await {
|
||||
// match maybe_chunk_result {
|
||||
// Some(Ok(chunk)) => {
|
||||
// if chunk.choices[0].delta.role == Some(ChatMessageRole::Assistant)
|
||||
// || chunk.choices[0].finish_reason == Some("stop".to_string())
|
||||
// {
|
||||
// expect!(chunk.choices[0].delta.content.len()).to_be(0);
|
||||
// } else {
|
||||
// expect!(chunk.choices[0].delta.content.len()).to_be_greater_than(0);
|
||||
//
|
||||
// let stream = client
|
||||
// .chat_stream(model, messages, Some(options))
|
||||
// .await
|
||||
// .expect("Failed to create stream.");
|
||||
//
|
||||
// stream
|
||||
// .for_each(|chunk_result| async {
|
||||
// match chunk_result {
|
||||
// Ok(chunks) => {
|
||||
// for chunk in &chunks {
|
||||
// if let Some(content) = &chunk.choices[0].delta.content {
|
||||
// print!("{}", content);
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// Err(error) => eprintln!("Error: {:?}", error),
|
||||
// }
|
||||
// Some(Err(error)) => eprintln!("Error processing chunk: {:?}", error),
|
||||
// None => (),
|
||||
// }
|
||||
// }
|
||||
// })
|
||||
// .await;
|
||||
// }
|
||||
|
||||
@@ -3,7 +3,7 @@ use mistralai_client::v1::{
|
||||
chat::{ChatMessage, ChatMessageRole, ChatParams, ChatResponseChoiceFinishReason},
|
||||
client::Client,
|
||||
constants::Model,
|
||||
tool::{Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType},
|
||||
tool::{Tool, ToolChoice},
|
||||
};
|
||||
|
||||
mod setup;
|
||||
@@ -14,19 +14,18 @@ fn test_client_chat() {
|
||||
|
||||
let client = Client::new(None, None, None, None).unwrap();
|
||||
|
||||
let model = Model::OpenMistral7b;
|
||||
let model = Model::mistral_small_latest();
|
||||
let messages = vec![ChatMessage::new_user_message(
|
||||
"Guess the next word: \"Eiffel ...\"?",
|
||||
)];
|
||||
let options = ChatParams {
|
||||
temperature: 0.0,
|
||||
temperature: Some(0.0),
|
||||
random_seed: Some(42),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let response = client.chat(model, messages, Some(options)).unwrap();
|
||||
|
||||
expect!(response.model).to_be(Model::OpenMistral7b);
|
||||
expect!(response.object).to_be("chat.completion".to_string());
|
||||
expect!(response.choices.len()).to_be(1);
|
||||
expect!(response.choices[0].index).to_be(0);
|
||||
@@ -50,21 +49,26 @@ fn test_client_chat_with_function_calling() {
|
||||
let tools = vec![Tool::new(
|
||||
"get_city_temperature".to_string(),
|
||||
"Get the current temperature in a city.".to_string(),
|
||||
vec![ToolFunctionParameter::new(
|
||||
"city".to_string(),
|
||||
"The name of the city.".to_string(),
|
||||
ToolFunctionParameterType::String,
|
||||
)],
|
||||
serde_json::json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"city": {
|
||||
"type": "string",
|
||||
"description": "The name of the city."
|
||||
}
|
||||
},
|
||||
"required": ["city"]
|
||||
}),
|
||||
)];
|
||||
|
||||
let client = Client::new(None, None, None, None).unwrap();
|
||||
|
||||
let model = Model::MistralSmallLatest;
|
||||
let model = Model::mistral_small_latest();
|
||||
let messages = vec![ChatMessage::new_user_message(
|
||||
"What's the current temperature in Paris?",
|
||||
)];
|
||||
let options = ChatParams {
|
||||
temperature: 0.0,
|
||||
temperature: Some(0.0),
|
||||
random_seed: Some(42),
|
||||
tool_choice: Some(ToolChoice::Auto),
|
||||
tools: Some(tools),
|
||||
@@ -73,12 +77,10 @@ fn test_client_chat_with_function_calling() {
|
||||
|
||||
let response = client.chat(model, messages, Some(options)).unwrap();
|
||||
|
||||
expect!(response.model).to_be(Model::MistralSmallLatest);
|
||||
expect!(response.object).to_be("chat.completion".to_string());
|
||||
expect!(response.choices.len()).to_be(1);
|
||||
expect!(response.choices[0].index).to_be(0);
|
||||
expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant);
|
||||
expect!(response.choices[0].message.content.clone()).to_be("".to_string());
|
||||
expect!(response.choices[0].finish_reason.clone())
|
||||
.to_be(ChatResponseChoiceFinishReason::ToolCalls);
|
||||
expect!(response.usage.prompt_tokens).to_be_greater_than(0);
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
use jrest::expect;
|
||||
use mistralai_client::v1::{client::Client, constants::EmbedModel};
|
||||
use mistralai_client::v1::{client::Client, constants::Model};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_client_embeddings_async() {
|
||||
let client: Client = Client::new(None, None, None, None).unwrap();
|
||||
|
||||
let model = EmbedModel::MistralEmbed;
|
||||
let model = Model::mistral_embed();
|
||||
let input = vec!["Embed this sentence.", "As well as this one."]
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
@@ -17,7 +17,6 @@ async fn test_client_embeddings_async() {
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
expect!(response.model).to_be(EmbedModel::MistralEmbed);
|
||||
expect!(response.object).to_be("list".to_string());
|
||||
expect!(response.data.len()).to_be(2);
|
||||
expect!(response.data[0].index).to_be(0);
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
use jrest::expect;
|
||||
use mistralai_client::v1::{client::Client, constants::EmbedModel};
|
||||
use mistralai_client::v1::{client::Client, constants::Model};
|
||||
|
||||
#[test]
|
||||
fn test_client_embeddings() {
|
||||
let client: Client = Client::new(None, None, None, None).unwrap();
|
||||
|
||||
let model = EmbedModel::MistralEmbed;
|
||||
let model = Model::mistral_embed();
|
||||
let input = vec!["Embed this sentence.", "As well as this one."]
|
||||
.iter()
|
||||
.map(|s| s.to_string())
|
||||
@@ -14,7 +14,6 @@ fn test_client_embeddings() {
|
||||
|
||||
let response = client.embeddings(model, input, options).unwrap();
|
||||
|
||||
expect!(response.model).to_be(EmbedModel::MistralEmbed);
|
||||
expect!(response.object).to_be("list".to_string());
|
||||
expect!(response.data.len()).to_be(2);
|
||||
expect!(response.data[0].index).to_be(0);
|
||||
|
||||
@@ -6,26 +6,19 @@ use mistralai_client::v1::{
|
||||
};
|
||||
|
||||
#[test]
|
||||
fn test_model_constant() {
|
||||
fn test_model_constants() {
|
||||
let models = vec![
|
||||
Model::OpenMistral7b,
|
||||
Model::OpenMixtral8x7b,
|
||||
Model::OpenMixtral8x22b,
|
||||
Model::OpenMistralNemo,
|
||||
Model::MistralTiny,
|
||||
Model::MistralSmallLatest,
|
||||
Model::MistralMediumLatest,
|
||||
Model::MistralLargeLatest,
|
||||
Model::MistralLarge,
|
||||
Model::CodestralLatest,
|
||||
Model::CodestralMamba,
|
||||
Model::mistral_small_latest(),
|
||||
Model::mistral_large_latest(),
|
||||
Model::open_mistral_nemo(),
|
||||
Model::codestral_latest(),
|
||||
];
|
||||
|
||||
let client = Client::new(None, None, None, None).unwrap();
|
||||
|
||||
let messages = vec![ChatMessage::new_user_message("A number between 0 and 100?")];
|
||||
let options = ChatParams {
|
||||
temperature: 0.0,
|
||||
temperature: Some(0.0),
|
||||
random_seed: Some(42),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
Reference in New Issue
Block a user