From cf68a773201ebe0e802face52af388711acf0c27 Mon Sep 17 00:00:00 2001 From: Ivan Gabriele Date: Fri, 7 Jun 2024 15:53:25 +0200 Subject: [PATCH] feat(chat)!: change safe_prompt, temperature & top_p to non-Option types BREAKING CHANGE: - `Chat::ChatParams.safe_prompt` & `Chat::ChatRequest.safe_prompt` are now `bool` instead of `Option`. Default is `false`. - `Chat::ChatParams.temperature` & `Chat::ChatRequest.temperature` are now `f32` instead of `Option`. Default is `0.7`. - `Chat::ChatParams.top_p` & `Chat::ChatRequest.top_p` are now `f32` instead of `Option`. Default is `1.0`. --- README.md | 14 ++-- README.template.md | 4 ++ examples/chat.rs | 2 +- examples/chat_async.rs | 2 +- examples/chat_with_function_calling.rs | 2 +- examples/chat_with_function_calling_async.rs | 2 +- examples/chat_with_streaming.rs | 2 +- src/v1/chat.rs | 70 +++++++++++++------- src/v1/tool.rs | 4 ++ tests/v1_client_chat_async_test.rs | 11 ++- tests/v1_client_chat_test.rs | 11 ++- 11 files changed, 85 insertions(+), 39 deletions(-) diff --git a/README.md b/README.md index 38bfbfb..a88429c 100644 --- a/README.md +++ b/README.md @@ -7,6 +7,10 @@ Rust client for the Mistral AI API. +> [!IMPORTANT] +> While we are in v0, minor versions may introduce breaking changes. +> Please, refer to the [CHANGELOG.md](./CHANGELOG.md) for more information. 
+ --- - [Supported APIs](#supported-apis) @@ -102,7 +106,7 @@ fn main() { tool_calls: None, }]; let options = ChatParams { - temperature: Some(0.0), + temperature: 0.0, random_seed: Some(42), ..Default::default() }; @@ -134,7 +138,7 @@ async fn main() { tool_calls: None, }]; let options = ChatParams { - temperature: Some(0.0), + temperature: 0.0, random_seed: Some(42), ..Default::default() }; @@ -174,7 +178,7 @@ async fn main() { tool_calls: None, }]; let options = ChatParams { - temperature: Some(0.0), + temperature: 0.0, random_seed: Some(42), ..Default::default() }; @@ -259,7 +263,7 @@ fn main() { tool_calls: None, }]; let options = ChatParams { - temperature: Some(0.0), + temperature: 0.0, random_seed: Some(42), tool_choice: Some(ToolChoice::Auto), tools: Some(tools), @@ -336,7 +340,7 @@ async fn main() { tool_calls: None, }]; let options = ChatParams { - temperature: Some(0.0), + temperature: 0.0, random_seed: Some(42), tool_choice: Some(ToolChoice::Auto), tools: Some(tools), diff --git a/README.template.md b/README.template.md index 5f3d496..285fead 100644 --- a/README.template.md +++ b/README.template.md @@ -7,6 +7,10 @@ Rust client for the Mistral AI API. +> [!IMPORTANT] +> While we are in v0, minor versions may introduce breaking changes. +> Please, refer to the [CHANGELOG.md](./CHANGELOG.md) for more information. 
+ --- - [Supported APIs](#supported-apis) diff --git a/examples/chat.rs b/examples/chat.rs index 12d5fd4..ad3be09 100644 --- a/examples/chat.rs +++ b/examples/chat.rs @@ -15,7 +15,7 @@ fn main() { tool_calls: None, }]; let options = ChatParams { - temperature: Some(0.0), + temperature: 0.0, random_seed: Some(42), ..Default::default() }; diff --git a/examples/chat_async.rs b/examples/chat_async.rs index 7034553..a3f35a5 100644 --- a/examples/chat_async.rs +++ b/examples/chat_async.rs @@ -16,7 +16,7 @@ async fn main() { tool_calls: None, }]; let options = ChatParams { - temperature: Some(0.0), + temperature: 0.0, random_seed: Some(42), ..Default::default() }; diff --git a/examples/chat_with_function_calling.rs b/examples/chat_with_function_calling.rs index 991d2f2..67660fb 100644 --- a/examples/chat_with_function_calling.rs +++ b/examples/chat_with_function_calling.rs @@ -53,7 +53,7 @@ fn main() { tool_calls: None, }]; let options = ChatParams { - temperature: Some(0.0), + temperature: 0.0, random_seed: Some(42), tool_choice: Some(ToolChoice::Auto), tools: Some(tools), diff --git a/examples/chat_with_function_calling_async.rs b/examples/chat_with_function_calling_async.rs index 0fb5213..5b91329 100644 --- a/examples/chat_with_function_calling_async.rs +++ b/examples/chat_with_function_calling_async.rs @@ -54,7 +54,7 @@ async fn main() { tool_calls: None, }]; let options = ChatParams { - temperature: Some(0.0), + temperature: 0.0, random_seed: Some(42), tool_choice: Some(ToolChoice::Auto), tools: Some(tools), diff --git a/examples/chat_with_streaming.rs b/examples/chat_with_streaming.rs index 8515a45..f5ad8d4 100644 --- a/examples/chat_with_streaming.rs +++ b/examples/chat_with_streaming.rs @@ -18,7 +18,7 @@ async fn main() { tool_calls: None, }]; let options = ChatParams { - temperature: Some(0.0), + temperature: 0.0, random_seed: Some(42), ..Default::default() }; diff --git a/src/v1/chat.rs b/src/v1/chat.rs index ebb309f..455303e 100644 --- a/src/v1/chat.rs +++ 
b/src/v1/chat.rs @@ -38,12 +38,14 @@ pub enum ChatMessageRole { User, } +/// The format that the model must output. +/// +/// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information. #[derive(Debug, Serialize, Deserialize)] pub struct ResponseFormat { #[serde(rename = "type")] pub type_: String, } - impl ResponseFormat { pub fn json_object() -> Self { Self { @@ -55,28 +57,55 @@ impl ResponseFormat { // ----------------------------------------------------------------------------- // Request +/// The parameters for the chat request. +/// +/// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information. #[derive(Debug)] pub struct ChatParams { + /// The maximum number of tokens to generate in the completion. + /// + /// Defaults to `None`. pub max_tokens: Option, + /// The seed to use for random sampling. If set, different calls will generate deterministic results. + /// + /// Defaults to `None`. pub random_seed: Option, - pub safe_prompt: Option, - pub temperature: Option, - pub tool_choice: Option, - pub tools: Option>, - pub top_p: Option, + /// The format that the model must output. + /// + /// Defaults to `None`. pub response_format: Option, + /// Whether to inject a safety prompt before all conversations. + /// + /// Defaults to `false`. + pub safe_prompt: bool, + /// What sampling temperature to use, between `0.0` and `1.0`. + /// + /// Defaults to `0.7`. + pub temperature: f32, + /// Specifies if/how functions are called. + /// + /// Defaults to `None`. + pub tool_choice: Option, + /// A list of available tools for the model. + /// + /// Defaults to `None`. + pub tools: Option>, + /// Nucleus sampling, where the model considers the results of the tokens with `top_p` probability mass. + /// + /// Defaults to `1.0`. 
+ pub top_p: f32, } impl Default for ChatParams { fn default() -> Self { Self { max_tokens: None, random_seed: None, - safe_prompt: None, - temperature: None, + safe_prompt: false, + response_format: None, + temperature: 0.7, tool_choice: None, tools: None, - top_p: None, - response_format: None, + top_p: 1.0, } } } @@ -85,12 +114,12 @@ impl ChatParams { Self { max_tokens: None, random_seed: None, - safe_prompt: None, - temperature: None, + safe_prompt: false, + response_format: None, + temperature: 0.7, tool_choice: None, tools: None, - top_p: None, - response_format: Some(ResponseFormat::json_object()), + top_p: 1.0, } } } @@ -105,20 +134,15 @@ pub struct ChatRequest { #[serde(skip_serializing_if = "Option::is_none")] pub random_seed: Option, #[serde(skip_serializing_if = "Option::is_none")] - pub safe_prompt: Option, + pub response_format: Option, + pub safe_prompt: bool, pub stream: bool, - #[serde(skip_serializing_if = "Option::is_none")] - pub temperature: Option, + pub temperature: f32, #[serde(skip_serializing_if = "Option::is_none")] pub tool_choice: Option, #[serde(skip_serializing_if = "Option::is_none")] pub tools: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - pub top_p: Option, - // TODO Check this prop (seen in official Python client but not in API doc). - // pub tool_choice: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub response_format: Option, + pub top_p: f32, } impl ChatRequest { pub fn new( diff --git a/src/v1/tool.rs b/src/v1/tool.rs index 2249a40..4d85b31 100644 --- a/src/v1/tool.rs +++ b/src/v1/tool.rs @@ -115,12 +115,16 @@ pub enum ToolType { Function, } +/// An enum representing how functions should be called. #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] pub enum ToolChoice { + /// The model is forced to call a function. #[serde(rename = "any")] Any, + /// The model can choose to either generate a message or call a function. 
#[serde(rename = "auto")] Auto, + /// The model won't call a function and will generate a message instead. #[serde(rename = "none")] None, } diff --git a/tests/v1_client_chat_async_test.rs b/tests/v1_client_chat_async_test.rs index 33862db..c06aa11 100644 --- a/tests/v1_client_chat_async_test.rs +++ b/tests/v1_client_chat_async_test.rs @@ -19,7 +19,7 @@ async fn test_client_chat_async() { "Guess the next word: \"Eiffel ...\"?", )]; let options = ChatParams { - temperature: Some(0.0), + temperature: 0.0, random_seed: Some(42), ..Default::default() }; @@ -37,7 +37,12 @@ async fn test_client_chat_async() { expect!(response.choices[0].finish_reason.clone()).to_be(ChatResponseChoiceFinishReason::Stop); expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant); - expect!(response.choices[0].message.content.clone()).to_start_with("Tower".to_string()); + expect!(response.choices[0] + .message + .content + .clone() + .contains("Tower")) + .to_be(true); expect!(response.usage.prompt_tokens).to_be_greater_than(0); expect!(response.usage.completion_tokens).to_be_greater_than(0); @@ -65,7 +70,7 @@ async fn test_client_chat_async_with_function_calling() { "What's the current temperature in Paris?", )]; let options = ChatParams { - temperature: Some(0.0), + temperature: 0.0, random_seed: Some(42), tool_choice: Some(ToolChoice::Any), tools: Some(tools), diff --git a/tests/v1_client_chat_test.rs b/tests/v1_client_chat_test.rs index 485ecca..e489591 100644 --- a/tests/v1_client_chat_test.rs +++ b/tests/v1_client_chat_test.rs @@ -19,7 +19,7 @@ fn test_client_chat() { "Guess the next word: \"Eiffel ...\"?", )]; let options = ChatParams { - temperature: Some(0.0), + temperature: 0.0, random_seed: Some(42), ..Default::default() }; @@ -31,7 +31,12 @@ fn test_client_chat() { expect!(response.choices.len()).to_be(1); expect!(response.choices[0].index).to_be(0); expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant); - 
expect!(response.choices[0].message.content.clone()).to_start_with("Tower".to_string()); + expect!(response.choices[0] + .message + .content + .clone() + .contains("Tower")) + .to_be(true); expect!(response.choices[0].finish_reason.clone()).to_be(ChatResponseChoiceFinishReason::Stop); expect!(response.usage.prompt_tokens).to_be_greater_than(0); expect!(response.usage.completion_tokens).to_be_greater_than(0); @@ -59,7 +64,7 @@ fn test_client_chat_with_function_calling() { "What's the current temperature in Paris?", )]; let options = ChatParams { - temperature: Some(0.0), + temperature: 0.0, random_seed: Some(42), tool_choice: Some(ToolChoice::Auto), tools: Some(tools),