- Replace closed Model enum with flexible string-based Model type with constructor methods for all current models (Mistral Large 3, Small 4, Magistral, Codestral, Devstral, Pixtral, Voxtral, etc.) - Add new API endpoints: FIM completions, Files, Fine-tuning, Batch jobs, OCR, Audio transcription, Moderations/Classifications, and Agent completions (sync + async for all) - Add new chat fields: frequency_penalty, presence_penalty, stop, n, parallel_tool_calls, reasoning_effort, min_tokens, json_schema response format - Add embedding fields: output_dimension, output_dtype - Tool parameters now accept raw JSON Schema (serde_json::Value) instead of limited enum types - Add tool call IDs and Required tool choice variant - Add DELETE HTTP method support and multipart file upload - Bump thiserror to v2, add reqwest multipart feature - Remove strum dependency (no longer needed) - Update all tests and examples for new API
256 lines
7.6 KiB
Rust
256 lines
7.6 KiB
Rust
use serde::{Deserialize, Serialize};
|
|
|
|
use crate::v1::{common, constants, tool};
|
|
|
|
// -----------------------------------------------------------------------------
|
|
// Definitions
|
|
|
|
/// A single message in a chat conversation.
///
/// Sent inside a `ChatRequest` and returned inside a `ChatResponseChoice`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ChatMessage {
    /// Author of the message (system / assistant / user / tool).
    pub role: ChatMessageRole,
    /// Text content of the message.
    pub content: String,
    /// Tool calls attached to an assistant message, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<Vec<tool::ToolCall>>,
    /// Tool call ID, required when role is Tool.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_call_id: Option<String>,
    /// Function name, used when role is Tool.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
}
|
|
impl ChatMessage {
|
|
pub fn new_system_message(content: &str) -> Self {
|
|
Self {
|
|
role: ChatMessageRole::System,
|
|
content: content.to_string(),
|
|
tool_calls: None,
|
|
tool_call_id: None,
|
|
name: None,
|
|
}
|
|
}
|
|
|
|
pub fn new_assistant_message(content: &str, tool_calls: Option<Vec<tool::ToolCall>>) -> Self {
|
|
Self {
|
|
role: ChatMessageRole::Assistant,
|
|
content: content.to_string(),
|
|
tool_calls,
|
|
tool_call_id: None,
|
|
name: None,
|
|
}
|
|
}
|
|
|
|
pub fn new_user_message(content: &str) -> Self {
|
|
Self {
|
|
role: ChatMessageRole::User,
|
|
content: content.to_string(),
|
|
tool_calls: None,
|
|
tool_call_id: None,
|
|
name: None,
|
|
}
|
|
}
|
|
|
|
pub fn new_tool_message(content: &str, tool_call_id: &str, name: Option<&str>) -> Self {
|
|
Self {
|
|
role: ChatMessageRole::Tool,
|
|
content: content.to_string(),
|
|
tool_calls: None,
|
|
tool_call_id: Some(tool_call_id.to_string()),
|
|
name: name.map(|n| n.to_string()),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// See the [Mistral AI API documentation](https://docs.mistral.ai/capabilities/completion/#chat-messages) for more information.
|
|
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
|
pub enum ChatMessageRole {
|
|
#[serde(rename = "system")]
|
|
System,
|
|
#[serde(rename = "assistant")]
|
|
Assistant,
|
|
#[serde(rename = "user")]
|
|
User,
|
|
#[serde(rename = "tool")]
|
|
Tool,
|
|
}
|
|
|
|
/// The format that the model must output.
///
/// Construct via `ResponseFormat::text`, `ResponseFormat::json_object`, or
/// `ResponseFormat::json_schema`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct ResponseFormat {
    /// Format discriminator: "text", "json_object", or "json_schema".
    #[serde(rename = "type")]
    pub type_: String,
    /// Raw JSON Schema for the output; only set for the "json_schema" type.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub json_schema: Option<serde_json::Value>,
}
|
|
impl ResponseFormat {
|
|
pub fn text() -> Self {
|
|
Self {
|
|
type_: "text".to_string(),
|
|
json_schema: None,
|
|
}
|
|
}
|
|
|
|
pub fn json_object() -> Self {
|
|
Self {
|
|
type_: "json_object".to_string(),
|
|
json_schema: None,
|
|
}
|
|
}
|
|
|
|
pub fn json_schema(schema: serde_json::Value) -> Self {
|
|
Self {
|
|
type_: "json_schema".to_string(),
|
|
json_schema: Some(schema),
|
|
}
|
|
}
|
|
}
|
|
|
|
// -----------------------------------------------------------------------------
|
|
// Request
|
|
|
|
/// The parameters for the chat request.
///
/// Every field is optional; `ChatParams::default()` leaves them all unset so
/// the API's own defaults apply.
#[derive(Clone, Debug)]
pub struct ChatParams {
    /// Maximum number of tokens to generate.
    pub max_tokens: Option<u32>,
    /// Minimum number of tokens to generate.
    pub min_tokens: Option<u32>,
    /// Seed for reproducible sampling.
    pub random_seed: Option<u32>,
    /// Constrains the output format (text / JSON object / JSON Schema).
    pub response_format: Option<ResponseFormat>,
    /// When `true`, asks the API to apply its safety prompt.
    pub safe_prompt: bool,
    /// Sampling temperature.
    pub temperature: Option<f32>,
    /// How the model chooses between text generation and tool calls.
    pub tool_choice: Option<tool::ToolChoice>,
    /// Tools the model is allowed to call.
    pub tools: Option<Vec<tool::Tool>>,
    /// Nucleus-sampling probability mass.
    pub top_p: Option<f32>,
    /// Stop sequences; generation halts when one is produced.
    pub stop: Option<Vec<String>>,
    /// Number of completions to return.
    pub n: Option<u32>,
    /// Frequency penalty (see the API documentation for exact semantics).
    pub frequency_penalty: Option<f32>,
    /// Presence penalty (see the API documentation for exact semantics).
    pub presence_penalty: Option<f32>,
    /// Whether the model may emit multiple tool calls in a single turn.
    pub parallel_tool_calls: Option<bool>,
    /// For reasoning models (Magistral). "high" or "none".
    pub reasoning_effort: Option<String>,
}
|
|
impl Default for ChatParams {
|
|
fn default() -> Self {
|
|
Self {
|
|
max_tokens: None,
|
|
min_tokens: None,
|
|
random_seed: None,
|
|
safe_prompt: false,
|
|
response_format: None,
|
|
temperature: None,
|
|
tool_choice: None,
|
|
tools: None,
|
|
top_p: None,
|
|
stop: None,
|
|
n: None,
|
|
frequency_penalty: None,
|
|
presence_penalty: None,
|
|
parallel_tool_calls: None,
|
|
reasoning_effort: None,
|
|
}
|
|
}
|
|
}
|
|
|
|
/// The JSON body sent to the chat completions endpoint.
///
/// Optional fields are omitted from the serialized payload when `None`.
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatRequest {
    /// Model to run the completion with.
    pub model: constants::Model,
    /// Conversation messages.
    pub messages: Vec<ChatMessage>,
    /// Whether the response should be streamed.
    pub stream: bool,

    // Optional generation parameters; see `ChatParams` for field semantics.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub min_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub random_seed: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<ResponseFormat>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub safe_prompt: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_choice: Option<tool::ToolChoice>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tools: Option<Vec<tool::Tool>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub n: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frequency_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub presence_penalty: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parallel_tool_calls: Option<bool>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub reasoning_effort: Option<String>,
}
|
|
impl ChatRequest {
|
|
pub fn new(
|
|
model: constants::Model,
|
|
messages: Vec<ChatMessage>,
|
|
stream: bool,
|
|
options: Option<ChatParams>,
|
|
) -> Self {
|
|
let opts = options.unwrap_or_default();
|
|
let safe_prompt = if opts.safe_prompt { Some(true) } else { None };
|
|
|
|
Self {
|
|
model,
|
|
messages,
|
|
stream,
|
|
max_tokens: opts.max_tokens,
|
|
min_tokens: opts.min_tokens,
|
|
random_seed: opts.random_seed,
|
|
safe_prompt,
|
|
temperature: opts.temperature,
|
|
tool_choice: opts.tool_choice,
|
|
tools: opts.tools,
|
|
top_p: opts.top_p,
|
|
response_format: opts.response_format,
|
|
stop: opts.stop,
|
|
n: opts.n,
|
|
frequency_penalty: opts.frequency_penalty,
|
|
presence_penalty: opts.presence_penalty,
|
|
parallel_tool_calls: opts.parallel_tool_calls,
|
|
reasoning_effort: opts.reasoning_effort,
|
|
}
|
|
}
|
|
}
|
|
|
|
// -----------------------------------------------------------------------------
|
|
// Response
|
|
|
|
/// The response body returned by the chat completions endpoint.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ChatResponse {
    /// Unique identifier of the completion.
    pub id: String,
    /// Object type tag returned by the API.
    pub object: String,
    /// Unix timestamp (in seconds).
    pub created: u64,
    /// Model that produced the completion.
    pub model: constants::Model,
    /// Generated completions.
    pub choices: Vec<ChatResponseChoice>,
    /// Token usage accounting for this request.
    pub usage: common::ResponseUsage,
}
|
|
|
|
/// One generated completion within a `ChatResponse`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ChatResponseChoice {
    /// Position of this choice in the response.
    pub index: u32,
    /// The generated message.
    pub message: ChatMessage,
    /// Why generation stopped for this choice.
    pub finish_reason: ChatResponseChoiceFinishReason,
}
|
|
|
|
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
|
pub enum ChatResponseChoiceFinishReason {
|
|
#[serde(rename = "stop")]
|
|
Stop,
|
|
#[serde(rename = "length")]
|
|
Length,
|
|
#[serde(rename = "tool_calls")]
|
|
ToolCalls,
|
|
#[serde(rename = "model_length")]
|
|
ModelLength,
|
|
#[serde(rename = "error")]
|
|
Error,
|
|
}
|