Files
mistralai-client-rs/src/v1/chat.rs
Ivan Gabriele cf68a77320 feat(chat)!: change safe_prompt, temperature & top_p to non-Option types
BREAKING CHANGE:
- `Chat::ChatParams.safe_prompt` & `Chat::ChatRequest.safe_prompt` are now `bool` instead of `Option<bool>`. Default is `false`.
- `Chat::ChatParams.temperature` & `Chat::ChatRequest.temperature` are now `f32` instead of `Option<f32>`. Default is `0.7`.
- `Chat::ChatParams.top_p` & `Chat::ChatRequest.top_p` are now `f32` instead of `Option<f32>`. Default is `1.0`.
2024-06-07 16:00:10 +02:00

212 lines
5.7 KiB
Rust

use serde::{Deserialize, Serialize};
use crate::v1::{common, constants, tool};
// -----------------------------------------------------------------------------
// Definitions
/// A single message of a chat conversation.
///
/// Used both in requests (`ChatRequest::messages`) and in responses
/// (`ChatResponseChoice::message`).
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ChatMessage {
/// The role of the author of this message.
pub role: ChatMessageRole,
/// The text content of the message.
pub content: String,
/// Tool calls attached to this message, if any.
///
/// The field is left out of the serialized output when it is `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<tool::ToolCall>>,
}
impl ChatMessage {
pub fn new_assistant_message(content: &str, tool_calls: Option<Vec<tool::ToolCall>>) -> Self {
Self {
role: ChatMessageRole::Assistant,
content: content.to_string(),
tool_calls,
}
}
pub fn new_user_message(content: &str) -> Self {
Self {
role: ChatMessageRole::User,
content: content.to_string(),
tool_calls: None,
}
}
}
/// The role of the author of a [`ChatMessage`].
///
/// Variants are (de)serialized under their lowercase names: `"assistant"` and `"user"`.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum ChatMessageRole {
    Assistant,
    User,
}
/// The format that the model must output.
///
/// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information.
// `Clone`, `PartialEq` and `Eq` are derived so that this type can be copied and compared like the
// other public types of this module; its only field is a `String`, so all three are available.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub struct ResponseFormat {
    /// The format name, (de)serialized under the `type` key.
    #[serde(rename = "type")]
    pub type_: String,
}
impl ResponseFormat {
    /// Returns a response format whose `type` is `"json_object"`.
    pub fn json_object() -> Self {
        let type_ = String::from("json_object");

        Self { type_ }
    }
}
// -----------------------------------------------------------------------------
// Request
/// The parameters for the chat request.
///
/// Every field has a default value (see the `Default` implementation), so a value can be built
/// with struct update syntax: `ChatParams { temperature: 0.2, ..Default::default() }`.
///
/// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information.
#[derive(Debug)]
pub struct ChatParams {
/// The maximum number of tokens to generate in the completion.
///
/// Defaults to `None`.
pub max_tokens: Option<u32>,
/// The seed to use for random sampling. If set, different calls will generate deterministic results.
///
/// Defaults to `None`.
pub random_seed: Option<u32>,
/// The format that the model must output.
///
/// Defaults to `None`.
pub response_format: Option<ResponseFormat>,
/// Whether to inject a safety prompt before all conversations.
///
/// Defaults to `false`.
pub safe_prompt: bool,
/// What sampling temperature to use, between `0.0` and `1.0`.
///
/// Defaults to `0.7`.
pub temperature: f32,
/// Specifies if/how functions are called.
///
/// Defaults to `None`.
pub tool_choice: Option<tool::ToolChoice>,
/// A list of available tools for the model.
///
/// Defaults to `None`.
pub tools: Option<Vec<tool::Tool>>,
/// Nucleus sampling, where the model considers the results of the tokens with `top_p` probability mass.
///
/// Defaults to `1.0`.
pub top_p: f32,
}
impl Default for ChatParams {
    /// Returns the parameters used when the caller does not provide any:
    /// every optional field is `None`, the safety prompt is disabled,
    /// `temperature` is `0.7` and `top_p` is `1.0`.
    fn default() -> Self {
        Self {
            // Sampling settings.
            temperature: 0.7,
            top_p: 1.0,
            random_seed: None,
            // Output settings.
            max_tokens: None,
            response_format: None,
            safe_prompt: false,
            // Tool settings.
            tool_choice: None,
            tools: None,
        }
    }
}
impl ChatParams {
pub fn json_default() -> Self {
Self {
max_tokens: None,
random_seed: None,
safe_prompt: false,
response_format: None,
temperature: 0.7,
tool_choice: None,
tools: None,
top_p: 1.0,
}
}
}
/// The chat request, built by `ChatRequest::new` from a model, a list of messages, a `stream`
/// flag and a set of `ChatParams`.
///
/// Fields of type `Option` are left out of the serialized output when they are `None`.
#[derive(Debug, Serialize, Deserialize)]
pub struct ChatRequest {
/// The messages of the conversation.
pub messages: Vec<ChatMessage>,
/// The model the request is addressed to.
pub model: constants::Model,
/// See `ChatParams::max_tokens`.
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option<u32>,
/// See `ChatParams::random_seed`.
#[serde(skip_serializing_if = "Option::is_none")]
pub random_seed: Option<u32>,
/// See `ChatParams::response_format`.
#[serde(skip_serializing_if = "Option::is_none")]
pub response_format: Option<ResponseFormat>,
/// See `ChatParams::safe_prompt`. Always serialized.
pub safe_prompt: bool,
/// NOTE(review): presumably tells the API whether to stream the response back; confirm against
/// the API documentation. Always serialized.
pub stream: bool,
/// See `ChatParams::temperature`. Always serialized.
pub temperature: f32,
/// See `ChatParams::tool_choice`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<tool::ToolChoice>,
/// See `ChatParams::tools`.
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<tool::Tool>>,
/// See `ChatParams::top_p`. Always serialized.
pub top_p: f32,
}
impl ChatRequest {
    /// Assembles a request from a model, a conversation, a stream flag and optional parameters.
    ///
    /// When `options` is `None`, the values of `ChatParams::default()` are used instead.
    pub fn new(
        model: constants::Model,
        messages: Vec<ChatMessage>,
        stream: bool,
        options: Option<ChatParams>,
    ) -> Self {
        let params = options.unwrap_or_default();

        Self {
            // Taken directly from the arguments.
            messages,
            model,
            stream,
            // Taken from the (possibly defaulted) parameters.
            max_tokens: params.max_tokens,
            random_seed: params.random_seed,
            response_format: params.response_format,
            safe_prompt: params.safe_prompt,
            temperature: params.temperature,
            tool_choice: params.tool_choice,
            tools: params.tools,
            top_p: params.top_p,
        }
    }
}
// -----------------------------------------------------------------------------
// Response
/// The chat response, as deserialized from the API.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ChatResponse {
/// Identifier of the response.
pub id: String,
/// Value of the `object` key of the response.
pub object: String,
/// Unix timestamp (in seconds).
pub created: u32,
/// The model reported by the response.
pub model: constants::Model,
/// The choices returned for the request.
pub choices: Vec<ChatResponseChoice>,
/// Usage information reported by the response.
pub usage: common::ResponseUsage,
}
/// One choice of a `ChatResponse`.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ChatResponseChoice {
/// Index of this choice.
pub index: u32,
/// The message of this choice.
pub message: ChatMessage,
/// The reason reported for the end of the generation.
pub finish_reason: ChatResponseChoiceFinishReason,
// TODO Check this prop (seen in API responses but undocumented).
// pub logprobs: ???
}
/// The reason reported for the end of a generation.
///
/// Variants are (de)serialized under their snake_case names: `"stop"` and `"tool_calls"`.
// NOTE(review): any other value sent by the API (presumably e.g. `length` when `max_tokens` is
// reached) would fail to deserialize; confirm the full list against the API documentation.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
#[serde(rename_all = "snake_case")]
pub enum ChatResponseChoiceFinishReason {
    Stop,
    ToolCalls,
}