feat(chat)!: change safe_prompt, temperature & top_p to non-Option types
BREAKING CHANGE: - `Chat::ChatParams.safe_prompt` & `Chat::ChatRequest.safe_prompt` are now `bool` instead of `Option<bool>`. Default is `false`. - `Chat::ChatParams.temperature` & `Chat::ChatRequest.temperature` are now `f32` instead of `Option<f32>`. Default is `0.7`. - `Chat::ChatParams.top_p` & `Chat::ChatRequest.top_p` are now `f32` instead of `Option<f32>`. Default is `1.0`.
This commit is contained in:
14
README.md
14
README.md
@@ -7,6 +7,10 @@
|
|||||||
|
|
||||||
Rust client for the Mistral AI API.
|
Rust client for the Mistral AI API.
|
||||||
|
|
||||||
|
> [!IMPORTANT]
|
||||||
|
> While we are in v0, minor versions may introduce breaking changes.
|
||||||
|
> Please, refer to the [CHANGELOG.md](./CHANGELOG.md) for more information.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
- [Supported APIs](#supported-apis)
|
- [Supported APIs](#supported-apis)
|
||||||
@@ -102,7 +106,7 @@ fn main() {
|
|||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
}];
|
}];
|
||||||
let options = ChatParams {
|
let options = ChatParams {
|
||||||
temperature: Some(0.0),
|
temperature: 0.0,
|
||||||
random_seed: Some(42),
|
random_seed: Some(42),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
@@ -134,7 +138,7 @@ async fn main() {
|
|||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
}];
|
}];
|
||||||
let options = ChatParams {
|
let options = ChatParams {
|
||||||
temperature: Some(0.0),
|
temperature: 0.0,
|
||||||
random_seed: Some(42),
|
random_seed: Some(42),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
@@ -174,7 +178,7 @@ async fn main() {
|
|||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
}];
|
}];
|
||||||
let options = ChatParams {
|
let options = ChatParams {
|
||||||
temperature: Some(0.0),
|
temperature: 0.0,
|
||||||
random_seed: Some(42),
|
random_seed: Some(42),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
@@ -259,7 +263,7 @@ fn main() {
|
|||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
}];
|
}];
|
||||||
let options = ChatParams {
|
let options = ChatParams {
|
||||||
temperature: Some(0.0),
|
temperature: 0.0,
|
||||||
random_seed: Some(42),
|
random_seed: Some(42),
|
||||||
tool_choice: Some(ToolChoice::Auto),
|
tool_choice: Some(ToolChoice::Auto),
|
||||||
tools: Some(tools),
|
tools: Some(tools),
|
||||||
@@ -336,7 +340,7 @@ async fn main() {
|
|||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
}];
|
}];
|
||||||
let options = ChatParams {
|
let options = ChatParams {
|
||||||
temperature: Some(0.0),
|
temperature: 0.0,
|
||||||
random_seed: Some(42),
|
random_seed: Some(42),
|
||||||
tool_choice: Some(ToolChoice::Auto),
|
tool_choice: Some(ToolChoice::Auto),
|
||||||
tools: Some(tools),
|
tools: Some(tools),
|
||||||
|
|||||||
@@ -7,6 +7,10 @@
|
|||||||
|
|
||||||
Rust client for the Mistral AI API.
|
Rust client for the Mistral AI API.
|
||||||
|
|
||||||
|
> [!IMPORTANT]
|
||||||
|
> While we are in v0, minor versions may introduce breaking changes.
|
||||||
|
> Please, refer to the [CHANGELOG.md](./CHANGELOG.md) for more information.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
- [Supported APIs](#supported-apis)
|
- [Supported APIs](#supported-apis)
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ fn main() {
|
|||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
}];
|
}];
|
||||||
let options = ChatParams {
|
let options = ChatParams {
|
||||||
temperature: Some(0.0),
|
temperature: 0.0,
|
||||||
random_seed: Some(42),
|
random_seed: Some(42),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ async fn main() {
|
|||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
}];
|
}];
|
||||||
let options = ChatParams {
|
let options = ChatParams {
|
||||||
temperature: Some(0.0),
|
temperature: 0.0,
|
||||||
random_seed: Some(42),
|
random_seed: Some(42),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ fn main() {
|
|||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
}];
|
}];
|
||||||
let options = ChatParams {
|
let options = ChatParams {
|
||||||
temperature: Some(0.0),
|
temperature: 0.0,
|
||||||
random_seed: Some(42),
|
random_seed: Some(42),
|
||||||
tool_choice: Some(ToolChoice::Auto),
|
tool_choice: Some(ToolChoice::Auto),
|
||||||
tools: Some(tools),
|
tools: Some(tools),
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ async fn main() {
|
|||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
}];
|
}];
|
||||||
let options = ChatParams {
|
let options = ChatParams {
|
||||||
temperature: Some(0.0),
|
temperature: 0.0,
|
||||||
random_seed: Some(42),
|
random_seed: Some(42),
|
||||||
tool_choice: Some(ToolChoice::Auto),
|
tool_choice: Some(ToolChoice::Auto),
|
||||||
tools: Some(tools),
|
tools: Some(tools),
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ async fn main() {
|
|||||||
tool_calls: None,
|
tool_calls: None,
|
||||||
}];
|
}];
|
||||||
let options = ChatParams {
|
let options = ChatParams {
|
||||||
temperature: Some(0.0),
|
temperature: 0.0,
|
||||||
random_seed: Some(42),
|
random_seed: Some(42),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -38,12 +38,14 @@ pub enum ChatMessageRole {
|
|||||||
User,
|
User,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// The format that the model must output.
|
||||||
|
///
|
||||||
|
/// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information.
|
||||||
#[derive(Debug, Serialize, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub struct ResponseFormat {
|
pub struct ResponseFormat {
|
||||||
#[serde(rename = "type")]
|
#[serde(rename = "type")]
|
||||||
pub type_: String,
|
pub type_: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ResponseFormat {
|
impl ResponseFormat {
|
||||||
pub fn json_object() -> Self {
|
pub fn json_object() -> Self {
|
||||||
Self {
|
Self {
|
||||||
@@ -55,28 +57,55 @@ impl ResponseFormat {
|
|||||||
// -----------------------------------------------------------------------------
|
// -----------------------------------------------------------------------------
|
||||||
// Request
|
// Request
|
||||||
|
|
||||||
|
/// The parameters for the chat request.
|
||||||
|
///
|
||||||
|
/// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct ChatParams {
|
pub struct ChatParams {
|
||||||
|
/// The maximum number of tokens to generate in the completion.
|
||||||
|
///
|
||||||
|
/// Defaults to `None`.
|
||||||
pub max_tokens: Option<u32>,
|
pub max_tokens: Option<u32>,
|
||||||
|
/// The seed to use for random sampling. If set, different calls will generate deterministic results.
|
||||||
|
///
|
||||||
|
/// Defaults to `None`.
|
||||||
pub random_seed: Option<u32>,
|
pub random_seed: Option<u32>,
|
||||||
pub safe_prompt: Option<bool>,
|
/// The format that the model must output.
|
||||||
pub temperature: Option<f32>,
|
///
|
||||||
pub tool_choice: Option<tool::ToolChoice>,
|
/// Defaults to `None`.
|
||||||
pub tools: Option<Vec<tool::Tool>>,
|
|
||||||
pub top_p: Option<f32>,
|
|
||||||
pub response_format: Option<ResponseFormat>,
|
pub response_format: Option<ResponseFormat>,
|
||||||
|
/// Whether to inject a safety prompt before all conversations.
|
||||||
|
///
|
||||||
|
/// Defaults to `false`.
|
||||||
|
pub safe_prompt: bool,
|
||||||
|
/// What sampling temperature to use, between `Some(0.0)` and `Some(1.0)`.
|
||||||
|
///
|
||||||
|
/// Defaults to `0.7`.
|
||||||
|
pub temperature: f32,
|
||||||
|
/// Specifies if/how functions are called.
|
||||||
|
///
|
||||||
|
/// Defaults to `None`.
|
||||||
|
pub tool_choice: Option<tool::ToolChoice>,
|
||||||
|
/// A list of available tools for the model.
|
||||||
|
///
|
||||||
|
/// Defaults to `None`.
|
||||||
|
pub tools: Option<Vec<tool::Tool>>,
|
||||||
|
/// Nucleus sampling, where the model considers the results of the tokens with `top_p` probability mass.
|
||||||
|
///
|
||||||
|
/// Defaults to `1.0`.
|
||||||
|
pub top_p: f32,
|
||||||
}
|
}
|
||||||
impl Default for ChatParams {
|
impl Default for ChatParams {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
max_tokens: None,
|
max_tokens: None,
|
||||||
random_seed: None,
|
random_seed: None,
|
||||||
safe_prompt: None,
|
safe_prompt: false,
|
||||||
temperature: None,
|
response_format: None,
|
||||||
|
temperature: 0.7,
|
||||||
tool_choice: None,
|
tool_choice: None,
|
||||||
tools: None,
|
tools: None,
|
||||||
top_p: None,
|
top_p: 1.0,
|
||||||
response_format: None,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -85,12 +114,12 @@ impl ChatParams {
|
|||||||
Self {
|
Self {
|
||||||
max_tokens: None,
|
max_tokens: None,
|
||||||
random_seed: None,
|
random_seed: None,
|
||||||
safe_prompt: None,
|
safe_prompt: false,
|
||||||
temperature: None,
|
response_format: None,
|
||||||
|
temperature: 0.7,
|
||||||
tool_choice: None,
|
tool_choice: None,
|
||||||
tools: None,
|
tools: None,
|
||||||
top_p: None,
|
top_p: 1.0,
|
||||||
response_format: Some(ResponseFormat::json_object()),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -105,20 +134,15 @@ pub struct ChatRequest {
|
|||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub random_seed: Option<u32>,
|
pub random_seed: Option<u32>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub safe_prompt: Option<bool>,
|
pub response_format: Option<ResponseFormat>,
|
||||||
|
pub safe_prompt: bool,
|
||||||
pub stream: bool,
|
pub stream: bool,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
pub temperature: f32,
|
||||||
pub temperature: Option<f32>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub tool_choice: Option<tool::ToolChoice>,
|
pub tool_choice: Option<tool::ToolChoice>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
pub tools: Option<Vec<tool::Tool>>,
|
pub tools: Option<Vec<tool::Tool>>,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
pub top_p: f32,
|
||||||
pub top_p: Option<f32>,
|
|
||||||
// TODO Check this prop (seen in official Python client but not in API doc).
|
|
||||||
// pub tool_choice: Option<String>,
|
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
|
||||||
pub response_format: Option<ResponseFormat>,
|
|
||||||
}
|
}
|
||||||
impl ChatRequest {
|
impl ChatRequest {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
|
|||||||
@@ -115,12 +115,16 @@ pub enum ToolType {
|
|||||||
Function,
|
Function,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// An enum representing how functions should be called.
|
||||||
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
|
||||||
pub enum ToolChoice {
|
pub enum ToolChoice {
|
||||||
|
/// The model is forced to call a function.
|
||||||
#[serde(rename = "any")]
|
#[serde(rename = "any")]
|
||||||
Any,
|
Any,
|
||||||
|
/// The model can choose to either generate a message or call a function.
|
||||||
#[serde(rename = "auto")]
|
#[serde(rename = "auto")]
|
||||||
Auto,
|
Auto,
|
||||||
|
/// The model won't call a function and will generate a message instead.
|
||||||
#[serde(rename = "none")]
|
#[serde(rename = "none")]
|
||||||
None,
|
None,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ async fn test_client_chat_async() {
|
|||||||
"Guess the next word: \"Eiffel ...\"?",
|
"Guess the next word: \"Eiffel ...\"?",
|
||||||
)];
|
)];
|
||||||
let options = ChatParams {
|
let options = ChatParams {
|
||||||
temperature: Some(0.0),
|
temperature: 0.0,
|
||||||
random_seed: Some(42),
|
random_seed: Some(42),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
@@ -37,7 +37,12 @@ async fn test_client_chat_async() {
|
|||||||
expect!(response.choices[0].finish_reason.clone()).to_be(ChatResponseChoiceFinishReason::Stop);
|
expect!(response.choices[0].finish_reason.clone()).to_be(ChatResponseChoiceFinishReason::Stop);
|
||||||
|
|
||||||
expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant);
|
expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant);
|
||||||
expect!(response.choices[0].message.content.clone()).to_start_with("Tower".to_string());
|
expect!(response.choices[0]
|
||||||
|
.message
|
||||||
|
.content
|
||||||
|
.clone()
|
||||||
|
.contains("Tower"))
|
||||||
|
.to_be(true);
|
||||||
|
|
||||||
expect!(response.usage.prompt_tokens).to_be_greater_than(0);
|
expect!(response.usage.prompt_tokens).to_be_greater_than(0);
|
||||||
expect!(response.usage.completion_tokens).to_be_greater_than(0);
|
expect!(response.usage.completion_tokens).to_be_greater_than(0);
|
||||||
@@ -65,7 +70,7 @@ async fn test_client_chat_async_with_function_calling() {
|
|||||||
"What's the current temperature in Paris?",
|
"What's the current temperature in Paris?",
|
||||||
)];
|
)];
|
||||||
let options = ChatParams {
|
let options = ChatParams {
|
||||||
temperature: Some(0.0),
|
temperature: 0.0,
|
||||||
random_seed: Some(42),
|
random_seed: Some(42),
|
||||||
tool_choice: Some(ToolChoice::Any),
|
tool_choice: Some(ToolChoice::Any),
|
||||||
tools: Some(tools),
|
tools: Some(tools),
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ fn test_client_chat() {
|
|||||||
"Guess the next word: \"Eiffel ...\"?",
|
"Guess the next word: \"Eiffel ...\"?",
|
||||||
)];
|
)];
|
||||||
let options = ChatParams {
|
let options = ChatParams {
|
||||||
temperature: Some(0.0),
|
temperature: 0.0,
|
||||||
random_seed: Some(42),
|
random_seed: Some(42),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
@@ -31,7 +31,12 @@ fn test_client_chat() {
|
|||||||
expect!(response.choices.len()).to_be(1);
|
expect!(response.choices.len()).to_be(1);
|
||||||
expect!(response.choices[0].index).to_be(0);
|
expect!(response.choices[0].index).to_be(0);
|
||||||
expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant);
|
expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant);
|
||||||
expect!(response.choices[0].message.content.clone()).to_start_with("Tower".to_string());
|
expect!(response.choices[0]
|
||||||
|
.message
|
||||||
|
.content
|
||||||
|
.clone()
|
||||||
|
.contains("Tower"))
|
||||||
|
.to_be(true);
|
||||||
expect!(response.choices[0].finish_reason.clone()).to_be(ChatResponseChoiceFinishReason::Stop);
|
expect!(response.choices[0].finish_reason.clone()).to_be(ChatResponseChoiceFinishReason::Stop);
|
||||||
expect!(response.usage.prompt_tokens).to_be_greater_than(0);
|
expect!(response.usage.prompt_tokens).to_be_greater_than(0);
|
||||||
expect!(response.usage.completion_tokens).to_be_greater_than(0);
|
expect!(response.usage.completion_tokens).to_be_greater_than(0);
|
||||||
@@ -59,7 +64,7 @@ fn test_client_chat_with_function_calling() {
|
|||||||
"What's the current temperature in Paris?",
|
"What's the current temperature in Paris?",
|
||||||
)];
|
)];
|
||||||
let options = ChatParams {
|
let options = ChatParams {
|
||||||
temperature: Some(0.0),
|
temperature: 0.0,
|
||||||
random_seed: Some(42),
|
random_seed: Some(42),
|
||||||
tool_choice: Some(ToolChoice::Auto),
|
tool_choice: Some(ToolChoice::Auto),
|
||||||
tools: Some(tools),
|
tools: Some(tools),
|
||||||
|
|||||||
Reference in New Issue
Block a user