Compare commits

24 Commits

| Author | SHA1 | Date |
|---|---|---|
|  | 67aa5bbaef |  |
|  | 415fd98167 |  |
|  | 8e9f7a5386 |  |
|  | 3afeec1d58 |  |
|  | 0c097aa56d |  |
|  | e6539c0ccf |  |
|  | 30156c5273 |  |
|  | ecd0c3028f |  |
|  | 0df67b1b25 |  |
|  | f7d012b280 |  |
|  | 5b5bd2d68e |  |
|  | 2fc0642a5e |  |
|  | cf68a77320 |  |
|  | e61ace9a18 |  |
|  | 64034402ca |  |
|  | 85c3611afb |  |
|  | da5fe54115 |  |
|  | 7a5e0679c1 |  |
|  | 99d9d099e2 |  |
|  | 91fb775132 |  |
|  | 7474aa6730 |  |
|  | 6a99eca49c |  |
|  | fccd59c0cc |  |
|  | a463cb3106 |  |
.github/ISSUE_TEMPLATE/bug_report.md (vendored, 15 changes)

```diff
@@ -1,17 +1,16 @@
 ---
 name: Bug report
 about: Create a report to help us improve
-title: ''
-labels: ''
-assignees: ''
+title: ""
+labels: "bug"
+assignees: ""

 ---

 **Describe the bug**

-A clear and concise description of what the bug is.
+...

-**To Reproduce**
+**Reproduction**

 Steps to reproduce the behavior:

@@ -20,7 +19,7 @@ Steps to reproduce the behavior:

 **Expected behavior**

-A clear and concise description of what you expected to happen.
+...

 **Screenshots**

@@ -32,4 +31,4 @@ If applicable, what version did you use?

 **Environment**

-Add useful information about your configuration and environment here.
+If applicable, add relevant information about your config and environment here.
```
.github/ISSUE_TEMPLATE/feature_request.md (vendored, 25 changes)

```diff
@@ -1,24 +1,19 @@
 ---
 name: Feature request
-about: Suggest an idea for this project
-title: ''
-labels: ''
-assignees: ''
+about: Suggest a new idea for the project.
+title: ""
+labels: "enhancement"
+assignees: ""

 ---

-**Is your feature request related to a problem? Please describe.**
+**Is your feature request related to some problems?**

-A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
+- _Ex. I'm always frustrated when..._

-**Describe the solution you'd like**
+**What are the solutions you'd like?**

-A clear and concise description of what you want to happen.
+- _Ex. A new option to..._

-**Describe alternatives you've considered**
+**Anything else?**

-A clear and concise description of any alternative solutions or features you've considered.
+- ...
-
-**Additional context**
-
-Add any other context or screenshots about the feature request here.
```
.github/workflows/test.yml (vendored, 4 changes)

```diff
@@ -1,6 +1,8 @@
 name: Test

-on: push
+on:
+  pull_request:
+  push:

 jobs:
   test:
```
CHANGELOG.md (44 changes)

