refactor!: modernize core types and client for latest Mistral API
BREAKING CHANGE: Model is now a string-based struct with constructor methods instead of a closed enum. EmbedModel is removed — use Model::mistral_embed() instead. Tool parameters now accept serde_json::Value (JSON Schema) instead of limited enum types. - Replace Model enum with flexible Model(String) supporting all current models: Large 3, Small 4, Medium 3.1, Magistral, Codestral, Devstral, Pixtral, Voxtral, Ministral, and arbitrary strings - Remove EmbedModel enum (consolidated into Model) - Chat: add frequency_penalty, presence_penalty, stop, n, min_tokens, parallel_tool_calls, reasoning_effort, json_schema response format - Embeddings: add output_dimension and output_dtype fields - Tools: accept raw JSON Schema, add tool call IDs and Required choice - Stream delta content is now Option<String> for tool call chunks - Add Length, ModelLength, Error finish reason variants - DRY HTTP transport with shared response handlers - Add DELETE method support and model get/delete endpoints - Make model_list fields more lenient with Option/default for API compat
This commit is contained in:
187
src/v1/chat.rs
187
src/v1/chat.rs
@@ -11,13 +11,31 @@ pub struct ChatMessage {
|
|||||||
pub content: String,
|
pub content: String,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub tool_calls: Option<Vec<tool::ToolCall>>,
|
pub tool_calls: Option<Vec<tool::ToolCall>>,
|
||||||
|
/// Tool call ID, required when role is Tool.
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tool_call_id: Option<String>,
|
||||||
|
/// Function name, used when role is Tool.
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub name: Option<String>,
|
||||||
}
|
}
|
||||||
impl ChatMessage {
|
impl ChatMessage {
|
||||||
|
pub fn new_system_message(content: &str) -> Self {
|
||||||
|
Self {
|
||||||
|
role: ChatMessageRole::System,
|
||||||
|
content: content.to_string(),
|
||||||
|
tool_calls: None,
|
||||||
|
tool_call_id: None,
|
||||||
|
name: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn new_assistant_message(content: &str, tool_calls: Option<Vec<tool::ToolCall>>) -> Self {
|
pub fn new_assistant_message(content: &str, tool_calls: Option<Vec<tool::ToolCall>>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
role: ChatMessageRole::Assistant,
|
role: ChatMessageRole::Assistant,
|
||||||
content: content.to_string(),
|
content: content.to_string(),
|
||||||
tool_calls,
|
tool_calls,
|
||||||
|
tool_call_id: None,
|
||||||
|
name: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -26,6 +44,18 @@ impl ChatMessage {
|
|||||||
role: ChatMessageRole::User,
|
role: ChatMessageRole::User,
|
||||||
content: content.to_string(),
|
content: content.to_string(),
|
||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
|
tool_call_id: None,
|
||||||
|
name: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new_tool_message(content: &str, tool_call_id: &str, name: Option<&str>) -> Self {
|
||||||
|
Self {
|
||||||
|
role: ChatMessageRole::Tool,
|
||||||
|
content: content.to_string(),
|
||||||
|
tool_calls: None,
|
||||||
|
tool_call_id: Some(tool_call_id.to_string()),
|
||||||
|
name: name.map(|n| n.to_string()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -44,17 +74,32 @@ pub enum ChatMessageRole {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// The format that the model must output.
|
/// The format that the model must output.
|
||||||
///
|
|
||||||
/// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information.
|
|
||||||
#[derive(Clone, Debug, Serialize, Deserialize)]
|
#[derive(Clone, Debug, Serialize, Deserialize)]
|
||||||
pub struct ResponseFormat {
|
pub struct ResponseFormat {
|
||||||
#[serde(rename = "type")]
|
#[serde(rename = "type")]
|
||||||
pub type_: String,
|
pub type_: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub json_schema: Option<serde_json::Value>,
|
||||||
}
|
}
|
||||||
impl ResponseFormat {
|
impl ResponseFormat {
|
||||||
|
pub fn text() -> Self {
|
||||||
|
Self {
|
||||||
|
type_: "text".to_string(),
|
||||||
|
json_schema: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn json_object() -> Self {
|
pub fn json_object() -> Self {
|
||||||
Self {
|
Self {
|
||||||
type_: "json_object".to_string(),
|
type_: "json_object".to_string(),
|
||||||
|
json_schema: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn json_schema(schema: serde_json::Value) -> Self {
|
||||||
|
Self {
|
||||||
|
type_: "json_schema".to_string(),
|
||||||
|
json_schema: Some(schema),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -63,91 +108,83 @@ impl ResponseFormat {
|
|||||||
// Request
|
// Request
|
||||||
|
|
||||||
/// The parameters for the chat request.
|
/// The parameters for the chat request.
|
||||||
///
|
|
||||||
/// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information.
|
|
||||||
#[derive(Clone, Debug)]
|
#[derive(Clone, Debug)]
|
||||||
pub struct ChatParams {
|
pub struct ChatParams {
|
||||||
/// The maximum number of tokens to generate in the completion.
|
|
||||||
///
|
|
||||||
/// Defaults to `None`.
|
|
||||||
pub max_tokens: Option<u32>,
|
pub max_tokens: Option<u32>,
|
||||||
/// The seed to use for random sampling. If set, different calls will generate deterministic results.
|
pub min_tokens: Option<u32>,
|
||||||
///
|
|
||||||
/// Defaults to `None`.
|
|
||||||
pub random_seed: Option<u32>,
|
pub random_seed: Option<u32>,
|
||||||
/// The format that the model must output.
|
|
||||||
///
|
|
||||||
/// Defaults to `None`.
|
|
||||||
pub response_format: Option<ResponseFormat>,
|
pub response_format: Option<ResponseFormat>,
|
||||||
/// Whether to inject a safety prompt before all conversations.
|
|
||||||
///
|
|
||||||
/// Defaults to `false`.
|
|
||||||
pub safe_prompt: bool,
|
pub safe_prompt: bool,
|
||||||
/// What sampling temperature to use, between `Some(0.0)` and `Some(1.0)`.
|
pub temperature: Option<f32>,
|
||||||
///
|
|
||||||
/// Defaults to `0.7`.
|
|
||||||
pub temperature: f32,
|
|
||||||
/// Specifies if/how functions are called.
|
|
||||||
///
|
|
||||||
/// Defaults to `None`.
|
|
||||||
pub tool_choice: Option<tool::ToolChoice>,
|
pub tool_choice: Option<tool::ToolChoice>,
|
||||||
/// A list of available tools for the model.
|
|
||||||
///
|
|
||||||
/// Defaults to `None`.
|
|
||||||
pub tools: Option<Vec<tool::Tool>>,
|
pub tools: Option<Vec<tool::Tool>>,
|
||||||
/// Nucleus sampling, where the model considers the results of the tokens with `top_p` probability mass.
|
pub top_p: Option<f32>,
|
||||||
///
|
pub stop: Option<Vec<String>>,
|
||||||
/// Defaults to `1.0`.
|
pub n: Option<u32>,
|
||||||
pub top_p: f32,
|
pub frequency_penalty: Option<f32>,
|
||||||
|
pub presence_penalty: Option<f32>,
|
||||||
|
pub parallel_tool_calls: Option<bool>,
|
||||||
|
/// For reasoning models (Magistral). "high" or "none".
|
||||||
|
pub reasoning_effort: Option<String>,
|
||||||
}
|
}
|
||||||
impl Default for ChatParams {
|
impl Default for ChatParams {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
max_tokens: None,
|
max_tokens: None,
|
||||||
|
min_tokens: None,
|
||||||
random_seed: None,
|
random_seed: None,
|
||||||
safe_prompt: false,
|
safe_prompt: false,
|
||||||
response_format: None,
|
response_format: None,
|
||||||
temperature: 0.7,
|
temperature: None,
|
||||||
tool_choice: None,
|
tool_choice: None,
|
||||||
tools: None,
|
tools: None,
|
||||||
top_p: 1.0,
|
top_p: None,
|
||||||
}
|
stop: None,
|
||||||
}
|
n: None,
|
||||||
}
|
frequency_penalty: None,
|
||||||
impl ChatParams {
|
presence_penalty: None,
|
||||||
pub fn json_default() -> Self {
|
parallel_tool_calls: None,
|
||||||
Self {
|
reasoning_effort: None,
|
||||||
max_tokens: None,
|
|
||||||
random_seed: None,
|
|
||||||
safe_prompt: false,
|
|
||||||
response_format: None,
|
|
||||||
temperature: 0.7,
|
|
||||||
tool_choice: None,
|
|
||||||
tools: None,
|
|
||||||
top_p: 1.0,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub struct ChatRequest {
|
pub struct ChatRequest {
|
||||||
pub messages: Vec<ChatMessage>,
|
|
||||||
pub model: constants::Model,
|
pub model: constants::Model,
|
||||||
|
pub messages: Vec<ChatMessage>,
|
||||||
|
pub stream: bool,
|
||||||
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub max_tokens: Option<u32>,
|
pub max_tokens: Option<u32>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub min_tokens: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub random_seed: Option<u32>,
|
pub random_seed: Option<u32>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub response_format: Option<ResponseFormat>,
|
pub response_format: Option<ResponseFormat>,
|
||||||
pub safe_prompt: bool,
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub stream: bool,
|
pub safe_prompt: Option<bool>,
|
||||||
pub temperature: f32,
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub temperature: Option<f32>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub tool_choice: Option<tool::ToolChoice>,
|
pub tool_choice: Option<tool::ToolChoice>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub tools: Option<Vec<tool::Tool>>,
|
pub tools: Option<Vec<tool::Tool>>,
|
||||||
pub top_p: f32,
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub top_p: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub stop: Option<Vec<String>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub n: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub frequency_penalty: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub presence_penalty: Option<f32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub parallel_tool_calls: Option<bool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub reasoning_effort: Option<String>,
|
||||||
}
|
}
|
||||||
impl ChatRequest {
|
impl ChatRequest {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
@@ -156,30 +193,28 @@ impl ChatRequest {
|
|||||||
stream: bool,
|
stream: bool,
|
||||||
options: Option<ChatParams>,
|
options: Option<ChatParams>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let ChatParams {
|
let opts = options.unwrap_or_default();
|
||||||
max_tokens,
|
let safe_prompt = if opts.safe_prompt { Some(true) } else { None };
|
||||||
random_seed,
|
|
||||||
safe_prompt,
|
|
||||||
temperature,
|
|
||||||
tool_choice,
|
|
||||||
tools,
|
|
||||||
top_p,
|
|
||||||
response_format,
|
|
||||||
} = options.unwrap_or_default();
|
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
messages,
|
|
||||||
model,
|
model,
|
||||||
|
messages,
|
||||||
max_tokens,
|
|
||||||
random_seed,
|
|
||||||
safe_prompt,
|
|
||||||
stream,
|
stream,
|
||||||
temperature,
|
max_tokens: opts.max_tokens,
|
||||||
tool_choice,
|
min_tokens: opts.min_tokens,
|
||||||
tools,
|
random_seed: opts.random_seed,
|
||||||
top_p,
|
safe_prompt,
|
||||||
response_format,
|
temperature: opts.temperature,
|
||||||
|
tool_choice: opts.tool_choice,
|
||||||
|
tools: opts.tools,
|
||||||
|
top_p: opts.top_p,
|
||||||
|
response_format: opts.response_format,
|
||||||
|
stop: opts.stop,
|
||||||
|
n: opts.n,
|
||||||
|
frequency_penalty: opts.frequency_penalty,
|
||||||
|
presence_penalty: opts.presence_penalty,
|
||||||
|
parallel_tool_calls: opts.parallel_tool_calls,
|
||||||
|
reasoning_effort: opts.reasoning_effort,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -192,7 +227,7 @@ pub struct ChatResponse {
|
|||||||
pub id: String,
|
pub id: String,
|
||||||
pub object: String,
|
pub object: String,
|
||||||
/// Unix timestamp (in seconds).
|
/// Unix timestamp (in seconds).
|
||||||
pub created: u32,
|
pub created: u64,
|
||||||
pub model: constants::Model,
|
pub model: constants::Model,
|
||||||
pub choices: Vec<ChatResponseChoice>,
|
pub choices: Vec<ChatResponseChoice>,
|
||||||
pub usage: common::ResponseUsage,
|
pub usage: common::ResponseUsage,
|
||||||
@@ -203,14 +238,18 @@ pub struct ChatResponseChoice {
|
|||||||
pub index: u32,
|
pub index: u32,
|
||||||
pub message: ChatMessage,
|
pub message: ChatMessage,
|
||||||
pub finish_reason: ChatResponseChoiceFinishReason,
|
pub finish_reason: ChatResponseChoiceFinishReason,
|
||||||
// TODO Check this prop (seen in API responses but undocumented).
|
|
||||||
// pub logprobs: ???
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||||
pub enum ChatResponseChoiceFinishReason {
|
pub enum ChatResponseChoiceFinishReason {
|
||||||
#[serde(rename = "stop")]
|
#[serde(rename = "stop")]
|
||||||
Stop,
|
Stop,
|
||||||
|
#[serde(rename = "length")]
|
||||||
|
Length,
|
||||||
#[serde(rename = "tool_calls")]
|
#[serde(rename = "tool_calls")]
|
||||||
ToolCalls,
|
ToolCalls,
|
||||||
|
#[serde(rename = "model_length")]
|
||||||
|
ModelLength,
|
||||||
|
#[serde(rename = "error")]
|
||||||
|
Error,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::from_str;
|
use serde_json::from_str;
|
||||||
|
|
||||||
use crate::v1::{chat, common, constants, error};
|
use crate::v1::{chat, common, constants, error, tool};
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
// Response
|
// Response
|
||||||
@@ -11,12 +11,11 @@ pub struct ChatStreamChunk {
|
|||||||
pub id: String,
|
pub id: String,
|
||||||
pub object: String,
|
pub object: String,
|
||||||
/// Unix timestamp (in seconds).
|
/// Unix timestamp (in seconds).
|
||||||
pub created: u32,
|
pub created: u64,
|
||||||
pub model: constants::Model,
|
pub model: constants::Model,
|
||||||
pub choices: Vec<ChatStreamChunkChoice>,
|
pub choices: Vec<ChatStreamChunkChoice>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub usage: Option<common::ResponseUsage>,
|
pub usage: Option<common::ResponseUsage>,
|
||||||
// TODO Check this prop (seen in API responses but undocumented).
|
|
||||||
// pub logprobs: ???,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
@@ -24,14 +23,15 @@ pub struct ChatStreamChunkChoice {
|
|||||||
pub index: u32,
|
pub index: u32,
|
||||||
pub delta: ChatStreamChunkChoiceDelta,
|
pub delta: ChatStreamChunkChoiceDelta,
|
||||||
pub finish_reason: Option<String>,
|
pub finish_reason: Option<String>,
|
||||||
// TODO Check this prop (seen in API responses but undocumented).
|
|
||||||
// pub logprobs: ???,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
pub struct ChatStreamChunkChoiceDelta {
|
pub struct ChatStreamChunkChoiceDelta {
|
||||||
pub role: Option<chat::ChatMessageRole>,
|
pub role: Option<chat::ChatMessageRole>,
|
||||||
pub content: String,
|
#[serde(default)]
|
||||||
|
pub content: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub tool_calls: Option<Vec<tool::ToolCall>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Extracts serialized chunks from a stream message.
|
/// Extracts serialized chunks from a stream message.
|
||||||
@@ -47,7 +47,6 @@ pub fn get_chunk_from_stream_message_line(
|
|||||||
return Ok(Some(vec![]));
|
return Ok(Some(vec![]));
|
||||||
}
|
}
|
||||||
|
|
||||||
// Attempt to deserialize the JSON string into ChatStreamChunk
|
|
||||||
match from_str::<ChatStreamChunk>(chunk_as_json) {
|
match from_str::<ChatStreamChunk>(chunk_as_json) {
|
||||||
Ok(chunk) => Ok(Some(vec![chunk])),
|
Ok(chunk) => Ok(Some(vec![chunk])),
|
||||||
Err(e) => Err(error::ApiError {
|
Err(e) => Err(error::ApiError {
|
||||||
|
|||||||
487
src/v1/client.rs
487
src/v1/client.rs
@@ -26,25 +26,10 @@ impl Client {
|
|||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
///
|
///
|
||||||
/// * `api_key` - An optional API key.
|
/// * `api_key` - An optional API key. If not provided, uses `MISTRAL_API_KEY` env var.
|
||||||
/// If not provided, the method will try to use the `MISTRAL_API_KEY` environment variable.
|
/// * `endpoint` - An optional custom API endpoint. Defaults to `https://api.mistral.ai/v1`.
|
||||||
/// * `endpoint` - An optional custom API endpoint. Defaults to the official API endpoint if not provided.
|
/// * `max_retries` - Optional maximum number of retries. Defaults to `5`.
|
||||||
/// * `max_retries` - Optional maximum number of retries for failed requests. Defaults to `5`.
|
/// * `timeout` - Optional timeout in seconds. Defaults to `120`.
|
||||||
/// * `timeout` - Optional timeout in seconds for requests. Defaults to `120`.
|
|
||||||
///
|
|
||||||
/// # Examples
|
|
||||||
///
|
|
||||||
/// ```
|
|
||||||
/// use mistralai_client::v1::client::Client;
|
|
||||||
///
|
|
||||||
/// let client = Client::new(Some("your_api_key_here".to_string()), None, Some(3), Some(60));
|
|
||||||
/// assert!(client.is_ok());
|
|
||||||
/// ```
|
|
||||||
///
|
|
||||||
/// # Errors
|
|
||||||
///
|
|
||||||
/// This method fails whenever neither the `api_key` is provided
|
|
||||||
/// nor the `MISTRAL_API_KEY` environment variable is set.
|
|
||||||
pub fn new(
|
pub fn new(
|
||||||
api_key: Option<String>,
|
api_key: Option<String>,
|
||||||
endpoint: Option<String>,
|
endpoint: Option<String>,
|
||||||
@@ -69,43 +54,15 @@ impl Client {
|
|||||||
endpoint,
|
endpoint,
|
||||||
max_retries,
|
max_retries,
|
||||||
timeout,
|
timeout,
|
||||||
|
|
||||||
functions,
|
functions,
|
||||||
last_function_call_result,
|
last_function_call_result,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Synchronously sends a chat completion request and returns the response.
|
// =========================================================================
|
||||||
///
|
// Chat Completions
|
||||||
/// # Arguments
|
// =========================================================================
|
||||||
///
|
|
||||||
/// * `model` - The [Model] to use for the chat completion.
|
|
||||||
/// * `messages` - A vector of [ChatMessage] to send as part of the chat.
|
|
||||||
/// * `options` - Optional [ChatParams] to customize the request.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
///
|
|
||||||
/// Returns a [Result] containing the `ChatResponse` if the request is successful,
|
|
||||||
/// or an [ApiError] if there is an error.
|
|
||||||
///
|
|
||||||
/// # Examples
|
|
||||||
///
|
|
||||||
/// ```
|
|
||||||
/// use mistralai_client::v1::{
|
|
||||||
/// chat::{ChatMessage, ChatMessageRole},
|
|
||||||
/// client::Client,
|
|
||||||
/// constants::Model,
|
|
||||||
/// };
|
|
||||||
///
|
|
||||||
/// let client = Client::new(None, None, None, None).unwrap();
|
|
||||||
/// let messages = vec![ChatMessage {
|
|
||||||
/// role: ChatMessageRole::User,
|
|
||||||
/// content: "Hello, world!".to_string(),
|
|
||||||
/// tool_calls: None,
|
|
||||||
/// }];
|
|
||||||
/// let response = client.chat(Model::OpenMistral7b, messages, None).unwrap();
|
|
||||||
/// println!("{:?}: {}", response.choices[0].message.role, response.choices[0].message.content);
|
|
||||||
/// ```
|
|
||||||
pub fn chat(
|
pub fn chat(
|
||||||
&self,
|
&self,
|
||||||
model: constants::Model,
|
model: constants::Model,
|
||||||
@@ -119,49 +76,13 @@ impl Client {
|
|||||||
match result {
|
match result {
|
||||||
Ok(data) => {
|
Ok(data) => {
|
||||||
utils::debug_pretty_json_from_struct("Response Data", &data);
|
utils::debug_pretty_json_from_struct("Response Data", &data);
|
||||||
|
|
||||||
self.call_function_if_any(data.clone());
|
self.call_function_if_any(data.clone());
|
||||||
|
|
||||||
Ok(data)
|
Ok(data)
|
||||||
}
|
}
|
||||||
Err(error) => Err(self.to_api_error(error)),
|
Err(error) => Err(self.to_api_error(error)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Asynchronously sends a chat completion request and returns the response.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `model` - The [Model] to use for the chat completion.
|
|
||||||
/// * `messages` - A vector of [ChatMessage] to send as part of the chat.
|
|
||||||
/// * `options` - Optional [ChatParams] to customize the request.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
///
|
|
||||||
/// Returns a [Result] containing a `Stream` of `ChatStreamChunk` if the request is successful,
|
|
||||||
/// or an [ApiError] if there is an error.
|
|
||||||
///
|
|
||||||
/// # Examples
|
|
||||||
///
|
|
||||||
/// ```
|
|
||||||
/// use mistralai_client::v1::{
|
|
||||||
/// chat::{ChatMessage, ChatMessageRole},
|
|
||||||
/// client::Client,
|
|
||||||
/// constants::Model,
|
|
||||||
/// };
|
|
||||||
///
|
|
||||||
/// #[tokio::main]
|
|
||||||
/// async fn main() {
|
|
||||||
/// let client = Client::new(None, None, None, None).unwrap();
|
|
||||||
/// let messages = vec![ChatMessage {
|
|
||||||
/// role: ChatMessageRole::User,
|
|
||||||
/// content: "Hello, world!".to_string(),
|
|
||||||
/// tool_calls: None,
|
|
||||||
/// }];
|
|
||||||
/// let response = client.chat_async(Model::OpenMistral7b, messages, None).await.unwrap();
|
|
||||||
/// println!("{:?}: {}", response.choices[0].message.role, response.choices[0].message.content);
|
|
||||||
/// }
|
|
||||||
/// ```
|
|
||||||
pub async fn chat_async(
|
pub async fn chat_async(
|
||||||
&self,
|
&self,
|
||||||
model: constants::Model,
|
model: constants::Model,
|
||||||
@@ -175,68 +96,13 @@ impl Client {
|
|||||||
match result {
|
match result {
|
||||||
Ok(data) => {
|
Ok(data) => {
|
||||||
utils::debug_pretty_json_from_struct("Response Data", &data);
|
utils::debug_pretty_json_from_struct("Response Data", &data);
|
||||||
|
|
||||||
self.call_function_if_any_async(data.clone()).await;
|
self.call_function_if_any_async(data.clone()).await;
|
||||||
|
|
||||||
Ok(data)
|
Ok(data)
|
||||||
}
|
}
|
||||||
Err(error) => Err(self.to_api_error(error)),
|
Err(error) => Err(self.to_api_error(error)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Asynchronously sends a chat completion request and returns a stream of message chunks.
|
|
||||||
///
|
|
||||||
/// # Arguments
|
|
||||||
///
|
|
||||||
/// * `model` - The [Model] to use for the chat completion.
|
|
||||||
/// * `messages` - A vector of [ChatMessage] to send as part of the chat.
|
|
||||||
/// * `options` - Optional [ChatParams] to customize the request.
|
|
||||||
///
|
|
||||||
/// # Returns
|
|
||||||
///
|
|
||||||
/// Returns a [Result] containing a `Stream` of `ChatStreamChunk` if the request is successful,
|
|
||||||
/// or an [ApiError] if there is an error.
|
|
||||||
///
|
|
||||||
/// # Examples
|
|
||||||
///
|
|
||||||
/// ```
|
|
||||||
/// use futures::stream::StreamExt;
|
|
||||||
/// use mistralai_client::v1::{
|
|
||||||
/// chat::{ChatMessage, ChatMessageRole},
|
|
||||||
/// client::Client,
|
|
||||||
/// constants::Model,
|
|
||||||
/// };
|
|
||||||
/// use std::io::{self, Write};
|
|
||||||
///
|
|
||||||
/// #[tokio::main]
|
|
||||||
/// async fn main() {
|
|
||||||
/// let client = Client::new(None, None, None, None).unwrap();
|
|
||||||
/// let messages = vec![ChatMessage {
|
|
||||||
/// role: ChatMessageRole::User,
|
|
||||||
/// content: "Hello, world!".to_string(),
|
|
||||||
/// tool_calls: None,
|
|
||||||
/// }];
|
|
||||||
///
|
|
||||||
/// let stream_result = client
|
|
||||||
/// .chat_stream(Model::OpenMistral7b,messages, None)
|
|
||||||
/// .await
|
|
||||||
/// .unwrap();
|
|
||||||
/// stream_result
|
|
||||||
/// .for_each(|chunk_result| async {
|
|
||||||
/// match chunk_result {
|
|
||||||
/// Ok(chunks) => chunks.iter().for_each(|chunk| {
|
|
||||||
/// print!("{}", chunk.choices[0].delta.content);
|
|
||||||
/// io::stdout().flush().unwrap();
|
|
||||||
/// // => "Once upon a time, [...]"
|
|
||||||
/// }),
|
|
||||||
/// Err(error) => {
|
|
||||||
/// eprintln!("Error processing chunk: {:?}", error)
|
|
||||||
/// }
|
|
||||||
/// }
|
|
||||||
/// })
|
|
||||||
/// .await;
|
|
||||||
/// print!("\n") // To persist the last chunk output.
|
|
||||||
/// }
|
|
||||||
pub async fn chat_stream(
|
pub async fn chat_stream(
|
||||||
&self,
|
&self,
|
||||||
model: constants::Model,
|
model: constants::Model,
|
||||||
@@ -292,9 +158,13 @@ impl Client {
|
|||||||
Ok(deserialized_stream)
|
Ok(deserialized_stream)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// =========================================================================
|
||||||
|
// Embeddings
|
||||||
|
// =========================================================================
|
||||||
|
|
||||||
pub fn embeddings(
|
pub fn embeddings(
|
||||||
&self,
|
&self,
|
||||||
model: constants::EmbedModel,
|
model: constants::Model,
|
||||||
input: Vec<String>,
|
input: Vec<String>,
|
||||||
options: Option<embedding::EmbeddingRequestOptions>,
|
options: Option<embedding::EmbeddingRequestOptions>,
|
||||||
) -> Result<embedding::EmbeddingResponse, error::ApiError> {
|
) -> Result<embedding::EmbeddingResponse, error::ApiError> {
|
||||||
@@ -305,7 +175,6 @@ impl Client {
|
|||||||
match result {
|
match result {
|
||||||
Ok(data) => {
|
Ok(data) => {
|
||||||
utils::debug_pretty_json_from_struct("Response Data", &data);
|
utils::debug_pretty_json_from_struct("Response Data", &data);
|
||||||
|
|
||||||
Ok(data)
|
Ok(data)
|
||||||
}
|
}
|
||||||
Err(error) => Err(self.to_api_error(error)),
|
Err(error) => Err(self.to_api_error(error)),
|
||||||
@@ -314,7 +183,7 @@ impl Client {
|
|||||||
|
|
||||||
pub async fn embeddings_async(
|
pub async fn embeddings_async(
|
||||||
&self,
|
&self,
|
||||||
model: constants::EmbedModel,
|
model: constants::Model,
|
||||||
input: Vec<String>,
|
input: Vec<String>,
|
||||||
options: Option<embedding::EmbeddingRequestOptions>,
|
options: Option<embedding::EmbeddingRequestOptions>,
|
||||||
) -> Result<embedding::EmbeddingResponse, error::ApiError> {
|
) -> Result<embedding::EmbeddingResponse, error::ApiError> {
|
||||||
@@ -325,18 +194,15 @@ impl Client {
|
|||||||
match result {
|
match result {
|
||||||
Ok(data) => {
|
Ok(data) => {
|
||||||
utils::debug_pretty_json_from_struct("Response Data", &data);
|
utils::debug_pretty_json_from_struct("Response Data", &data);
|
||||||
|
|
||||||
Ok(data)
|
Ok(data)
|
||||||
}
|
}
|
||||||
Err(error) => Err(self.to_api_error(error)),
|
Err(error) => Err(self.to_api_error(error)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_last_function_call_result(&self) -> Option<Box<dyn Any + Send>> {
|
// =========================================================================
|
||||||
let mut result_lock = self.last_function_call_result.lock().unwrap();
|
// Models
|
||||||
|
// =========================================================================
|
||||||
result_lock.take()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn list_models(&self) -> Result<model_list::ModelListResponse, error::ApiError> {
|
pub fn list_models(&self) -> Result<model_list::ModelListResponse, error::ApiError> {
|
||||||
let response = self.get_sync("/models")?;
|
let response = self.get_sync("/models")?;
|
||||||
@@ -344,7 +210,6 @@ impl Client {
|
|||||||
match result {
|
match result {
|
||||||
Ok(data) => {
|
Ok(data) => {
|
||||||
utils::debug_pretty_json_from_struct("Response Data", &data);
|
utils::debug_pretty_json_from_struct("Response Data", &data);
|
||||||
|
|
||||||
Ok(data)
|
Ok(data)
|
||||||
}
|
}
|
||||||
Err(error) => Err(self.to_api_error(error)),
|
Err(error) => Err(self.to_api_error(error)),
|
||||||
@@ -359,68 +224,136 @@ impl Client {
|
|||||||
match result {
|
match result {
|
||||||
Ok(data) => {
|
Ok(data) => {
|
||||||
utils::debug_pretty_json_from_struct("Response Data", &data);
|
utils::debug_pretty_json_from_struct("Response Data", &data);
|
||||||
|
|
||||||
Ok(data)
|
Ok(data)
|
||||||
}
|
}
|
||||||
Err(error) => Err(self.to_api_error(error)),
|
Err(error) => Err(self.to_api_error(error)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn get_model(&self, model_id: &str) -> Result<model_list::ModelListData, error::ApiError> {
|
||||||
|
let response = self.get_sync(&format!("/models/{}", model_id))?;
|
||||||
|
let result = response.json::<model_list::ModelListData>();
|
||||||
|
match result {
|
||||||
|
Ok(data) => {
|
||||||
|
utils::debug_pretty_json_from_struct("Response Data", &data);
|
||||||
|
Ok(data)
|
||||||
|
}
|
||||||
|
Err(error) => Err(self.to_api_error(error)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_model_async(
|
||||||
|
&self,
|
||||||
|
model_id: &str,
|
||||||
|
) -> Result<model_list::ModelListData, error::ApiError> {
|
||||||
|
let response = self.get_async(&format!("/models/{}", model_id)).await?;
|
||||||
|
let result = response.json::<model_list::ModelListData>().await;
|
||||||
|
match result {
|
||||||
|
Ok(data) => {
|
||||||
|
utils::debug_pretty_json_from_struct("Response Data", &data);
|
||||||
|
Ok(data)
|
||||||
|
}
|
||||||
|
Err(error) => Err(self.to_api_error(error)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn delete_model(
|
||||||
|
&self,
|
||||||
|
model_id: &str,
|
||||||
|
) -> Result<model_list::ModelDeleteResponse, error::ApiError> {
|
||||||
|
let response = self.delete_sync(&format!("/models/{}", model_id))?;
|
||||||
|
let result = response.json::<model_list::ModelDeleteResponse>();
|
||||||
|
match result {
|
||||||
|
Ok(data) => Ok(data),
|
||||||
|
Err(error) => Err(self.to_api_error(error)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn delete_model_async(
|
||||||
|
&self,
|
||||||
|
model_id: &str,
|
||||||
|
) -> Result<model_list::ModelDeleteResponse, error::ApiError> {
|
||||||
|
let response = self
|
||||||
|
.delete_async(&format!("/models/{}", model_id))
|
||||||
|
.await?;
|
||||||
|
let result = response.json::<model_list::ModelDeleteResponse>().await;
|
||||||
|
match result {
|
||||||
|
Ok(data) => Ok(data),
|
||||||
|
Err(error) => Err(self.to_api_error(error)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// =========================================================================
|
||||||
|
// Function Calling
|
||||||
|
// =========================================================================
|
||||||
|
|
||||||
pub fn register_function(&mut self, name: String, function: Box<dyn tool::Function>) {
|
pub fn register_function(&mut self, name: String, function: Box<dyn tool::Function>) {
|
||||||
let mut functions = self.functions.lock().unwrap();
|
let mut functions = self.functions.lock().unwrap();
|
||||||
|
|
||||||
functions.insert(name, function);
|
functions.insert(name, function);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn get_last_function_call_result(&self) -> Option<Box<dyn Any + Send>> {
|
||||||
|
let mut result_lock = self.last_function_call_result.lock().unwrap();
|
||||||
|
result_lock.take()
|
||||||
|
}
|
||||||
|
|
||||||
|
// =========================================================================
|
||||||
|
// HTTP Transport
|
||||||
|
// =========================================================================
|
||||||
|
|
||||||
|
fn user_agent(&self) -> String {
|
||||||
|
format!(
|
||||||
|
"mistralai-client-rs/{}",
|
||||||
|
env!("CARGO_PKG_VERSION")
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
fn build_request_sync(
|
fn build_request_sync(
|
||||||
&self,
|
&self,
|
||||||
request: reqwest::blocking::RequestBuilder,
|
request: reqwest::blocking::RequestBuilder,
|
||||||
) -> reqwest::blocking::RequestBuilder {
|
) -> reqwest::blocking::RequestBuilder {
|
||||||
let user_agent = format!(
|
request
|
||||||
"ivangabriele/mistralai-client-rs/{}",
|
|
||||||
env!("CARGO_PKG_VERSION")
|
|
||||||
);
|
|
||||||
|
|
||||||
let request_builder = request
|
|
||||||
.bearer_auth(&self.api_key)
|
.bearer_auth(&self.api_key)
|
||||||
.header("Accept", "application/json")
|
.header("Accept", "application/json")
|
||||||
.header("User-Agent", user_agent);
|
.header("User-Agent", self.user_agent())
|
||||||
|
}
|
||||||
|
|
||||||
request_builder
|
fn build_request_sync_no_accept(
|
||||||
|
&self,
|
||||||
|
request: reqwest::blocking::RequestBuilder,
|
||||||
|
) -> reqwest::blocking::RequestBuilder {
|
||||||
|
request
|
||||||
|
.bearer_auth(&self.api_key)
|
||||||
|
.header("User-Agent", self.user_agent())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_request_async(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
|
fn build_request_async(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
|
||||||
let user_agent = format!(
|
request
|
||||||
"ivangabriele/mistralai-client-rs/{}",
|
|
||||||
env!("CARGO_PKG_VERSION")
|
|
||||||
);
|
|
||||||
|
|
||||||
let request_builder = request
|
|
||||||
.bearer_auth(&self.api_key)
|
.bearer_auth(&self.api_key)
|
||||||
.header("Accept", "application/json")
|
.header("Accept", "application/json")
|
||||||
.header("User-Agent", user_agent);
|
.header("User-Agent", self.user_agent())
|
||||||
|
}
|
||||||
|
|
||||||
request_builder
|
fn build_request_async_no_accept(
|
||||||
|
&self,
|
||||||
|
request: reqwest::RequestBuilder,
|
||||||
|
) -> reqwest::RequestBuilder {
|
||||||
|
request
|
||||||
|
.bearer_auth(&self.api_key)
|
||||||
|
.header("User-Agent", self.user_agent())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn build_request_stream(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
|
fn build_request_stream(&self, request: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
|
||||||
let user_agent = format!(
|
request
|
||||||
"ivangabriele/mistralai-client-rs/{}",
|
|
||||||
env!("CARGO_PKG_VERSION")
|
|
||||||
);
|
|
||||||
|
|
||||||
let request_builder = request
|
|
||||||
.bearer_auth(&self.api_key)
|
.bearer_auth(&self.api_key)
|
||||||
.header("Accept", "text/event-stream")
|
.header("Accept", "text/event-stream")
|
||||||
.header("User-Agent", user_agent);
|
.header("User-Agent", self.user_agent())
|
||||||
|
|
||||||
request_builder
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn call_function_if_any(&self, response: chat::ChatResponse) -> () {
|
fn call_function_if_any(&self, response: chat::ChatResponse) {
|
||||||
let next_result = match response.choices.get(0) {
|
let next_result = match response.choices.first() {
|
||||||
Some(first_choice) => match first_choice.message.tool_calls.to_owned() {
|
Some(first_choice) => match first_choice.message.tool_calls.as_ref() {
|
||||||
Some(tool_calls) => match tool_calls.get(0) {
|
Some(tool_calls) => match tool_calls.first() {
|
||||||
Some(first_tool_call) => {
|
Some(first_tool_call) => {
|
||||||
let functions = self.functions.lock().unwrap();
|
let functions = self.functions.lock().unwrap();
|
||||||
match functions.get(&first_tool_call.function.name) {
|
match functions.get(&first_tool_call.function.name) {
|
||||||
@@ -431,7 +364,6 @@ impl Client {
|
|||||||
.execute(first_tool_call.function.arguments.to_owned())
|
.execute(first_tool_call.function.arguments.to_owned())
|
||||||
.await
|
.await
|
||||||
});
|
});
|
||||||
|
|
||||||
Some(result)
|
Some(result)
|
||||||
}
|
}
|
||||||
None => None,
|
None => None,
|
||||||
@@ -448,10 +380,10 @@ impl Client {
|
|||||||
*last_result_lock = next_result;
|
*last_result_lock = next_result;
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn call_function_if_any_async(&self, response: chat::ChatResponse) -> () {
|
async fn call_function_if_any_async(&self, response: chat::ChatResponse) {
|
||||||
let next_result = match response.choices.get(0) {
|
let next_result = match response.choices.first() {
|
||||||
Some(first_choice) => match first_choice.message.tool_calls.to_owned() {
|
Some(first_choice) => match first_choice.message.tool_calls.as_ref() {
|
||||||
Some(tool_calls) => match tool_calls.get(0) {
|
Some(tool_calls) => match tool_calls.first() {
|
||||||
Some(first_tool_call) => {
|
Some(first_tool_call) => {
|
||||||
let functions = self.functions.lock().unwrap();
|
let functions = self.functions.lock().unwrap();
|
||||||
match functions.get(&first_tool_call.function.name) {
|
match functions.get(&first_tool_call.function.name) {
|
||||||
@@ -459,7 +391,6 @@ impl Client {
|
|||||||
let result = function
|
let result = function
|
||||||
.execute(first_tool_call.function.arguments.to_owned())
|
.execute(first_tool_call.function.arguments.to_owned())
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
Some(result)
|
Some(result)
|
||||||
}
|
}
|
||||||
None => None,
|
None => None,
|
||||||
@@ -482,27 +413,8 @@ impl Client {
|
|||||||
debug!("Request URL: {}", url);
|
debug!("Request URL: {}", url);
|
||||||
|
|
||||||
let request = self.build_request_sync(reqwest_client.get(url));
|
let request = self.build_request_sync(reqwest_client.get(url));
|
||||||
|
|
||||||
let result = request.send();
|
let result = request.send();
|
||||||
match result {
|
self.handle_sync_response(result)
|
||||||
Ok(response) => {
|
|
||||||
if response.status().is_success() {
|
|
||||||
Ok(response)
|
|
||||||
} else {
|
|
||||||
let response_status = response.status();
|
|
||||||
let response_body = response.text().unwrap_or_default();
|
|
||||||
debug!("Response Status: {}", &response_status);
|
|
||||||
utils::debug_pretty_json_from_string("Response Data", &response_body);
|
|
||||||
|
|
||||||
Err(error::ApiError {
|
|
||||||
message: format!("{}: {}", response_status, response_body),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(error) => Err(error::ApiError {
|
|
||||||
message: error.to_string(),
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_async(&self, path: &str) -> Result<reqwest::Response, error::ApiError> {
|
async fn get_async(&self, path: &str) -> Result<reqwest::Response, error::ApiError> {
|
||||||
@@ -510,29 +422,9 @@ impl Client {
|
|||||||
let url = format!("{}{}", self.endpoint, path);
|
let url = format!("{}{}", self.endpoint, path);
|
||||||
debug!("Request URL: {}", url);
|
debug!("Request URL: {}", url);
|
||||||
|
|
||||||
let request_builder = reqwest_client.get(url);
|
let request = self.build_request_async(reqwest_client.get(url));
|
||||||
let request = self.build_request_async(request_builder);
|
|
||||||
|
|
||||||
let result = request.send().await;
|
let result = request.send().await;
|
||||||
match result {
|
self.handle_async_response(result).await
|
||||||
Ok(response) => {
|
|
||||||
if response.status().is_success() {
|
|
||||||
Ok(response)
|
|
||||||
} else {
|
|
||||||
let response_status = response.status();
|
|
||||||
let response_body = response.text().await.unwrap_or_default();
|
|
||||||
debug!("Response Status: {}", &response_status);
|
|
||||||
utils::debug_pretty_json_from_string("Response Data", &response_body);
|
|
||||||
|
|
||||||
Err(error::ApiError {
|
|
||||||
message: format!("{}: {}", response_status, response_body),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Err(error) => Err(error::ApiError {
|
|
||||||
message: error.to_string(),
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn post_sync<T: std::fmt::Debug + serde::ser::Serialize>(
|
fn post_sync<T: std::fmt::Debug + serde::ser::Serialize>(
|
||||||
@@ -545,29 +437,22 @@ impl Client {
|
|||||||
debug!("Request URL: {}", url);
|
debug!("Request URL: {}", url);
|
||||||
utils::debug_pretty_json_from_struct("Request Body", params);
|
utils::debug_pretty_json_from_struct("Request Body", params);
|
||||||
|
|
||||||
let request_builder = reqwest_client.post(url).json(params);
|
let request = self.build_request_sync(reqwest_client.post(url).json(params));
|
||||||
let request = self.build_request_sync(request_builder);
|
|
||||||
|
|
||||||
let result = request.send();
|
let result = request.send();
|
||||||
match result {
|
self.handle_sync_response(result)
|
||||||
Ok(response) => {
|
}
|
||||||
if response.status().is_success() {
|
|
||||||
Ok(response)
|
|
||||||
} else {
|
|
||||||
let response_status = response.status();
|
|
||||||
let response_body = response.text().unwrap_or_default();
|
|
||||||
debug!("Response Status: {}", &response_status);
|
|
||||||
utils::debug_pretty_json_from_string("Response Data", &response_body);
|
|
||||||
|
|
||||||
Err(error::ApiError {
|
fn post_sync_empty(
|
||||||
message: format!("{}: {}", response_body, response_status),
|
&self,
|
||||||
})
|
path: &str,
|
||||||
}
|
) -> Result<reqwest::blocking::Response, error::ApiError> {
|
||||||
}
|
let reqwest_client = reqwest::blocking::Client::new();
|
||||||
Err(error) => Err(error::ApiError {
|
let url = format!("{}{}", self.endpoint, path);
|
||||||
message: error.to_string(),
|
debug!("Request URL: {}", url);
|
||||||
}),
|
|
||||||
}
|
let request = self.build_request_sync(reqwest_client.post(url));
|
||||||
|
let result = request.send();
|
||||||
|
self.handle_sync_response(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn post_async<T: serde::ser::Serialize + std::fmt::Debug>(
|
async fn post_async<T: serde::ser::Serialize + std::fmt::Debug>(
|
||||||
@@ -580,29 +465,19 @@ impl Client {
|
|||||||
debug!("Request URL: {}", url);
|
debug!("Request URL: {}", url);
|
||||||
utils::debug_pretty_json_from_struct("Request Body", params);
|
utils::debug_pretty_json_from_struct("Request Body", params);
|
||||||
|
|
||||||
let request_builder = reqwest_client.post(url).json(params);
|
let request = self.build_request_async(reqwest_client.post(url).json(params));
|
||||||
let request = self.build_request_async(request_builder);
|
|
||||||
|
|
||||||
let result = request.send().await;
|
let result = request.send().await;
|
||||||
match result {
|
self.handle_async_response(result).await
|
||||||
Ok(response) => {
|
}
|
||||||
if response.status().is_success() {
|
|
||||||
Ok(response)
|
|
||||||
} else {
|
|
||||||
let response_status = response.status();
|
|
||||||
let response_body = response.text().await.unwrap_or_default();
|
|
||||||
debug!("Response Status: {}", &response_status);
|
|
||||||
utils::debug_pretty_json_from_string("Response Data", &response_body);
|
|
||||||
|
|
||||||
Err(error::ApiError {
|
async fn post_async_empty(&self, path: &str) -> Result<reqwest::Response, error::ApiError> {
|
||||||
message: format!("{}: {}", response_status, response_body),
|
let reqwest_client = reqwest::Client::new();
|
||||||
})
|
let url = format!("{}{}", self.endpoint, path);
|
||||||
}
|
debug!("Request URL: {}", url);
|
||||||
}
|
|
||||||
Err(error) => Err(error::ApiError {
|
let request = self.build_request_async(reqwest_client.post(url));
|
||||||
message: error.to_string(),
|
let result = request.send().await;
|
||||||
}),
|
self.handle_async_response(result).await
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn post_stream<T: serde::ser::Serialize + std::fmt::Debug>(
|
async fn post_stream<T: serde::ser::Serialize + std::fmt::Debug>(
|
||||||
@@ -615,22 +490,70 @@ impl Client {
|
|||||||
debug!("Request URL: {}", url);
|
debug!("Request URL: {}", url);
|
||||||
utils::debug_pretty_json_from_struct("Request Body", params);
|
utils::debug_pretty_json_from_struct("Request Body", params);
|
||||||
|
|
||||||
let request_builder = reqwest_client.post(url).json(params);
|
let request = self.build_request_stream(reqwest_client.post(url).json(params));
|
||||||
let request = self.build_request_stream(request_builder);
|
|
||||||
|
|
||||||
let result = request.send().await;
|
let result = request.send().await;
|
||||||
|
self.handle_async_response(result).await
|
||||||
|
}
|
||||||
|
|
||||||
|
fn delete_sync(&self, path: &str) -> Result<reqwest::blocking::Response, error::ApiError> {
|
||||||
|
let reqwest_client = reqwest::blocking::Client::new();
|
||||||
|
let url = format!("{}{}", self.endpoint, path);
|
||||||
|
debug!("Request URL: {}", url);
|
||||||
|
|
||||||
|
let request = self.build_request_sync(reqwest_client.delete(url));
|
||||||
|
let result = request.send();
|
||||||
|
self.handle_sync_response(result)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn delete_async(&self, path: &str) -> Result<reqwest::Response, error::ApiError> {
|
||||||
|
let reqwest_client = reqwest::Client::new();
|
||||||
|
let url = format!("{}{}", self.endpoint, path);
|
||||||
|
debug!("Request URL: {}", url);
|
||||||
|
|
||||||
|
let request = self.build_request_async(reqwest_client.delete(url));
|
||||||
|
let result = request.send().await;
|
||||||
|
self.handle_async_response(result).await
|
||||||
|
}
|
||||||
|
|
||||||
|
fn handle_sync_response(
|
||||||
|
&self,
|
||||||
|
result: Result<reqwest::blocking::Response, reqwest::Error>,
|
||||||
|
) -> Result<reqwest::blocking::Response, error::ApiError> {
|
||||||
match result {
|
match result {
|
||||||
Ok(response) => {
|
Ok(response) => {
|
||||||
if response.status().is_success() {
|
if response.status().is_success() {
|
||||||
Ok(response)
|
Ok(response)
|
||||||
} else {
|
} else {
|
||||||
let response_status = response.status();
|
let status = response.status();
|
||||||
let response_body = response.text().await.unwrap_or_default();
|
let body = response.text().unwrap_or_default();
|
||||||
debug!("Response Status: {}", &response_status);
|
debug!("Response Status: {}", &status);
|
||||||
utils::debug_pretty_json_from_string("Response Data", &response_body);
|
utils::debug_pretty_json_from_string("Response Data", &body);
|
||||||
|
|
||||||
Err(error::ApiError {
|
Err(error::ApiError {
|
||||||
message: format!("{}: {}", response_status, response_body),
|
message: format!("{}: {}", status, body),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(error) => Err(error::ApiError {
|
||||||
|
message: error.to_string(),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_async_response(
|
||||||
|
&self,
|
||||||
|
result: Result<reqwest::Response, reqwest::Error>,
|
||||||
|
) -> Result<reqwest::Response, error::ApiError> {
|
||||||
|
match result {
|
||||||
|
Ok(response) => {
|
||||||
|
if response.status().is_success() {
|
||||||
|
Ok(response)
|
||||||
|
} else {
|
||||||
|
let status = response.status();
|
||||||
|
let body = response.text().await.unwrap_or_default();
|
||||||
|
debug!("Response Status: {}", &status);
|
||||||
|
utils::debug_pretty_json_from_string("Response Data", &body);
|
||||||
|
Err(error::ApiError {
|
||||||
|
message: format!("{}: {}", status, body),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ use serde::{Deserialize, Serialize};
|
|||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
pub struct ResponseUsage {
|
pub struct ResponseUsage {
|
||||||
pub prompt_tokens: u32,
|
pub prompt_tokens: u32,
|
||||||
|
#[serde(default)]
|
||||||
pub completion_tokens: u32,
|
pub completion_tokens: u32,
|
||||||
pub total_tokens: u32,
|
pub total_tokens: u32,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,35 +1,131 @@
|
|||||||
|
use std::fmt;
|
||||||
|
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
pub const API_URL_BASE: &str = "https://api.mistral.ai/v1";
|
pub const API_URL_BASE: &str = "https://api.mistral.ai/v1";
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
/// A Mistral AI model identifier.
|
||||||
pub enum Model {
|
///
|
||||||
#[serde(rename = "open-mistral-7b")]
|
/// Use the associated constants for known models, or construct with `Model::new()` for any model string.
|
||||||
OpenMistral7b,
|
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
|
||||||
#[serde(rename = "open-mixtral-8x7b")]
|
#[serde(transparent)]
|
||||||
OpenMixtral8x7b,
|
pub struct Model(pub String);
|
||||||
#[serde(rename = "open-mixtral-8x22b")]
|
|
||||||
OpenMixtral8x22b,
|
impl Model {
|
||||||
#[serde(rename = "open-mistral-nemo", alias = "open-mistral-nemo-2407")]
|
pub fn new(id: impl Into<String>) -> Self {
|
||||||
OpenMistralNemo,
|
Self(id.into())
|
||||||
#[serde(rename = "mistral-tiny")]
|
}
|
||||||
MistralTiny,
|
|
||||||
#[serde(rename = "mistral-small-latest", alias = "mistral-small-2402")]
|
// Flagship / Premier
|
||||||
MistralSmallLatest,
|
pub fn mistral_large_latest() -> Self {
|
||||||
#[serde(rename = "mistral-medium-latest", alias = "mistral-medium-2312")]
|
Self::new("mistral-large-latest")
|
||||||
MistralMediumLatest,
|
}
|
||||||
#[serde(rename = "mistral-large-latest", alias = "mistral-large-2407")]
|
pub fn mistral_large_3() -> Self {
|
||||||
MistralLargeLatest,
|
Self::new("mistral-large-3-25-12")
|
||||||
#[serde(rename = "mistral-large-2402")]
|
}
|
||||||
MistralLarge,
|
pub fn mistral_medium_latest() -> Self {
|
||||||
#[serde(rename = "codestral-latest", alias = "codestral-2405")]
|
Self::new("mistral-medium-latest")
|
||||||
CodestralLatest,
|
}
|
||||||
#[serde(rename = "open-codestral-mamba")]
|
pub fn mistral_medium_3_1() -> Self {
|
||||||
CodestralMamba,
|
Self::new("mistral-medium-3-1-25-08")
|
||||||
|
}
|
||||||
|
pub fn mistral_small_latest() -> Self {
|
||||||
|
Self::new("mistral-small-latest")
|
||||||
|
}
|
||||||
|
pub fn mistral_small_4() -> Self {
|
||||||
|
Self::new("mistral-small-4-0-26-03")
|
||||||
|
}
|
||||||
|
pub fn mistral_small_3_2() -> Self {
|
||||||
|
Self::new("mistral-small-3-2-25-06")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ministral
|
||||||
|
pub fn ministral_3_14b() -> Self {
|
||||||
|
Self::new("ministral-3-14b-25-12")
|
||||||
|
}
|
||||||
|
pub fn ministral_3_8b() -> Self {
|
||||||
|
Self::new("ministral-3-8b-25-12")
|
||||||
|
}
|
||||||
|
pub fn ministral_3_3b() -> Self {
|
||||||
|
Self::new("ministral-3-3b-25-12")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reasoning
|
||||||
|
pub fn magistral_medium_latest() -> Self {
|
||||||
|
Self::new("magistral-medium-latest")
|
||||||
|
}
|
||||||
|
pub fn magistral_small_latest() -> Self {
|
||||||
|
Self::new("magistral-small-latest")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Code
|
||||||
|
pub fn codestral_latest() -> Self {
|
||||||
|
Self::new("codestral-latest")
|
||||||
|
}
|
||||||
|
pub fn codestral_2508() -> Self {
|
||||||
|
Self::new("codestral-2508")
|
||||||
|
}
|
||||||
|
pub fn codestral_embed() -> Self {
|
||||||
|
Self::new("codestral-embed-25-05")
|
||||||
|
}
|
||||||
|
pub fn devstral_2() -> Self {
|
||||||
|
Self::new("devstral-2-25-12")
|
||||||
|
}
|
||||||
|
pub fn devstral_small_2() -> Self {
|
||||||
|
Self::new("devstral-small-2-25-12")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multimodal / Vision
|
||||||
|
pub fn pixtral_large() -> Self {
|
||||||
|
Self::new("pixtral-large-2411")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Audio
|
||||||
|
pub fn voxtral_mini_transcribe() -> Self {
|
||||||
|
Self::new("voxtral-mini-transcribe-2-26-02")
|
||||||
|
}
|
||||||
|
pub fn voxtral_small() -> Self {
|
||||||
|
Self::new("voxtral-small-25-07")
|
||||||
|
}
|
||||||
|
pub fn voxtral_mini() -> Self {
|
||||||
|
Self::new("voxtral-mini-25-07")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Legacy (kept for backward compatibility)
|
||||||
|
pub fn open_mistral_nemo() -> Self {
|
||||||
|
Self::new("open-mistral-nemo")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Embedding
|
||||||
|
pub fn mistral_embed() -> Self {
|
||||||
|
Self::new("mistral-embed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Moderation
|
||||||
|
pub fn mistral_moderation_latest() -> Self {
|
||||||
|
Self::new("mistral-moderation-26-03")
|
||||||
|
}
|
||||||
|
|
||||||
|
// OCR
|
||||||
|
pub fn mistral_ocr_latest() -> Self {
|
||||||
|
Self::new("mistral-ocr-latest")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
impl fmt::Display for Model {
|
||||||
pub enum EmbedModel {
|
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
|
||||||
#[serde(rename = "mistral-embed")]
|
write!(f, "{}", self.0)
|
||||||
MistralEmbed,
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&str> for Model {
|
||||||
|
fn from(s: &str) -> Self {
|
||||||
|
Self(s.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<String> for Model {
|
||||||
|
fn from(s: String) -> Self {
|
||||||
|
Self(s)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,42 +8,63 @@ use crate::v1::{common, constants};
|
|||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct EmbeddingRequestOptions {
|
pub struct EmbeddingRequestOptions {
|
||||||
pub encoding_format: Option<EmbeddingRequestEncodingFormat>,
|
pub encoding_format: Option<EmbeddingRequestEncodingFormat>,
|
||||||
|
pub output_dimension: Option<u32>,
|
||||||
|
pub output_dtype: Option<EmbeddingOutputDtype>,
|
||||||
}
|
}
|
||||||
impl Default for EmbeddingRequestOptions {
|
impl Default for EmbeddingRequestOptions {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
encoding_format: None,
|
encoding_format: None,
|
||||||
|
output_dimension: None,
|
||||||
|
output_dtype: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub struct EmbeddingRequest {
|
pub struct EmbeddingRequest {
|
||||||
pub model: constants::EmbedModel,
|
pub model: constants::Model,
|
||||||
pub input: Vec<String>,
|
pub input: Vec<String>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub encoding_format: Option<EmbeddingRequestEncodingFormat>,
|
pub encoding_format: Option<EmbeddingRequestEncodingFormat>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub output_dimension: Option<u32>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub output_dtype: Option<EmbeddingOutputDtype>,
|
||||||
}
|
}
|
||||||
impl EmbeddingRequest {
|
impl EmbeddingRequest {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
model: constants::EmbedModel,
|
model: constants::Model,
|
||||||
input: Vec<String>,
|
input: Vec<String>,
|
||||||
options: Option<EmbeddingRequestOptions>,
|
options: Option<EmbeddingRequestOptions>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let EmbeddingRequestOptions { encoding_format } = options.unwrap_or_default();
|
let opts = options.unwrap_or_default();
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
model,
|
model,
|
||||||
input,
|
input,
|
||||||
encoding_format,
|
encoding_format: opts.encoding_format,
|
||||||
|
output_dimension: opts.output_dimension,
|
||||||
|
output_dtype: opts.output_dtype,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||||
#[allow(non_camel_case_types)]
|
#[serde(rename_all = "lowercase")]
|
||||||
pub enum EmbeddingRequestEncodingFormat {
|
pub enum EmbeddingRequestEncodingFormat {
|
||||||
float,
|
Float,
|
||||||
|
Base64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum EmbeddingOutputDtype {
|
||||||
|
Float,
|
||||||
|
Int8,
|
||||||
|
Uint8,
|
||||||
|
Binary,
|
||||||
|
Ubinary,
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
@@ -51,9 +72,8 @@ pub enum EmbeddingRequestEncodingFormat {
|
|||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
pub struct EmbeddingResponse {
|
pub struct EmbeddingResponse {
|
||||||
pub id: String,
|
|
||||||
pub object: String,
|
pub object: String,
|
||||||
pub model: constants::EmbedModel,
|
pub model: constants::Model,
|
||||||
pub data: Vec<EmbeddingResponseDataItem>,
|
pub data: Vec<EmbeddingResponseDataItem>,
|
||||||
pub usage: common::ResponseUsage,
|
pub usage: common::ResponseUsage,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,23 +15,44 @@ pub struct ModelListData {
|
|||||||
pub id: String,
|
pub id: String,
|
||||||
pub object: String,
|
pub object: String,
|
||||||
/// Unix timestamp (in seconds).
|
/// Unix timestamp (in seconds).
|
||||||
pub created: u32,
|
pub created: u64,
|
||||||
pub owned_by: String,
|
pub owned_by: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub root: Option<String>,
|
pub root: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
pub archived: bool,
|
pub archived: bool,
|
||||||
pub name: String,
|
#[serde(default)]
|
||||||
pub description: String,
|
pub name: Option<String>,
|
||||||
pub capabilities: ModelListDataCapabilies,
|
#[serde(default)]
|
||||||
pub max_context_length: u32,
|
pub description: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub capabilities: Option<ModelListDataCapabilities>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub max_context_length: Option<u32>,
|
||||||
|
#[serde(default)]
|
||||||
pub aliases: Vec<String>,
|
pub aliases: Vec<String>,
|
||||||
/// ISO 8601 date (`YYYY-MM-DDTHH:MM:SSZ`).
|
/// ISO 8601 date (`YYYY-MM-DDTHH:MM:SSZ`).
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub deprecation: Option<String>,
|
pub deprecation: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
pub struct ModelListDataCapabilies {
|
pub struct ModelListDataCapabilities {
|
||||||
|
#[serde(default)]
|
||||||
pub completion_chat: bool,
|
pub completion_chat: bool,
|
||||||
|
#[serde(default)]
|
||||||
pub completion_fim: bool,
|
pub completion_fim: bool,
|
||||||
|
#[serde(default)]
|
||||||
pub function_calling: bool,
|
pub function_calling: bool,
|
||||||
|
#[serde(default)]
|
||||||
pub fine_tuning: bool,
|
pub fine_tuning: bool,
|
||||||
|
#[serde(default)]
|
||||||
|
pub vision: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
|
pub struct ModelDeleteResponse {
|
||||||
|
pub id: String,
|
||||||
|
pub object: String,
|
||||||
|
pub deleted: bool,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,16 @@
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::{any::Any, collections::HashMap, fmt::Debug};
|
use std::{any::Any, fmt::Debug};
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
// Definitions
|
// Definitions
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||||
pub struct ToolCall {
|
pub struct ToolCall {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub id: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub r#type: Option<String>,
|
||||||
pub function: ToolCallFunction,
|
pub function: ToolCallFunction,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -22,31 +26,12 @@ pub struct Tool {
|
|||||||
pub function: ToolFunction,
|
pub function: ToolFunction,
|
||||||
}
|
}
|
||||||
impl Tool {
|
impl Tool {
|
||||||
|
/// Create a tool with a JSON Schema parameters object.
|
||||||
pub fn new(
|
pub fn new(
|
||||||
function_name: String,
|
function_name: String,
|
||||||
function_description: String,
|
function_description: String,
|
||||||
function_parameters: Vec<ToolFunctionParameter>,
|
parameters: serde_json::Value,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
let properties: HashMap<String, ToolFunctionParameterProperty> = function_parameters
|
|
||||||
.into_iter()
|
|
||||||
.map(|param| {
|
|
||||||
(
|
|
||||||
param.name,
|
|
||||||
ToolFunctionParameterProperty {
|
|
||||||
r#type: param.r#type,
|
|
||||||
description: param.description,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
let property_names = properties.keys().cloned().collect();
|
|
||||||
|
|
||||||
let parameters = ToolFunctionParameters {
|
|
||||||
r#type: ToolFunctionParametersType::Object,
|
|
||||||
properties,
|
|
||||||
required: property_names,
|
|
||||||
};
|
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
r#type: ToolType::Function,
|
r#type: ToolType::Function,
|
||||||
function: ToolFunction {
|
function: ToolFunction {
|
||||||
@@ -63,50 +48,9 @@ impl Tool {
|
|||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
pub struct ToolFunction {
|
pub struct ToolFunction {
|
||||||
name: String,
|
pub name: String,
|
||||||
description: String,
|
pub description: String,
|
||||||
parameters: ToolFunctionParameters,
|
pub parameters: serde_json::Value,
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
|
||||||
pub struct ToolFunctionParameter {
|
|
||||||
name: String,
|
|
||||||
description: String,
|
|
||||||
r#type: ToolFunctionParameterType,
|
|
||||||
}
|
|
||||||
impl ToolFunctionParameter {
|
|
||||||
pub fn new(name: String, description: String, r#type: ToolFunctionParameterType) -> Self {
|
|
||||||
Self {
|
|
||||||
name,
|
|
||||||
r#type,
|
|
||||||
description,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
|
||||||
pub struct ToolFunctionParameters {
|
|
||||||
r#type: ToolFunctionParametersType,
|
|
||||||
properties: HashMap<String, ToolFunctionParameterProperty>,
|
|
||||||
required: Vec<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
|
||||||
pub struct ToolFunctionParameterProperty {
|
|
||||||
r#type: ToolFunctionParameterType,
|
|
||||||
description: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
|
||||||
pub enum ToolFunctionParametersType {
|
|
||||||
#[serde(rename = "object")]
|
|
||||||
Object,
|
|
||||||
}
|
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
|
||||||
pub enum ToolFunctionParameterType {
|
|
||||||
#[serde(rename = "string")]
|
|
||||||
String,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||||
@@ -127,6 +71,9 @@ pub enum ToolChoice {
|
|||||||
/// The model won't call a function and will generate a message instead.
|
/// The model won't call a function and will generate a message instead.
|
||||||
#[serde(rename = "none")]
|
#[serde(rename = "none")]
|
||||||
None,
|
None,
|
||||||
|
/// The model must call at least one tool.
|
||||||
|
#[serde(rename = "required")]
|
||||||
|
Required,
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
|
|||||||
Reference in New Issue
Block a user