8 Commits

Author SHA1 Message Date
Ivan Gabriele
f7d012b280 ci(release): v0.10.0 2024-06-07 16:55:50 +02:00
Ivan Gabriele
5b5bd2d68e docs(changelog): update 2024-06-07 16:55:39 +02:00
Xavier Gillard
2fc0642a5e feat(chat): add the 'system' and 'tool' message roles (#10)
* add the 'system' and 'tool' message roles (see: https://docs.mistral.ai/capabilities/completion/ )

* docs(chat): add offical doc link in ChatMessageRole

* ci(github): listen to pull_request event in Test workflow

---------

Co-authored-by: Ivan Gabriele <ivan.gabriele@protonmail.com>
2024-06-07 16:49:55 +02:00
Ivan Gabriele
cf68a77320 feat(chat)!: change safe_prompt, temperature & top_p to non-Option types
BREAKING CHANGE:
- `Chat::ChatParams.safe_prompt` & `Chat::ChatRequest.safe_prompt` are now `bool` instead of `Option<bool>`. Default is `false`.
- `Chat::ChatParams.temperature` & `Chat::ChatRequest.temperature` are now `f32` instead of `Option<f32>`. Default is `0.7`.
- `Chat::ChatParams.top_p` & `Chat::ChatRequest.top_p` are now `f32` instead of `Option<f32>`. Default is `1.0`.
2024-06-07 16:00:10 +02:00
Ivan Gabriele
e61ace9a18 test(chat): simplify chat response message content check to cover variations 2024-06-07 14:36:49 +02:00
Ivan Gabriele
64034402ca build(makefile): fix env file loading 2024-06-07 14:36:48 +02:00
Nick Anderson
85c3611afb feat(chat): add response_format for JSON return values 2024-06-07 14:36:37 +02:00
seurimas
da5fe54115 fix(chat): skip serializing tool_calls if null, to avoid 422 error 2024-06-07 14:12:22 +02:00
15 changed files with 144 additions and 43 deletions

View File

@@ -1,6 +1,8 @@
name: Test name: Test
on: push on:
pull_request:
push:
jobs: jobs:
test: test:

View File

@@ -1,3 +1,20 @@
## [0.10.0](https://github.com/ivangabriele/mistralai-client-rs/compare/v0.9.0...v) (2024-06-07)
### ⚠ BREAKING CHANGES
* **chat:** - `Chat::ChatParams.safe_prompt` & `Chat::ChatRequest.safe_prompt` are now `bool` instead of `Option<bool>`. Default is `false`.
- `Chat::ChatParams.temperature` & `Chat::ChatRequest.temperature` are now `f32` instead of `Option<f32>`. Default is `0.7`.
- `Chat::ChatParams.top_p` & `Chat::ChatRequest.top_p` are now `f32` instead of `Option<f32>`. Default is `1.0`.
### Features
* **chat:** add response_format for JSON return values ([85c3611](https://github.com/ivangabriele/mistralai-client-rs/commit/85c3611afbbe8df30dfc7512cc381ed304ce4024))
* **chat:** add the 'system' and 'tool' message roles ([#10](https://github.com/ivangabriele/mistralai-client-rs/issues/10)) ([2fc0642](https://github.com/ivangabriele/mistralai-client-rs/commit/2fc0642a5e4c024b15710acaab7735480e8dfe6a))
* **chat:** change safe_prompt, temperature & top_p to non-Option types ([cf68a77](https://github.com/ivangabriele/mistralai-client-rs/commit/cf68a773201ebe0e802face52af388711acf0c27))
### Bug Fixes
* **chat:** skip serializing tool_calls if null, to avoid 422 error ([da5fe54](https://github.com/ivangabriele/mistralai-client-rs/commit/da5fe54115ce622379776661a440e2708b24810c))
## [0.9.0](https://github.com/ivangabriele/mistralai-client-rs/compare/v0.8.0...v) (2024-04-13) ## [0.9.0](https://github.com/ivangabriele/mistralai-client-rs/compare/v0.8.0...v) (2024-04-13)

View File

@@ -2,7 +2,7 @@
name = "mistralai-client" name = "mistralai-client"
description = "Mistral AI API client library for Rust (unofficial)." description = "Mistral AI API client library for Rust (unofficial)."
license = "Apache-2.0" license = "Apache-2.0"
version = "0.9.0" version = "0.10.0"
edition = "2021" edition = "2021"
rust-version = "1.76.0" rust-version = "1.76.0"

View File

@@ -53,20 +53,20 @@ release-major:
$(call RELEASE_TEMPLATE,major) $(call RELEASE_TEMPLATE,major)
test: test:
@$(source_env_if_not_ci) @$(source_env_if_not_ci) && \
cargo test --no-fail-fast cargo test --no-fail-fast
test-cover: test-cover:
@$(source_env_if_not_ci) @$(source_env_if_not_ci) && \
cargo llvm-cov cargo llvm-cov
test-doc: test-doc:
@$(source_env_if_not_ci) @$(source_env_if_not_ci) && \
cargo test --doc --no-fail-fast cargo test --doc --no-fail-fast
test-examples: test-examples:
@$(source_env_if_not_ci) @$(source_env_if_not_ci) && \
@for example in $$(ls examples/*.rs | sed 's/examples\/\(.*\)\.rs/\1/'); do \ for example in $$(ls examples/*.rs | sed 's/examples\/\(.*\)\.rs/\1/'); do \
echo "Running $$example"; \ echo "Running $$example"; \
cargo run --example $$example; \ cargo run --example $$example; \
done done
test-watch: test-watch:
@source ./.env @source ./.env && \
cargo watch -x "test -- --nocapture" cargo watch -x "test -- --nocapture"

View File

@@ -7,6 +7,10 @@
Rust client for the Mistral AI API. Rust client for the Mistral AI API.
> [!IMPORTANT]
> While we are in v0, minor versions may introduce breaking changes.
> Please, refer to the [CHANGELOG.md](./CHANGELOG.md) for more information.
--- ---
- [Supported APIs](#supported-apis) - [Supported APIs](#supported-apis)
@@ -102,7 +106,7 @@ fn main() {
tool_calls: None, tool_calls: None,
}]; }];
let options = ChatParams { let options = ChatParams {
temperature: Some(0.0), temperature: 0.0,
random_seed: Some(42), random_seed: Some(42),
..Default::default() ..Default::default()
}; };
@@ -134,7 +138,7 @@ async fn main() {
tool_calls: None, tool_calls: None,
}]; }];
let options = ChatParams { let options = ChatParams {
temperature: Some(0.0), temperature: 0.0,
random_seed: Some(42), random_seed: Some(42),
..Default::default() ..Default::default()
}; };
@@ -174,7 +178,7 @@ async fn main() {
tool_calls: None, tool_calls: None,
}]; }];
let options = ChatParams { let options = ChatParams {
temperature: Some(0.0), temperature: 0.0,
random_seed: Some(42), random_seed: Some(42),
..Default::default() ..Default::default()
}; };
@@ -259,7 +263,7 @@ fn main() {
tool_calls: None, tool_calls: None,
}]; }];
let options = ChatParams { let options = ChatParams {
temperature: Some(0.0), temperature: 0.0,
random_seed: Some(42), random_seed: Some(42),
tool_choice: Some(ToolChoice::Auto), tool_choice: Some(ToolChoice::Auto),
tools: Some(tools), tools: Some(tools),
@@ -336,7 +340,7 @@ async fn main() {
tool_calls: None, tool_calls: None,
}]; }];
let options = ChatParams { let options = ChatParams {
temperature: Some(0.0), temperature: 0.0,
random_seed: Some(42), random_seed: Some(42),
tool_choice: Some(ToolChoice::Auto), tool_choice: Some(ToolChoice::Auto),
tools: Some(tools), tools: Some(tools),

View File

@@ -7,6 +7,10 @@
Rust client for the Mistral AI API. Rust client for the Mistral AI API.
> [!IMPORTANT]
> While we are in v0, minor versions may introduce breaking changes.
> Please, refer to the [CHANGELOG.md](./CHANGELOG.md) for more information.
--- ---
- [Supported APIs](#supported-apis) - [Supported APIs](#supported-apis)

View File

@@ -15,7 +15,7 @@ fn main() {
tool_calls: None, tool_calls: None,
}]; }];
let options = ChatParams { let options = ChatParams {
temperature: Some(0.0), temperature: 0.0,
random_seed: Some(42), random_seed: Some(42),
..Default::default() ..Default::default()
}; };

View File

@@ -16,7 +16,7 @@ async fn main() {
tool_calls: None, tool_calls: None,
}]; }];
let options = ChatParams { let options = ChatParams {
temperature: Some(0.0), temperature: 0.0,
random_seed: Some(42), random_seed: Some(42),
..Default::default() ..Default::default()
}; };

View File

@@ -53,7 +53,7 @@ fn main() {
tool_calls: None, tool_calls: None,
}]; }];
let options = ChatParams { let options = ChatParams {
temperature: Some(0.0), temperature: 0.0,
random_seed: Some(42), random_seed: Some(42),
tool_choice: Some(ToolChoice::Auto), tool_choice: Some(ToolChoice::Auto),
tools: Some(tools), tools: Some(tools),

View File

@@ -54,7 +54,7 @@ async fn main() {
tool_calls: None, tool_calls: None,
}]; }];
let options = ChatParams { let options = ChatParams {
temperature: Some(0.0), temperature: 0.0,
random_seed: Some(42), random_seed: Some(42),
tool_choice: Some(ToolChoice::Auto), tool_choice: Some(ToolChoice::Auto),
tools: Some(tools), tools: Some(tools),

View File

@@ -18,7 +18,7 @@ async fn main() {
tool_calls: None, tool_calls: None,
}]; }];
let options = ChatParams { let options = ChatParams {
temperature: Some(0.0), temperature: 0.0,
random_seed: Some(42), random_seed: Some(42),
..Default::default() ..Default::default()
}; };

View File

@@ -9,6 +9,7 @@ use crate::v1::{common, constants, tool};
pub struct ChatMessage { pub struct ChatMessage {
pub role: ChatMessageRole, pub role: ChatMessageRole,
pub content: String, pub content: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_calls: Option<Vec<tool::ToolCall>>, pub tool_calls: Option<Vec<tool::ToolCall>>,
} }
impl ChatMessage { impl ChatMessage {
@@ -29,37 +30,101 @@ impl ChatMessage {
} }
} }
/// See the [Mistral AI API documentation](https://docs.mistral.ai/capabilities/completion/#chat-messages) for more information.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum ChatMessageRole { pub enum ChatMessageRole {
#[serde(rename = "system")]
System,
#[serde(rename = "assistant")] #[serde(rename = "assistant")]
Assistant, Assistant,
#[serde(rename = "user")] #[serde(rename = "user")]
User, User,
#[serde(rename = "tool")]
Tool,
}
/// The format that the model must output.
///
/// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information.
#[derive(Debug, Serialize, Deserialize)]
pub struct ResponseFormat {
#[serde(rename = "type")]
pub type_: String,
}
impl ResponseFormat {
pub fn json_object() -> Self {
Self {
type_: "json_object".to_string(),
}
}
} }
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
// Request // Request
/// The parameters for the chat request.
///
/// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information.
#[derive(Debug)] #[derive(Debug)]
pub struct ChatParams { pub struct ChatParams {
/// The maximum number of tokens to generate in the completion.
///
/// Defaults to `None`.
pub max_tokens: Option<u32>, pub max_tokens: Option<u32>,
/// The seed to use for random sampling. If set, different calls will generate deterministic results.
///
/// Defaults to `None`.
pub random_seed: Option<u32>, pub random_seed: Option<u32>,
pub safe_prompt: Option<bool>, /// The format that the model must output.
pub temperature: Option<f32>, ///
/// Defaults to `None`.
pub response_format: Option<ResponseFormat>,
/// Whether to inject a safety prompt before all conversations.
///
/// Defaults to `false`.
pub safe_prompt: bool,
/// What sampling temperature to use, between `Some(0.0)` and `Some(1.0)`.
///
/// Defaults to `0.7`.
pub temperature: f32,
/// Specifies if/how functions are called.
///
/// Defaults to `None`.
pub tool_choice: Option<tool::ToolChoice>, pub tool_choice: Option<tool::ToolChoice>,
/// A list of available tools for the model.
///
/// Defaults to `None`.
pub tools: Option<Vec<tool::Tool>>, pub tools: Option<Vec<tool::Tool>>,
pub top_p: Option<f32>, /// Nucleus sampling, where the model considers the results of the tokens with `top_p` probability mass.
///
/// Defaults to `1.0`.
pub top_p: f32,
} }
impl Default for ChatParams { impl Default for ChatParams {
fn default() -> Self { fn default() -> Self {
Self { Self {
max_tokens: None, max_tokens: None,
random_seed: None, random_seed: None,
safe_prompt: None, safe_prompt: false,
temperature: None, response_format: None,
temperature: 0.7,
tool_choice: None, tool_choice: None,
tools: None, tools: None,
top_p: None, top_p: 1.0,
}
}
}
impl ChatParams {
pub fn json_default() -> Self {
Self {
max_tokens: None,
random_seed: None,
safe_prompt: false,
response_format: None,
temperature: 0.7,
tool_choice: None,
tools: None,
top_p: 1.0,
} }
} }
} }
@@ -74,20 +139,15 @@ pub struct ChatRequest {
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub random_seed: Option<u32>, pub random_seed: Option<u32>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub safe_prompt: Option<bool>, pub response_format: Option<ResponseFormat>,
pub safe_prompt: bool,
pub stream: bool, pub stream: bool,
#[serde(skip_serializing_if = "Option::is_none")] pub temperature: f32,
pub temperature: Option<f32>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub tool_choice: Option<tool::ToolChoice>, pub tool_choice: Option<tool::ToolChoice>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<tool::Tool>>, pub tools: Option<Vec<tool::Tool>>,
#[serde(skip_serializing_if = "Option::is_none")] pub top_p: f32,
pub top_p: Option<f32>,
// TODO Check this prop (seen in official Python client but not in API doc).
// pub tool_choice: Option<String>,
// TODO Check this prop (seen in official Python client but not in API doc).
// pub response_format: Option<String>,
} }
impl ChatRequest { impl ChatRequest {
pub fn new( pub fn new(
@@ -104,6 +164,7 @@ impl ChatRequest {
tool_choice, tool_choice,
tools, tools,
top_p, top_p,
response_format,
} = options.unwrap_or_default(); } = options.unwrap_or_default();
Self { Self {
@@ -118,6 +179,7 @@ impl ChatRequest {
tool_choice, tool_choice,
tools, tools,
top_p, top_p,
response_format,
} }
} }
} }

View File

@@ -115,12 +115,16 @@ pub enum ToolType {
Function, Function,
} }
/// An enum representing how functions should be called.
#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)] #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum ToolChoice { pub enum ToolChoice {
/// The model is forced to call a function.
#[serde(rename = "any")] #[serde(rename = "any")]
Any, Any,
/// The model can choose to either generate a message or call a function.
#[serde(rename = "auto")] #[serde(rename = "auto")]
Auto, Auto,
/// The model won't call a function and will generate a message instead.
#[serde(rename = "none")] #[serde(rename = "none")]
None, None,
} }

View File

@@ -16,10 +16,10 @@ async fn test_client_chat_async() {
let model = Model::OpenMistral7b; let model = Model::OpenMistral7b;
let messages = vec![ChatMessage::new_user_message( let messages = vec![ChatMessage::new_user_message(
"Just guess the next word: \"Eiffel ...\"?", "Guess the next word: \"Eiffel ...\"?",
)]; )];
let options = ChatParams { let options = ChatParams {
temperature: Some(0.0), temperature: 0.0,
random_seed: Some(42), random_seed: Some(42),
..Default::default() ..Default::default()
}; };
@@ -37,8 +37,12 @@ async fn test_client_chat_async() {
expect!(response.choices[0].finish_reason.clone()).to_be(ChatResponseChoiceFinishReason::Stop); expect!(response.choices[0].finish_reason.clone()).to_be(ChatResponseChoiceFinishReason::Stop);
expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant); expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant);
expect!(response.choices[0].message.content.clone()) expect!(response.choices[0]
.to_be("Tower. The Eiffel Tower is a famous landmark in Paris, France.".to_string()); .message
.content
.clone()
.contains("Tower"))
.to_be(true);
expect!(response.usage.prompt_tokens).to_be_greater_than(0); expect!(response.usage.prompt_tokens).to_be_greater_than(0);
expect!(response.usage.completion_tokens).to_be_greater_than(0); expect!(response.usage.completion_tokens).to_be_greater_than(0);
@@ -66,7 +70,7 @@ async fn test_client_chat_async_with_function_calling() {
"What's the current temperature in Paris?", "What's the current temperature in Paris?",
)]; )];
let options = ChatParams { let options = ChatParams {
temperature: Some(0.0), temperature: 0.0,
random_seed: Some(42), random_seed: Some(42),
tool_choice: Some(ToolChoice::Any), tool_choice: Some(ToolChoice::Any),
tools: Some(tools), tools: Some(tools),

View File

@@ -16,10 +16,10 @@ fn test_client_chat() {
let model = Model::OpenMistral7b; let model = Model::OpenMistral7b;
let messages = vec![ChatMessage::new_user_message( let messages = vec![ChatMessage::new_user_message(
"Just guess the next word: \"Eiffel ...\"?", "Guess the next word: \"Eiffel ...\"?",
)]; )];
let options = ChatParams { let options = ChatParams {
temperature: Some(0.0), temperature: 0.0,
random_seed: Some(42), random_seed: Some(42),
..Default::default() ..Default::default()
}; };
@@ -31,8 +31,12 @@ fn test_client_chat() {
expect!(response.choices.len()).to_be(1); expect!(response.choices.len()).to_be(1);
expect!(response.choices[0].index).to_be(0); expect!(response.choices[0].index).to_be(0);
expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant); expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant);
expect!(response.choices[0].message.content.clone()) expect!(response.choices[0]
.to_be("Tower. The Eiffel Tower is a famous landmark in Paris, France.".to_string()); .message
.content
.clone()
.contains("Tower"))
.to_be(true);
expect!(response.choices[0].finish_reason.clone()).to_be(ChatResponseChoiceFinishReason::Stop); expect!(response.choices[0].finish_reason.clone()).to_be(ChatResponseChoiceFinishReason::Stop);
expect!(response.usage.prompt_tokens).to_be_greater_than(0); expect!(response.usage.prompt_tokens).to_be_greater_than(0);
expect!(response.usage.completion_tokens).to_be_greater_than(0); expect!(response.usage.completion_tokens).to_be_greater_than(0);
@@ -60,7 +64,7 @@ fn test_client_chat_with_function_calling() {
"What's the current temperature in Paris?", "What's the current temperature in Paris?",
)]; )];
let options = ChatParams { let options = ChatParams {
temperature: Some(0.0), temperature: 0.0,
random_seed: Some(42), random_seed: Some(42),
tool_choice: Some(ToolChoice::Auto), tool_choice: Some(ToolChoice::Auto),
tools: Some(tools), tools: Some(tools),