diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml
index f109c6b..264af0a 100644
--- a/.github/workflows/test.yml
+++ b/.github/workflows/test.yml
@@ -40,3 +40,18 @@ jobs:
run: make test-doc
env:
MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
+
+ test_examples:
+    name: Test Examples
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout
+ uses: actions/checkout@v4
+ - name: Setup Rust
+ uses: actions-rs/toolchain@v1
+ with:
+ toolchain: 1.76.0
+      - name: Run examples
+        run: make test-examples
+ env:
+ MISTRAL_API_KEY: ${{ secrets.MISTRAL_API_KEY }}
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index 749a9d8..d6f00f4 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -4,7 +4,10 @@
- [Requirements](#requirements)
- [First setup](#first-setup)
- [Optional requirements](#optional-requirements)
+- [Local Development](#local-development)
- [Test](#test)
+- [Documentation](#documentation)
+ - [Readme](#readme)
- [Code of Conduct](#code-of-conduct)
- [Commit Message Format](#commit-message-format)
@@ -41,6 +44,8 @@ Then edit the `.env` file to set your `MISTRAL_API_KEY`.
- [cargo-llvm-cov](https://github.com/taiki-e/cargo-llvm-cov?tab=readme-ov-file#installation) for `make test-cover`
- [cargo-watch](https://github.com/watchexec/cargo-watch#install) for `make test-watch`.
+## Local Development
+
### Test
```sh
@@ -53,6 +58,16 @@ or
make test-watch
```
+## Documentation
+
+### Readme
+
+> [!IMPORTANT]
+> Do not edit the `README.md` file directly. It is generated from the `README.template.md` file.
+
+1. Edit the `README.template.md` file.
+2. Run `make readme` to generate/update the `README.md` file.
+
## Code of Conduct
Help us keep this project open and inclusive. Please read and follow our [Code of Conduct](./CODE_OF_CONDUCT.md).
diff --git a/Cargo.toml b/Cargo.toml
index 3761f3b..bb6f63a 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -15,14 +15,18 @@ readme = "README.md"
repository = "https://github.com/ivangabriele/mistralai-client-rs"
[dependencies]
+async-stream = "0.3.5"
+async-trait = "0.1.77"
+env_logger = "0.11.3"
futures = "0.3.30"
+log = "0.4.21"
reqwest = { version = "0.11.24", features = ["json", "blocking", "stream"] }
serde = { version = "1.0.197", features = ["derive"] }
serde_json = "1.0.114"
strum = "0.26.1"
-strum_macros = "0.26.1"
thiserror = "1.0.57"
tokio = { version = "1.36.0", features = ["full"] }
+tokio-stream = "0.1.14"
[dev-dependencies]
jrest = "0.2.3"
diff --git a/Makefile b/Makefile
index 6e940ec..02eada5 100644
--- a/Makefile
+++ b/Makefile
@@ -1,6 +1,6 @@
SHELL := /bin/bash
-.PHONY: test
+.PHONY: doc readme test
define source_env_if_not_ci
@if [ -z "$${CI}" ]; then \
@@ -26,20 +26,47 @@ doc:
cargo doc
open ./target/doc/mistralai_client/index.html
+readme:
+ @echo "Generating README.md from template..."
+ @> README.md # Clear README.md content before starting
+ @while IFS= read -r line || [[ -n "$$line" ]]; do \
+	  if [[ $$line == *"<CODE>"* && $$line == *"</CODE>"* ]]; then \
+	    example_path=$$(echo $$line | sed -n 's/.*<CODE>\(.*\)<\/CODE>.*/\1/p'); \
+ if [ -f $$example_path ]; then \
+ echo '```rs' >> README.md; \
+ cat $$example_path >> README.md; \
+ echo '```' >> README.md; \
+ else \
+ echo "Error: Example $$example_path not found." >&2; \
+ fi; \
+ else \
+ echo "$$line" >> README.md; \
+ fi; \
+ done < README.template.md
+ @echo "README.md has been generated."
+
release-patch:
$(call RELEASE_TEMPLATE,patch)
-
release-minor:
$(call RELEASE_TEMPLATE,minor)
-
release-major:
$(call RELEASE_TEMPLATE,major)
test:
- @$(source_env_if_not_ci) && cargo test --no-fail-fast
+	@$(source_env_if_not_ci) && \
+ cargo test --no-fail-fast
test-cover:
- @$(source_env_if_not_ci) && cargo llvm-cov
+	@$(source_env_if_not_ci) && \
+ cargo llvm-cov
test-doc:
- @$(source_env_if_not_ci) && cargo test --doc --no-fail-fast
+	@$(source_env_if_not_ci) && \
+ cargo test --doc --no-fail-fast
+test-examples:
+	@$(source_env_if_not_ci) && \
+	for example in $$(ls examples/*.rs | sed 's/examples\/\(.*\)\.rs/\1/'); do \
+ echo "Running $$example"; \
+ cargo run --example $$example; \
+ done
test-watch:
- @source ./.env && cargo watch -x "test -- --nocapture"
+	@source ./.env && \
+ cargo watch -x "test -- --nocapture"
diff --git a/README.md b/README.md
index a828533..3f55332 100644
--- a/README.md
+++ b/README.md
@@ -15,13 +15,16 @@ Rust client for the Mistral AI API.
- [As an environment variable](#as-an-environment-variable)
- [As a client argument](#as-a-client-argument)
- [Usage](#usage)
- - [Chat without streaming](#chat-without-streaming)
- - [Chat without streaming (async)](#chat-without-streaming-async)
+ - [Chat](#chat)
+ - [Chat (async)](#chat-async)
- [Chat with streaming (async)](#chat-with-streaming-async)
+ - [Chat with Function Calling](#chat-with-function-calling)
+ - [Chat with Function Calling (async)](#chat-with-function-calling-async)
- [Embeddings](#embeddings)
- [Embeddings (async)](#embeddings-async)
- [List models](#list-models)
- [List models (async)](#list-models-async)
+- [Contributing](#contributing)
---
@@ -34,8 +37,8 @@ Rust client for the Mistral AI API.
- [x] Embedding (async)
- [x] List models
- [x] List models (async)
-- [ ] Function Calling
-- [ ] Function Calling (async)
+- [x] Function Calling
+- [x] Function Calling (async)
## Installation
@@ -53,6 +56,18 @@ You can get your Mistral API Key there: .
Just set the `MISTRAL_API_KEY` environment variable.
+```rs
+use mistralai_client::v1::client::Client;
+
+fn main() {
+ let client = Client::new(None, None, None, None);
+}
+```
+
+```sh
+MISTRAL_API_KEY=your_api_key cargo run
+```
+
#### As a client argument
```rs
@@ -67,11 +82,11 @@ fn main() {
## Usage
-### Chat without streaming
+### Chat
```rs
use mistralai_client::v1::{
- chat_completion::{ChatCompletionParams, ChatMessage, ChatMessageRole},
+ chat::{ChatMessage, ChatMessageRole, ChatParams},
client::Client,
constants::Model,
};
@@ -82,10 +97,11 @@ fn main() {
let model = Model::OpenMistral7b;
let messages = vec![ChatMessage {
- role: ChatMessageRole::user,
+ role: ChatMessageRole::User,
content: "Just guess the next word: \"Eiffel ...\"?".to_string(),
+ tool_calls: None,
}];
- let options = ChatCompletionRequestOptions {
+ let options = ChatParams {
temperature: Some(0.0),
random_seed: Some(42),
..Default::default()
@@ -93,15 +109,15 @@ fn main() {
let result = client.chat(model, messages, Some(options)).unwrap();
println!("Assistant: {}", result.choices[0].message.content);
- // => "Assistant: Tower. [...]"
+ // => "Assistant: Tower. The Eiffel Tower is a famous landmark in Paris, France."
}
```
-### Chat without streaming (async)
+### Chat (async)
```rs
use mistralai_client::v1::{
- chat_completion::{ChatCompletionParams, ChatMessage, ChatMessageRole},
+ chat::{ChatMessage, ChatMessageRole, ChatParams},
client::Client,
constants::Model,
};
@@ -113,18 +129,25 @@ async fn main() {
let model = Model::OpenMistral7b;
let messages = vec![ChatMessage {
- role: ChatMessageRole::user,
+ role: ChatMessageRole::User,
content: "Just guess the next word: \"Eiffel ...\"?".to_string(),
+ tool_calls: None,
}];
- let options = ChatCompletionRequestOptions {
+ let options = ChatParams {
temperature: Some(0.0),
random_seed: Some(42),
..Default::default()
};
- let result = client.chat_async(model, messages, Some(options)).await.unwrap();
- println!("Assistant: {}", result.choices[0].message.content);
- // => "Assistant: Tower. [...]"
+ let result = client
+ .chat_async(model, messages, Some(options))
+ .await
+ .unwrap();
+ println!(
+ "{:?}: {}",
+ result.choices[0].message.role, result.choices[0].message.content
+ );
+ // => "Assistant: Tower. The Eiffel Tower is a famous landmark in Paris, France."
}
```
@@ -133,38 +156,206 @@ async fn main() {
```rs
use futures::stream::StreamExt;
use mistralai_client::v1::{
- chat_completion::{ChatCompletionParams, ChatMessage, ChatMessageRole},
+ chat::{ChatMessage, ChatMessageRole, ChatParams},
client::Client,
constants::Model,
};
+use std::io::{self, Write};
-[#tokio::main]
+#[tokio::main]
async fn main() {
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
- let client = Client::new(None, None, None, None).unwrap();
+ let client = Client::new(None, None, None, None).unwrap();
let model = Model::OpenMistral7b;
let messages = vec![ChatMessage {
- role: ChatMessageRole::user,
- content: "Just guess the next word: \"Eiffel ...\"?".to_string(),
+ role: ChatMessageRole::User,
+ content: "Tell me a short happy story.".to_string(),
+ tool_calls: None,
}];
- let options = ChatCompletionParams {
+ let options = ChatParams {
temperature: Some(0.0),
random_seed: Some(42),
..Default::default()
};
- let stream_result = client.chat_stream(model, messages, Some(options)).await;
- let mut stream = stream_result.expect("Failed to create stream.");
- while let Some(chunk_result) = stream.next().await {
- match chunk_result {
- Ok(chunk) => {
- println!("Assistant (message chunk): {}", chunk.choices[0].delta.content);
+ let stream_result = client
+ .chat_stream(model, messages, Some(options))
+ .await
+ .expect("Failed to create stream.");
+ stream_result
+ .for_each(|chunk_result| async {
+ match chunk_result {
+ Ok(chunks) => chunks.iter().for_each(|chunk| {
+ print!("{}", chunk.choices[0].delta.content);
+ io::stdout().flush().unwrap();
+ // => "Once upon a time, [...]"
+ }),
+ Err(error) => {
+ eprintln!("Error processing chunk: {:?}", error)
+ }
}
- Err(e) => eprintln!("Error processing chunk: {:?}", e),
- }
+ })
+ .await;
+
+ print!("\n") // To persist the last chunk output.
+}
+```
+
+### Chat with Function Calling
+
+```rs
+use mistralai_client::v1::{
+ chat::{ChatMessage, ChatMessageRole, ChatParams},
+ client::Client,
+ constants::Model,
+ tool::{Function, Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType},
+};
+use serde::Deserialize;
+use std::any::Any;
+
+#[derive(Debug, Deserialize)]
+struct GetCityTemperatureArguments {
+ city: String,
+}
+
+struct GetCityTemperatureFunction;
+#[async_trait::async_trait]
+impl Function for GetCityTemperatureFunction {
+    async fn execute(&self, arguments: String) -> Box<dyn Any + Send> {
+ // Deserialize arguments, perform the logic, and return the result
+ let GetCityTemperatureArguments { city } = serde_json::from_str(&arguments).unwrap();
+
+ let temperature = match city.as_str() {
+ "Paris" => "20°C",
+ _ => "Unknown city",
+ };
+
+ Box::new(temperature.to_string())
}
}
+
+fn main() {
+ let tools = vec![Tool::new(
+ "get_city_temperature".to_string(),
+ "Get the current temperature in a city.".to_string(),
+ vec![ToolFunctionParameter::new(
+ "city".to_string(),
+ "The name of the city.".to_string(),
+ ToolFunctionParameterType::String,
+ )],
+ )];
+
+ // This example suppose you have set the `MISTRAL_API_KEY` environment variable.
+ let mut client = Client::new(None, None, None, None).unwrap();
+ client.register_function(
+ "get_city_temperature".to_string(),
+ Box::new(GetCityTemperatureFunction),
+ );
+
+ let model = Model::MistralSmallLatest;
+ let messages = vec![ChatMessage {
+ role: ChatMessageRole::User,
+ content: "What's the temperature in Paris?".to_string(),
+ tool_calls: None,
+ }];
+ let options = ChatParams {
+ temperature: Some(0.0),
+ random_seed: Some(42),
+ tool_choice: Some(ToolChoice::Auto),
+ tools: Some(tools),
+ ..Default::default()
+ };
+
+ client.chat(model, messages, Some(options)).unwrap();
+ let temperature = client
+ .get_last_function_call_result()
+ .unwrap()
+        .downcast::<String>()
+ .unwrap();
+ println!("The temperature in Paris is: {}.", temperature);
+ // => "The temperature in Paris is: 20°C."
+}
+```
+
+### Chat with Function Calling (async)
+
+```rs
+use mistralai_client::v1::{
+ chat::{ChatMessage, ChatMessageRole, ChatParams},
+ client::Client,
+ constants::Model,
+ tool::{Function, Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType},
+};
+use serde::Deserialize;
+use std::any::Any;
+
+#[derive(Debug, Deserialize)]
+struct GetCityTemperatureArguments {
+ city: String,
+}
+
+struct GetCityTemperatureFunction;
+#[async_trait::async_trait]
+impl Function for GetCityTemperatureFunction {
+    async fn execute(&self, arguments: String) -> Box<dyn Any + Send> {
+ // Deserialize arguments, perform the logic, and return the result
+ let GetCityTemperatureArguments { city } = serde_json::from_str(&arguments).unwrap();
+
+ let temperature = match city.as_str() {
+ "Paris" => "20°C",
+ _ => "Unknown city",
+ };
+
+ Box::new(temperature.to_string())
+ }
+}
+
+#[tokio::main]
+async fn main() {
+ let tools = vec![Tool::new(
+ "get_city_temperature".to_string(),
+ "Get the current temperature in a city.".to_string(),
+ vec![ToolFunctionParameter::new(
+ "city".to_string(),
+ "The name of the city.".to_string(),
+ ToolFunctionParameterType::String,
+ )],
+ )];
+
+ // This example suppose you have set the `MISTRAL_API_KEY` environment variable.
+ let mut client = Client::new(None, None, None, None).unwrap();
+ client.register_function(
+ "get_city_temperature".to_string(),
+ Box::new(GetCityTemperatureFunction),
+ );
+
+ let model = Model::MistralSmallLatest;
+ let messages = vec![ChatMessage {
+ role: ChatMessageRole::User,
+ content: "What's the temperature in Paris?".to_string(),
+ tool_calls: None,
+ }];
+ let options = ChatParams {
+ temperature: Some(0.0),
+ random_seed: Some(42),
+ tool_choice: Some(ToolChoice::Auto),
+ tools: Some(tools),
+ ..Default::default()
+ };
+
+ client
+ .chat_async(model, messages, Some(options))
+ .await
+ .unwrap();
+ let temperature = client
+ .get_last_function_call_result()
+ .unwrap()
+        .downcast::<String>()
+ .unwrap();
+ println!("The temperature in Paris is: {}.", temperature);
+ // => "The temperature in Paris is: 20°C."
+}
```
### Embeddings
@@ -174,18 +365,18 @@ use mistralai_client::v1::{client::Client, constants::EmbedModel};
fn main() {
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
- let client: Client = Client::new(None, None, None, None).unwrap();
+ let client: Client = Client::new(None, None, None, None).unwrap();
- let model = EmbedModel::MistralEmbed;
- let input = vec!["Embed this sentence.", "As well as this one."]
- .iter()
- .map(|s| s.to_string())
- .collect();
- let options = None;
+ let model = EmbedModel::MistralEmbed;
+ let input = vec!["Embed this sentence.", "As well as this one."]
+ .iter()
+ .map(|s| s.to_string())
+ .collect();
+ let options = None;
- let response = client.embeddings(model, input, options).unwrap();
- println!("Embeddings: {:?}", response.data);
- // => "Embeddings: [{...}, {...}]"
+ let response = client.embeddings(model, input, options).unwrap();
+ println!("First Embedding: {:?}", response.data[0]);
+ // => "First Embedding: {...}"
}
```
@@ -197,18 +388,21 @@ use mistralai_client::v1::{client::Client, constants::EmbedModel};
#[tokio::main]
async fn main() {
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
- let client: Client = Client::new(None, None, None, None).unwrap();
+ let client: Client = Client::new(None, None, None, None).unwrap();
- let model = EmbedModel::MistralEmbed;
- let input = vec!["Embed this sentence.", "As well as this one."]
- .iter()
- .map(|s| s.to_string())
- .collect();
- let options = None;
+ let model = EmbedModel::MistralEmbed;
+ let input = vec!["Embed this sentence.", "As well as this one."]
+ .iter()
+ .map(|s| s.to_string())
+ .collect();
+ let options = None;
- let response = client.embeddings_async(model, input, options).await.unwrap();
- println!("Embeddings: {:?}", response.data);
- // => "Embeddings: [{...}, {...}]"
+ let response = client
+ .embeddings_async(model, input, options)
+ .await
+ .unwrap();
+ println!("First Embedding: {:?}", response.data[0]);
+ // => "First Embedding: {...}"
}
```
@@ -235,10 +429,14 @@ use mistralai_client::v1::client::Client;
#[tokio::main]
async fn main() {
// This example suppose you have set the `MISTRAL_API_KEY` environment variable.
- let client = Client::new(None, None, None, None).await.unwrap();
+ let client = Client::new(None, None, None, None).unwrap();
- let result = client.list_models_async().unwrap();
+ let result = client.list_models_async().await.unwrap();
println!("First Model ID: {:?}", result.data[0].id);
// => "First Model ID: open-mistral-7b"
}
```
+
+## Contributing
+
+Please read [CONTRIBUTING.md](./CONTRIBUTING.md) for details on how to contribute to this library.
diff --git a/README.template.md b/README.template.md
new file mode 100644
index 0000000..5f3d496
--- /dev/null
+++ b/README.template.md
@@ -0,0 +1,123 @@
+# Mistral AI Rust Client
+
+[](https://crates.io/crates/mistralai-client)
+[](https://docs.rs/mistralai-client/latest/mistralai-client)
+[](https://github.com/ivangabriele/mistralai-client-rs/actions?query=branch%3Amain+workflow%3ATest++)
+[](https://app.codecov.io/github/ivangabriele/mistralai-client-rs)
+
+Rust client for the Mistral AI API.
+
+---
+
+- [Supported APIs](#supported-apis)
+- [Installation](#installation)
+ - [Mistral API Key](#mistral-api-key)
+ - [As an environment variable](#as-an-environment-variable)
+ - [As a client argument](#as-a-client-argument)
+- [Usage](#usage)
+ - [Chat](#chat)
+ - [Chat (async)](#chat-async)
+ - [Chat with streaming (async)](#chat-with-streaming-async)
+ - [Chat with Function Calling](#chat-with-function-calling)
+ - [Chat with Function Calling (async)](#chat-with-function-calling-async)
+ - [Embeddings](#embeddings)
+ - [Embeddings (async)](#embeddings-async)
+ - [List models](#list-models)
+ - [List models (async)](#list-models-async)
+- [Contributing](#contributing)
+
+---
+
+## Supported APIs
+
+- [x] Chat without streaming
+- [x] Chat without streaming (async)
+- [x] Chat with streaming
+- [x] Embedding
+- [x] Embedding (async)
+- [x] List models
+- [x] List models (async)
+- [x] Function Calling
+- [x] Function Calling (async)
+
+## Installation
+
+You can install the library in your project using:
+
+```sh
+cargo add mistralai-client
+```
+
+### Mistral API Key
+
+You can get your Mistral API Key there: <https://console.mistral.ai/api-keys/>.
+
+#### As an environment variable
+
+Just set the `MISTRAL_API_KEY` environment variable.
+
+```rs
+use mistralai_client::v1::client::Client;
+
+fn main() {
+ let client = Client::new(None, None, None, None);
+}
+```
+
+```sh
+MISTRAL_API_KEY=your_api_key cargo run
+```
+
+#### As a client argument
+
+```rs
+use mistralai_client::v1::client::Client;
+
+fn main() {
+ let api_key = "your_api_key";
+
+ let client = Client::new(Some(api_key), None, None, None).unwrap();
+}
+```
+
+## Usage
+
+### Chat
+
+<CODE>examples/chat.rs</CODE>
+
+### Chat (async)
+
+<CODE>examples/chat_async.rs</CODE>
+
+### Chat with streaming (async)
+
+<CODE>examples/chat_with_streaming.rs</CODE>
+
+### Chat with Function Calling
+
+<CODE>examples/chat_with_function_calling.rs</CODE>
+
+### Chat with Function Calling (async)
+
+<CODE>examples/chat_with_function_calling_async.rs</CODE>
+
+### Embeddings
+
+<CODE>examples/embeddings.rs</CODE>
+
+### Embeddings (async)
+
+<CODE>examples/embeddings_async.rs</CODE>
+
+### List models
+
+<CODE>examples/list_models.rs</CODE>
+
+### List models (async)
+
+<CODE>examples/list_models_async.rs</CODE>
+
+## Contributing
+
+Please read [CONTRIBUTING.md](./CONTRIBUTING.md) for details on how to contribute to this library.
diff --git a/examples/chat.rs b/examples/chat.rs
new file mode 100644
index 0000000..12d5fd4
--- /dev/null
+++ b/examples/chat.rs
@@ -0,0 +1,26 @@
+use mistralai_client::v1::{
+ chat::{ChatMessage, ChatMessageRole, ChatParams},
+ client::Client,
+ constants::Model,
+};
+
+fn main() {
+ // This example suppose you have set the `MISTRAL_API_KEY` environment variable.
+ let client = Client::new(None, None, None, None).unwrap();
+
+ let model = Model::OpenMistral7b;
+ let messages = vec![ChatMessage {
+ role: ChatMessageRole::User,
+ content: "Just guess the next word: \"Eiffel ...\"?".to_string(),
+ tool_calls: None,
+ }];
+ let options = ChatParams {
+ temperature: Some(0.0),
+ random_seed: Some(42),
+ ..Default::default()
+ };
+
+ let result = client.chat(model, messages, Some(options)).unwrap();
+ println!("Assistant: {}", result.choices[0].message.content);
+ // => "Assistant: Tower. The Eiffel Tower is a famous landmark in Paris, France."
+}
diff --git a/examples/chat_async.rs b/examples/chat_async.rs
new file mode 100644
index 0000000..7034553
--- /dev/null
+++ b/examples/chat_async.rs
@@ -0,0 +1,33 @@
+use mistralai_client::v1::{
+ chat::{ChatMessage, ChatMessageRole, ChatParams},
+ client::Client,
+ constants::Model,
+};
+
+#[tokio::main]
+async fn main() {
+ // This example suppose you have set the `MISTRAL_API_KEY` environment variable.
+ let client = Client::new(None, None, None, None).unwrap();
+
+ let model = Model::OpenMistral7b;
+ let messages = vec![ChatMessage {
+ role: ChatMessageRole::User,
+ content: "Just guess the next word: \"Eiffel ...\"?".to_string(),
+ tool_calls: None,
+ }];
+ let options = ChatParams {
+ temperature: Some(0.0),
+ random_seed: Some(42),
+ ..Default::default()
+ };
+
+ let result = client
+ .chat_async(model, messages, Some(options))
+ .await
+ .unwrap();
+ println!(
+ "{:?}: {}",
+ result.choices[0].message.role, result.choices[0].message.content
+ );
+ // => "Assistant: Tower. The Eiffel Tower is a famous landmark in Paris, France."
+}
diff --git a/examples/chat_with_function_calling.rs b/examples/chat_with_function_calling.rs
new file mode 100644
index 0000000..991d2f2
--- /dev/null
+++ b/examples/chat_with_function_calling.rs
@@ -0,0 +1,71 @@
+use mistralai_client::v1::{
+ chat::{ChatMessage, ChatMessageRole, ChatParams},
+ client::Client,
+ constants::Model,
+ tool::{Function, Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType},
+};
+use serde::Deserialize;
+use std::any::Any;
+
+#[derive(Debug, Deserialize)]
+struct GetCityTemperatureArguments {
+ city: String,
+}
+
+struct GetCityTemperatureFunction;
+#[async_trait::async_trait]
+impl Function for GetCityTemperatureFunction {
+    async fn execute(&self, arguments: String) -> Box<dyn Any + Send> {
+ // Deserialize arguments, perform the logic, and return the result
+ let GetCityTemperatureArguments { city } = serde_json::from_str(&arguments).unwrap();
+
+ let temperature = match city.as_str() {
+ "Paris" => "20°C",
+ _ => "Unknown city",
+ };
+
+ Box::new(temperature.to_string())
+ }
+}
+
+fn main() {
+ let tools = vec![Tool::new(
+ "get_city_temperature".to_string(),
+ "Get the current temperature in a city.".to_string(),
+ vec![ToolFunctionParameter::new(
+ "city".to_string(),
+ "The name of the city.".to_string(),
+ ToolFunctionParameterType::String,
+ )],
+ )];
+
+ // This example suppose you have set the `MISTRAL_API_KEY` environment variable.
+ let mut client = Client::new(None, None, None, None).unwrap();
+ client.register_function(
+ "get_city_temperature".to_string(),
+ Box::new(GetCityTemperatureFunction),
+ );
+
+ let model = Model::MistralSmallLatest;
+ let messages = vec![ChatMessage {
+ role: ChatMessageRole::User,
+ content: "What's the temperature in Paris?".to_string(),
+ tool_calls: None,
+ }];
+ let options = ChatParams {
+ temperature: Some(0.0),
+ random_seed: Some(42),
+ tool_choice: Some(ToolChoice::Auto),
+ tools: Some(tools),
+ ..Default::default()
+ };
+
+ client.chat(model, messages, Some(options)).unwrap();
+ let temperature = client
+ .get_last_function_call_result()
+ .unwrap()
+        .downcast::<String>()
+ .unwrap();
+ println!("The temperature in Paris is: {}.", temperature);
+ // => "The temperature in Paris is: 20°C."
+}
diff --git a/examples/chat_with_function_calling_async.rs b/examples/chat_with_function_calling_async.rs
new file mode 100644
index 0000000..0fb5213
--- /dev/null
+++ b/examples/chat_with_function_calling_async.rs
@@ -0,0 +1,75 @@
+use mistralai_client::v1::{
+ chat::{ChatMessage, ChatMessageRole, ChatParams},
+ client::Client,
+ constants::Model,
+ tool::{Function, Tool, ToolChoice, ToolFunctionParameter, ToolFunctionParameterType},
+};
+use serde::Deserialize;
+use std::any::Any;
+
+#[derive(Debug, Deserialize)]
+struct GetCityTemperatureArguments {
+ city: String,
+}
+
+struct GetCityTemperatureFunction;
+#[async_trait::async_trait]
+impl Function for GetCityTemperatureFunction {
+    async fn execute(&self, arguments: String) -> Box<dyn Any + Send> {
+ // Deserialize arguments, perform the logic, and return the result
+ let GetCityTemperatureArguments { city } = serde_json::from_str(&arguments).unwrap();
+
+ let temperature = match city.as_str() {
+ "Paris" => "20°C",
+ _ => "Unknown city",
+ };
+
+ Box::new(temperature.to_string())
+ }
+}
+
+#[tokio::main]
+async fn main() {
+ let tools = vec![Tool::new(
+ "get_city_temperature".to_string(),
+ "Get the current temperature in a city.".to_string(),
+ vec![ToolFunctionParameter::new(
+ "city".to_string(),
+ "The name of the city.".to_string(),
+ ToolFunctionParameterType::String,
+ )],
+ )];
+
+ // This example suppose you have set the `MISTRAL_API_KEY` environment variable.
+ let mut client = Client::new(None, None, None, None).unwrap();
+ client.register_function(
+ "get_city_temperature".to_string(),
+ Box::new(GetCityTemperatureFunction),
+ );
+
+ let model = Model::MistralSmallLatest;
+ let messages = vec![ChatMessage {
+ role: ChatMessageRole::User,
+ content: "What's the temperature in Paris?".to_string(),
+ tool_calls: None,
+ }];
+ let options = ChatParams {
+ temperature: Some(0.0),
+ random_seed: Some(42),
+ tool_choice: Some(ToolChoice::Auto),
+ tools: Some(tools),
+ ..Default::default()
+ };
+
+ client
+ .chat_async(model, messages, Some(options))
+ .await
+ .unwrap();
+ let temperature = client
+ .get_last_function_call_result()
+ .unwrap()
+        .downcast::<String>()
+ .unwrap();
+ println!("The temperature in Paris is: {}.", temperature);
+ // => "The temperature in Paris is: 20°C."
+}
diff --git a/examples/chat_with_streaming.rs b/examples/chat_with_streaming.rs
new file mode 100644
index 0000000..8515a45
--- /dev/null
+++ b/examples/chat_with_streaming.rs
@@ -0,0 +1,45 @@
+use futures::stream::StreamExt;
+use mistralai_client::v1::{
+ chat::{ChatMessage, ChatMessageRole, ChatParams},
+ client::Client,
+ constants::Model,
+};
+use std::io::{self, Write};
+
+#[tokio::main]
+async fn main() {
+ // This example suppose you have set the `MISTRAL_API_KEY` environment variable.
+ let client = Client::new(None, None, None, None).unwrap();
+
+ let model = Model::OpenMistral7b;
+ let messages = vec![ChatMessage {
+ role: ChatMessageRole::User,
+ content: "Tell me a short happy story.".to_string(),
+ tool_calls: None,
+ }];
+ let options = ChatParams {
+ temperature: Some(0.0),
+ random_seed: Some(42),
+ ..Default::default()
+ };
+
+ let stream_result = client
+ .chat_stream(model, messages, Some(options))
+ .await
+ .unwrap();
+ stream_result
+ .for_each(|chunk_result| async {
+ match chunk_result {
+ Ok(chunks) => chunks.iter().for_each(|chunk| {
+ print!("{}", chunk.choices[0].delta.content);
+ io::stdout().flush().unwrap();
+ // => "Once upon a time, [...]"
+ }),
+ Err(error) => {
+ eprintln!("Error processing chunk: {:?}", error)
+ }
+ }
+ })
+ .await;
+ print!("\n") // To persist the last chunk output.
+}
diff --git a/examples/embeddings.rs b/examples/embeddings.rs
new file mode 100644
index 0000000..898e7d4
--- /dev/null
+++ b/examples/embeddings.rs
@@ -0,0 +1,17 @@
+use mistralai_client::v1::{client::Client, constants::EmbedModel};
+
+fn main() {
+ // This example suppose you have set the `MISTRAL_API_KEY` environment variable.
+ let client: Client = Client::new(None, None, None, None).unwrap();
+
+ let model = EmbedModel::MistralEmbed;
+ let input = vec!["Embed this sentence.", "As well as this one."]
+ .iter()
+ .map(|s| s.to_string())
+ .collect();
+ let options = None;
+
+ let response = client.embeddings(model, input, options).unwrap();
+ println!("First Embedding: {:?}", response.data[0]);
+ // => "First Embedding: {...}"
+}
diff --git a/examples/embeddings_async.rs b/examples/embeddings_async.rs
new file mode 100644
index 0000000..a93d374
--- /dev/null
+++ b/examples/embeddings_async.rs
@@ -0,0 +1,21 @@
+use mistralai_client::v1::{client::Client, constants::EmbedModel};
+
+#[tokio::main]
+async fn main() {
+ // This example suppose you have set the `MISTRAL_API_KEY` environment variable.
+ let client: Client = Client::new(None, None, None, None).unwrap();
+
+ let model = EmbedModel::MistralEmbed;
+ let input = vec!["Embed this sentence.", "As well as this one."]
+ .iter()
+ .map(|s| s.to_string())
+ .collect();
+ let options = None;
+
+ let response = client
+ .embeddings_async(model, input, options)
+ .await
+ .unwrap();
+ println!("First Embedding: {:?}", response.data[0]);
+ // => "First Embedding: {...}"
+}
diff --git a/examples/list_models.rs b/examples/list_models.rs
new file mode 100644
index 0000000..7551b0a
--- /dev/null
+++ b/examples/list_models.rs
@@ -0,0 +1,10 @@
+use mistralai_client::v1::client::Client;
+
+fn main() {
+ // This example suppose you have set the `MISTRAL_API_KEY` environment variable.
+ let client = Client::new(None, None, None, None).unwrap();
+
+ let result = client.list_models().unwrap();
+ println!("First Model ID: {:?}", result.data[0].id);
+ // => "First Model ID: open-mistral-7b"
+}
diff --git a/examples/list_models_async.rs b/examples/list_models_async.rs
new file mode 100644
index 0000000..a801928
--- /dev/null
+++ b/examples/list_models_async.rs
@@ -0,0 +1,11 @@
+use mistralai_client::v1::client::Client;
+
+#[tokio::main]
+async fn main() {
+ // This example suppose you have set the `MISTRAL_API_KEY` environment variable.
+ let client = Client::new(None, None, None, None).unwrap();
+
+ let result = client.list_models_async().await.unwrap();
+ println!("First Model ID: {:?}", result.data[0].id);
+ // => "First Model ID: open-mistral-7b"
+}
diff --git a/src/v1/chat_completion.rs b/src/v1/chat.rs
similarity index 64%
rename from src/v1/chat_completion.rs
rename to src/v1/chat.rs
index 701764b..20853d2 100644
--- a/src/v1/chat_completion.rs
+++ b/src/v1/chat.rs
@@ -1,6 +1,6 @@
use serde::{Deserialize, Serialize};
-use crate::v1::{common, constants};
+use crate::v1::{common, constants, tool};
// -----------------------------------------------------------------------------
// Definitions
@@ -9,88 +9,115 @@ use crate::v1::{common, constants};
pub struct ChatMessage {
pub role: ChatMessageRole,
pub content: String,
+    pub tool_calls: Option<Vec<tool::ToolCall>>,
+}
+impl ChatMessage {
+    pub fn new_assistant_message(content: &str, tool_calls: Option<Vec<tool::ToolCall>>) -> Self {
+ Self {
+ role: ChatMessageRole::Assistant,
+ content: content.to_string(),
+ tool_calls,
+ }
+ }
+
+ pub fn new_user_message(content: &str) -> Self {
+ Self {
+ role: ChatMessageRole::User,
+ content: content.to_string(),
+ tool_calls: None,
+ }
+ }
}
-#[derive(Clone, Debug, strum_macros::Display, Eq, PartialEq, Deserialize, Serialize)]
-#[allow(non_camel_case_types)]
+#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
pub enum ChatMessageRole {
- assistant,
- user,
+ #[serde(rename = "assistant")]
+ Assistant,
+ #[serde(rename = "user")]
+ User,
}
// -----------------------------------------------------------------------------
// Request
#[derive(Debug)]
-pub struct ChatCompletionParams {
- pub tools: Option,
- pub temperature: Option,
+pub struct ChatParams {
pub max_tokens: Option,
- pub top_p: Option,
pub random_seed: Option,
pub safe_prompt: Option,
+    pub temperature: Option<f32>,
+    pub tool_choice: Option<tool::ToolChoice>,
+    pub tools: Option<Vec<tool::Tool>>,
+    pub top_p: Option<f32>,
}
-impl Default for ChatCompletionParams {
+impl Default for ChatParams {
fn default() -> Self {
Self {
- tools: None,
- temperature: None,
max_tokens: None,
- top_p: None,
random_seed: None,
safe_prompt: None,
+ temperature: None,
+ tool_choice: None,
+ tools: None,
+ top_p: None,
}
}
}
#[derive(Debug, Serialize, Deserialize)]
-pub struct ChatCompletionRequest {
+pub struct ChatRequest {
pub messages: Vec,
pub model: constants::Model,
- #[serde(skip_serializing_if = "Option::is_none")]
- pub tools: Option,
- #[serde(skip_serializing_if = "Option::is_none")]
- pub temperature: Option,
+
#[serde(skip_serializing_if = "Option::is_none")]
pub max_tokens: Option,
#[serde(skip_serializing_if = "Option::is_none")]
- pub top_p: Option,
- #[serde(skip_serializing_if = "Option::is_none")]
pub random_seed: Option,
- pub stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub safe_prompt: Option,
+ pub stream: bool,
+ #[serde(skip_serializing_if = "Option::is_none")]
+    pub temperature: Option<f32>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_choice: Option<tool::ToolChoice>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tools: Option<Vec<tool::Tool>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub top_p: Option<f32>,
// TODO Check this prop (seen in official Python client but not in API doc).
// pub tool_choice: Option,
// TODO Check this prop (seen in official Python client but not in API doc).
// pub response_format: Option,
}
-impl ChatCompletionRequest {
+impl ChatRequest {
pub fn new(
model: constants::Model,
messages: Vec,
stream: bool,
- options: Option,
+ options: Option,
) -> Self {
- let ChatCompletionParams {
- tools,
- temperature,
+ let ChatParams {
max_tokens,
- top_p,
random_seed,
safe_prompt,
+ temperature,
+ tool_choice,
+ tools,
+ top_p,
} = options.unwrap_or_default();
Self {
messages,
model,
- tools,
- temperature,
+
max_tokens,
- top_p,
random_seed,
- stream,
safe_prompt,
+ stream,
+ temperature,
+ tool_choice,
+ tools,
+ top_p,
}
}
}
@@ -99,51 +126,29 @@ impl ChatCompletionRequest {
// Response
#[derive(Clone, Debug, Deserialize, Serialize)]
-pub struct ChatCompletionResponse {
+pub struct ChatResponse {
pub id: String,
pub object: String,
/// Unix timestamp (in seconds).
pub created: u32,
pub model: constants::Model,
- pub choices: Vec,
+ pub choices: Vec,
pub usage: common::ResponseUsage,
}
#[derive(Clone, Debug, Deserialize, Serialize)]
-pub struct ChatCompletionResponseChoice {
+pub struct ChatResponseChoice {
pub index: u32,
pub message: ChatMessage,
- pub finish_reason: String,
+ pub finish_reason: ChatResponseChoiceFinishReason,
// TODO Check this prop (seen in API responses but undocumented).
// pub logprobs: ???
}
-// -----------------------------------------------------------------------------
-// Stream
-
-#[derive(Debug, Deserialize)]
-pub struct ChatCompletionStreamChunk {
- pub id: String,
- pub object: String,
- /// Unix timestamp (in seconds).
- pub created: u32,
- pub model: constants::Model,
- pub choices: Vec,
- // TODO Check this prop (seen in API responses but undocumented).
- // pub usage: ???,
-}
-
-#[derive(Debug, Deserialize)]
-pub struct ChatCompletionStreamChunkChoice {
- pub index: u32,
- pub delta: ChatCompletionStreamChunkChoiceDelta,
- pub finish_reason: Option,
- // TODO Check this prop (seen in API responses but undocumented).
- // pub logprobs: ???,
-}
-
-#[derive(Debug, Deserialize)]
-pub struct ChatCompletionStreamChunkChoiceDelta {
- pub role: Option,
- pub content: String,
+#[derive(Clone, Debug, Eq, PartialEq, Deserialize, Serialize)]
+pub enum ChatResponseChoiceFinishReason {
+ #[serde(rename = "stop")]
+ Stop,
+ #[serde(rename = "tool_calls")]
+ ToolCalls,
}
diff --git a/src/v1/chat_stream.rs b/src/v1/chat_stream.rs
new file mode 100644
index 0000000..1daf481
--- /dev/null
+++ b/src/v1/chat_stream.rs
@@ -0,0 +1,57 @@
+use serde::{Deserialize, Serialize};
+use serde_json::from_str;
+
+use crate::v1::{chat, common, constants, error};
+
+// -----------------------------------------------------------------------------
+// Response
+
+#[derive(Clone, Debug, Deserialize, Serialize)]
+pub struct ChatStreamChunk {
+ pub id: String,
+ pub object: String,
+ /// Unix timestamp (in seconds).
+ pub created: u32,
+ pub model: constants::Model,
+ pub choices: Vec,
+ pub usage: Option,
+ // TODO Check this prop (seen in API responses but undocumented).
+ // pub logprobs: ???,
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize)]
+pub struct ChatStreamChunkChoice {
+ pub index: u32,
+ pub delta: ChatStreamChunkChoiceDelta,
+ pub finish_reason: Option,
+ // TODO Check this prop (seen in API responses but undocumented).
+ // pub logprobs: ???,
+}
+
+#[derive(Clone, Debug, Deserialize, Serialize)]
+pub struct ChatStreamChunkChoiceDelta {
+ pub role: Option,
+ pub content: String,
+}
+
+/// Extracts serialized chunks from a stream message.
+pub fn get_chunk_from_stream_message_line(
+ line: &str,
+) -> Result