```diff
@@ -1,3 +1,47 @@
+## [0.12.0](https://github.com/ivangabriele/mistralai-client-rs/compare/v0.11.0...v) (2024-07-24)
+
+### Features
+
+* implement the Debug trait for Client ([#11](https://github.com/ivangabriele/mistralai-client-rs/issues/11)) ([3afeec1](https://github.com/ivangabriele/mistralai-client-rs/commit/3afeec1d586022e43c7b10906acec5e65927ba7d))
+* mark Function trait as Send ([#12](https://github.com/ivangabriele/mistralai-client-rs/issues/12)) ([8e9f7a5](https://github.com/ivangabriele/mistralai-client-rs/commit/8e9f7a53863879b2ad618e9e5707b198e4f3b135))
+
+## [0.11.0](https://github.com/ivangabriele/mistralai-client-rs/compare/v0.10.0...v) (2024-06-22)
+
+### Features
+
+* **constants:** add OpenMixtral8x22b, MistralTiny & CodestralLatest to Model enum ([ecd0c30](https://github.com/ivangabriele/mistralai-client-rs/commit/ecd0c3028fdcfab32b867eb1eed86182f5f4ab81))
+
+### Bug Fixes
+
+* **chat:** implement Clone trait for ChatParams & ResponseFormat ([0df67b1](https://github.com/ivangabriele/mistralai-client-rs/commit/0df67b1b2571fb04b636ce015a2daabe629ff352))
+
+## [0.10.0](https://github.com/ivangabriele/mistralai-client-rs/compare/v0.9.0...v) (2024-06-07)
+
+### ⚠ BREAKING CHANGES
+
+* **chat:**
+  - `Chat::ChatParams.safe_prompt` & `Chat::ChatRequest.safe_prompt` are now `bool` instead of `Option<bool>`. Default is `false`.
+  - `Chat::ChatParams.temperature` & `Chat::ChatRequest.temperature` are now `f32` instead of `Option<f32>`. Default is `0.7`.
+  - `Chat::ChatParams.top_p` & `Chat::ChatRequest.top_p` are now `f32` instead of `Option<f32>`. Default is `1.0`.
+
+### Features
+
+* **chat:** add response_format for JSON return values ([85c3611](https://github.com/ivangabriele/mistralai-client-rs/commit/85c3611afbbe8df30dfc7512cc381ed304ce4024))
+* **chat:** add the 'system' and 'tool' message roles ([#10](https://github.com/ivangabriele/mistralai-client-rs/issues/10)) ([2fc0642](https://github.com/ivangabriele/mistralai-client-rs/commit/2fc0642a5e4c024b15710acaab7735480e8dfe6a))
+* **chat:** change safe_prompt, temperature & top_p to non-Option types ([cf68a77](https://github.com/ivangabriele/mistralai-client-rs/commit/cf68a773201ebe0e802face52af388711acf0c27))
+
+### Bug Fixes
+
+* **chat:** skip serializing tool_calls if null, to avoid 422 error ([da5fe54](https://github.com/ivangabriele/mistralai-client-rs/commit/da5fe54115ce622379776661a440e2708b24810c))
+
+## [0.9.0](https://github.com/ivangabriele/mistralai-client-rs/compare/v0.8.0...v) (2024-04-13)
+
+### ⚠ BREAKING CHANGES
+
+* `Model.OpenMistral8x7b` has been renamed to `Model.OpenMixtral8x7b`.
+
+### Bug Fixes
+
+* **deps:** update rust crate reqwest to 0.12.0 ([#6](https://github.com/ivangabriele/mistralai-client-rs/issues/6)) ([fccd59c](https://github.com/ivangabriele/mistralai-client-rs/commit/fccd59c0cc783edddec1b404363faabb009eecd6))
+* fix typo in OpenMixtral8x7b model name ([#8](https://github.com/ivangabriele/mistralai-client-rs/issues/8)) ([6a99eca](https://github.com/ivangabriele/mistralai-client-rs/commit/6a99eca49c0cc8e3764a56f6dfd7762ec44a4c3b))
+
 ## [0.8.0](https://github.com/ivangabriele/mistralai-client-rs/compare/v0.7.0...v) (2024-03-09)
```
Cargo.toml

```diff
@@ -2,7 +2,7 @@
 name = "mistralai-client"
 description = "Mistral AI API client library for Rust (unofficial)."
 license = "Apache-2.0"
-version = "0.8.0"
+version = "0.12.0"

 edition = "2021"
 rust-version = "1.76.0"
@@ -20,7 +20,7 @@ async-trait = "0.1.77"
 env_logger = "0.11.3"
 futures = "0.3.30"
 log = "0.4.21"
-reqwest = { version = "0.11.24", features = ["json", "blocking", "stream"] }
+reqwest = { version = "0.12.0", features = ["json", "blocking", "stream"] }
 serde = { version = "1.0.197", features = ["derive"] }
 serde_json = "1.0.114"
 strum = "0.26.1"
```
Makefile (14 changes)

```diff
@@ -14,7 +14,7 @@ define source_env_if_not_ci
 endef

 define RELEASE_TEMPLATE
-	conventional-changelog -p conventionalcommits -i ./CHANGELOG.md -s
+	npx conventional-changelog-cli -p conventionalcommits -i ./CHANGELOG.md -s
 	git add .
 	git commit -m "docs(changelog): update"
 	git push origin HEAD
@@ -53,20 +53,20 @@ release-major:
 	$(call RELEASE_TEMPLATE,major)

 test:
-	@$(source_env_if_not_ci)
+	@$(source_env_if_not_ci) && \
 	cargo test --no-fail-fast
 test-cover:
-	@$(source_env_if_not_ci)
+	@$(source_env_if_not_ci) && \
 	cargo llvm-cov
 test-doc:
-	@$(source_env_if_not_ci)
+	@$(source_env_if_not_ci) && \
 	cargo test --doc --no-fail-fast
 test-examples:
-	@$(source_env_if_not_ci)
-	@for example in $$(ls examples/*.rs | sed 's/examples\/\(.*\)\.rs/\1/'); do \
+	@$(source_env_if_not_ci) && \
+	for example in $$(ls examples/*.rs | sed 's/examples\/\(.*\)\.rs/\1/'); do \
 		echo "Running $$example"; \
 		cargo run --example $$example; \
 	done
 test-watch:
-	@source ./.env
+	@source ./.env && \
 	cargo watch -x "test -- --nocapture"
```
README.md (14 changes)

```diff
@@ -7,6 +7,10 @@
 Rust client for the Mistral AI API.

+> [!IMPORTANT]
+> While we are in v0, minor versions may introduce breaking changes.
+> Please, refer to the [CHANGELOG.md](./CHANGELOG.md) for more information.
+
 ---

 - [Supported APIs](#supported-apis)
@@ -102,7 +106,7 @@ fn main() {
         tool_calls: None,
     }];
     let options = ChatParams {
-        temperature: Some(0.0),
+        temperature: 0.0,
         random_seed: Some(42),
         ..Default::default()
     };
@@ -134,7 +138,7 @@ async fn main() {
         tool_calls: None,
     }];
     let options = ChatParams {
-        temperature: Some(0.0),
+        temperature: 0.0,
         random_seed: Some(42),
         ..Default::default()
     };
@@ -174,7 +178,7 @@ async fn main() {
         tool_calls: None,
     }];
     let options = ChatParams {
-        temperature: Some(0.0),
+        temperature: 0.0,
         random_seed: Some(42),
         ..Default::default()
     };
@@ -259,7 +263,7 @@ fn main() {
         tool_calls: None,
     }];
     let options = ChatParams {
-        temperature: Some(0.0),
+        temperature: 0.0,
         random_seed: Some(42),
         tool_choice: Some(ToolChoice::Auto),
         tools: Some(tools),
@@ -336,7 +340,7 @@ async fn main() {
         tool_calls: None,
     }];
     let options = ChatParams {
-        temperature: Some(0.0),
+        temperature: 0.0,
         random_seed: Some(42),
         tool_choice: Some(ToolChoice::Auto),
         tools: Some(tools),
```
```diff
@@ -15,7 +15,7 @@ fn main() {
         tool_calls: None,
     }];
     let options = ChatParams {
-        temperature: Some(0.0),
+        temperature: 0.0,
         random_seed: Some(42),
         ..Default::default()
     };
@@ -16,7 +16,7 @@ async fn main() {
         tool_calls: None,
     }];
     let options = ChatParams {
-        temperature: Some(0.0),
+        temperature: 0.0,
         random_seed: Some(42),
         ..Default::default()
     };
@@ -53,7 +53,7 @@ fn main() {
         tool_calls: None,
     }];
     let options = ChatParams {
-        temperature: Some(0.0),
+        temperature: 0.0,
         random_seed: Some(42),
         tool_choice: Some(ToolChoice::Auto),
         tools: Some(tools),
@@ -54,7 +54,7 @@ async fn main() {
         tool_calls: None,
     }];
     let options = ChatParams {
-        temperature: Some(0.0),
+        temperature: 0.0,
         random_seed: Some(42),
         tool_choice: Some(ToolChoice::Auto),
         tools: Some(tools),
@@ -18,7 +18,7 @@ async fn main() {
         tool_calls: None,
     }];
     let options = ChatParams {
-        temperature: Some(0.0),
+        temperature: 0.0,
         random_seed: Some(42),
         ..Default::default()
     };
```
```diff
@@ -9,6 +9,7 @@ use crate::v1::{common, constants, tool};
 pub struct ChatMessage {
     pub role: ChatMessageRole,
     pub content: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
     pub tool_calls: Option<Vec<tool::ToolCall>>,
 }
 impl ChatMessage {
@@ -29,37 +30,101 @@ impl ChatMessage {
     }
 }

+/// See the [Mistral AI API documentation](https://docs.mistral.ai/capabilities/completion/#chat-messages) for more information.
 #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
 pub enum ChatMessageRole {
+    #[serde(rename = "system")]
+    System,
     #[serde(rename = "assistant")]
     Assistant,
     #[serde(rename = "user")]
     User,
+    #[serde(rename = "tool")]
+    Tool,
+}
+
+/// The format that the model must output.
+///
+/// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information.
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct ResponseFormat {
+    #[serde(rename = "type")]
+    pub type_: String,
+}
+impl ResponseFormat {
+    pub fn json_object() -> Self {
+        Self {
+            type_: "json_object".to_string(),
+        }
+    }
 }

 // -----------------------------------------------------------------------------
 // Request

-#[derive(Debug)]
+/// The parameters for the chat request.
+///
+/// See the [API documentation](https://docs.mistral.ai/api/#operation/createChatCompletion) for more information.
+#[derive(Clone, Debug)]
 pub struct ChatParams {
+    /// The maximum number of tokens to generate in the completion.
+    ///
+    /// Defaults to `None`.
     pub max_tokens: Option<u32>,
+    /// The seed to use for random sampling. If set, different calls will generate deterministic results.
+    ///
+    /// Defaults to `None`.
     pub random_seed: Option<u32>,
-    pub safe_prompt: Option<bool>,
-    pub temperature: Option<f32>,
+    /// The format that the model must output.
+    ///
+    /// Defaults to `None`.
+    pub response_format: Option<ResponseFormat>,
+    /// Whether to inject a safety prompt before all conversations.
+    ///
+    /// Defaults to `false`.
+    pub safe_prompt: bool,
+    /// What sampling temperature to use, between `Some(0.0)` and `Some(1.0)`.
+    ///
+    /// Defaults to `0.7`.
+    pub temperature: f32,
+    /// Specifies if/how functions are called.
+    ///
+    /// Defaults to `None`.
     pub tool_choice: Option<tool::ToolChoice>,
+    /// A list of available tools for the model.
+    ///
+    /// Defaults to `None`.
     pub tools: Option<Vec<tool::Tool>>,
-    pub top_p: Option<f32>,
+    /// Nucleus sampling, where the model considers the results of the tokens with `top_p` probability mass.
+    ///
+    /// Defaults to `1.0`.
+    pub top_p: f32,
 }
 impl Default for ChatParams {
     fn default() -> Self {
         Self {
             max_tokens: None,
             random_seed: None,
-            safe_prompt: None,
-            temperature: None,
+            safe_prompt: false,
+            response_format: None,
+            temperature: 0.7,
             tool_choice: None,
             tools: None,
-            top_p: None,
+            top_p: 1.0,
+        }
+    }
+}
+
+impl ChatParams {
+    pub fn json_default() -> Self {
+        Self {
+            max_tokens: None,
+            random_seed: None,
+            safe_prompt: false,
+            response_format: None,
+            temperature: 0.7,
+            tool_choice: None,
+            tools: None,
+            top_p: 1.0,
         }
     }
 }
```
```diff
@@ -74,20 +139,15 @@ pub struct ChatRequest {
     #[serde(skip_serializing_if = "Option::is_none")]
     pub random_seed: Option<u32>,
     #[serde(skip_serializing_if = "Option::is_none")]
-    pub safe_prompt: Option<bool>,
+    pub response_format: Option<ResponseFormat>,
+    pub safe_prompt: bool,
     pub stream: bool,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub temperature: Option<f32>,
+    pub temperature: f32,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub tool_choice: Option<tool::ToolChoice>,
     #[serde(skip_serializing_if = "Option::is_none")]
     pub tools: Option<Vec<tool::Tool>>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub top_p: Option<f32>,
-    // TODO Check this prop (seen in official Python client but not in API doc).
-    // pub tool_choice: Option<String>,
-    // TODO Check this prop (seen in official Python client but not in API doc).
-    // pub response_format: Option<String>,
+    pub top_p: f32,
 }
 impl ChatRequest {
     pub fn new(
@@ -104,6 +164,7 @@ impl ChatRequest {
             tool_choice,
             tools,
             top_p,
+            response_format,
         } = options.unwrap_or_default();

         Self {
@@ -118,6 +179,7 @@ impl ChatRequest {
             tool_choice,
             tools,
             top_p,
+            response_format,
         }
     }
 }
```
```diff
@@ -10,6 +10,7 @@ use std::{

 use crate::v1::{chat, chat_stream, constants, embedding, error, model_list, tool, utils};

+#[derive(Debug)]
 pub struct Client {
     pub api_key: String,
     pub endpoint: String,
```
```diff
@@ -6,14 +6,20 @@ pub const API_URL_BASE: &str = "https://api.mistral.ai/v1";
 pub enum Model {
     #[serde(rename = "open-mistral-7b")]
     OpenMistral7b,
-    #[serde(rename = "open-mistral-8x7b")]
-    OpenMistral8x7b,
+    #[serde(rename = "open-mixtral-8x7b")]
+    OpenMixtral8x7b,
+    #[serde(rename = "open-mixtral-8x22b")]
+    OpenMixtral8x22b,
+    #[serde(rename = "mistral-tiny")]
+    MistralTiny,
     #[serde(rename = "mistral-small-latest")]
     MistralSmallLatest,
     #[serde(rename = "mistral-medium-latest")]
     MistralMediumLatest,
     #[serde(rename = "mistral-large-latest")]
     MistralLargeLatest,
+    #[serde(rename = "codestral-latest")]
+    CodestralLatest,
 }

 #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
```
```diff
@@ -1,6 +1,6 @@
 use async_trait::async_trait;
 use serde::{Deserialize, Serialize};
-use std::{any::Any, collections::HashMap};
+use std::{any::Any, collections::HashMap, fmt::Debug};

 // -----------------------------------------------------------------------------
 // Definitions
@@ -115,12 +115,16 @@ pub enum ToolType {
     Function,
 }

+/// An enum representing how functions should be called.
 #[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
 pub enum ToolChoice {
+    /// The model is forced to call a function.
     #[serde(rename = "any")]
     Any,
+    /// The model can choose to either generate a message or call a function.
     #[serde(rename = "auto")]
     Auto,
+    /// The model won't call a function and will generate a message instead.
     #[serde(rename = "none")]
     None,
 }
@@ -129,6 +133,12 @@ pub enum ToolChoice {
 // Custom

 #[async_trait]
-pub trait Function {
+pub trait Function: Send {
     async fn execute(&self, arguments: String) -> Box<dyn Any + Send>;
 }
+
+impl Debug for dyn Function {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "Function()")
+    }
+}
```
```diff
@@ -16,10 +16,10 @@ async fn test_client_chat_async() {

     let model = Model::OpenMistral7b;
     let messages = vec![ChatMessage::new_user_message(
-        "Just guess the next word: \"Eiffel ...\"?",
+        "Guess the next word: \"Eiffel ...\"?",
     )];
     let options = ChatParams {
-        temperature: Some(0.0),
+        temperature: 0.0,
         random_seed: Some(42),
         ..Default::default()
     };
@@ -37,8 +37,12 @@ async fn test_client_chat_async() {
     expect!(response.choices[0].finish_reason.clone()).to_be(ChatResponseChoiceFinishReason::Stop);

     expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant);
-    expect!(response.choices[0].message.content.clone())
-        .to_be("Tower. The Eiffel Tower is a famous landmark in Paris, France.".to_string());
+    expect!(response.choices[0]
+        .message
+        .content
+        .clone()
+        .contains("Tower"))
+    .to_be(true);

     expect!(response.usage.prompt_tokens).to_be_greater_than(0);
     expect!(response.usage.completion_tokens).to_be_greater_than(0);
@@ -66,7 +70,7 @@ async fn test_client_chat_async_with_function_calling() {
         "What's the current temperature in Paris?",
     )];
     let options = ChatParams {
-        temperature: Some(0.0),
+        temperature: 0.0,
         random_seed: Some(42),
         tool_choice: Some(ToolChoice::Any),
         tools: Some(tools),
```
```diff
@@ -16,10 +16,10 @@ fn test_client_chat() {

     let model = Model::OpenMistral7b;
     let messages = vec![ChatMessage::new_user_message(
-        "Just guess the next word: \"Eiffel ...\"?",
+        "Guess the next word: \"Eiffel ...\"?",
     )];
     let options = ChatParams {
-        temperature: Some(0.0),
+        temperature: 0.0,
         random_seed: Some(42),
         ..Default::default()
     };
@@ -31,8 +31,12 @@ fn test_client_chat() {
     expect!(response.choices.len()).to_be(1);
     expect!(response.choices[0].index).to_be(0);
     expect!(response.choices[0].message.role.clone()).to_be(ChatMessageRole::Assistant);
-    expect!(response.choices[0].message.content.clone())
-        .to_be("Tower. The Eiffel Tower is a famous landmark in Paris, France.".to_string());
+    expect!(response.choices[0]
+        .message
+        .content
+        .clone()
+        .contains("Tower"))
+    .to_be(true);
     expect!(response.choices[0].finish_reason.clone()).to_be(ChatResponseChoiceFinishReason::Stop);
     expect!(response.usage.prompt_tokens).to_be_greater_than(0);
     expect!(response.usage.completion_tokens).to_be_greater_than(0);
@@ -60,7 +64,7 @@ fn test_client_chat_with_function_calling() {
         "What's the current temperature in Paris?",
     )];
     let options = ChatParams {
-        temperature: Some(0.0),
+        temperature: 0.0,
         random_seed: Some(42),
         tool_choice: Some(ToolChoice::Auto),
         tools: Some(tools),
```
```diff
@@ -1,6 +1,11 @@
 use jrest::expect;
 use mistralai_client::v1::{client::Client, error::ClientError};

+#[derive(Debug)]
+struct _Foo {
+    _client: Client,
+}
+
 #[test]
 fn test_client_new_with_none_params() {
     let maybe_original_mistral_api_key = std::env::var("MISTRAL_API_KEY").ok();
```
tests/v1_constants_test.rs (new file, 41 lines)

```diff
@@ -0,0 +1,41 @@
+use jrest::expect;
+use mistralai_client::v1::{
+    chat::{ChatMessage, ChatParams},
+    client::Client,
+    constants::Model,
+};
+
+#[test]
+fn test_model_constant() {
+    let models = vec![
+        Model::OpenMistral7b,
+        Model::OpenMixtral8x7b,
+        Model::OpenMixtral8x22b,
+        Model::MistralTiny,
+        Model::MistralSmallLatest,
+        Model::MistralMediumLatest,
+        Model::MistralLargeLatest,
+        Model::CodestralLatest,
+    ];
+
+    let client = Client::new(None, None, None, None).unwrap();
+
+    let messages = vec![ChatMessage::new_user_message("A number between 0 and 100?")];
+    let options = ChatParams {
+        temperature: 0.0,
+        random_seed: Some(42),
+        ..Default::default()
+    };
+
+    for model in models {
+        let response = client
+            .chat(model.clone(), messages.clone(), Some(options.clone()))
+            .unwrap();
+
+        expect!(response.model).to_be(model);
+        expect!(response.object).to_be("chat.completion".to_string());
+        expect!(response.choices.len()).to_be(1);
+        expect!(response.choices[0].index).to_be(0);
+        expect!(response.choices[0].message.content.len()).to_be_greater_than(0);
+    }
+}
```
tests/v1_tool_test.rs (new file, 7 lines)

```diff
@@ -0,0 +1,7 @@
+use mistralai_client::v1::client::Client;
+
+trait _Trait: Send {}
+struct _Foo {
+    _dummy: Client,
+}
+impl _Trait for _Foo {}
```